loading

DINO(Emerging Properties in Self-Supervised Vision Transformers) 논문 리뷰


SSL, Vision Transformer

Published on November 12, 2023 by JunYoung

SSL ViT DINO

12 min READ

들어가며 …

제목에서 알 수 있듯이 이 논문은 Vision Transformer가 자기 학습을 통해 습득할 수 있는 능력이나 특성에 대해 논의한다. ViT의 프레임워크가 제안된 배경에는 자연어 분야의 Transformer 구조가 존재하는데, 이미 GPT나 BERT와 같은 후속 연구를 기반으로 NLP에서는 Large Dataset의 self-supervised learning이 downstream task에서 보다 풍부한 semantic information을 제공한다는 사실이 증명된 바 있다. 이와는 다르게 ViT의 학습 구조를 보게 되면 언어 모델과 같이 대용량의 이미지 데이터셋을 사용하여 사전 학습하는 과정이 이후의 downstream task에도 도움이 된다는 사실은 증명이 되었으나, 여전히 supervised learning 구조에서 벗어나지 못한 것을 알 수 있다.


발견한 특성들

논문 구성은 간단하게도 아이디어를 develop하는 과정(self-supervised learning 방법론에 대한 approach)에 대한 motivation으로 시작하게 되고, 해당 방법론으로부터 온 효과를 언급하면서 이를 증명할 여러 실험 결과들을 보여주게 된다. 논문에서 발견한 ViT의 self-supervised learning 특성을 요약하면 다음과 같다.

  • 자기 학습을 통해 획득한 ViT의 feature는 이미지의 semantic segmentation 정보를 가지게 되고, 이는 지도 학습으로 학습된 ViT나 convnet에서도 발견되지 않은 특성이다. 실제로 아래 그림과 같이 attention 정보를 통해 네트워크가 각 이미지 단위로 포커싱하고있는 영역이 곧 이미지 상에서 object의 semantic한 정보 그 자체라는 것을 확인할 수 있다.

  • 또한 이렇게 획득된 feature는 유사도에 기반한 $k$-NN classifier로 활용될 수 있고 small ViT로도 ImageNet(recognition task)에서 좋은 정확도를 보임을 확인하였다.
  • 마지막으로 여러 셋팅에서의 실험 및 충분한 ablation을 통해 ViT의 자기 학습 과정에서 효과적으로 쓰일 수 있는 방법론들을 직접 실험들을 통해 규명했으며, momentum encoder, multi-crop training 그리고 smaller patch(more number of patches)가 중요하게 쓰인다.

Self-supervised learning Frameworks

대표적인 label이 없는 환경에서의 unsupervised(self-supervised) learning 접근법으로는 SimCLR, MoCo 그리고 BYOL가 있었다. 갑자기 이 얘기를 꺼낸 이유는 이 페이퍼에서 논하고자 했던 property가 곧 ViT의 self-supervised learning으로부터 나오기 때문에, 제안된 structure의 근거를 알기 위해서는 이전 논문들의 참고가 필수적이기 때문이다. DINO 논문에서는 SimCLR, MoCo 그리고 BYOL 중 BYOL에서 inspiration을 얻었다고 한다.

SimCLR의 구조는 첫번째 그림과 같다. Input $x$가 주어지면 이를 정해진 random augmentation을 적용한 각각의 샘플 $\tilde{x}_i$ 그리고 $\tilde{x}_j$로 만들게 되고, 이를 뉴럴 네트워크 $f(\cdot)$에 통과시킨 output $h_i$, $h_j$ 를 하나의 representation/embedding이라고 했을 때 이를 Linear operation($g(\cdot)$)으로 mapping한 최종 latent인 $z_i$ 그리고 $z_j$를 contrastive하게 학습하는 방법을 사용하였다. Moco도 큰 틀에서는 contrastive learning과 두 개의 branch를 사용한다는 점에서 SimCLR와 거의 동일하지만, 차이점이라고 한다면 SimCLR은 배치 내에서 동일한 인코더를 기준으로 representation 학습을 진행하지만 MoCo는 학습이 되지 않고 EMA 방식으로 업데이트되는 momentum encoder가 사용된다는 점이다. 쿼리에 사용되는 배치는 매 학습마다 enqueuedequeue 를 통해 최신 mini-batch가 지속적으로 업데이트되며, positive logit은 동일 배치의 샘플에 대해, negative logit은 이전 queue의 샘플에 대해 연산을 진행하게 된다. Querying에 사용되는 encoder를 점진적으로 학습하는 방법을 적용했다는 점에서 차이가 생긴다.


BYOL paper에 대한 짧은 논문 리뷰

BYOL은 momentum encoder의 장점을 가져오면서 학습의 전반적 형태는 SimCLR와 같은 방식을 가져왔다. 그러나 알고리즘 측면에서 큰 차이가 있는데, 이는 바로 BYOL은 contrastive learning을 하지 않는다는 점, 즉 negative pair가 필요하지 않다는 것이다.

BYOL에서 가장 크게 주목할 점은 어떻게 negative pair와 같이 collapse를 방지할 만한 장치가 없이도 안정적인 학습이 가능한가에 대한 부분이다. 바로 이 부분에서 대체 왜 online/target 두 브랜치가 서로 assymetric(비대칭)하게 구성되었는가를 확인할 수 있다. 예컨데 predictor $q_\theta(\cdot)$는 projection 목적이 되는 $g_\theta$와 동일한 신경망 구조를 가진다. 단순하게 생각했을때 prediction은 하나의 classifier라고 생각할 수 있지만 그렇지는 않고 projection network와 같은 dimension의 output을 내보낸다. 하지만 바로 이 projector 부분이 학습되면서 optimal point에 놓여있는 것이 가장 주요한 학습 키포인트로 작용한다.

Projector가 수렴했다는 것은 곧 projector가 어느 정도 optimal한 영역에 있다고 볼 수 있고 이를 $q^\ast_\theta$라고 한다면 online branch의 input이 되는 $z_\theta$에 대해 $q^\ast_\theta(z_\theta) = \mathbb{E}[z^\prime_\xi \vert z_\theta]$로 표현이 가능하다. 수식 상에서 조건부 expectation은 $z_\theta$에 대한 함수가 되며, 조건부 확률 분포와 동일한 의미를 가진다. 즉 우리가 흔히 optimal하게 학습된 neural network를 특정 도메인의 데이터셋 ${X, Y}$에 대해 parameterized posterior $p_\theta(Y \vert X)$로 표현하는 것처럼 projector가 수렴했다는 가정 하에 $z^\prime_\xi$와의 수식으로 해석할 수 있다. 결국 이 가정을 통해 수식을 다시 전개하게 되면 simplified BYOL loss(원래 BYOL에서는 view를 교차하는 형태로 symmetric한 cost function을 구성하는 것과 더불어 latent의 정규화 작업이 추가됨)은 다음과 같이 표현 가능하며

[ \mathcal{L}_\text{BYOL} = \mathbb{E}\left(\parallel \mathbb{E}(z^\prime_\xi\vert z_\theta)-z_\xi^\prime\parallel_2^2\right) ]

결국 학습 파라미터(online branch) $\theta$에 대한 gradient는 다음과 같이 expected variance의 gradient로 수렴하게 된다.

[ \nabla_\theta \mathcal{L}_\text{BYOL}= \nabla_\theta\mathbb{E}\left(\sum_i \text{Var}(z^\prime_{\xi, i} \vert z_\theta) \right) ]

이러한 가정은 optimal projector가 수렴했다는 전제에서 성립하게 되는데, 이를 통해 BYOL loss는 수렴된 projector를 변화시키지 않고 online network를 업데이트할 수 있다. 위의 수식은 파라미터 및 projection을 다변수로 가지는 최적화 함수를 Lagrangian으로 표현했을 때의 envelop theorem 그리고 optimality condition에 기반한다.

BYOL에서는 이렇게 업데이트되는 $\theta$에 대해 online branch와 target branch $\xi$가 동시에 감소하는 방향은 loss surface $\mathcal{L}$에는 정의될 수 없다는 것이다. Target branch에서의 projection $z^\prime_\xi$와 online branch에서의 $z_\theta$에 대한 Variance로 loss 최적화 식을 만들었었고 이게 의미하는 바는 projector가 어느 정도 수렴한 상황에서 가장 말단의 posterior는 고정된 상태라고 보는 것이다. 임의의 random variable에 대해 조건부 분산은 조건 변수가 추가될수록 이전 분산의 lower bound가 된다. 만약 BYOL을 통한 최적화 과정이 collapse를 일으킨다면 online network의 projection인 $z_\theta$는 더이상 무작위로 분포한 확률 랜덤 변수가 아닌 constant $c$로 고정될 수 있고, 이는 parameter space에서기존 업데이트 과정이 lower bound가 됨을 명시할 수 있는 근거가 된다.

[ \text{Var}(z^\prime_\xi \vert z_\theta) \le \text{Var}(z^\prime_\xi \vert c) ]

즉 collapse가 일어날 수 있는 환경이 parameter surface에서 보다 큰 값을 가지기 때문에 더 불안정(unstable), collapse가 발생하지 않는다.

만약 반대라면 어떻게 될까? 이 논문에서는 $\xi$를 loss function을 기준으로 업데이트하지 않고 EMA를 사용하여 점진적 과부하를 걸었는데, 이는 같은 위의 수식으로 그 이유를 찾을 수 있다. Target network에 collapse가 발생한다면 이번에는 $z^\prime_\xi = c$ 인 deterministic constant가 되고, 이번에는 조건부 변수가 아닌 메인 변수에 해당되므로 분산이 0이 된다.

[ \text{Var}(c \vert z_\theta) = 0 \le \text{Var}(z^\prime_\xi \vert z_\theta) ]

즉 $\xi$에 대해서 학습하게 되면 무조건 collapse가 발생하게 된다는 것을 볼 수 있다. BYOL는 이러한 이론적 배경에 근거하여 negative pair를 굳이 구하지 않더라도 similarity loss를 기반으로 점진적으로 latent를 bootstrapping (과거의 online parameter가 미래의 online parameter의 학습에 도움이 되는 과정)하는 방법을 제시하였고, 이는 batch size로부터의 자유 및 자기 학습 방법의 지평을 보다 넓힐 수 있는 계기가 되었다.


DINO approach

DINO는 ‘knowledge distillation with no labels’의 줄임말로, 말 그대로 ViT 학습에 SSL 프레임 워크를 제안한 형태가 된다. 이 방법 역시 student/teacher(혹은 online/target)의 두 브랜치 간의 학습이 진행되는데, 안정적인 pseudo label을 만들어내는 teacher은 loss term에 대한 파라미터 최적화가 발생하지 않고consistency를 통해 지속적으로 update되는 student parameter를 지수 평균 이동(exponential moving average)으로 가져온다.

Knowledge Distillation

논문에서 접근한 SSL은 다음과 같다. Student model과 teacher model은 학습되는 중간에는 데이터가 매핑되는 함수로 작용하고, 이는 곧 probability mapper로 해석 가능하다. 만약 student network($g$)의 파라미터를 $\theta_s$라 한다면 입력 신호 $x$에 대한 output logit $g_{\theta_s}(x)$를 구할 수 있다. 그리고 이 logit에 softmax function을 적용하면 확률로의 직접 매핑이 가능하다. 이때 softmax가 적용되는 dimension은 특징자(feature embedding) 축이라고 생각하면 된다.

[ P_s(x)^{(i)} = \frac{\exp(g_{\theta_s}(x)^{(i)}/\tau_s)}{\sum_{k=1}^K\exp(g_{\theta_s}(x)^{(k)}/\tau_s)} ]

여기서 temperature $\tau_s$가 사용되는데, 이는 다양한 논문들에서 probability distribution의 분포를 결정하는 하이퍼파라미터 혹은 학습 가능한 파라미터로 많이 사용된다. 이 논문에서의 temperature parameter의 목적은 student network에 의한 probability의 sharpness 조절 역할을 하게 된다. 마찬가지로 teacher network에 대해서도 다음과 같은 formulation이 가능하다.

[ P_t(x)^{(i)} = \frac{\exp(g_{\theta_t}(x)^{(i)}/\tau_t)}{\sum_{k=1}^K\exp(g_{\theta_t}(x)^{(k)}/\tau_t)} ]

Knowledge distillation에서 학습은 teacher의 output을 일종의 ground truth로 가정한 student output과의 consistency loss이다. 즉 cross entropy에서 one-hot label을 teacher network의 output으로 바꿨다고 생각하면 된다.

[ \underset{\theta_s}{\min} H(P_t(x),~P_s(x)) = \min_{\theta_s} {-P_t(x) \log P_s(x)} ]

크로스 엔트로피가 의미하는 것이 정보이론에서 “하나의 확률분포”가 “또다른 확률분포”가 가지는 정보와 얼마나 가까운지에 따른 거리 metric이기 때문에 결국 학습 목적은 학생으로 하여금 선생의 지식을 잘 모방하도록 하는 것이 된다. 하지만 단순히 이 방법론으로 마무리되는 알고리즘은 아니고, DINO는 효과적인 학습을 위해 다양한 방법들을 추가하게 된다. 예컨데 위의 수식은 앞서 보여준 framework와는 다르게 augmentation에 대한 내용이 없지만, 저자는 바로 이 수식 전개 직후 단순 distillation을 사용함에 따라 생기는 문제점들을 언급한다.

Data augmentation

Transformer 구조가 가지는 가장 큰 문제점 중 하나가 local-to-global correspondence가 적다는 것이다. Transformer는 attention을 기반으로 단번에 global information을 인지하기 때문에 convolution network에 비해 가지는 장점도 있겠지만, local information을 포착하기 전에 모든 attention map들이 global feature에서 수렴해버린다면 CNN이 가지는 계층적 구조에 의한 correspondency(feature간 상관관계에서 얻을 수 있는 high to low level 효과)를 SSL에서 기대할 수 없다는 문제가 있다.

따라서 이를 해결하는 방법으로 augmentation의 비대칭을 사용하였다. 이 프로세스를 요약하면 다음과 같다:

  1. Teacher가 local/global에 대한 consistency를 가지고 어느 정도 수렴했다는 가정 하에, teacher는 이미지의 global한 형태를 보고 ‘그럴 듯한’ 예측을 한다.
  2. 위의 가정이 있다면 teacher network에는 계속 global한 image 정보만 주면 된다.
  3. Teacher는 student의 파라미터로부터 EMA된다. 즉, teacher의 바람직한 수렴을 위해서는 student가 앞서 말했던 local/global에 대한 consistency 정보를 학습할 수 있는 환경이 되어야한다.
  4. 따라서 student에는 local image 정보를 같이 준다.

길게 설명했지만 풀어쓰자면 teacher network에는 이미지에 큰 범위에서의 augmentation이 들어간 global view $x_1^g$, $x_2^g$만 예측에 사용되고, student network에는 multi-crop strategy와 같은 이미지의 작은 범위까지의 augmentation이 적용된 local/global 정보가 예측에 사용된다.

[ \underset{\theta_s}{\min}\sum_{x \in {x_1^g,~x_2^g}} ~~\sum_{x^\prime\in V,~x^\prime \neq x} H(P_t(x), P_s(x^\prime)) ]

논문에서는 보통의 방식과 같이 multi-crop image들을 생성했는데, 2개의 global views는 원본 이미지 대비 $50\%$보다 큰 크기만큼 잘라서 쓰고 여러 local view는 반대로 $50\%$보다 작은 크기만큼 잘라서 쓴다.

Avoiding collapse

Self-supervised learning의 문제점은 representation 학습에 대한 ground truth가 없기 때문에 collapse가 발생할 수 있다는 것이다. 사실상 우리가 많이 알고 있는 contrastive learning이든, clustering 방식이든, predictor를 다는 BYOL과 같은 구조라던지 Batch Normalization을 도입하는 등등의 approach는 공통적으로 collapse를 막는 역할을 하게 된다. 물론 DINO 또한 normalization 구조라던지 앞서 언급한 여러 방법론으로 stabilization을 수행할 수 있었지만, 이 논문에서는 momentum teacher network의 output을 centering 및 sharpening하는 구조를 통해 이러한 효과를 얻을 수 있다고 한다. Sharpening/Centering에 대한 내용은 조금 알아보기 쉽게 나타내면 Sharpening은 temperature 조절을 통해 softmax 예측값 분포를 보다 명확하게 드러내는 것이고 centering은 teacher output에 center value $c$를 bias term으로 더해주어 예측값 사이의 차이를 조절해주게 된다. 즉 sharpening과 centering은 효과만 보게 되면 서로 반대의 역할을 수행한다. 여기서 드는 의문점은, 굳이 sharpening을 통해 prediction의 entropy minimization을 수행할 목적이었다면 왜 다시 centering이라는 방법으로 다시금 prediction을 재조정하는 과정을 거치는지에 대한 부분이다. 이 부분에 대해서 나름대로 이해한 것은 다음과 같다.

우선 centering에 사용될 bias term $c$는 다음 식을 통해 exponentially update가 된다. EMA 방식으로 teacher parameter가 업데이트되는 것과 동일하다.

[ c \leftarrow mc + (1-m)\frac{1}{B}\sum_{i=1}^B g_{\theta_t}(x_i) ]

식을 자세히 보면 center $c$에는 결국 학습 시 사용되는 batch size랑은 무관하게, 기존 input에 대한 model의 output(prediction) 정보를 평균으로 저장하게 된다. 이제 centering에 대한 맥락은 얼추 이해했고, 다시 sharpening으로 돌아가보도록 하자.

예컨데 고양이라는 이미지에 대해 모델이 낸 prediction을 sharpening하는 작업을 하게 되면 뭉뚱그려진 예측값을 어느 정도 명확하게 하면서 feature map을 선명하게 만들어 준다는 장점이 있지만, 만약 배치 단위로 들어오는 특정 input이 모델로 하여금 지속적으로 collapse가 발생하게 한다면, 이는 불난 집에 부채질하는 격이 된다.

원래 목적이라 함은 다른 input에 대해 특징을 잘 잡아내는 feature를 뽑고자 sharpening을 도입했는데, contrastive learning과 같은 제어장치가 없다면 모델은 그냥 단순히 네트워크 예측 자체의 entropy를 낮추는 방향으로 끊임없이 학습이 될 것이기 때문에 이미지의 종류에 상관없이 단일의 feature를 뽑게 되고, 이러한 문제를 trivial solution이 발생한다고 한다. 결국 기존 SSL approach를 사용하지 않고는 이를 근본적으로 해결하기 어렵다는 문제가 발생한다.

그렇기 때문에 만약 centering term이 있다면 이를 단순히 model prediction에 더해주는 것만으로도 이전 배치들의 정보를 가져올 수 있으며, batch size의 크기에 robust한 학습 효과를 보여주는 것이다. 예컨데 contrastive learning에서는 positive sample과 negative sample 쌍을 얻기 위해 최대한 많은 배치 수가 필요했고, 그 이유는 모델이 학습할 때 metric learning을 적은 단위의 배치 내에서 진행하는 것보다는 큰 배치 내에서 진행하는 것이 전체 데이터셋의 확률 분포를 잘 나타낼 수 있기 때문이었다. 하지만 위와 같이 output을 뽑아서 배치 단위의 prediction을 저장하고, 이를 이후의 output을 sharpning할 때 smoothing에 사용하는 것만으로도 배치 사이즈를 키우지 않고 이러한 효과를 볼 수 있다는 것이 바로 sharpning/centering이 가지는 장점이다. 사실 까놓고 말하자면 단순하게 이전 prediction을 일종의 prototype으로 저장해놓고 쓴다는 느낌인데, 저자들은 이 방법론이 실제로 학습에 미치는 영향을 보여주기 위해 실험을 진행하였다.

Sharpening의 효과는 entropy를 0으로 만든다. 그리고 centering의 효과는 smoothing을 통해 어떤 input이 들어오든 entropy를 유지시킨다. 둘 중 하나만 사용하면 epoch에 따라 representation overfitting/underfitting이 발생하는모습을 잘 확인할 수 있다. 무엇보다 이를 잘 보여주는 실험이 KL divergence에 있지 않을까 싶다.

Teacher/student 구조를 쓰면서 얻고 싶은 장점은 EMA 방식으로 기존 representation 정보를 차곡차곡 모아가는 teacher network의 prediction을 student가 따라가면서 서로 간의 학습에 bootstrapping이 일어날 수 있다는 것인데 만약 representation이 collapse가 된다면 이러한 효과를 볼 수 없을 것이고, 결국 bootstrapping이 없다는 것은 학습이 진행되면서 student/teacher prediction 차이가 없어진다는 것이다.


실험 결과

Classification

SSL/Unsupervised Learning의 경우 학습된 feature를 증명하는 과정이 여러 가지로 분류된다. 우선 downstream task에 맞게 head를 다는 과정이 필요하고, 이 head를 어떻게 써먹냐에 따라 linear classifier/fine-tuning/k-NN classifier로 분류된다.

  • Linear classifier : 학습된 backbone을 frozen한 채로 linear classifier만 학습해서 representation의 효과를 보고자 하는 것
  • Fine tuning : 학습된 backbone을 head에 붙인 채로 fine tuning하여 representation의 효과를 보고자 하는 것
  • k-NN classifier : classifier 같은 부수적인 요소 없이 단순히 embedding으로 retrieval해서 representation/metric learning 자체 효과를 보고자 하는 것

뭐 성능 자체와 관련해서는 상당히 좋게 나온 것을 확인할 수 있고, ViT baseline의 다른 SSL 방식과 비교했을 때도 유의미하게 높은 classification 성능을 보여준다.

ViT Attention map

하지만 classification 보다는 DINO의 가장 큰 특징은 ViT의 attention map을 보면 잘 드러나는데, 바로 local feature에 attention을 집중할 수 있다(localization)는 것이다. 이는 기존 ViT 방식으로는 얻을 수 없는 feature map에 해당된다.

Segmentation처럼 high-level image task의 경우 모델의 예측이 정교해야하기 때문에 상대적으로 classification task에 비해 SSL이 달성할 수 있는 성능 수치가 그리 높지 않았다. 그럼에도 불구하고 DINO의 attention map을 보면 알 수 있듯이 이 페이퍼에서 제안한 학습 방법은 ViT의 input에 대한 attention을 효과적으로 localization하는 것에 성공하였고, 정량적으로도 그 수치를 증명하였다.


결론

DINO는 BYOL을 비롯한 기존 SSL 방식에서 motivation된 self-distillation 구조와 더불어 collapse를 방지하고 학습 안정화를 위해 sharpening/centering을 도입하여 ViT를 효과적으로 학습하였다. 그런데 이렇게 안정적으로 ViT를 라벨 없이 SSL로만 학습하고 보니 이게 무슨 일이람. ViT의 attention map이 localization되는 중요한 변화를 확인할 수 있었다. 이러한 점이 시사하는 바는 상당히 크다.

지금까지는 supervised learning이 절대적인 학습법이었으며, 사실 SSL이 학습 안정화를 토대로 가끔 supervised learning의 성능을 넘는 경우도 있긴 했지만은 모든 task에 정통으로 사용될 수 있는 방법은 아니었으며 linear probing이나 fine tuning 시에 미리 학습된 representation의 효과를 강하게 보여주었지 실질적으로 SSL로 학습된 representation이 가능성을 보여주는 경우는 많지 않았다. 하지만 애초에 구조상 inductive bias가 없어 localization이 힘든 transformer baseline인 ViT를 SSL하였더니 attention map이 segmentation 효과를 보여주었고, 이는 NLP가 아닌 Computer Vision 분야에서도 classification 뿐만 아니라 여러 task에 SSL이 우월한 성능을 보여줄 수 있음을 증명하였다.

A n o t h e r p o s t i n c a t e g o r y