CVPR 2021에서 발표된 self-supervised learning 논문. 기존의 다른 self-supervised learning 방법인 SimCLR, SwAV, BYOL과의 비교를 통해 제안된 방법을 설명하고 있다. 개인적으로 문장이나 전체적인 구조가 매우 이해하기 쉽게 잘 쓰여져 있다고 느꼈다.
Introduction
Self-supervsied learning에서는 보통 Siamese network 구조를 많이 이용한다. 이는 weight를 서로 공유하는 neural network를 의미하는데, 이들은 각 entity를 비교하는 데에 유용하게 사용될 수 있다.
그러나 Siamese network는 output이 하나의 constant로 수렴하는 collapsing이 발생할 수 있다는 문제가 있다. 이를 해결하기 위해, 기존의 연구들은 다음의 방법들을 이용했다.
- Contrastive learning (SimCLR): negative pair을 추가로 이용
- Clustering (SwAV)
- Momentum encoder (BYOL)

본 논문에서 제안하는 방법인 SimSiam은 shared encoder을 이용하고, positive pair만을 가지고 이들이 가까워지는 방향으로 학습을 진행한다. 위의 세 가지 방법에서 하나씩 component를 제거한 구조라고 볼 수 있다.
- "SimCLR without negative pairs"
- "SwAV without online clustering"
- "BYOL without momentum encoder"
대신, SimSiam은 collapsing을 방지하기 위한 방법으로 stop-gradient를 이용한다.
Method
SimSiam network의 전반적인 구조는 아래 그림과 같다. 우선 forward pass가 어떻게 진행되는지 설명하고, 사용된 loss를 설명한다.

Forward pass
- 하나의 image xx에 대해 두가지 random augmentation을 진행해 x1x1, x2x2를 만든다.
- 두 augmented image는 encoder network encoder network ff를 통과한다. 이 때 두 encoder은 weight를 공유한다.
- Encoder을 통과한 두 vector 중 하나만 prediction MLP head hh를 통과해 새로운 vector를 얻는다.
- p1=h(f(x1))p1=h(f(x1))
- z2=f(x2)z2=f(x2)
Loss
- Forward pass로 얻은 두 vector의 negative cosine similarity를 최소화한다.
- D(p1,z2)=−p1||p1||2⋅z2||z2||2D(p1,z2)=−p1||p1||2⋅z2||z2||2 ... (1)
- 이는 두 vector를 l2-normalize한 후 MSE를 측정한 것과 동일하다.
- 두 vector의 순서를 바꿔서 한번 더 비교한 symmetrized loss를 사용한다.
- L=12D(p1,z2)+12D(p2,z1)L=12D(p1,z2)+12D(p2,z1)
Stop-grad
- SimSiam의 핵심적인 구조는 여기에 stop-grad를 추가하는 것이다.
- 이는 (1)을 다음과 같이 수정함으로써 구현될 수 있다.
- D(p1,stopgrad(z2))D(p1,stopgrad(z2))
- 이 경우 loss역시 다음과 같이 바뀐다.
- L=12D(p1,stopgrad(z2))+12D(p2,stopgrad(z1))L=12D(p1,stopgrad(z2))+12D(p2,stopgrad(z1))
- 이때 x2x2에 대한 encoder의 관점에서 보면, 첫 번째 term의 z2z2로부터는 gradient를 전달받지 않고,
- 반면 두 번째 term의 p2p2로부터 gradient를 전달받음을 알 수 있다. (x1x1의 경우도 동일)
- 이를 Pseudo-code로 나타내면 다음과 같다.

Empirical study

Stop-gradient를 썼을 때와 쓰지 않았을 때를 비교해 stop-gradient가 collapse를 해결했음을 보였다.
- (left) stop-gradient를 사용하지 않는 경우, optimizer이 빠르게 degenerated solution을 찾고 minimum loss에 도달한다.
- (middle) stop-gradient를 사용하지 않는 경우, l2-normalized output의 standard deviation이 거의 0임을 확인할 수 있다. 반면, stop gradient를 사용하는 경우 l2-normalized output의 standard deviation이 1/√d1/√d로, 이는 output이 collapse하지 않았음을 알 수 있다.
- (right) kNN으로 classification을 진행한 결과 역시 stop-gradient를 사용하지 않을 경우 매우 낮았다.
이외에도 BN, similarity function, symmeterization 등의 실험을 통해 collapse prevention이 오롯이 stop-gradient 덕분이었음을 보였다.
Hypothesis
SimSiam의 작동원리에 대한 가설로써, SimSiam은 Expectation-Maximization (EM)과 같은 방식으로 동작한다는 가정을 세웠다.
Formulation.
다음과 같은 loss 함수를 고려한다고 하자.
L(θ,η)=Ex,T[Fθ(T(x))−ηx||22]
- F : network parameterized by θ
- T : augmentation
η는 또다른 variable의 set이다. Set의 크기는 image x의 갯수와 동일하고, 이 때 ηx는 image x에 대한 representation이라고 볼 수 있다. 이는 꼭 network의 output일 필요는 없다.
그러면 이 경우를 k-means clustering과 동일하다고 생각할 수 있다. θ는 clustering center이 되고, ηx는 sample x에 할당된 vector (x의 representation)이라고 볼 수 있는 것이다.
마찬가지로, 위의 optimization problem 역시 k-means clustering과 동일하게 변수 하나씩을 고정해 가며 alternating algorithm으로 해결할 수 있다.
θt←argminθL(θ,ηt−1) ... (7)
ηt←argminηL(θt,η) ... (8)
Solving for θ.
θ에 대한 optimization (7)은 SGD를 이용해 해결할 수 있다. 단, 이 때 η의 경우 constant로 취급되므로 η에 대한 gradient는 stop해야 한다.
Solving for η.
η에 대한 optimization (8)은 다음과 같이 쉽게 해를 구할 수 있다. 단순히 MSE를 최소화하는 문제이므로..
ηTx←ET[Fθt(T(x))] ... (9)
즉, ηx는 x의 모든 augmentation에 대한 기댓값으로 나타내어질 수 있다는 것이다.
One-step alternation.
Augmentation을 하나만 사용하면, (9)는 다음과 같이 바뀌게 된다
ηTx←Fθt(T′(x)) ... (10)
이를 (7)에 대입하면 다음과 같다.
θt+1←argminθEx,T[||Fθ(T(x))−Fθt(T′(x))||22]
위 식에서 θt는 constant이고, 이 과정을 SGD로 구현한다면 이는 stop gradient를 적용한 Siamese architecture과 동일하게 볼 수 있다.
Predictor.
위 가정은 predictor h은 포함되어 있지 않으나, (10)에서 사용된 approximation을 위해서는 h를 도입하는 것이 도움이 될 수 있다.
SimSiam에서는 h가 다음을 minimize하기 위해 사용된다.
Ez[||h(z1)−z2||22]
이 때 h의 optimal solution은 다음을 만족해야 한다.
h(z1)=Ez[z2]=ET[f(T(x))]
이는 (9)와 비슷한 형태인데, (10)에서 expectation term을 무시했기 때문에, h를 사용하는 것이 이 gap을 채워줄 수 있다는 것이다. 실제로 Expectation을 계산하는 데에는 computation이 많이 들기 때문에, h가 expectation을 approximate할 수 있도록 대신 사용하는 것이다.
Comparison
다른 SOTA Self-supervised learning 방법들과의 비교이다.

본 방법은 SOTA 방법들보다 더 나은 representation을 학습한다는 것 보다는, 적은 batch size로도 SOTA와 견줄만한 성능을 낼 수 있다는 것에 초점을 맞추고 있다. negative pair이나 momentum encoder을 사용하지 않으면서, 다른 방법들보다 적은 iteration에서 더 좋은 성능을 보인다.

Transfer learning에서의 결과도 함께 report했다.