GAN-3 (CGAN 论文笔记)

GAN-3 (CGAN 论文笔记)

刚才有GAN的一些理解,这个主要是对CGAN的理解,CGAN相对于GAN的基础之上加上了条件,C指的就是条件, 网络结构如图

avator

loss函数改成了

avator 其中的y就是条件,比如可以是labels或者其他的函数,paper里面说关于在手写数字识别上面的是用的class labels,相比GAN有labels的信息加入,从这里也可以看出这里的监督学习和无监督学习之间的关系。

部分Tensorflow的代码如下

# discriminator loss
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_real, 
                                                                         labels = tf.ones_like(d_logits_real)) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_fake, 
                                                                         labels = tf.zeros_like(d_logits_fake)))
# loss
d_loss = tf.add(d_loss_real, d_loss_fake)
# generator loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_fake,
                                                                    labels = tf.ones_like(d_logits_fake)) * (1 - smooth) )

G和D分别定义如下

def get_generator(digit, noise_img,  reuse = False):
    with tf.variable_scope("generator", reuse = reuse):
        concatenated_img_digit = tf.concat([digit, noise_img], 1)
        
#         output = tf.layers.dense(concatenated_img_digit, 256)
 
        output = fully_connected('gf1',concatenated_img_digit,128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate = 0.5)
        
#         output = tf.layers.dense(output, 128)
 
        output = fully_connected('gf2',output, 128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate = 0.5)
 
#         logits = tf.layers.dense(output, 784)
        logits = fully_connected('gf3',output,784)
        outputs = tf.tanh(logits)
        
        return logits, outputs
 
 
def get_discriminator(digit, img,  reuse = False):
    with tf.variable_scope("discriminator", reuse=reuse):
        concatenated_img_digit = tf.concat([digit, img], 1)
        
#         output = tf.layers.dense(concatenated_img_digit, 256)
        output = fully_connected('df1',concatenated_img_digit,128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate = 0.5)
        
#         output = tf.layers.dense(concatenated_img_digit, 128)
        output = fully_connected('df2',output, 128)
        output = leakyRelu(output)
        output = tf.layers.dropout(output, rate = 0.5)
 
#         logits = tf.layers.dense(output, 1)
        logits = fully_connected('df3',output, 1)
        outputs = tf.sigmoid(logits)
        
        return logits, outputs


注意上面代码中,digitimgconcat起来共同作为输入了,这就是上面第一个图的意思,相当于这时候的输出有了label的信息,这样的好处是,是最后可以指定让其伪造某个具体的数字。 从这里也可以对比之前GAN的代码,之前GAN的时候也是784的维度拉平了之后去算loss,不过那里没有考虑到label的信息,

# generator
g_logits, g_outputs = get_generator(real_img_digit, noise_img)
 
sample_images = tf.reshape(g_outputs, [-1, 28, 28, 1])
tf.summary.image("sample_images", sample_images, 10)
 
# discriminator
d_logits_real, d_outputs_real = get_discriminator(real_img_digit, real_img)
d_logits_fake, d_outputs_fake = get_discriminator(real_img_digit, g_outputs, reuse = True)

具体可见 link

另外 根据之前的一个torch gan 的实现,我把其中的改了一下成为了cgan的version. 然后还可以对比一下二者的结果。

修改的过程中学到了几个地方,

  • 输入不一样

现在gan 的输入有label的信息,正如之前tensorflow的version一样, 所以这里需要加一个lable的信息,但是需要转成onehot的形式, 所以查了torch如何转onehot的,

代码如下

import torch
batch_size = 5 
nb_digits = 10
 # Dummy input that HAS to be 2D for the scatter (you can use view(-1,1) if needed)
y = torch.LongTensor(batch_size,1).random_() % nb_digits
# One hot encoding buffer that you create out of the loop and just keep reusing
y_onehot = torch.FloatTensor(batch_size, nb_digits)
 
# In your for loop
y_onehot.zero_()
y_onehot.scatter_(1, y, 1)

print(y)
print(y.shape)
print(y_onehot)
print(y_onehot.shape)

打赏,谢谢~~

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,多谢支持~

打开微信扫一扫,即可进行扫码打赏哦