🌌 Deep Learning/Implementation

[PyTorch Implementation] PyTorch로 구현한 cycleGAN의 loss 부분 설명

복만 2021. 8. 4. 14:30

cycleGAN 논문에서는 cycleGAN의 Loss function을 다음과 같이 정의하고 있다. (identity loss는 optional)

 

 

이 수식은 코드로 볼 때 훨씬 간단해진다.

 

나는 PyTorch를 이용해 cycleGAN의 loss 부분을 다음과 같이 작성했다.

코드는 저자가 공개한 official codesimplified version(unofficial)을 참고했다.

처음 공부할 때는 simplified version을 통해 전체 구조를 이해한 후, official code를 참고하면 이해하기 편하다.

(*cycleGAN의 detail이나, 이론적인 부분은 이 글에서는 설명하지 않기로 한다.)

 


 

우선, 각 stage의 loss를 다음과 같이 식으로 정리할 수 있다.

참고로, realA와 realB는 각 domain의 sample이다.

각 Generator와 Discriminator을 G_{B->A}, G_{A->B}, D_A, D_B로 표현했다.

 

Generator loss

 

Discriminator A loss

 

Discriminator B loss

 


다음은 위 식을 바탕으로 작성한 코드이다.

 

1. 전체적인 구조

- Generator, Discriminator A, Discrimator B의 train이 각각 이루어진다.

- Generator loss는 identity lossGAN losscycle loss의 합으로 이루어져 있다.

- lambda_identity가 0이 아닐 경우만 identity loss를 계산하며,

- Discriminator loss 계산 시에 Replay buffer을 이용한다.

    *Replay buffer: GAN의 불안정성을 해소하기 위해 이전에 generator가 생성한 사진을 주기적으로 다시 discriminator에게 보여주는 방식

for real_A, real_B in dataloader:
    #Train Generator =====================================================
    if lambda_identity != 0:
    	loss_idt = compute_idt_loss(real_A, real_B, generator_A2B, generator_B2A, loss_f_identity, lambda_identity)
    else:
    	loss_idt = 0
    loss_GAN, fake_A, fake_B = compute_GAN_loss(real_A, real_B, generator_A2B, generator_B2A, discriminator_A, discriminator_B, loss_f_GAN)
    loss_cycle, cycle_A, cycle_B = compute_cycle_loss(real_A, real_B, fake_A, fake_B, generator_A2B, generator_B2A, loss_f_cycle, lambda_cycle)
    
    loss_G = loss_idt + loss_GAN + loss_cycle
    do_step(optimizer_G, loss_G, scheduler_G)
    
    #Train Discriminator A =====================================================
    fake_A_ = fake_A_buffer.push_and_pop(fake_A)
    loss_D_A = compute_discriminator_loss(real_A, fake_A_, discriminator_A, loss_f_GAN)
    do_step(optimizer_D_A, loss_D_A, scheduler_D_A)
    
    #Train Discriminator B =====================================================
    fake_B_ = fake_B_buffer.push_and_pop(fake_B)
    loss_D_B = compute_discriminator_loss(real_B, fake_B_, discriminator_B, loss_f_GAN)
    do_step(optimizer_D_B, loss_D_B, scheduler_D_B)
    
def do_step(optimizer, loss, scheduler=None):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
    	scheduler.step()

 

2. 각 loss를 계산하는 함수

- 이 부분이 Loss를 계산하는 부분으로, 위에서 수식으로 표현한 부분이다.

- Discriminator loss의 경우, real loss와 fake loss의 평균값을 사용했다.

def compute_idt_loss(real_A, real_B, generator_A2B, generator_B2A, loss_f_identity, lambda_identity):
    same_A = generator_B2A(real_A)
    same_B = generator_A2B(real_B)
    loss_idt_A = loss_f_identity(same_A, real_A)*lambda_identity
    loss_idt_B = loss_f_identity(same_B, real_B)*lambda_identity
    return loss_idt_A+loss_idt_B

def compute_GAN_loss(real_A, real_B, generator_A2B, generator_B2A, discriminator_A, discriminator_B, loss_f_GAN):
    fake_B = generator_A2B(real_A)
    fake_A = generator_B2A(real_B)
    pred_fake_B = discriminator_B(fake_B)
    pred_fake_A = discriminator_A(fake_A)
    loss_GAN_A2B = loss_f_GAN(pred_fake_B, True)
    loss_GAN_B2A = loss_f_GAN(pred_fake_A, True)
    return loss_GAN_A2B+loss_GAN_B2A, fake_A, fake_B

def compute_cycle_loss(real_A, real_B, fake_A, fake_B, generator_A2B, generator_B2A, loss_f_cycle, lambda_cycle):
    cycle_A = generator_B2A(fake_B)
    cycle_B = generator_A2B(fake_A)
    loss_cycle_ABA = loss_f_cycle(cycle_A, real_A)*lambda_cycle
    loss_cycle_BAB = loss_f_cycle(cycle_B, real_B)*lambda_cycle
    return loss_cycle_ABA+loss_cycle_BAB, cycle_A, cycle_B

def compute_discriminator_loss(real_im, fake_im, discriminator, loss_f_GAN):
    pred_real = discriminator(real_im)
    pred_fake = discriminator(fake_im.detach())
    loss_D_real = loss_f_GAN(pred_real, True)
    loss_D_fake = loss_f_GAN(pred_fake, False)
    return (loss_D_real+loss_D_fake)*0.5

 

3. GAN loss function

- official code에서 했던 것과 동일하게, GAN loss를 계산하는 class를 새로 정의했다.

- GAN loss를 계산할 때는 nn.BCEWithLogitsLoss() 혹은 nn.MSELoss()를 이용할 수 있다.

- cycleGAN에서는 vanishing gradient 문제를 해결하기 위해 MSELoss를 이용했다.

class GANLoss(nn.Module):
    def __init__(self, loss):
        """
        :param loss: nn.BCEWithLogitsLoss() / nn.MSELoss()
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(1.0))
        self.register_buffer('fake_label', torch.tensor(0.0))
        self.loss = loss

    def __call__(self, pred, real):
        if real:
            target = self.real_label
        else:
            target = self.fake_label
        target = target.expand_as(pred)
        return self.loss(pred, target)

 


 

 

반응형