cycleGAN 논문에서는 cycleGAN의 Loss function을 다음과 같이 정의하고 있다. (identity loss는 optional)
이 수식은 코드로 볼 때 훨씬 간단해진다.
나는 PyTorch를 이용해 cycleGAN의 loss 부분을 다음과 같이 작성했다.
코드는 저자가 공개한 official code와 simplified version(unofficial)을 참고했다.
처음 공부할 때는 simplified version을 통해 전체 구조를 이해한 후, official code를 참고하면 이해하기 편하다.
(*cycleGAN의 detail이나, 이론적인 부분은 이 글에서는 설명하지 않기로 한다.)
우선, 각 stage의 loss를 다음과 같이 식으로 정리할 수 있다.
참고로, realA와 realB는 각 domain의 sample이다.
각 Generator와 Discriminator을 G_{B->A}, G_{A->B}, D_A, D_B로 표현했다.
Generator loss
Discriminator A loss
Discriminator B loss
다음은 위 식을 바탕으로 작성한 코드이다.
1. 전체적인 구조
- Generator, Discriminator A, Discrimator B의 train이 각각 이루어진다.
- Generator loss는 identity loss, GAN loss, cycle loss의 합으로 이루어져 있다.
- lambda_identity가 0이 아닐 경우만 identity loss를 계산하며,
- Discriminator loss 계산 시에 Replay buffer을 이용한다.
*Replay buffer: GAN의 불안정성을 해소하기 위해 이전에 generator가 생성한 사진을 주기적으로 다시 discriminator에게 보여주는 방식
for real_A, real_B in dataloader:
#Train Generator =====================================================
if lambda_identity != 0:
loss_idt = compute_idt_loss(real_A, real_B, generator_A2B, generator_B2A, loss_f_identity, lambda_identity)
else:
loss_idt = 0
loss_GAN, fake_A, fake_B = compute_GAN_loss(real_A, real_B, generator_A2B, generator_B2A, discriminator_A, discriminator_B, loss_f_GAN)
loss_cycle, cycle_A, cycle_B = compute_cycle_loss(real_A, real_B, fake_A, fake_B, generator_A2B, generator_B2A, loss_f_cycle, lambda_cycle)
loss_G = loss_idt + loss_GAN + loss_cycle
do_step(optimizer_G, loss_G, scheduler_G)
#Train Discriminator A =====================================================
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_D_A = compute_discriminator_loss(real_A, fake_A_, discriminator_A, loss_f_GAN)
do_step(optimizer_D_A, loss_D_A, scheduler_D_A)
#Train Discriminator B =====================================================
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_D_B = compute_discriminator_loss(real_B, fake_B_, discriminator_B, loss_f_GAN)
do_step(optimizer_D_B, loss_D_B, scheduler_D_B)
def do_step(optimizer, loss, scheduler=None):
optimizer.zero_grad()
loss.backward()
optimizer.step()
if scheduler is not None:
scheduler.step()
2. 각 loss를 계산하는 함수
- 이 부분이 Loss를 계산하는 부분으로, 위에서 수식으로 표현한 부분이다.
- Discriminator loss의 경우, real loss와 fake loss의 평균값을 사용했다.
def compute_idt_loss(real_A, real_B, generator_A2B, generator_B2A, loss_f_identity, lambda_identity):
same_A = generator_B2A(real_A)
same_B = generator_A2B(real_B)
loss_idt_A = loss_f_identity(same_A, real_A)*lambda_identity
loss_idt_B = loss_f_identity(same_B, real_B)*lambda_identity
return loss_idt_A+loss_idt_B
def compute_GAN_loss(real_A, real_B, generator_A2B, generator_B2A, discriminator_A, discriminator_B, loss_f_GAN):
fake_B = generator_A2B(real_A)
fake_A = generator_B2A(real_B)
pred_fake_B = discriminator_B(fake_B)
pred_fake_A = discriminator_A(fake_A)
loss_GAN_A2B = loss_f_GAN(pred_fake_B, True)
loss_GAN_B2A = loss_f_GAN(pred_fake_A, True)
return loss_GAN_A2B+loss_GAN_B2A, fake_A, fake_B
def compute_cycle_loss(real_A, real_B, fake_A, fake_B, generator_A2B, generator_B2A, loss_f_cycle, lambda_cycle):
cycle_A = generator_B2A(fake_B)
cycle_B = generator_A2B(fake_A)
loss_cycle_ABA = loss_f_cycle(cycle_A, real_A)*lambda_cycle
loss_cycle_BAB = loss_f_cycle(cycle_B, real_B)*lambda_cycle
return loss_cycle_ABA+loss_cycle_BAB, cycle_A, cycle_B
def compute_discriminator_loss(real_im, fake_im, discriminator, loss_f_GAN):
pred_real = discriminator(real_im)
pred_fake = discriminator(fake_im.detach())
loss_D_real = loss_f_GAN(pred_real, True)
loss_D_fake = loss_f_GAN(pred_fake, False)
return (loss_D_real+loss_D_fake)*0.5
3. GAN loss function
- official code에서 했던 것과 동일하게, GAN loss를 계산하는 class를 새로 정의했다.
- GAN loss를 계산할 때는 nn.BCEWithLogitsLoss() 혹은 nn.MSELoss()를 이용할 수 있다.
- cycleGAN에서는 vanishing gradient 문제를 해결하기 위해 MSELoss를 이용했다.
class GANLoss(nn.Module):
def __init__(self, loss):
"""
:param loss: nn.BCEWithLogitsLoss() / nn.MSELoss()
"""
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(1.0))
self.register_buffer('fake_label', torch.tensor(0.0))
self.loss = loss
def __call__(self, pred, real):
if real:
target = self.real_label
else:
target = self.fake_label
target = target.expand_as(pred)
return self.loss(pred, target)
'🌌 Deep Learning > Implementation' 카테고리의 다른 글
[PyTorch Implementation] StyleGAN2 (2) | 2022.09.26 |
---|---|
[PyTorch Implementation] PointNet 설명과 코드 (0) | 2022.08.12 |
[PyTorch Implementation] ResNet-B, ResNet-C, ResNet-D, ResNet Tweaks (1) | 2022.06.05 |
[PyTorch Implementation] CBAM: Convolutional Block Attention Module 설명 + 코드 (0) | 2022.04.22 |
[PyTorch Implementation] 3D Segmentation model - VoxResNet, Attention U-Net, V-Net (0) | 2020.12.30 |