GAN-3 (CGAN 论文笔记)

# 小郑之家~

与原始 GAN 相比，loss 函数改成了如下形式：

# Discriminator loss.
# NOTE(fix): the original multiplied the whole cross-entropy term by
# (1 - smooth), which merely rescales the loss/gradient. One-sided label
# smoothing (Salimans et al., 2016) instead softens the *real* labels to
# (1 - smooth), so the labels argument carries the smoothing.
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real,
        labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake,
        labels=tf.zeros_like(d_logits_fake)))

# Generator loss: push the discriminator to label fakes as real
# (same smoothed target as the real-image labels above, mirroring the
# original code's use of smooth on this term).
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake,
        labels=tf.ones_like(d_logits_fake) * (1 - smooth)))



G和D分别定义如下

def get_generator(digit, noise_img, reuse=False):
    """Conditional generator: map (label, noise) to a 784-dim image.

    Returns a tuple ``(logits, outputs)`` where ``outputs`` is the
    tanh-squashed version of ``logits``.
    """
    with tf.variable_scope("generator", reuse=reuse):
        # Condition on the class label by concatenating it with the noise.
        hidden = tf.concat([digit, noise_img], 1)

        # Two hidden blocks: fully connected -> leaky ReLU -> dropout.
        for name in ('gf1', 'gf2'):
            hidden = fully_connected(name, hidden, 128)
            hidden = leakyRelu(hidden)
            hidden = tf.layers.dropout(hidden, rate=0.5)

        # Output layer: 784 = 28 * 28 flattened pixels, squashed to (-1, 1).
        logits = fully_connected('gf3', hidden, 784)
        outputs = tf.tanh(logits)

        return logits, outputs

def get_discriminator(digit, img, reuse=False):
    """Conditional discriminator: score an image given its class label.

    Returns a tuple ``(logits, outputs)`` where ``outputs`` is the
    sigmoid probability of the image being real.
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        # Condition on the class label by concatenating it with the image.
        hidden = tf.concat([digit, img], 1)

        # Two hidden blocks: fully connected -> leaky ReLU -> dropout.
        for name in ('df1', 'df2'):
            hidden = fully_connected(name, hidden, 128)
            hidden = leakyRelu(hidden)
            hidden = tf.layers.dropout(hidden, rate=0.5)

        # Single-logit real/fake score.
        logits = fully_connected('df3', hidden, 1)
        outputs = tf.sigmoid(logits)

        return logits, outputs



# Build the generator, conditioned on the digit label.
g_logits, g_outputs = get_generator(real_img_digit, noise_img)

# Reshape flat 784-dim outputs to 28x28x1 images and log 10 samples
# to TensorBoard for visual inspection during training.
sample_images = tf.reshape(g_outputs, [-1, 28, 28, 1])
tf.summary.image("sample_images", sample_images, 10)

# Score real images, then generated images with reuse=True so both
# passes share the same discriminator weights and the same labels.
d_logits_real, d_outputs_real = get_discriminator(real_img_digit, real_img)
d_logits_fake, d_outputs_fake = get_discriminator(real_img_digit, g_outputs, reuse = True)



• 输入不一样：CGAN 的生成器和判别器都额外拼接了类别标签作为条件输入（代码中的 `tf.concat([digit, ...], 1)`），标签需要先转换成 one-hot 向量。下面是在 PyTorch 中把整数标签转成 one-hot 的示例：

# Minimal PyTorch demo: convert integer class labels to one-hot rows
# by writing a 1 into each label's column with scatter_.
import torch

batch_size = 5
nb_digits = 10

# Random labels in [0, nb_digits); scatter_ needs a 2-D index tensor,
# so the labels are kept as a (batch_size, 1) column vector.
y = torch.LongTensor(batch_size, 1).random_() % nb_digits

# Zero-filled float buffer, then set column y[i] of row i to 1.
y_onehot = torch.zeros(batch_size, nb_digits)
y_onehot.scatter_(1, y, 1)

print(y)
print(y.shape)
print(y_onehot)
print(y_onehot.shape)