2021 ICCV์ Accept๋ ๋ ผ๋ฌธ์ธ "Rethinking the Truly Unsupervised Image-to-Image Translation"์ ์ ๋ฆฌํ ๊ธ์ ๋๋ค. Naver CLOVA AI์์ ์์ฑ๋ ๋ ผ๋ฌธ์ ๋๋ค.
์ด์ ๊น์ง์ Unsupervised model (cycleGAN ๋ฑ)์ ์ฌ์ค Semi-supervised ๋ชจ๋ธ์ด๋ผ๊ณ ํด์ผ ํ๋ค๊ณ ์๊ธฐํ๋ฉฐ, Data collection(labeling)์ด ํ์ํ์ง ์์ Truly unsupervised model์ธ TUNIT์ ์ ์ํ๊ณ ์์ต๋๋ค.
๋ ผ๋ฌธ ๋งํฌ: https://arxiv.org/pdf/2006.06500
Official code: https://github.com/clovaai/tunit
1. Levels of Supervision in Generative Models
์ด์ ๊น์ง๋ Generative model์ ๋ฐ์ดํฐ์ ์ข ๋ฅ์ ๋ฐ๋ผ Supervisied์ Unsupervised ๋ ๊ฐ์ง๋ก ๋๋ด์ต๋๋ค.
- Image์ Label์ด ์์ผ๋ก ์กด์ฌํ๋ ๊ฒฝ์ฐ๋ Supervised๋ผ๊ณ ํฉ๋๋ค. Conditional GAN ๋ฑ์ด ์ฌ๊ธฐ์ ์ํ๊ณ , ํ์ต์ด ๋น๊ต์ ์ฝ์ง๋ง ์ด๋ฌํ ๋ฐ์ดํฐ๋ ์ป๊ธฐ ํ๋ค๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
- Task์ ๋ฐ๋ผ, Image์ Label ๋ฐ์ดํฐ๋ฅผ ์์ผ๋ก ์ป๊ธฐ ํ๋ค๊ฑฐ๋ ์์ ๋ถ๊ฐ๋ฅํ ๊ฒฝ์ฐ๋ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ์๋ ๊ฐ Domain ๋ณ๋ก ๋ฐ์ดํฐ๋ฅผ ์์งํฉ๋๋ค. ๊ผญ Image์ Label ๋ฐ์ดํฐ๊ฐ ์์ผ๋ก ์กด์ฌํ์ง ์์๋ ๋ฉ๋๋ค. ์ด๋ฅผ Unsupervised๋ผ๊ณ ํ๊ณ , cycleGAN ๋ฑ์ด ์ด์ ์ํฉ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ Generative model์ Supervision์ ์ ๋๋ฅผ ์ธ ๊ฐ์ง๋ก ๋ถ๋ฅํฉ๋๋ค. ์ด์ ๊น์ง Unsupervised๋ผ๊ณ ํ๋ ๊ฒ ์ญ์ Domain ๋ณ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ชจ์ผ๊ณ , ๋ผ๋ฒจ๋ง์ ํด์ผ ํ๋ค๋ ์ ์์ Semi-supervised๋ผ๊ณ ๋ถ๋ฌ์ผ ํ๊ณ , ์ง์ ํ Unsupervised ๋ชจ๋ธ์ ์ด๋ฌํ ๋ฐ์ดํฐ ๋ผ๋ฒจ๋ง ์์ด๋ Image-to-image translation์ ์ํํ ์ ์์ด์ผ ํ๋ค๊ณ ์ฃผ์ฅํฉ๋๋ค.
Truly Unsupervised Learning
๋ณธ ๋ ผ๋ฌธ์์๋, True unsupervised learning์ด๋ ๊ฐ ๋ฐ์ดํฐ์ ๋ํ ๋ผ๋ฒจ(Class)์ด ์๋, ์ฌ๋ฌ Domain์ ์ด๋ฏธ์ง๋ค๋ก ๊ตฌ์ฑ๋ ํ๋์ ๋ฐ์ดํฐ์ ์์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํด๋ผ ์ ์๋ ๊ฒ์ด๋ผ๊ณ ์ ์ํฉ๋๋ค.
True unsupervised learning์ ์ฅ์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ๋ณ๋์ Data annotation์ ํ์ง ์์๋ ๋ฉ๋๋ค.
- ๋ฐ๋ผ์, ์ด๋ฌํ ๋ผ๋ฒจ๋ง ์์ ์์ ์ค๋ Noise๋ฅผ ๋ฐฉ์งํ ์ ์์ต๋๋ค.
- Semi-supervised model์ ๋ํ ๊ฐ๋ ฅํ Baseline์ ์ ๊ณตํ ์ ์์ต๋๋ค.
TUNIT Architecture
TUNIT์ ๋ ๊ฐ์ง์ ๊ตฌ์กฐ, ์ธ ๊ฐ์ง์ Network๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
Guiding network๋ Input image์ Domain(label)์ ๋ถ๋ฅ(Clustering)ํ๊ณ Style code๋ฅผ ์ถ์ถํ๋ ์ญํ ์ ํฉ๋๋ค.
GAN์ Input image๋ฅผ Target domain์ผ๋ก ๋ณํํ๋ ์ญํ , ์ฆ Mapping function์ ํ์ตํฉ๋๋ค.
Guiding network๋ ๋ ๊ฐ์ Branch๋ฅผ ๊ฐ์ง ํ๋์ Encoder๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
ํ๋์ Branch($E_C$)์์๋ Clustering ๊ฒฐ๊ณผ(Pseudo label)๋ฅผ ์ถ๋ ฅํ๊ณ ,
๋ค๋ฅธ ํ๋์ Branch($E_S$)๋ Style code๋ฅผ ๋ด์ vector๋ฅผ ์ถ๋ ฅํฉ๋๋ค.
Style code๋ GAN์ Generator๊ฐ ์ด๋ฏธ์ง๋ฅผ ๋ณํํ๋ ๋ฐ์ ์ฌ์ฉ๋๊ณ ,
Pseudo label์ Discriminator๊ฐ ๋ณํ๋ ์ด๋ฏธ์ง์ Real/Fake๋ฅผ ํ๋จํ๋ ๋ฐ์ ์ฌ์ฉ๋ฉ๋๋ค.
Training Guiding Network
Guiding network์ ํ์ต ๋ฐฉ๋ฒ์ ๋ํด ๋จผ์ ์์๋ณด๊ฒ ์ต๋๋ค.
์์ ๋งํ๋๋ก, Guiding network๋ ๋์ผํ Encoder๋ฅผ ๊ณต์ ํ๋ ๋ ๊ฐ์ Branch๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
Pseudo label์ ์์ฑํ๋ $E_C$๋ Mutual information (MI)๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๋ Clustering ๊ธฐ๋ฒ์ ์ด์ฉํ๊ณ ,
Style code๋ฅผ ์์ฑํ๋ $E_S$๋ Contrastive loss๋ฅผ ์ด์ฉํฉ๋๋ค.
Training $E_C$ : use differentiable clustering method based on mutual information (MI) maximization
ํ์ต ๋ฐฉ๋ฒ
- Input image $x$์, ์ด๋ฅผ Randomํ๊ฒ Augmentationํ Image $x^+$๋ฅผ ์ด์ฉํฉ๋๋ค.
- ์ด๋ค์ $E_C$์ input์ผ๋ก ์ค ๊ฒฐ๊ณผ๋ฅผ ๊ฐ๊ฐ $p$, $p+$๋ผ๊ณ ํ๋ฉฐ, ์ด๋ ๊ฐ K๊ฐ์ domain ๊ฐ๊ฐ์ ์ํ ํ๋ฅ ์ ๋ด์ vector, ์ฆ Pseudo label์ ๋๋ค. ($p=E_C(x)$)
- $p$, $p+$์ Mutual information์ด ์ต๋๊ฐ ๋๋ ๋ฐฉํฅ์ผ๋ก Encoder๋ฅผ ํ์ต์ํต๋๋ค.
- Loss function์ ์์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
(Maximize) $L_{MI} = I(p,p^+) = I(P) = \sum^K_{i=1}\sum^K_{j=1}P_{ij}ln\frac{P_{ij}}{P_iP_j}$
์ฆ, $x$์ $x+$์ Pseudo label์ Mutual information์ด ์ต๋๊ฐ ๋๊ฒ ํ๋ ๊ฒ์ ๋๋ค.
Mutual information์ ๋ variable์ ์ํธ์์กด์ฑ์ ์ธก์ ํ ๊ฒ์ผ๋ก, ์ํค๋ฐฑ๊ณผ์์๋ ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช ํ๊ณ ์์ต๋๋ค.
Mutual information is therefore the reduction in uncertainty about variable X , or the expected reduction in the number of yes/no questions needed to guess X after observing Y .
Mutual information์ ๋ํ ๋ณด๋ค ์์ธํ ์ค๋ช ์ ์๋ ๋งํฌ๋ฅผ ์ฐธ๊ณ ๋ฐ๋๋๋ค.
http://www.scholarpedia.org/article/Mutual_information
๊ฐ๋จํ ๋งํ์๋ฉด, Mutual information์ ์ต๋๊ฐ ๋๊ฒ ํจ์ผ๋ก์จ, Encoder์ ๋ Image $x$์ $x+$๊ฐ ๊ฐ์ label์ ์ํ๋๋ก ํฉ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ Clustering์ ํ ๋ฐฉ๋ฒ์ผ๋ก์จ Mutual information maximization์ ์ฌ์ฉํ์์ผ๋, ๋ค๋ฅธ Clustering ๋ฐฉ๋ฒ์ ์ฌ์ฉํด๋ ๋๋ค๊ณ ์๊ธฐํฉ๋๋ค.
Training $E_S$ : use contrastive loss
ํ์ต ๋ฐฉ๋ฒ
- MoCo๋ผ๋ ๋ชจ๋ธ์์ ๊ฐ์ ธ์จ ๋ฐฉ๋ฒ์ ๋๋ค.
- $E_C$๋ฅผ ํ์ต์ํฌ ๋์ ๋ง์ฐฌ๊ฐ์ง๋ก, randomly augmented sample์ ์ด์ฉํฉ๋๋ค.
- Input image $x$๋ฅผ Randomly augmentํ Image $x+$์ Positive sample๋ก ์ด์ฉํ๊ณ , ๋ค๋ฅธ Image๋ค์ Negative sample๋ก ์ด์ฉํฉ๋๋ค ($x_n^-$).
- 1๊ฐ์ Positive sample $x+$๊ณผ N๊ฐ์ Negative sample $x_n^-$, ์ด N+1๊ฐ์ sample์ $E_S$ ์ ๋ฃ์ Output์ ์ด์ฉํด N+1 way classification์ ์ํํฉ๋๋ค.
- ์ฆ, $s=E_S(x)$๋ผ๊ณ ํ ๋, Positive pair์ style vector ($s$, $s^+$) ๊ฐ์ ์ ์ฌ๋๋ฅผ ์ต๋ํํ๊ณ , Negative pair์ style vector ($s$, $s^-$) ๊ฐ์ ์ ์ฌ๋๋ฅผ ์ต์ํํ๋ ๋ฐฉ์์ ๋๋ค.
- Loss function์ ์์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
(Minimize) $L^E_{style} = -log\frac{exp(s\cdot s^+ / \tau)}{\sum^N_{i=0}exp(s \cdot s^-_i / \tau)}$
Training Guiding Network
์ ๋ฆฌํ์๋ฉด,
- Guiding Network๋ ๋์ผํ Encoder๋ฅผ ๊ณต์ ํ๋ฉฐ, Pseudo label์ ์์ฑํ๋ $E_C$์, Style code๋ฅผ ์์ฑํ๋ $E_S$ ๋ ๊ฐ์ง Branch๋ก ์ด๋ฃจ์ด์ ธ ์์ผ๋ฉฐ,
- ๊ฐ๊ฐ์ Loss function $L_{MI}$์ $L_{style}$์ ์ด์ฉํด ๊ฐ Branch๋ฅผ ํ์ต์ํต๋๋ค.
๊ทธ๋ฌ๋ ๊ฐ Branch๋ฅผ ๋ฐ๋ก๋ฐ๋ก ํ์ต์ํค๋ ๊ฒ์ด ์๋๊ณ , ๋ Loss function์ ํฉ์ณ ํจ๊ป (Jointly) ํ์ต์ ์งํํฉ๋๋ค.
์ด๋ ๊ฒ ๋ Task๋ฅผ ํ ๋ฒ์ ํ์ต์์ผฐ์ ๋์ ์ฅ์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- Clustering์ Style code ํ์ต ๊ณผ์ ์์ ํ์ตํ๋ Rich representation์ ํ์ฉํ ์ ์์ต๋๋ค.
- Style code๋ Clustering ํ์ต ๊ณผ์ ์์ ํ์ตํ๋ Domain-specific nature๊ณผ, ๊ฐ์ Domain์ ์ํ๋ ๋ฐ์ดํฐ๋ค์ ์ ์ฌ์ฑ์ ํ์ฉํ ์ ์์ต๋๋ค.
์ด๋ฌํ ์ด์ ๋ก Joint training์ ์ด์ฉํ์๊ณ , Guiding network์ Loss function์ ๋ค์๊ณผ ๊ฐ๊ฒ ๋ฉ๋๋ค.
(Minimize) $-L_{MI} + L_{style}$
์ค์ ๋ก $L_{MI}$๋ฅผ ์ด์ฉํด $E_C$๋ฅผ ํ์ต์์ผฐ์ ๋๋ณด๋ค, $L_{style}$๊น์ง ํ์ฉํด Joint training์ ํ์ ๋ ์์ฑ๋ Pseudo label์ ์ ํ๋๊ฐ ๋์๋ค๊ณ ํฉ๋๋ค.
Training Generative Network (GAN)
Generative Network (GAN)์ 3๊ฐ์ง์ Loss function์ ์ด์ฉํด ํ์ต์์ผฐ์ต๋๋ค.
1) Realisticํ Image๋ฅผ ์์ฑํ๊ธฐ ์ํ Adversarial Loss์,
2) Style code๋ฅผ ์ ๋ณด์กดํ๊ธฐ ์ํ Style Contrastive Loss,
3) cycleGAN์ identity loss์ ์ ์ฌํ Image Reconstruction Loss
์ธ ๊ฐ์ง๋ฅผ ์ด์ฉํ์ต๋๋ค.
Adversarial Loss
์ผ๋ฐ์ ์ธ GAN Loss์ ๋์ผํฉ๋๋ค.
์ฐจ์ด์ ์, Generator์ Reference image์ style code $\widetilde{s}$๊ฐ input์ผ๋ก ๋ค์ด๊ฐ๋ค๋ ๊ฒ์ ๋๋ค.
Generator์ Style code $\widetilde{s}$๋ฅผ ๋ฐ์ํ๋ฉด์ Input image $x$๋ฅผ Target domain $\widetilde{y}$๋ก ๋ณํํ๋ ๊ฒ์ด ๋ชฉํ์ ๋๋ค.
์์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
์ฐธ๊ณ ๋ก, Style code๋ฅผ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ AdaIN์ ์ด์ฉํ์ต๋๋ค.
Generator ๊ตฌ์กฐ์ 5๊ฐ์ AdaIN layer๊ฐ ํฌํจ๋์ด ์๋๋ฐ,
1) Style vector์ MLP์ ํต๊ณผ์์ผ ๊ฐ AdaIN์ ์ฌ์ฉํ Parameter๋ค์ ๋ฝ์๋ธ ๋ค์,
2) ์ด๋ฅผ ๊ฐ AdaIN layer์ ์ ์ฉ์ํค๋ ๋ฐฉ๋ฒ์ผ๋ก Input image์ Style์ ์ ํ์ต๋๋ค.
Style Contrastive Loss
Style Code๋ฅผ ์ ๋ณด์กดํ๊ธฐ ์ํ ์ถ๊ฐ์ ์ธ Loss์ ๋๋ค.
Generator์ ํตํด ์์ฑ๋ Image์ Style code $s'=E_S(G(x, \widetilde{s}))$์, Reference image์ Style code $\widetilde{s}$๊ฐ์ Contrastive Loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์ด๋ ์์ฑ๋ Image๊ฐ Input์ผ๋ก ์ฃผ์ด์ง Style code๋ฅผ ์ผ๋ง๋ ์ ๋ณด์กดํ๊ณ ์๋์ง๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์์์ ์๋์ ๊ฐ์ต๋๋ค.
Image Reconstruction Loss
Reference image๋ฅผ Source image์ ๋์ผํ๊ฒ ์ฃผ์์ ๋, ์์ฑ๋ Image๊ฐ ์๋ณธ๊ณผ ์ผ๋ง๋ ๋์ผํ์ง๋ฅผ ๊ณ์ฐํ๋ Loss์ ๋๋ค.
์ด๋ Source Image์ Style code๋ฅผ ์ด์ฉํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ๋คํธ์ํฌ๊ฐ ์ ํ์ต๋์ด ์๋ค๋ฉด ์๋ณธ๊ณผ ๋์ผํ Image๊ฐ ๋์ค๋ ๊ฒ์ด ์ด์์ ์ ๋๋ค.
CycleGAN์ Identity Loss์ ์ ์ฌํ๋ค๊ณ ๋๊ผ์ต๋๋ค.
์์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Train All Network Jointly
์์, Guiding network๋ฅผ ํ์ตํ ๋ ๋ Loss function๋ค์ ํ๋ฒ์ ํ์ตํ๋ค๊ณ ํ์ต๋๋ค.
GAN์ ํ์ต๋ ๋ชจ๋ Loss function์ ํ ๋ฒ์ ํ์ต์ํค๊ณ ,
์ด ๋ฟ ์๋๋ผ Guiding network์ GAN์ ํ์ต ์ญ์ ๋์์ ์งํ๋ฉ๋๋ค.
๋ค์ ๋งํด, ๋ชจ๋ ํ์ต์ด end-to-end๋ก ํ ๋ฒ์ ์งํ๋๋ ๊ฒ์ ๋๋ค.
์ด๋ Guiding network์์ Clustering๊ณผ Style code๊ฐ ์๋ก์ ํ์ต์ ๋์์ ์ฃผ์๋ ๊ฒ์ฒ๋ผ,
GAN๊ณผ Guiding network ์ญ์ ์๋ก์ ํ์ต ๊ณผ์ ์ ๋์์ ์ค๋ค๊ณ ํฉ๋๋ค.
Experiments & Results
์คํ์ Supervised Translation์์ SOTA ๋ชจ๋ธ์ธ FUNIT์ Unsupervised setting์ผ๋ก ๋ฐ๊พผ ๊ฒ๊ณผ ๋น๊ตํ์ต๋๋ค. ์ด์ธ์๋ ๋ค์ํ ์คํ์ ์งํํ์ง๋ง, ์ผ๋ถ ๊ฒฐ๊ณผ๋ง ๊ฐ๋ตํ ์๊ฐํ๋๋ก ํ๊ฒ ์ต๋๋ค.
์์ธํ ์คํ๋ด์ฉ๊ณผ ๊ฒฐ๊ณผ๋ ๋ ผ๋ฌธ์ ์ฐธ๊ณ ํ์๋ฉด ๋ ๊ฒ ๊ฐ์ต๋๋ค.
Labeled Dataset์ ๋ํ ์คํ
Unlabeled Dataset์ ๋ํ ์คํ
Conclusion
๋ ผ๋ฌธ์์ ์ฃผ์ฅํ๋ Contribution์ ์ฌ๋ฌ ๊ฐ์ง๊ฐ ์์ง๋ง, ๊ทธ ์ค์์๋ ๊ฐ์ฅ ์ค์ํ ๊ฒ์ Unsupervised image-to-image translation์ ์ฌ์ ์ํ๊ณ , ์ด๋ฌํ Task๋ฅผ ์ํํ๋ End-to-end model์ ์ ์ํ๋ค๋ ์ ์ธ ๊ฒ ๊ฐ์ต๋๋ค.
๋ณธ ๊ฒ์๋ฌผ์์๋ ์ต๋ํ ๊ฐ๋ตํ๊ฒ ๋ ผ๋ฌธ์ ์ ๋ฆฌํ์ง๋ง, ๋ณด๋ค ๋ํ ์ผํ ๋ถ๋ถ์ด๋ ๊ตฌํ ๊ด๋ จ ๋ด์ฉ๋ค์ ๋ ผ๋ฌธ๊ณผ ๊ณต์ ๊นํ ์ฝ๋๋ฅผ ์ฐธ๊ณ ํ์๋ฉด ๋ ๊ฒ ๊ฐ์ต๋๋ค.