CVPR 2021에서 발표된 self-supervised learning 논문. 기존의 다른 self-supervised learning 방법인 SimCLR, SwAV, BYOL과의 비교를 통해 제안된 방법을 설명하고 있다. 개인적으로 문장이나 전체적인 구조가 매우 이해하기 쉽게 잘 쓰여져 있다고 느꼈다.
Introduction
Self-supervsied learning에서는 보통 Siamese network 구조를 많이 이용한다. 이는 weight를 서로 공유하는 neural network를 의미하는데, 이들은 각 entity를 비교하는 데에 유용하게 사용될 수 있다.
그러나 Siamese network는 output이 하나의 constant로 수렴하는 collapsing이 발생할 수 있다는 문제가 있다. 이를 해결하기 위해, 기존의 연구들은 다음의 방법들을 이용했다.
- Contrastive learning (SimCLR): negative pair을 추가로 이용
- Clustering (SwAV)
- Momentum encoder (BYOL)
본 논문에서 제안하는 방법인 SimSiam은 shared encoder을 이용하고, positive pair만을 가지고 이들이 가까워지는 방향으로 학습을 진행한다. 위의 세 가지 방법에서 하나씩 component를 제거한 구조라고 볼 수 있다.
- "SimCLR without negative pairs"
- "SwAV without online clustering"
- "BYOL without momentum encoder"
대신, SimSiam은 collapsing을 방지하기 위한 방법으로 stop-gradient를 이용한다.
Method
SimSiam network의 전반적인 구조는 아래 그림과 같다. 우선 forward pass가 어떻게 진행되는지 설명하고, 사용된 loss를 설명한다.
Forward pass
- 하나의 image $x$에 대해 두가지 random augmentation을 진행해 $x_1$, $x_2$를 만든다.
- 두 augmented image는 encoder network encoder network $f$를 통과한다. 이 때 두 encoder은 weight를 공유한다.
- Encoder을 통과한 두 vector 중 하나만 prediction MLP head $h$를 통과해 새로운 vector를 얻는다.
- $p_1=h(f(x_1))$
- $z_2=f(x_2)$
Loss
- Forward pass로 얻은 두 vector의 negative cosine similarity를 최소화한다.
- $\mathcal D(p_1,z_2)=-\frac{p_1}{||p_1||_2}\cdot \frac{z_2}{||z_2||_2}$ ... (1)
- 이는 두 vector를 l2-normalize한 후 MSE를 측정한 것과 동일하다.
- 두 vector의 순서를 바꿔서 한번 더 비교한 symmetrized loss를 사용한다.
- $\mathcal L = \frac12 \mathcal D(p_1,z_2)+\frac12 \mathcal{D}(p_2,z_1)$
Stop-grad
- SimSiam의 핵심적인 구조는 여기에 stop-grad를 추가하는 것이다.
- 이는 (1)을 다음과 같이 수정함으로써 구현될 수 있다.
- $\mathcal D(p_1, \text{stopgrad}(z_2))$
- 이 경우 loss역시 다음과 같이 바뀐다.
- $\mathcal L = \frac12 \mathcal D(p_1,\text{stopgrad}(z_2))+\frac12 \mathcal{D}(p_2,\text{stopgrad}(z_1))$
- 이때 $x_2$에 대한 encoder의 관점에서 보면, 첫 번째 term의 $z_2$로부터는 gradient를 전달받지 않고,
- 반면 두 번째 term의 $p_2$로부터 gradient를 전달받음을 알 수 있다. ($x_1$의 경우도 동일)
- 이를 Pseudo-code로 나타내면 다음과 같다.
Empirical study
Stop-gradient를 썼을 때와 쓰지 않았을 때를 비교해 stop-gradient가 collapse를 해결했음을 보였다.
- (left) stop-gradient를 사용하지 않는 경우, optimizer이 빠르게 degenerated solution을 찾고 minimum loss에 도달한다.
- (middle) stop-gradient를 사용하지 않는 경우, l2-normalized output의 standard deviation이 거의 0임을 확인할 수 있다. 반면, stop gradient를 사용하는 경우 l2-normalized output의 standard deviation이 $1/\sqrt d$로, 이는 output이 collapse하지 않았음을 알 수 있다.
- (right) kNN으로 classification을 진행한 결과 역시 stop-gradient를 사용하지 않을 경우 매우 낮았다.
이외에도 BN, similarity function, symmeterization 등의 실험을 통해 collapse prevention이 오롯이 stop-gradient 덕분이었음을 보였다.
Hypothesis
SimSiam의 작동원리에 대한 가설로써, SimSiam은 Expectation-Maximization (EM)과 같은 방식으로 동작한다는 가정을 세웠다.
Formulation.
다음과 같은 loss 함수를 고려한다고 하자.
$\mathcal L(\theta,\eta)=\mathbb E_{x,\mathcal T}[\mathcal F_\theta(\mathcal T(x))-\eta_x||_2^2]$
- $\mathcal F$ : network parameterized by $\theta$
- $\mathcal T$ : augmentation
$\eta$는 또다른 variable의 set이다. Set의 크기는 image $x$의 갯수와 동일하고, 이 때 $\eta_x$는 image $x$에 대한 representation이라고 볼 수 있다. 이는 꼭 network의 output일 필요는 없다.
그러면 이 경우를 k-means clustering과 동일하다고 생각할 수 있다. $\theta$는 clustering center이 되고, $\eta_x$는 sample $x$에 할당된 vector ($x$의 representation)이라고 볼 수 있는 것이다.
마찬가지로, 위의 optimization problem 역시 k-means clustering과 동일하게 변수 하나씩을 고정해 가며 alternating algorithm으로 해결할 수 있다.
$\theta^t \leftarrow \text{argmin}_\theta \mathcal L(\theta,\eta^{t-1})$ ... (7)
$\eta^t \leftarrow \text{argmin}_\eta\mathcal L(\theta^t,\eta)$ ... (8)
Solving for $\theta$.
$\theta$에 대한 optimization (7)은 SGD를 이용해 해결할 수 있다. 단, 이 때 $\eta$의 경우 constant로 취급되므로 $\eta$에 대한 gradient는 stop해야 한다.
Solving for $\eta$.
$\eta$에 대한 optimization (8)은 다음과 같이 쉽게 해를 구할 수 있다. 단순히 MSE를 최소화하는 문제이므로..
$\eta_x^T \leftarrow \mathbb E_\mathcal T [\mathcal F_{\theta^t}(\mathcal T (x))]$ ... (9)
즉, $\eta_x$는 $x$의 모든 augmentation에 대한 기댓값으로 나타내어질 수 있다는 것이다.
One-step alternation.
Augmentation을 하나만 사용하면, (9)는 다음과 같이 바뀌게 된다
$\eta_x^T \leftarrow\mathcal F_{\theta^t}(\mathcal T' (x))$ ... (10)
이를 (7)에 대입하면 다음과 같다.
$\theta^{t+1} \leftarrow \text{argmin}_\theta\mathbb E_{x,\mathcal T} [||F_{\theta}(\mathcal T(x))-F_{\theta^t}(\mathcal T' (x))||_2^2]$
위 식에서 $\theta^t$는 constant이고, 이 과정을 SGD로 구현한다면 이는 stop gradient를 적용한 Siamese architecture과 동일하게 볼 수 있다.
Predictor.
위 가정은 predictor $h$은 포함되어 있지 않으나, (10)에서 사용된 approximation을 위해서는 $h$를 도입하는 것이 도움이 될 수 있다.
SimSiam에서는 $h$가 다음을 minimize하기 위해 사용된다.
$\mathbb E_z[||h(z_1)-z_2||_2^2]$
이 때 $h$의 optimal solution은 다음을 만족해야 한다.
$h(z_1)=\mathbb E_z[z_2]=\mathbb E_\mathcal T[f(\mathcal T(x))]$
이는 (9)와 비슷한 형태인데, (10)에서 expectation term을 무시했기 때문에, $h$를 사용하는 것이 이 gap을 채워줄 수 있다는 것이다. 실제로 Expectation을 계산하는 데에는 computation이 많이 들기 때문에, $h$가 expectation을 approximate할 수 있도록 대신 사용하는 것이다.
Comparison
다른 SOTA Self-supervised learning 방법들과의 비교이다.
본 방법은 SOTA 방법들보다 더 나은 representation을 학습한다는 것 보다는, 적은 batch size로도 SOTA와 견줄만한 성능을 낼 수 있다는 것에 초점을 맞추고 있다. negative pair이나 momentum encoder을 사용하지 않으면서, 다른 방법들보다 적은 iteration에서 더 좋은 성능을 보인다.
Transfer learning에서의 결과도 함께 report했다.