STN论文笔记

STN论文笔记

paper 地址在 https://arxiv.org/pdf/1506.02025.pdf

初步印象

初步看了之后的感觉是这个模块是对img或者feature作了一个前期的处理,或者可以理解成在处理之前作了一个较准,因为较准后的图肯定会好一些,所以加上这个module之后,效果会好一些。可以把它放在cnn之前, 可以放多个这个module进行串联起来,这样看起来还增加了网络的深度,也可以并行的做,比如alphapose里面就有这个结构,大致的思想就是用这个模块去学得一个变换,这个变换可以理解成align。那么是不是需要一些gt的变换呢?paper里面说并不需要,如果需要的话,gt的变换如何找就是个问题。至少要花费很多的力气来做这个gt。并且网络结构是可微的,这样可以使得梯度回传来更新参数。

motivation

因为实际自然中各种场景都有,而训练集总归是有限的一个很小的部分,那么如何在有限的数据集中训练出好的模型呢,比如一个图中有一个狗,把图作各种变换之后也至少能够识别出还是那个狗,不然的话网络就太差了,这就需要网络有一定的”不变性”,比如平移不变,roate不变等,有了这些之后无论图怎么变化,都能够对其作出准确的识别,这样才能够更加适合更加贴合自然的场景。我估计这种需求在自动驾驶里面非常之高,只是自己的猜测,因为那种要求的精度非常高,但是实际的场景又多种多样。

其实加了max-Pooling层的网络是有一定的平移不变性,但是这种平移不变性能够接纳的平移度非常的小,因为往往其步长为2,这可能在一些小的任务上可行,比如机械臂这种,或者找关键点这种。即希望图抖一下之后还是能够精确的检测出。

为了处理可能的自然场景,实际中往往会做一些数据增强的操作,比如随机crop, 旋转,光照,对比度,等等。但是感觉这些并非从其根本上来解决,这些只是让网络”尽可能多地做题”,而这里面的STN想法感觉是从网络本身来提高其性能,就像学ba一定没有天才厉害一样,修炼其资质才是根本。

具体

网络结构如图所示, avator

下面分别介绍三个网络的作用。

  • Localisation net

它是以input feature U作输入,输出的是参数 \theta, 即 \theta = f(U)其中\theta的类型取决于变换的类型,其中f可以取fcn或者cnn都可以,但是需要最后一层是回归层来用得到变换的参数。

  • 一个问题是,前面说没有变换的gt,那么变换的参数是如何学到的呢?

  • Grid generator

从图上也可以看出,Grid generator 用上一步得到的变换参数,从而就可以得到一个sampling grid, 所以这一块叫 grid generator,

  • Sampler

将input U 和 sampling grid 作为输入送到Sampler进行采样就可以得到输出的V。

Parameterised Sampling Grid

输出V的每个pixel上的值实际上都是经过一个sampling kernel 在input U上的一个特定的位置经过操作得到的。具体的这个核决定操作的类型,比如如果是affine map的话就是下面的形式的,

avator

其中带左边带s上标的是input feature 上的位置,带t的是输出上的位置,其实上面的公式中代表着许多的内容,比如有crop, translation, rotation, scale, crop是因为这个变换可以是压缩的,只要当右边2*2的矩阵的行列式小于1,那么mapped regular grid就会在一个range没有xis, yis大的平行四边行里面。

其中的变换可以变得有限制,也可以很general,比如

avator

这样就只能对应crop,istropic scale, translation, 再比如 变换甚至可以是平面上的投影变换,即每个像素都做个affine。

Differentiable Image Sampling

刚才提到采样实际上是用采样核的,就像cnn中的filter一样,general的形式是

avator

这个变换关于每个channel的形式是一样的,这样在channels之间保证空间上的一致性。 理论上来说,只要梯度或者是次梯度可以关于输入xis, yis定义的话,就可以用。

比如integer sampling kernel

avator

这个的意思是取离xis, yis最近的那个像素点位置上的值作为输出xit, yit上的位置。 还有双线性interpolation,

avator

STN可以应用到哪里,怎么用

STN可以放到CNN中去,并且因为其计算很快,并不会影响训练速度,有时候甚至会加快训练速度,因为STN里面是有采样的,而且作了变换之后,feature更好了一些,算loss的时候可能有用的信息就比较集中。

也可以用STN去专门做一些上采样或者下采样的动作,只要采样核是固定的就可以了。

可以使用多个STN,还可以并行的方式使用多个STN,这在alphapose中就用到了,因为那里对人做姿态估计,因为一张图上可能会有多人,通过这种方式可能会减少对人的检测的误判。

更细节的东西需要根据代码来看。

示例代码

在torch的教程上看到了一个代码,关于stn的部分是这样的。

# Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

能够看到localization network就是和通常的block没有什么区别,就是作了两次的cnn-pooling-relu,可以理解成是提特征, 然后最后一层接了一个回归,会输出6个参数,(回归直观的理解就是输出可连续变化的量),这6个参数因为都是实数,所以是可连续变化的。然后初始化参数,先全部弄成0,然后再让它等于个identity matrix, 其shape是(2,3),

然后是前向计算部分

 # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs) 
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

可以看出,输入x,先进行predict参数,即得到theta, 然后 用theta来得到grid,这即paper里面说的grid-generator, 然后进行sampler,把input x和 grid作为输入得到最终的输出结果x。

最后来看看它在网络中的具体位置,这个是mnist的情况,其他的有可能不一样

   def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


可见是在数据一进来的时候就用了stn,所以可以理解成是一个前期的数据处理。好像类似的网络也有其它的,比如SENET, DeformableCNN,感觉也有类似的思想在里面。感觉本质上是一样的,相当于是这里有两步操作,先变换后cnn,而他们那里直接就是在cnn的时候就寻求一些变化。 前面提的那个问题现在也有了答案,从上面的代码中可以看出来要学的变换根本没有gt可以用,只是初始化的时候是一个identity,那么网络就会根据从后面传来的梯度进行更新和学习。

打赏,谢谢~~

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,多谢支持~

打开微信扫一扫,即可进行扫码打赏哦