cs231 Generative Adversarial Networks (GANs)
迄今为止,在CS231N中,我们所探索的神经网络的所有应用都是采用输入并被训练以产生标记输出的判别模型。这包括从图像类别的分类到句子生成(这仍然是一个分类问题,我们的标签在词汇空间中,并且我们已经学会了递归来捕获多词标签)。在本笔记本中,我们将使用神经网络建立生成模型。具体地说,我们将学习如何构建生成类似于一组训练图像的新图像的模型。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
#%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images):
images = np.reshape(images, [images.shape[0], -1]) #