The anchor generation mechanism, part 2

If we rewrite all of the code above with torch instead of numpy, it looks like the following. Many torch functions are very similar to their numpy counterparts, but there are differences as well.
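For example, here is a minimal sketch of the differences that matter for the code below: torch's Tensor.repeat tiles the whole tensor like np.tile, while np.repeat, which repeats each element, corresponds to torch.repeat_interleave; numpy's .T becomes .t() or .permute() in torch.

import numpy as np
import torch

ratios = [0.5, 1, 2]

# np.tile cycles the whole array; torch's Tensor.repeat does the same thing
print(np.tile(ratios, 3))                         # -> 0.5, 1, 2, 0.5, 1, 2, 0.5, 1, 2
print(torch.tensor(ratios).repeat(3))             # -> 0.5, 1, 2, 0.5, 1, 2, 0.5, 1, 2

# np.repeat repeats each element; the torch equivalent is repeat_interleave, not repeat
print(np.repeat(ratios, 3))                       # -> 0.5, 0.5, 0.5, 1, 1, 1, 2, 2, 2
print(torch.tensor(ratios).repeat_interleave(3))  # -> 0.5, 0.5, 0.5, 1, 1, 1, 2, 2, 2

With that in mind, the torch version of the code: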


import torch
import torch.nn as nn

class Anchors(nn.Module):
    def __init__(self, pyramid_levels=None, strides=None, sizes=None, ratios=None, scales=None):
        super(Anchors, self).__init__()

        # fall back to the default values when an argument is not given
        self.pyramid_levels = pyramid_levels if pyramid_levels is not None else [3, 4, 5, 6, 7]
        self.strides = strides if strides is not None else [2 ** x for x in self.pyramid_levels]
        self.sizes = sizes if sizes is not None else [2 ** (x + 2) for x in self.pyramid_levels]  # [32, 64, 128, 256, 512]
        self.ratios = ratios if ratios is not None else torch.tensor([0.5, 1.0, 2.0])
        self.scales = scales if scales is not None else torch.tensor([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])

    def forward(self, image):
        image_shape = torch.tensor(image.shape[2:])  # input is NCHW; keep (H, W)
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]  # ceil(H / 2^x), ceil(W / 2^x) at each level
        all_anchors = torch.zeros((0, 4))  # start from an empty (0, 4) tensor and append level by level
        for idx, p in enumerate(self.pyramid_levels):
            anchors = generate_anchors(base_size=self.sizes[idx], ratios=self.ratios, scales=self.scales)
            shifted_anchors = shift(image_shapes[idx], self.strides[idx], anchors)
            all_anchors = torch.cat((all_anchors, shifted_anchors))
        all_anchors = torch.unsqueeze(all_anchors, 0)
        print(all_anchors.shape)
        return all_anchors

def generate_anchors(base_size=16, ratios=None, scales=None):

    if ratios is None:
        ratios = torch.tensor([0.5, 1, 2])

    if scales is None:
        scales =torch.tensor([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])

    num_anchors = len(ratios) * len(scales)   # number of anchors generated at each location

    anchors = torch.zeros((num_anchors, 4))

    # width and height start out as base_size * scale (torch's repeat behaves like np.tile)
    anchors[:, 2:] = base_size * scales.repeat(2, len(ratios)).t()

    areas = anchors[:, 2] * anchors[:, 3]

    # rescale w and h so that w * h == area and h / w == ratio
    # (np.repeat corresponds to torch.repeat_interleave, not torch's repeat)
    anchors[:, 2] = torch.sqrt(areas / ratios.repeat_interleave(len(scales)))
    anchors[:, 3] = anchors[:, 2] * ratios.repeat_interleave(len(scales))

    # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
    anchors[:, 0::2] -= (anchors[:, 2] * 0.5).unsqueeze(1)
    anchors[:, 1::2] -= (anchors[:, 3] * 0.5).unsqueeze(1)

    return anchors
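
A quick sanity check of generate_anchors (the values are approximate and assume the default ratios and scales):

base_anchors = generate_anchors(base_size=32)
print(base_anchors.shape)   # torch.Size([9, 4]): 3 ratios x 3 scales
print(base_anchors[0])      # ratio 0.5, scale 1: roughly (-22.6, -11.3, 22.6, 11.3),
                            # i.e. a ~45 x 23 box centred on the origin with area 32 * 32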

def compute_shape(image_shape, pyramid_levels):
    image_shape = torch.as_tensor(image_shape[:2])  # (H, W); accept a tuple, list or tensor
    image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels]
    return image_shapes


def anchors_for_shape(
    image_shape,
    pyramid_levels=None,
    ratios=None,
    scales=None,
    strides=None,
    sizes=None,
    shapes_callback=None,
):

    # fill in the same defaults as the Anchors module when the arguments are omitted
    if pyramid_levels is None:
        pyramid_levels = [3, 4, 5, 6, 7]
    if strides is None:
        strides = [2 ** x for x in pyramid_levels]
    if sizes is None:
        sizes = [2 ** (x + 2) for x in pyramid_levels]

    image_shapes = compute_shape(image_shape, pyramid_levels)

    # compute anchors over all pyramid levels
    all_anchors = torch.zeros((0, 4))
    for idx, p in enumerate(pyramid_levels):
        anchors         = generate_anchors(base_size=sizes[idx], ratios=ratios, scales=scales)
        shifted_anchors = shift(image_shapes[idx], strides[idx], anchors)
        all_anchors     = torch.cat((all_anchors, shifted_anchors))
    return all_anchors
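
A rough usage sketch of anchors_for_shape (it assumes a completed shift function; see the sketch at the end of the listing). For a 512x640 input, levels 3-7 give feature maps of 64x80, 32x40, 16x20, 8x10 and 4x5, i.e. 6820 locations with 9 anchors each:

all_anchors = anchors_for_shape((512, 640, 3))
print(all_anchors.shape)   # expected: torch.Size([61380, 4]), i.e. 6820 * 9 anchors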


# Previously we only produced the base anchors for a single location; shift() places that set at every location on the feature map to produce all the anchors.

def shift(shape, stride, anchors):   # shape is e.g. (10, 8): not the image size, but how many rows and columns of anchor centres to generate
    shift_x = (torch.arange(0, shape[1]) + 0.5) * stride
    shift_y = (torch.arange(0, shape[0]) + 0.5) * stride

    shift_x, shift_y = torch.meshgrid(shift_x, shift_y)
    shift_x = shift_x.contiguous()
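
The original listing breaks off here. Below is a minimal sketch of how the rest of the torch shift function could look, following the structure of the numpy version it is translated from (the indexing='xy' argument of torch.meshgrid, available in recent torch releases, is assumed so that it matches np.meshgrid):

def shift(shape, stride, anchors):
    # shape = (rows, cols) of the feature map; accept plain ints or tensor elements
    rows, cols = int(shape[0]), int(shape[1])

    # one anchor centre per feature-map cell, offset by half a stride
    shift_x = (torch.arange(0, cols) + 0.5) * stride
    shift_y = (torch.arange(0, rows) + 0.5) * stride

    shift_x, shift_y = torch.meshgrid(shift_x, shift_y, indexing='xy')

    # (K, 4) offsets, ordered (x, y, x, y) so both box corners move by the same amount
    shifts = torch.stack((
        shift_x.reshape(-1), shift_y.reshape(-1),
        shift_x.reshape(-1), shift_y.reshape(-1),
    ), dim=1)

    A = anchors.shape[0]   # anchors per location (9 by default)
    K = shifts.shape[0]    # number of locations (rows * cols)

    # broadcast-add: every base anchor is placed at every shifted centre
    all_anchors = (anchors.reshape(1, A, 4) + shifts.reshape(K, 1, 4)).reshape(K * A, 4)
    return all_anchors

With a complete shift in place, the whole pipeline can be checked end to end; for a 512x640 input the module should report 61380 anchors:

anchor_module = Anchors()
dummy_image = torch.zeros(1, 3, 512, 640)    # N, C, H, W
anchors = anchor_module(dummy_image)         # forward() also prints the shape itself
print(anchors.shape)                         # expected: torch.Size([1, 61380, 4])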


