GAN-1

I keep hearing that "GAN" is all the rage, so I decided to learn about it myself. The book I'm reading is "Learning Generative Adversarial Networks", and I plan to record my notes as I go. Also, sometimes slow is fast!

Chapter 2. Unsupervised Learning with GAN

  • My first impression is that GANs promise more creativity and imagination.

  • Why are generative models worth studying?

    • Sampling (or generation) is straightforward.
    • Training does not involve maximum likelihood estimation (MLE).
    • They are robust to overfitting, since the generator never sees the training data.
    • GANs are good at learning the data distribution.
  • A first look at GANs

A GAN is a network that learns through adversarial competition. As an example, suppose you want to become a top chess player. One way is to play against people stronger than you. Losing at first doesn't matter: by studying your opponent's moves you can figure out where your own play was weak, and you improve. If you keep playing the same opponent, they will notice your progress; to keep beating you, they also have to study, working out counters to the new moves you have learned. The two of you keep learning in competition with each other until you are both quite strong. The analogy may not be perfect, but this is the core idea behind GANs.
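
To make the analogy precise, the adversarial game is usually written as the minimax objective below. (This is the standard formulation from Goodfellow et al., 2014; the original notes do not state it, so it is added here for reference.) The discriminator D maximizes the value function while the generator G minimizes it:

    \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

Training alternates a gradient step for D with a gradient step for G, which is exactly what the run_a_gan loop later in these notes does.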

  • A schematic of the GAN network architecture (figure)

  • The GAN learning process (figure)

Next, let's get a feel for the process using handwritten digits (MNIST).

  • discriminator
    def discriminator():
        model = nn.Sequential(
            Flatten(),
            nn.Linear(784, 256),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(256, 1)
        )
        return model
    
  • generator
def generator(noise_dim=NOISE_DIM):
    model = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )
    return model
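
Before wiring everything together, it helps to sanity-check the shapes. This little smoke test is my own addition (it assumes the two builders above, plus the Flatten module defined in the full listing below):

import torch

G = generator(96)   # 96 is the noise dimension used throughout these notes
D = discriminator()

z = torch.rand(4, 96) * 2 - 1            # a small batch of noise in (-1, 1)
fake = G(z)                              # expected shape: (4, 784)
scores = D(fake.view(4, 1, 28, 28))      # expected shape: (4, 1)
print(fake.shape, scores.shape)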

The full code is as follows:

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn

import numpy as np

import matplotlib.pyplot as plt 
import matplotlib.gridspec as gridspec

def show_images(images):

    images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg,sqrtimg]))
    return 

def preprocess_img(x):
    return 2 * x - 1.0 

def deprocess_img(x):
    return (x + 1.0) / 2.0 

def rel_error(x,y):
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def count_params(model):
    """Count the number of parameters in the current TensorFlow graph """
    param_count = np.sum([np.prod(p.size()) for p in model.parameters()])
    return param_count


# A custom sampler that draws elements sequentially, starting from a fixed offset
class ChunkSampler(sampler.Sampler): 
    """Samples elements sequentially from some offset. 
    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples

NUM_TRAIN = 50000   # size of the training set
NUM_VAL = 5000      # size of the validation set

NOISE_DIM = 96
batch_size = 128

mnist_train = dset.MNIST('./datasets/MNIST_data', train=True, download=True,
                           transform=T.ToTensor())
loader_train = DataLoader(mnist_train, batch_size=batch_size,
                          sampler=ChunkSampler(NUM_TRAIN, 0)) # sample NUM_TRAIN examples starting at index 0

mnist_val = dset.MNIST('./datasets/MNIST_data', train=True, download=True,
                           transform=T.ToTensor())
loader_val = DataLoader(mnist_val, batch_size=batch_size,
                        sampler=ChunkSampler(NUM_VAL, NUM_TRAIN)) # sample NUM_VAL examples starting at index NUM_TRAIN


#imgs = loader_train.__iter__().next()[0].view(batch_size, 784).numpy().squeeze()
#show_images(imgs)

def sample_noise(batch_size, dim):
    """
    Generate a PyTorch Tensor of uniform random noise.

    Input:
    - batch_size: Integer giving the batch size of noise to generate.
    - dim: Integer giving the dimension of noise to generate.

    Output:
    - A PyTorch Tensor of shape (batch_size, dim) containing uniform
      random noise in the range (-1, 1).
    """
    # torch.rand gives U(0, 1); rescale and shift to get uniform noise in (-1, 1).
    # (The original code added two independent uniforms, which gives a
    # triangular, not uniform, distribution.)
    temp = torch.rand(batch_size, dim) * 2 - 1

    return temp

class Flatten(nn.Module):
    def forward(self, x):
        N, C, H, W = x.size() # read in N, C, H, W
        return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image

class Unflatten(nn.Module):
    """
    An Unflatten module receives an input of shape (N, C*H*W) and reshapes it
    to produce an output of shape (N, C, H, W).
    """
    def __init__(self, N=-1, C=128, H=7, W=7):
        super(Unflatten, self).__init__()
        self.N = N
        self.C = C
        self.H = H
        self.W = W
    def forward(self, x):
        return x.view(self.N, self.C, self.H, self.W)

def initialize_weights(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d):
        init.xavier_uniform(m.weight.data)
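
# Note (added): initialize_weights is defined but never invoked in the original
# listing; to apply it, call D.apply(initialize_weights) and
# G.apply(initialize_weights) after the models are built.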


def discriminator():
    """Build and return a PyTorch model implementing the discriminator above."""
    model = nn.Sequential(
        Flatten(),
        nn.Linear(784, 256),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.01, inplace=True),
        nn.Linear(256, 1)
    )
    return model

def generator(noise_dim=NOISE_DIM):
    """Build and return a PyTorch model implementing the generator above."""
    model = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 784),
        nn.Tanh(),
    )
    return model


Bce_loss = nn.BCEWithLogitsLoss()

def discriminator_loss(logits_real, logits_fake):
    """
    Computes the discriminator loss described above.

    Inputs:
    - logits_real: PyTorch Variable of shape (N,) giving scores for the real data.
    - logits_fake: PyTorch Variable of shape (N,) giving scores for the fake data.

    Returns:
    - loss: PyTorch Variable containing (scalar) the loss for the discriminator.
    """
    loss = None
    # Shape of the logits; used to build a matching tensor of target labels.
    N = logits_real.size()

    # Target labels: all ones. The discriminator should classify real
    # images as real (1) and fake images as fake (0).
    true_labels = Variable(torch.ones(N))

    real_image_loss = Bce_loss(logits_real, true_labels)     # score real images as real
    fake_image_loss = Bce_loss(logits_fake, 1 - true_labels) # score fake images as fake

    loss = real_image_loss + fake_image_loss

    return loss
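# Note (added): discriminator_loss implements the standard GAN discriminator
# objective, maximizing E[log D(x)] + E[log(1 - D(G(z)))], expressed above as a
# BCE minimization over the real and fake logits.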

def generator_loss(logits_fake):
    """
    Computes the generator loss described above.

    Inputs:
    - logits_fake: PyTorch Variable of shape (N,) giving scores for the fake data.

    Returns:
    - loss: PyTorch Variable containing the (scalar) loss for the generator.
    """
    # Shape of the logits; used to build a matching tensor of target labels.
    N = logits_fake.size()

    # The generator wants all of its "fake" samples to be scored as real (1)
    true_labels = Variable(torch.ones(N))

    # Compute the generator loss
    loss = Bce_loss(logits_fake, true_labels)

    return loss
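# Note (added): training against all-ones labels gives the "non-saturating"
# generator loss -E[log D(G(z))] from the original GAN paper, which yields
# stronger gradients early in training than minimizing E[log(1 - D(G(z)))].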


def get_optimizer(model):
 """
    Construct and return an Adam optimizer for the model with learning rate 1e-3,
    beta1=0.5, and beta2=0.999.

    Input:
    - model: A PyTorch model that we want to optimize.

    Returns:
    - An Adam optimizer for the model with the desired hyperparameters.
    """
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))
    return optimizer


def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250,
              batch_size=128, noise_size=96, num_epochs=10):
    """
    Train a GAN!

    Inputs:
    - D, G: PyTorch modules for the discriminator and generator, respectively.
    - D_solver, G_solver: torch.optim Optimizers to use for training the
      discriminator and generator.
    - discriminator_loss, generator_loss: Functions to use for computing the generator and
      discriminator loss, respectively.
    - show_every: Show samples after every show_every iterations.
    - batch_size: Batch size to use for training.
    - noise_size: Dimension of the noise to use as input to the generator.
    - num_epochs: Number of epochs over the training dataset to use for training.
    """

    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in loader_train:
            if len(x) != batch_size:
                continue
            D_solver.zero_grad()
            #real_data = Variable(x)
            logits_real = D(2 * (x - 0.5)) # rescale images from [0, 1] to [-1, 1], as in preprocess_img

            #g_fake_seed = Variable(sample_noise(batch_size, noise_size))
            g_fake_seed = sample_noise(batch_size, noise_size)
            fake_images = G(g_fake_seed).detach()
            #fake_images = G(g_fake_seed)
            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))

            d_total_error = discriminator_loss(logits_real, logits_fake)
            d_total_error.backward()
            D_solver.step()

            G_solver.zero_grad()
            #g_fake_seed = Variable(sample_noise(batch_size, noise_size))
            g_fake_seed = sample_noise(batch_size, noise_size)
            fake_images = G(g_fake_seed)

            gen_logits_fake = D(fake_images.view(batch_size, 1, 28, 28))
            g_error = generator_loss(gen_logits_fake)
            g_error.backward()
            G_solver.step()

            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G: {:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = fake_images.data.cpu().numpy()
                show_images(imgs_numpy[0:16])
                plt.savefig("%d.png" %(iter_count))
                #plt.show()
                #print()
                print("having written %d" % (iter_count))
            iter_count += 1

D = discriminator()

# Make the generator
#G = generator().type(dtype)
G = generator()

# Use the function you wrote earlier to get optimizers for the Discriminator and the Generator
D_solver = get_optimizer(D)
G_solver = get_optimizer(G)
# Run it!
run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss)

Results

After training for a while, the images produced by the generator become quite close to the real ones, as the figure below shows. (figure: generated MNIST samples)
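
Once training is done, brand-new digits come from simply pushing fresh noise through the generator. A minimal sketch (my own addition, reusing G, sample_noise, deprocess_img, and show_images from the listing above):

# Draw 16 new digits from the trained generator
z = sample_noise(16, NOISE_DIM)     # fresh noise in (-1, 1)
samples = G(z).detach()             # (16, 784), values in (-1, 1) from the Tanh
samples = deprocess_img(samples)    # map back to [0, 1] for display
show_images(samples.cpu().numpy())
plt.savefig("final_samples.png")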
