이번 글에서는 WGAN 코드를 리뷰해보겠습니다.
Wasserstein GAN 논문에 대한 내용은 아래의 글에서 보실 수 있습니다.
그럼 시작해보겠습니다.
transformation = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = torchvision.datasets.MNIST(root = '/content/drive/MyDrive/MNIST', train = True, download = True,
transform = transformation)
print("dataset size: ", len(train_dataset))
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 2)
latent_dim = 100
clip_value = 0.005
n_critic = 5
sample_interval = 400
img_shape = (1, 32, 32)
cuda = True if torch.cuda.is_available() else False
먼저 데이터를 불러와줍니다. 데이터는 MNIST를 사용하였습니다.
latent dimension은 100으로 지정하였고, 논문에서 사용된 weight clipping은 0.005로 지정하였습니다. 논문에는 0.01로 지정되어 있었으나, 코드를 돌려보니 좋지 못한 결과가 나와서 이를 줄여서 실험을 진행했습니다.
n_critic은 generator보다 critic을 얼마나 더 많이 학습 시킬 것인지를 나타내는 값입니다. 저는 5배 더 많이 학습시키도록 설정하였습니다.
그리고 sample_interval은 얼마나 자주 샘플 이미지를 생성해볼 것인지를 결정합니다. 저는 400 배치마다 저장하도록 설정하였습니다.
# Generator
class Generator(torch.nn.Module):
def __init__(self, channels):
super().__init__()
# Filters [1024, 512, 256]
# Input_dim = 100
# Output_dim = C (number of channels)
self.main_module = nn.Sequential(
# Z latent vector 100
nn.ConvTranspose2d(in_channels=100, out_channels=1024, kernel_size=4, stride=1, padding=0),
nn.BatchNorm2d(num_features=1024),
nn.ReLU(True),
# State (1024x4x4)
nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(num_features=512),
nn.ReLU(True),
# State (512x8x8)
nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(num_features=256),
nn.ReLU(True),
# State (256x16x16)
nn.ConvTranspose2d(in_channels=256, out_channels=channels, kernel_size=4, stride=2, padding=1))
# output of main module --> Image (Cx32x32)
self.output = nn.Tanh()
def forward(self, x):
x = self.main_module(x)
return self.output(x)
해당 구조는 DCGAN의 구조를 그대로 사용하였습니다.
class Discriminator(torch.nn.Module):
def __init__(self, channels):
super().__init__()
# Filters [256, 512, 1024]
# Input_dim = channels (Cx64x64)
# Output_dim = 1
self.main_module = nn.Sequential(
# Image (Cx32x32)
nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
# State (256x16x16)
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
# State (512x8x8)
nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.2, inplace=True))
# outptut of main module --> State (1024x4x4)
self.output = nn.Sequential(
nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0),
# Output 1
)
def forward(self, x):
x = self.main_module(x)
return self.output(x)
critic 부분도 DCGAN과 동일하게 사용하였는데, 차이가 있다면 맨 마지막에 sigmoid를 제거하였습니다.
# Initialize generator and discriminator
generator = Generator(1)
discriminator = Discriminator(1)
if cuda:
generator.cuda()
discriminator.cuda()
fixed_noise = torch.randn(16, 100, 1, 1).cuda()
# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=0.00005)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=0.00005)
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
fixed_noise는 고정된 noise vector를 만들어주고자 생성하였고, optimizer로는 RMSprop을 사용하였습니다.
batches_done = 0
for epoch in tqdm(range(EPOCHS)):
D_losses = []
G_losses = []
for i, (imgs, _) in enumerate(trainloader):
# Configure input
real_imgs = Variable(imgs.type(Tensor))
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim, 1, 1))))
# Generate a batch of images
fake_imgs = generator(z).detach()
# Adversarial loss
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
writer.add_scalar("Train/Wasserstein_estimate", -loss_D.item(), batches_done)
loss_D.backward()
optimizer_D.step()
D_losses.append(-loss_D.item())
# Clip weights of discriminator
for p in discriminator.parameters():
p.data.clamp_(-clip_value, clip_value)
이미지를 real_imgs로 가져오고, z라는 변수로 sample noise를 만들어준 다음, 이를 generator에 투입해서 가짜 이미지를 만들어줍니다.
그리고 논문에 나왔던대로, Discriminator loss를 계산해줍니다.
단, 논문에서는 gradient ascent를 해주는 식으로 Algorithm이 설계되었기 때문에 우리는 gradient descent를 적용하기 위해 부호를 반대로 바꿔줍니다.
그래서 writer.add_scalar를 적용할때는 -loss_D.item()으로 입력되도록 만든 것입니다. 부호가 반대로 들어갔기 때문이죠.
그리고 마지막으로는 weight clipping을 해줍니다.
# Train the generator every n_critic iterations
if i % n_critic == 0:
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Generate a batch of images
gen_imgs = generator(z)
# Adversarial loss
loss_G = -torch.mean(discriminator(gen_imgs))
writer.add_scalar("Train/Generator_loss", loss_G.item(), batches_done // n_critic)
G_losses.append(loss_G.item())
loss_G.backward()
optimizer_G.step()
n_critic은 어느정도의 빈도로 generator를 update 해줄 것인지 결정하는 것이라고 앞에서 설명하였는데, 이를 이용해서 batch의 index가 n_critic의 배수일 때만 generator를 학습하도록 만들어줍니다.
sample noise인 z를 다시 generator에 투입하고, 이를 통해서 가짜 이미지를 만들어줍니다.
그리고 이를 discriminator에 투입한 값의 평균에 마이너스를 취한 값을 generator loss로 계산해줍니다.
이를 통해서 얻게 되는 loss graph는 다음과 같았습니다.
분명 논문에서는 엄청나게 깔끔하게 하향하는 곡선이 올라와있었는데, 실제로 돌려보니 그렇지는 않았습니다..
아무래도 hyperparameter들을 조정하거나 해야 더 좋은 성능이 나올 것 같은데...
나와있는 수많은 코드들을 직접 돌려보고 해봤는데도 멀쩡히 돌아가는 코드를 못 찾아서 몇일동안 고군분투하다가 이정도에서 마무리하기로 하였습니다.
GAN을 향상시킨 모델이라고 들었어서 기대를 많이 했지만, 여전히 생성 모델로 좋은 결과를 만드는 것이 쉽지만은 않네요..
해당 코드는 제 Github에 올라와있으니, 여기서 전체 코드를 확인하실 수 있습니다.
github.com/PeterKim1/paper_code_review/tree/master/5.%20Wasserstein%20GAN(WGAN)