loading

MaskCLIP 논문 리뷰


Zero-shot, Multimodal learning

Published on January 24, 2023 by JunYoung

Multimodal Zero shot SSL AI Deep learning

10 min READ

들어가며…

사실 리뷰할 MaskCLIP은 ECCV paper와 arxiv paper 두 종류가 있다. 논문 제목은 다르지만 main idea의 제목이 같다보니 조금 혼란스러운 감이 없지 않아 있었다. 처음에 읽고 싶었던 논문은 ‘Extract Free Dense Labels from CLIP‘이며, 그 다음에 추가로 arxiv에 올라온 논문 제목은 ‘MaskCLIP: Masked Self-Distillation Advances Contrastive Language-Image Pretraining‘으로, 아직 학회에 제출을 하진 않은 것 같지만 느낌상 NeurIPS 2023에 올라올 것 같은 내용이다. ECCV 2022에 올라갔던 논문인 Dense label extraction은 segmentation task에 댇해 CLIP의 zero-shot representation을 어떻게 하면 가장 잘 활용할 수 있을까에 대한 실험을 진행했던 연구였고, 아카이브에 올라온 논문인 self-distillation은 CLIP의 image representation 학습에 self-supervised learning 방식 중 하나인 generative approach(masked autoencoder 방식)을 적용한 연구다.


CLIP의 장점?

VL(Visual-language) contrastive learning를 통해 vision의 modality인 image와 text modality와의 관계성을 찾은 것이 CLIP 논문이었고, 해당 논문을 기반으로 다양한 dataset의 downstream task(classification, object detection 그리고 segmentation 등등)을 zero-shot 혹은 few-shot 으로 해결할 수 있는 길이 열리게 되었다.

사실상 GPT와 같은 대용량 언어 모델 이외에 computer vision에서 웹 상에서 획득할 수 있는 text prompt 기반 대용량 데이터셋을 학습에 사용한 것은 CLIP이 처음이었으며, 단순히 downstream task를 잘 해결할 뿐만 아니라 language와 vision의 관계를 찾은 것에서 image captioning, view synthesis와 같이 prompt 기반 DL engineering이 가능해졌다는 점에서 멀티모달의 새 시대를 연 장본인이라고 볼 수 있다.


Segmentation에서는 활용할 수 없을까?

ECCV 2022 논문인 ‘Extract Free Dense Labels from CLIP’는 바로 이러한 CLIP의 장점에 집중하여, 어떻게 하면 CLIP의 zero-shot performance를 segmentation에서도 활용할 수 있을까에 대해 연구를 진행하였다.

이는 굉장히 큰 의미가 있었던 것이 만약 우리가 위와 같은 그림을 segmentation해야 한다고 가정했을 때, 조커와 배트맨을 독립적으로 instance labeling 해야하는 문제가 발생한다. 물론 단순하게 생각해서 각 사람 이미지를 독립적으로 구분하는 작업을 하나씩 라벨링하고 이를 supervision으로 활용해서 학습하게 된다면 얼추 segmentation 성능이 기대하는 것만큼 나오겟지만, 그렇다고 해서 모든 영화의 모든 인물들에 대한 프레임별 labeling을 수작업으로 진행하고 각각을 새로운 label에 매칭하는 것은 학습 과정에서나 데이터셋 구축에서도 시간과 노력이 많이 드는 과정이다.
기존의 SOTA segmentation 방식은 ImageNet에 대한 pre-trained weight으로 초기화한 representation을 활용하여 segmentation task에 적용하는 한편, 이 논문에서는 CLIP이 가지는 global image representation(Web image scale에 대한 representation)을 활용하고자 한다. CLIP이 ImageNet base로 학습된 representation에 비해 가지는 장점은 다음과 같다.

  1. 각 feature에 local image semantic(각각의 feature dimension은 이미지의 일부에 대한 정보를 담는다)을 학습 가능하다. 사실 이 부분은 ImageNet에 pre-trained된 baseline weight의 효과와 거의 유사하다고 볼 수 있다.
  2. Open-vocabulary의 concept을 학습할 수 있다. 앞서 설명했던 것처럼 구체적인 labeling 없이도 원하는 object에 대한 segmentation이 가능하다.
  3. 물체들 간의 상호작용, 관계 그리고 spatial location에 대한 풍부한 문맥상의 정보를 학습 가능하다.

이러한 CLIP의 장점이 모여 segmentation task에서 CLIP representation의 supervision을 활용할 수 있다면, 효율적인 task performance를 기대할 수 있으리라는 것이다.


Failure and success

사실 처음부터 저자들이 실험에 성공한 것은 아니었다. 가장 쉽게 생각할 수 있는 방법은, CLIP 모델 자체도 image encoder로 ResNet 구조를 활용하기 때문에 pre-trained weight을 가져와서 초기화시킨뒤, CLIP image encoder의 weight를 segmentation task에 맞게 fine-tuning 시키면 성능이 좋아질 수 있겠다는 생각을 하게 된다. 예를 들어 segmentation task에는 DeepLab 모델들이 SOTA로 사용되었는데, DeepLab의 weight를 CLIP image encoder의 weight로 초기화시킨 뒤에 backbone을 segmentation에 맞게 학습을 시키는 것이다. 이러한 과정이 CLIP의 text embedding을 굳이 사용하지 않아도 된다는 장점이 있었지만, 이러한 접근법은 CLIP의 장점 중 하나인 unseen class에 대한 segmentation을 진행할 수 있는 능력을 완전히 배제한 것이다.
따라서 단순히 weight를 초기화하고 fine-tuning을 진행하는 기존 방식에서 벗어나, MaskCLIP이라고 불리는 접근법은 CLIP의 image encoder로부터 추출된 patch-level feature를 활용하여, 기존 CLIP의 마지막 layer였던 attention pooling을 진행하지 않고 CLIP의 text encoder로부터 얻을 수 있는 $1 \times 1$ convolution weight를 기반으로 dense prediction을 진행한다. 사실 여기에서는 convolution 연산이 가지는 의미 자체를 하나의 trick으로 사용한 것인데, 원래라면 기존의 text embedding과 image embedding 사이의 similarity를 통해 classification을 진행하는 task라고 해보면, 사실 이 과정은 특정 prompt에 대한 text embedding과 그에 맞는 image embedding 간의 correlation을 계산하는 것이고, 결국 $1 \times 1$ convolution 연산이 가지는 의미가 각 dimension에 대한 inner production이기 때문이다.
또한 CLIP baseline 구조를 그대로 사용했기 때문에 MaskCLIP은 좋은 zero-shot 성능을 보였으며, ResNet과 ViT 구조 모두 segmentation에 대해 적용할 수 있었다(이는 attention pooling 연산 구조 자체가 ViT의 image class token이 self-attention하는 과정과 완전히 동일하기 때문). 추가로 segmentation의 성능 향상을 위해 학습할 필요가 없는 key smoothingprompt denoising 과정을 추가하였다. 두 방법에 대해서는 뒤에서 보다 자세히 보고 넘어가도록 하겠다.


MaskCLIP approach

MaskCLIP+로 표시된 부분은 이후에 따로 설명하기로 하고, 우선 MaskCLIP으로 표시된 부분(짙은 회색)을 먼저 확인하면 위와 같다. 참고로 MaskCLIP에서는 학습 과정이 전혀 필요 없고, 단순히 특정 이미지에 대해 얻고자 하는 class label에 대한 text prompt를 만든 뒤, 이렇게 만든 각 class별 text prompt를 encoding한 결과를 $1 \times 1$ convolution의 weight으로 사용하여 image embeddingdense prediction으로 확장시킨다. 기존 attention pooling이 image feature의 average를 class에 대한 token으로 간주하여 모든 pixel 정보에 대한 attention을 구했던 것과 달리, MaskCLIP에서는 attention pooling 이전의 모든 픽셀 정보들을 활용하는 형태가 된다.

하지만 단순히 이렇게 했을 경우에는 prediction 성능이 크게 좋지는 않았는데, 이유는 아무래도 text prompt와의 similarity를 기준으로 prediction이 진행되는 구조이다 보니, 기존 segmentation task에서와 같이 locality(비슷한 위치의 pixel은 같은 class일 확률이 높다)에 대한 정보가 담기기 힘들고, 무엇보다 feature 단위의 pixel prediction은 noisy한 결과를 보여줄 수 있기 때문이다. 이러한 문제들을 해결하고자 제시한 방법이 바로 앞서 설명을 넘겼던 key smoothingprompt denoising 과정이다.

Key smoothing

굳이 attention pooling layer가 필요없다면, 결국 key, value값이 더 이상 학습에 관여하지 않게 된다는 의미이다. 그러나 결국 attention layer도 CLIP 학습에 있어 prediction에 영향을 주었을 것이고, segmentation task를 위해 이를 아예 제거하고 본다는 것은 attention layer가 CLIP image encoder의 representation 학습에 끼친 영향력을 아예 무시한다는 것과 같다. 이러면 안되는 것이, attention layer로 하여금 한 픽셀과 다른 픽셀의 연관성을 key value로 학습했을 것이고, 이러한 학습 결과가 각 local semantic에 대해 ‘어떤 부분들을 참고해야하는지‘에 대한 정보를 담았을 것이라고 볼 수 있다. 예컨데 위의 그림과 같이 $K_1,~K_2$ 그리고 $K_3$에 해당되는 feature 영역이 있다고 생각해보자. 서로 다른 물체 영역에 속하는 $K_1$(잔디밭)과 $K_2$(고양이)는 유사도가 낮지만, $K_2$(고양이)와 $K_3$(고양이)는 유사도가 높을 것이다.

[ \text{pred}_i = \sum_j \cos \left( \frac{k_i}{\parallel k_i \parallel_2},~\frac{k_j}{\parallel k_j \parallel_2} \right) \text{pred}_j
]

바로 이러한 key value의 특성을 고려한 prediction이 key smoothing이며, $i$번째 patch에 대한 prediction을 진행할 때, 다른 위치의 patch에 대한 key 값과의 유사도를 고려하겠다는 의미가 된다. 참고로 논문에서는 similarity와 곱해지는 $\text{pred}$ 부분의 index가 $i$로 표기되어 있는데, 의미하는 바를 읽어보면 $j$가 맞는 듯하여 위와 같이 식을 수정해보았다.

Prmopt denoising

또다른 문제는 image에 포함된 class가 너무 많을 경우 발생하는 문제다. 만약 위와 같은 그림에서 ‘cat’, ‘grass’ 그리고 ‘tree stump’를 구별하는 segmentation task라고 한다면, ‘tree stump’의 경우 이미지의 극히 일부에만 포함되어있어 오히려 key smoothing 과정에서 다른 class에 대한 confidence를 방해하는 역할을 한다. 따라서 본 논문에서는 만약 특정 class에 대한 key similarity가 모든 feature pixel에 대해 $0.5$보다 작을 경우, 해당 class를 prediction 과정에서 아예 배제시키는 방법을 사용하였고, 이를 통해 더 좋은 성능을 얻을 수 있었다고 한다.


Mask CLIP to Mask CLIP+

여기서 멈추지 않고 논문에서는 Mask CLIP 대신 Mask CLIP+라는 방법을 추가로 제시한다. Mask CLIP의 장점은 CLIP base로 zero-shot segmentation이 가능하다는 것이지만, 달리 표현한다면 CLIP architecture 대신 segmentation에 좋은 성능을 보이는 기존 baseline인 DeepLab이나 PSPNet 등등을 사용할 수 없다는 단점이 있다. 특히 ASPP와 같은 방식을 사용할 수 없다는 점에서 segmentation task에 특화된 네트워크 설계가 CLIP based method에서는 한정적일 수밖에 없다.

따라서 해당 논문에서는 MaskCLIP+에 대해 위와 같은 framework를 제안한다. 기존 CLIP network는 학습하지 않고, 단순히 MaskCLIP+에 사용될 network가 학습할 수 있는 pseudo-label을 만들어준다. MaskCLIP+에서 사용되는 backbone은 기존 fine-tuning 방식과 동일하게 ImageNet에 대해 pre-trained된 네트워크를 사용한다. 그러나 기존 방식과 다른 점은 fine-tuning 과정에서 사용하는 classifier는 학습하지 않고 CLIP의 text-prompt를 기반으로 생성한 $1 \times 1$ convolution을 사용한다는 것이다. 이를 통해 CLIP이 아닌 다른 network 베이스에도 CLIP space에 대한 knowledge distillation이 가능하다는 접근을 보여주었다.
다만, MaskCLIP+에서 pseudo label을 계속해서 supervision으로 사용하게 되면 segmentation backbone(dilated backbone)의 성능 upper bound가 CLIP 구조에 한정될 수 밖에 없기 때문에 학습 schedule 상 약 $1/10$ 까지는 CLIP의 pseudo label을 활용하되, 이후에는 성능 수렴이 발생하여 이를 제거하고 self-training 과정을 거쳤다고 한다.

그 결과 놀라울 정도로 좋은 segmentation 성능을 보여주었고, 단순히 fine-tuning하는 기존 방식에서 벗어나 CLIP guidance learning 기반의 text prompt의 representation이 supervised based segmentation에 특화된 네트워크 구조에서도 zero-shot segmentation이 가능하게끔 할 수 있는 방법론으로 제시되었다.

Fully supervised mIoU가 사실상 zero-shot이 가질 수 있는 최대 한계치라고 볼 수 있는데, MaskCLIP+가 달성한 수치는 기존 SOTA를 뛰어넘어 더 이상 발전이 힘든 수준까지 올라간 것을 볼 수 있다.


CLIP의 문제점

이렇듯 완벽할줄만 알았던 CLIP도 사실은 문제점이 있었다. 사실 원래 모든 연구는 크나 작으나 어느 정도 약점을 가지고 있는데, CLIP은 그에 비해 contribution이 너무 크기 때문에 사실상 흠을 잡기가 애매했다. 이러한 상황에서 ECCV에서 나온 MaskCLIP과 지금부터 리뷰할 arxiv의 MaskCLIP은 방향성이 다른데, ECCV는 segmentation에서 CLIP을 사용할 수 있을 정도로 CLIP이 학습한 text prompt based representation이 풍부한 contextual meaning을 담을 수 있음에 집중했고, arxiv에 올라온 MaskCLIP은 CLIP이 제시한 학습법으로는 image representation의 충분한 학습이 힘들다는 점에 집중했다.

예컨데 만약 좌측과 같이 고양이가 있고, 해당 이미지에 대해 “Two cats are on the grass”라는 문장이 있다면, 사실 text prompt가 줄 수 있는 contextual information은 단순히 object에 대한 정보와 둘 사이의 관계에 대한 내용이지, 실질적으로 이미지의 background나 object를 제외한 디테일한 사물/texture 형태에 대한 어떠한 정보도 줄 수 없다는 것을 알 수 있다.
따라서 단순히 language information guidance를 토대로 CLIP을 학습하게 되면 충분한 데이터셋을 기반으로 image to text relationship 성립이 가능하지만, 정말로 학습된 후에 특정 text prompt를 기준으로 네트워크가 어디에 집중하는지 확인해보면, image description을 기준으로 학습된 부분은 representation에 잘 포함되지만 그렇지 않은 영역은 representation에 잘 포함되지 않는다는 것이다.


Learn image representation with self-supervised learning

따라서 논문에서 제시하고자 한 해결책은 CLIP도 결국엔 이미지 자체의 supervision을 사용할 수는 없기 때문에 SimCLR, Moco 등등 contrastive한 방법이나 MAE, BEiT 혹은 DINO와 같이 generative한 approach로 SSL을 CLIP 학습에 함께 사용하고, 이를 통해 image representation의 효율적인 학습을 이끌어내고자 한 것이다. 물론 기존에 self-supervised learning 방법을 CLIP에 적용하고자 했던 SLIP과 같은 방법들이 있었지만, 각 방법들은 contrastive learning 방법을 적용한다던가, salient object에 집중하게 하여 이 논문이 제시한 CLIP의 문제점 중 하나인 surrounding object에 대한 contextual information이 부족한 점을 채우지 못했다.

또한 BEiT, MAE 방식의 경우 token이나 image patch에 대해 self-supervision을 가지고 image representation을 학습하는 형태가 되기 때문에, high level feature인 image embedding과 text embedding 간의 유사도를 기준으로 학습하는 CLIP 방식에서 generative learning 방법을 적용하기엔 서로 학습 목적으로 삼는 representation의 level이 다르다는 문제가 발생한다.
따라서 해당 논문에서는 pixel 단위의 supervision이 아닌 feature 단위의 supervision과, patch 단위가 아닌 global 단위로 representation을 학습하는 CLIP 구조 자체를 self-teacher로 삼아 patch 단위로 knowledge distillation을 진행하는 self-distillation 방식을 사용한다.

모든 ablation에 대한 저자들의 framework 제안은 위와 같다. 가장 우측에 있는 MaskCLIP 학습 형태가 최종적으로 제안하는 학습 형태가 된다. 이 논문의 main contribution은 크게 세 개로 나눌 수 있다.

  1. VL contrastive에 새로운 학습 framework를 제안함으로써, masked self-distillation을 통해 VL model의 transfer 성능을 올릴 수 있다.
  2. 단순히 SSL 방법론을 적용하여 visual model의 성능을 높이고자 한 것이 아니라, 여러 MaskCLIP 형태에 대해 ablation을 진행하였다.
  3. 제안한 MaskCLIP(self-distillating) 구조가 zero-shot, linear-probing 그리고 finetuning과 같은 setting에서 모두 좋은 성능을 보였다.

Results

결과적으로는 MaskCLIP이 기존 CLIP의 성능에 비해 월등히 좋아졌으며, 논문에서 주장하는 것은 image representation에 대한 추가 loss term을 통해 수렴 속도가 빨라졌다는 점에 집중을 한다. 아마 이 부분은 image representation에 supervision을 추가로 주면서 단순한 embedding인 text space보다 복잡한 signal인 image space가 text에 대한 representation보다 학습 속도가 느리기 때문에 발생하는 수렴 차이를 어느 정도 좁혀줄 수 있는 방법이기 때문이지 않을까 생각해본다.

실제로 논문이 주장한 바와 같이 특정 text prompt에 대해서 네트워크가 집중하는 부분이 보다 reasonable해진 것을 알 수 있다.

A n o t h e r p o s t i n c a t e g o r y