loading

Prompt learning in Vision-Language(CoOp, CoCoOp) 논문 리뷰


prompt learning

Published on February 19, 2023 by JunYoung

VLP Prompt learning Downstream tasks

19 min READ

들어가며…

대용량의 pre-trained VL model(CLIP, ALIGN 등등)은 여러 downstream task에서 효과적으로 representation을 transfer할 수 있음을 보여주었다. 특히 zero-shot이나 linear probing에서 보여준 성능은 vision과 language를 함께 활용하는 것이 보다 open world set에 대해 적합한 학습 구조라는 것을 증명하였다. 대부분 one-hot encoding과 같은 discretized label(class를 label로 mapping하는 형태)를 사용하는 것이 traditional representation learning의 트렌드였다면, VLP(Vision language pre-training) task에서는 image와 text를 같은 feature space에 어떻게 하면 유의미한 관계를 가지게끔 위치시킬 수 있을지 연구하게 되었다. 이러한 연구를 토대로 각 downstream task를 해결하기 위한 prompting을 적용하였으며, 여기서 말하는 prompting이란 일종의 ‘텍스트를 구성하는 방법론’이라는 의미가 된다. 이번 포스팅에서 다룰 주제가 prompt learning인 만큼 prompt에 대한 개념이 계속 사용되기 때문에 대략적인 개념만 짚고 넘어가도록 하자.

자세한 내용은 뒤에서 더 디테일하게 설명할 것이지만 이 논문에서 저자들은 prompt engineering이 challenging하다는 점을 문제로 삼는다. 보통의 딥러닝에서 학습 성능을 높이기 위한 방법으로 hyperparameter를 조정하는 과정을 거치곤 하는데, 이처럼 VLP task에서 prompt를 어떠한 방식으로 만들어내는지에 따라 task별 성능의 등락폭이 너무나도 커져서 무시할 수 없는 전처리 과정으로 간주되었다. 하지만 단순히 단일 task의 성능을 높이기 위해 사실상 무한에 가까운 경우의 수를 가지는 prompt를 모두 테스트하는 것도 굉장히 시간이 오래 걸리는데, 이들 중 가장 좋은 성능을 보이는 prompt를 기준으로 학습을 진행했더라도 다른 task에서는 또다시 prompt engineering 과정을 처음부터 시작해야하는 것이다. 그리고 특정 언어에 대한 representation을 학습하기 위해 prompt를 만드는 과정이 domain expertise를 요구한다는 점이 문제가 된다.

따라서 이 논문에서는 NLP task에서 제시한 prompt learning 방법론을 사용하여, prompt 자체를 최적화하는 Context Optimization(CoOp)을 제시하였고, CoOp은 prompt의 context word를 학습 가능한 벡터로 간주하여 VL task의 model parameter는 frozen 시킨 채로 각 downstream task에 최적의 prompt를 찾기 위한 학습 과정을 거친다. CoOp을 통해 더 적은 shot(학습 샘플)으로도 hand-crafted prompt의 성능을 따라잡을 수 있었으며 학습이 충분히 진행된 후에는 downstream task의 성능이 대략 $15\%$ 상승한 것으로 나타났다.


Limitations of existing methods and apperance of VLP

CLIP 논문 리뷰에서도 언급했지만 CLIP이 문제시했던 내용은 충분히 많은 representation을 학습할 수 있는 Web 기반 이미지 데이터셋에 대한 부분이었고, 구체적으로는 왜 기존 image modality에 대한 학습법이 좋은 representation learning 방법이 아니었는지 언급하지 않고 넘어왔었다. ResNet과 Vision Transformer와 같은 연구들에서 확인할 수 있듯이 classification task의 경우 정해진 갯수의 object category가 있고, 각 카테고리에 대한 description은 indexing을 거쳐 discrete label로 간주하게 된다. 예를 들어 CIFAR10과 같은 경우에는,

위와 같이 총 $10$개의 클래스를 가진 데이터셋을 사용하는데, 각각의 카테고리 정보(airplane, cat, frog 등등)은 학습 과정에서 단어의 뜻으로 supervision을 주는 것이 아닌 one-hot encoding 형태로 cross-entropy loss를 최적화하는데 사용되었다. ImageNet dataset은 약 1000개의 클래스로 구분되는데, 다음과 같은 클래스들이 포함된다.

281: 'tabby, tabby cat',
282: 'tiger cat',
283: 'Persian cat',
284: 'Siamese cat, Siamese',
285: 'Egyptian cat',

이는 완전히 무관한 카테고리인 ‘orange’라던지 ‘mushroom’과의 언어적/맥락적 유사도를 전혀 고려하지 않고 단순히 서로 다른 카테고리일 경우에는 다른 인덱스를 부여하는 형태의 supervision이 될 수 밖에 없다. 이러한 discretized label이 가지는 문제는 다음과 같이 묘사할 수 있다.

만약 siamese cattiger cat이 있다면 분명 두 사진은 다른 카테고리이지만, 어쨌든 카테고리의 이름에서 볼 수 있듯이 ‘고양이’라는 공통점을 가지고 있고, 무엇보다 두 이미지에서 확인할 수 있는 object의 attribute(동물, 털이 있음, 수염이 있음 등등)가 유사하다고 판단할 수 있다. 그와는 반대로 아래에서 볼 수 있는 사과 이미지는 위의 두 사진과 사실상 겹치는 attribute가 없다고 볼 수 있다.

하지만 기존 classification task에서 사용하는 카테고리별 label은 text가 가지는 연관성을 이미지 특성과 전혀 연관짓지 못한다는 것이다. 그렇기 때문에 실제로 open world set에 대해 적용할 수 있는 다양한 object를 판별하기 위해서는 continous signal인 image를 discrete signal인 one-hot encoding으로 단순하게 mapping을 하는 것이 옳지 않다. 또한 카테고리를 묘사하는 text가 더 자세하면 자세할수록 network가 분류해야할 category의 개수는 증가하게 되고 그렇게 되면 정답인 label 이외의 다른 object들에 대한 노드가 방해 요소로 작용할 수 밖에 없다. 기존 computer vision 관련 딥러닝에서 다루던 방법은 visual recognition system을 closed set으로 간주한다는 점이 한계가 되었고, 데이터가 추가되면 기존에 학습했던 데이터에 추가하여 새로운 classifier를 학습해야하는 문제점이 생기게 되었다(보다 관심이 생긴다면 continual learning, domain generalization 관련 논문을 찾아보면 좋다).

따라서 비교적 최근 pre-training을 discretized label이 아닌 text embedding에 대해 진행하는 VLP model인 CLIP과 ALIGN이 등장하기 시작했고, vision representation learning에서 기본 방식에 비해 zero-shot transfer 성능을 눈에 띄게 높일 수 있었다. 학습 아이디어는 간단한데, 바로 image encoder과 text encoder의 output을 align하는 방식을 사용한 것이다. CLIP과 ALIGN 논문에서는 모두 contrastive loss를 objective function으로 사용했는데, 만약 image를 묘사하는 text랑 서로 positive pair(image-text pair를 학습 데이터로 사용했다)라면 embedding space 상에서 가까워지도록, negative pair라면 멀어지도록 학습하게 된다. 이와 같은 방법으로 CLIP과 ALIGN은 대용량의 데이터셋에 대해 pre-trained된 네트워크를 사용하여 다양한 task에 맞는 prompting을 거쳐 knowledge transfer을 진행하였다.

Task specific prompting과정은 위의 그림들 중 좌측(CLIP framework)에서 확인할 수 있는데, classification을 진행할 class의 object name을 사용하여 “A photo of (Class)” 등의 형태(task마다 prompt 형식이 달라짐)로 formatting을 진행한 뒤, text encoder를 거쳐 나온 텍스트 임베딩classifier weight으로 사용하게 된다.


Limitations of prompt engineering in VLP

사전 학습된 network를 사용하는 VL task에서 성능에 주된 영향을 보여준 것은 적절한 prompt engineering이었다. 그러나 적절한 prompt를 정의하는 것은 쉽지 않았기 때문에 이를 tuning하는 과정이 굉장히 오래 걸리는 작업이었고, 약간의 prompt 변화로도 성능이 크게 좌우되는 task도 존재했다. 예를 들어 Caltech101 dataset과 같은 경우, “a photo of (class)” 대신 “a photo of a (class)”를 썼을 때 무려 성능이 $5\%$ 증가할 정도로 변동이 큰 것을 확인할 수 있다.

게다가 prompt engineering은 task에 대한 prior knowledge도 있어야하며 language model이 작동하는 방식에 대한 domain knowledge가 필요하다.

‘texture’라던지, “centered satellite photo”와 같은 specific한 description은 해당 dataset이나 분야에 대해 제대로 알고 있지 않다면 prompt engineering에 적용하기 힘들기 때문이다. 그리고 가장 큰 문제는 앞서 말했던 것처럼 이와 같은 prompt engineering이 실질적으로 optimal한지 확신할 수 없기 때문에 일정 수준의 성능 향상으로만 만족할 수 밖에 없는 상황이다.


Prompt engineering in VLP task

위에서 볼 수 있는 여러 문제점들을 해결하기 위해 NLP task에서 비교적 최근 등장한 prompt learning research의 아이디어를 기반으로 CoOp에서는 downstream task에 최적화된 prompt를 찾고자 하였다. 기존의 prompt를 찾는 방식은 최적의 hyperparameter를 찾기 위해 tuning 및 검증하는 단계와 같았지만 CoOp 논문에서 제시하는 방법은 적절한 prompt를 찾는 과정을 자동화하는 것이다. 즉 prompt의 context words를 학습 가능한 vector로 취급한다는 것이다. 다양한 task를 다루기 위해 논문에서는 두 가지 방식을 제시하였다. 첫번째는 unified context로 모든 class에 대해 동일한 prompt를 학습시키는 방법이고 두번째는 각 class마다의 prompt를 학습시키는 방법이다. 대부분 첫번째 방법이 좋은 성능을 보였으나, 몇몇 fine-grained category를 가진 task에서는 두번째 방법이 더 효과적이었다고 한다.


Contributions

논문에서 선택한 방법인 prompt learning은 기존의 prompt engineering에서 사용한 방식처럼 text prompt를 discretized explanation(category)로 생각하지 않고 continous signal로 취급한 것과 같다. CoOp의 prompt optimization의 효과를 확인해보기 위해 $11$개의 벤치마크 데이터셋을 사용, 여러 object나 scene 그리고 action 등등 다양한 형태의 category로 분류된 task에서 evaluation을 진행하였다. 결과적으로 논문에서 제시한 본인들의 contribution은 다음과 같다.

  • 직접 VLP 네트워크에 대해 다루지는 않고, 이를 활용한 downstream application을 연구한 시기적절한 연구라고 생각한다. 그리고 기존의 VLP 방식에서의 비효율적인 문제점을 제시하였다(prompt engineering).
  • 사전 학습된 VL model을 통해 prompt engineering이 자동화될 수 있게 하기 위해, continous prompt learning 방식을 기반으로 접근하였고 unified/class specific 방식 두 가지를 제시함으로써 보다 다양한 recognition task에 적용될 수 있는 방법을 제시하였다.
  • Hand crafted prompt 방식과 linear probe model 방식을 downstream task에 적용했을 때 VLP 네트워크의 representation을 transfer learning이 진행되는 것보다 더 효율적으로 최적화가 가능하며 성능 측면에서도 기존 방법들을 뛰어넘었다. 그리고 VL model에 있어서 domain shift에 보다 robust하다는 특징이 있다.

Related works

Vision Language models

결국 prompt learning을 적용한 것은 VLP task의 pre-trained representation효과적으로 transfer해서 사용하고 싶기 때문인데, 비교적 최근 연구인 CLIPALIGN과 같이 text와 image encoder의 결과를 contrastive하게 학습한 형태가 image/text multimodal의 기본 아키텍쳐로 사용된다. 두 연구 모두 web 기반의 대용량 데이터셋을 사용했다는 점과 large minibatch를 기반으로 contrastive learning을 진행했다는 점을 공통점으로 삼을 수 있다. 물론 CLIP과 ALIGN 이전에도 text와 image를 같은 embedding space에 올리고자 했던 연구는 있었지만, text embedding을 추출하는 방식(Word2Vec, TF-IDF 등등)이나 matching하는 방식(metric learning, multi-label classification, n-gram language learning 등등)이 현재 SOTA인 contrastive representation learning based와는 차이가 있다.

그러나 이 논문에서 밝히는 바는 본인들의 연구는 이러한 기존의 vision-language model의 연구와 방향성이 다르다고 한다. 이는 기존의 VLP task는 image와 text를 동일 embedding space 상에 alignment하는 방법에 대해 초점이 맞춰져 있었다면 이 논문에서는 이미 학습된 pre-trained knowledge를 transfer하는 방식에 대해 초점을 맞춘 것이다. 논문에서는 hand-crafted 방식으로 prompt engineering을 하는 과정을 prompt learning으로 바꾸는 것이 효과적일 것이라고 밝혔다.

What is prompt learning?

Large pre-trained language model(LLM)의 knowledge probing을 하는 방식으로는 ‘빈칸 채우기’ 방식의 cloze texts 방법이 제시되었고, 이는 NLP task에서 prompt learning 연구가 진행될 수 있는 발판이 되었다.

Knowledge probing이 빈칸 채우기 방식이라고 했는데, 간단하게도 probing의 기본 컨셉은 주어진 cloze-style의 prompt에 대해 정답을 생성하게끔 하는 것이다. ‘How can we know what language models know?’ 논문에서는 text mining을 통해 여러 prompt를 후보군으로 생성하고, training accuracy를 보이는 prompt를 optimal prompt로 사용하는 방식을 제안하였다. AutoPrompt에서 제시한 방법은 아래 그림에서 볼 수 있듯이 label likelihood에 대해 가장 큰 gradient 변화를 주는 token을 searching한 뒤(gradient based search라고 부른다) 이를 prompt generation에 사용하였다.

이 논문에서는 continous prompt 방식을 사용하게 되는데, 이 방법의 문제점이라면 기존 discrete token에 대해 text embedding space에서 찾는 것보다 학습된 word가 정확히 어떤 prompt를 나타내는지 시각화할 수 없다는 점이다. 그럼에도 불구하고 저자들이 continous prompt 방식을 사용한 것은 VLP task의 주된 목적은 명확한 prompt embedding을 추출하는 것이 아니라, VL model을 downstream task에 사용했을 때 좋은 성능을 보이는 prompt를 tuning하는 과정을 자동화하기 위함이다.


Method

방법론은 굉장히 심플하다. 단순히 기존의 prompt engineering 부분을 학습 하능한 context vector로 설정하고 이를 최적화하는 과정을 사용한다. 물론 VLP 모델 자체를 학습하는 것과는 orthogonal하다는 저자들의 말과 같이 학습 과정은 CLIP baseline framework에서 step (2)에서 (3) 과정과 같다. 각 downstream task에 대해서는 supervision이 있기 때문에 context를 최적화할 수 있는 것이다.

CLIP baseline

네트워크 구조는 CLIP을 사용하였는데, 구조를 간단히 확인해보자면 vision encoderlanguage encoder 각각 있는데 vision encoder는 ResNet50과 같은 CNN baseline과 ViT와 같은 transformer baseline을 사용하였으며 language encoder로는 transformer를 사용하엿다. CLIP의 텍스트 인코딩 방식은 BPE를 사용한다. 학습 과정은 related work에서 간단하게 설명했었지만 다시 한번 설명하면 다음과 같다. Batch 단위의 image-text pair를 가지고 matched pair(contrastive learning에서의 positive pair)에 대해서는 cosine similarity를 최대화하고 unmatched pair(contrastive learning에서의 negative pair)에 대해서는 cosine similarity를 최소화하는 방향으로 학습한다. 다양한 image/text representation에 대해서 학습시키기 위해 CLIP은 web 기반 대용량($400M$)의 paired dataset을 학습에 사용한다.

CLIP의 주된 contribution 중 하나가 바로 zero-shot inference 성능이 높다는 것인데, CLIP은 웹 기반으로 다양한 text prompt에 대해 학습이 되어있기 때문에 다양한 category를 가지는 classification dataset에 대해 downstream task를 수행할 수 있다. 예컨데 $f$가 image encoder에 image $x$를 통과시켜 얻은 feature이며, $(w_i)_{i=1}^K$를 각 class에 대한 description $t_i$를 text encoder에 통과시켜 얻은 weight라고 생각하면 된다. $K$는 downstream task의 클래스 개수를 의미하고 prompt는 ‘a photo of a (class).’와 같이 설정하여, (class)라 명시된 부분에 ‘cat’, ‘dog’와 같이 class 이름이 들어가게 된다. 그런 뒤 prediction은 cosine similarity를 기반으로 softmax probability를 사용하게 된다. cosine similarity를 일종의 score라고 보면 된다(유사도가 높을수록, 해당 class일 확률이 증가하는 구조).

[ p(y = i \vert x) = \frac{\exp (\cos (w_i, f) / \tau)}{\sum_{j=1}^K \exp(\cos (w_j, f)/\tau)} ]

$\tau$는 temperature parameter로 CLIP pre-train 과정에서 학습되는 parameter이다. 기존 classifier가 closed-set visual concepts(정해진 class를 구분하는 task로서 학습됨)에 대해서만 discrete label을 학습하는 구조였다면, CLIP은 open-set visual concept를 유기적으로 학습할 수 있다는 것이 high-capacity network를 구성할 수 있는 방법이 되었다.

Context Optimization

그러나 위의 방법에서 볼 수 있듯이 ‘a photo of a (class)’와 같은 prompt는 사람이 직접 각 task에 대해 좋은 성능을 보이는 prompt를 찾는 tuning 과정이 필요하다. CoOp 논문에서는 이를 자동화할 수 있는 방법으로 다음과 같이 두 방법들을 제시한다.

Unified Context

모든 class에 대해서 같은 context를 공유하는 방식이다. Text prompt $g(\cdot)$에 주어지는 prompt는 다음과 같은 형태로 정의해볼 수 있다.

[ t = (V)_1(V)_2 \cdots (V)_M (\text{CLASS}) ]

$(V)_m$으로 표시된 부분이 각각 특정 word의 embedding이라고 가정하면 된다(CLIP에서는 $512$의 dimension을 가진다). $M$은 사용할 word embedding의 갯수가 되며, 이는 hyperparameter로 정해주게 된다. Prompt $t$를 text encoder에 통과하면 각 class에 대한 classification weight vector를 구할 수 있고, prediction probability는 위에서 봤던 식과 동일하게 구할 수 있다.

[ p(y = i \vert x) = \frac{\exp (\cos (g(t_i), f) / \tau)}{\sum_{j=1}^K \exp(\cos (g(t_j), f)/\tau)} ]

그런데 사실 최적의 context 구조가 “a photo of (class)” 형태일 수도 있지만, “a photo of (class), a type of object”일 수도 있기 때문에 학습 가능한 prompt를 다음과 같이 지정해줄 수도 있다.

[ t = (V)_1 \cdots (V)_{\frac{M}{2}}(\text{CLASS})(V)_{\frac{M}{2}+1} \cdots (V)_M ]

이처럼 prompt는 latter cell을 채울 수도 있고, termination signal을 통해 더이상 채우지 않을 수도 있게끔 학습할 수 있다.

Class-specific context

앞서 언급한 방법은 모든 class에 대해 동일한 context를 학습하게 되고, 다른 방법으로는 각 class 마다의 context를 학습하는 방법이 있다. 서로 다른 class index $i$와 $j$에 대해,

[ (V)_1^i(V)_2^i \cdots (V)_M^i \neq (V)_1^j(V)_2^j \cdots (V)_M^j ]

위와 같이 서로 다른 context를 학습하는 것이다. 이 방법은 일반적인 task와는 다르게 fine-grained classification이 필요할 때 효과적이었다고 한다.


Experiments

저자들이 실험한 데이터셋은 총 $11$개로, ImageNet, Caltech101, Oxford-Pets, StanfordCars, Flowers102, Food101, FGVCAircraft, SUN397, DTD, EuroSAT 그리고 UCF101을 사용했다고 한다.

사용된 dataset의 statistics는 위와 같았다. Hand-crafted prompt로 사용한 prompt는 연구에서 ablation을 통해 가장 좋은 성능을 보이는 prompt를 적용했을때다.

$11$개의 dataset에 대해 CoOp 방식을 사용한 평균 결과는 위와 같이 나와있다. ‘end’라고 표시된 것이 CLASS description을 마지막에 넣은 context를 최적화했을 때가 되고 ‘mid’라고 표시된 것이 CLASS description을 중간에 넣은 context를 최적화했을 때가 된다. CSC는 class specific하게 학습했을 경우인데, 대체로 경향성을 보게 되면 unified prompt를 적용했을 때가 성능이 좋은 것을 볼 수 있다.

물론 모든 경우에 unified prompt가 좋은 성능을 보이지는 않았고 일부 dataset에 대해서는 sample 수가 증가할수록 오히려 CSC가 더 좋은 성능을 보인 경우도 있었다. 이 논문의 가장 main contribution이라고 생각하면, zero-shot CLIP에 대해 여러 dataset의 fine-tuning(Linear probing) 과정에서4-shot 이전까지는 few-shot 성능이 zero-shot 성능 이상으로 보장되지 않았던 것이 limitation으로 제시되었는데, prompt learning을 통해 few-shot 성능이 얼추 zero-shot 성능 이상으로 올라가는 경향성을 확인할 수 있다는 것이다.

위의 그래프를 보게 되면 장점이 더 명확해지는데, Zero-shot CLIP과 비교했을 때 CoOp 방식에 $16$-shot을 적용한 few-shot network가 적게는 $1.24\%$, 많게는 $45.97\%$의 성능 향상을 보이는 것을 알 수 있다. 물론 오히려 성능이 하락한 경우(Food101)도 있지만, $11$개의 dataset 중 $10$개의 dataset에 대해 성능 향상을 보이는 것을 확인할 수 있다.

그리고 CoOp을 통한 학습이 위와 같은 domain shifting 상황에 대해서 보다 robustness를 가지는 것을 확인할 수 있다. 특히 ImageNet-R, ImageNet-Sketch 등의 dataset에 대해서는 단순히 source에 대한 성능 향상에 대한 경향성보다는 더 큰 폭으로 성능이 좋아질 수 있는 것<>을 볼 수 있다.

그리고 실험 결과를 보면 context length를 $16$만큼 사용했는데, 이에 대한 ablationvision backbone에 따른 CoOp의 경향성 또한 확인하였다. 기존 zero-shot CLIP이 보였던 성능 추이에 비슷하게 나오는 것을 확인할 수 있으며, backbone 구조에 무관하게 CoOp 방식이 성능 향상에 잘 적용될 수 있다는 점을 보여주었다.

위는 appendix에 각 dataset에 따라 학습된 context와 가장 유사한 vector를 나타낸 것이다. Continous prompt learning을 사용했기 때문에 학습된 vector가 특정 단어를 나타낼 수는 없지만, 위를 통해 간접적으로나마 학습된 prompt와 유사한 word를 찾을 수 있다. 사실 이걸 보면서 느낀 점은 prompt를 continous하게 학습하게 되면 기존의 text 구조 체계가 무너질 수 밖에 없기 때문에, 성능을 높이는 방법으로는 CoOp이 적절할 수는 있으나 image/text와의 관계성을 보여주기엔 한계쩜이 많다는 것이다. 저자들이 related works를 작성하는 과정에서 본인들의 연구가 VLP와는 동떨어진 연구라고 주장했던 이유가 바로 이 때문이 아닐까 조심스럽게 추측해본다.


Limitation in CoOp and appearance of CoCoOp

CoCoOp에는 치명적인 문제가 존재한다. 이는 바로 학습 과정에서 context가 downstream task에 overfitting이 되다보니, in-domain class에 대해서는 좋은 성능을 보이지만 비슷한 distribution을 가지는 out-of domain class에 대해서는 낮은 성능을 보인다는 점이다.

예를 들어 SUN397 dataset에 존재하는 class category인 ‘Arrival gate’나 ‘Cathedral’과 같이 존재하는 class에 대한 accuracy는 zero-shot에 비해 학습된 prompt가 더 좋은 성능을 보이지만,

‘Wind farm’이나 ‘Trail railway’와 같이 비슷한 distribution(scene understanding이라는 관점에서) hand crafted prompt를 사용하는 zero-shot baseline에 비해 오히려 성능이 나빠지는 것을 확인할 수 있다. 즉 최적화된 text prompt가 특정 dataset을 기준으로 seen class에 대해서만 overfitting되는 문제가 발생한다.

사실 이러한 문제는 어느 정도 당연히 제시될 수 밖에 없는 것이 앞서 CoOp 논문 experiment 과정에서 마지막에 언급한 Appendix 표에도 잘 드러나 있는데, 학습된 prompt와의 nearest neighbor을 시각화했을 때 합리적인 word나 문장이 전혀 생성되지 않고 사실상 이미지의 문맥과는 거의 동떨어진 description이 나오는 것을 확인할 수 있었다. 애초에 논문에서 유연한 context를 실험해보기 위해 CLASS prompt를 문맥 중간에 넣거나 뒤쪽에 넣는 식으로 variation을 주었지만, 이러한 과정이 text domain에 대한 intuition에 아무런 영향도 끼치지 못했다는 것처럼 보인다.


Conditional Context Optimization

논문에서는 이러한 weak-generalization 문제를 해결하기 위해 image captioning과 비슷한 방법을 사용한다. Input image가 ‘학습되는 prompt’에 guidance를 줄 수 있게끔 meta-network를 구성하고, 이를 통해 prompt가 class에 overfitting되는 문제를 정규화하는 것이다. Image captioning에서도 instance에 따른 optimization이 class shift 문제에 대해 보다 robust한 학습이 가능하다고 밝혔다. Prompt가 최적화하는 과정에서 text encoder의 영향만 받다 보니 image representation의 transfer이 약해진다고 판단한 듯하다.

따라서 image encoder(ViT, ResNet)의 class token을 기반으로 meta network를 학습하여, 각 image sample instance에 대해 conditioning된 meta-token을 생성하는 방법을 사용하였다. Context tokens($v_1, v_2, \cdots v_M$)이 오롯이 특정 dataset에 대한 prompt 최적화가 되면 overfitting이 될 수 있기 때문에, context token으로 하게끔 일반화 가능한 정도의 prompt만 학습하고, 나머지 각 image에 대한 정보는 $\pi$가 줄 수 있게 tokenlightweight meta network로 conditioning한다는 것이다. 컨셉을 보고 생각했던 것은 context token을 학습하는 과정은 잘 짜여진 도화지를 만드는 작업이고, $\pi$를 학습하는 과정은 좋은 representation을 그릴 수 있는 팔레트를 구성하는 작업이라는 것이다. 기존의 CoOp이 맨땅에서 그럴듯한 prompt를 만들어내고자 하다보니 overfitting될 수 밖에 없었고, CoCoOp은 conditioning을 통해 text encoder와 image encoder의 역할을 분리함으로써 정규화가 가능해진다고 해석되었다.

두 예시에 대해 각각의 CoOp vs. CoCoOp 결과는 위와 같다. 위의 결과는 class specific prompt에 대한 성능이 CoOp와 CoCoOp이 얼추 비슷하게 나온다는 것을 보여주고, 아래의 결과는 unseen class prompt에 대한 성능이 CoCoOp이 훨씬 좋다는 것을 보여준다. CoOp에서는 zero-shot 성능보다 $10\%$나 떨어지는 경우에 대해서도 성능이 오히려 올라가는 결과를 보여주었다.


Conditional Context Optimization(CoCoOp)

사실 논문을 보면 알 수 있듯이 related worksmethod가 원래 본인들 연구였던 CoOp paper에서 살짝만 바꾼 수준이다. 뭐 본인들 paper을 그대로 reference로 작성하였기 때문에 굳이 tackle을 걸 필요는 없지만 나도 나중에 논문을 쓰면 저렇게 쓰고 싶다는 생각을 해보았다. 사실 원래 연구에 그냥 meta-network만 추가한 것과 같아서 experiment setting도 쉬워보였기 때문이다.

아무튼 그렇기 때문에 나머지 수식은 모두 동일하고 meta network를 통한 meta token $\pi$에 대한 식만 살펴보면 다음과 같다. 저자들이 가장 단순하게 먼저 생각했던 것은 $M$개의 context token에 대해 따로 학습되는 $M$개의 neural networks를 설계하는 것이었다. 그러나 $M\times$의 neural network를 보두 학습하는 것은 CoOp 기준으로 $16$개의 network를 학습하는 과정이 되기 때문에 너무 heavy하고, parameter를 적게 사용할 수 없는 방법이 되기 때문에 $M$개의 context vector에 모두 공통적으로 더해질 수 있는 token을 생성하는 meta network를 구성하였다.

파라미터 $\theta$를 학습 가능한 파라미터로 가지는 Meta-Net($h_\theta(\cdot)$)에 대해 input image embedding $f = E_I(x)$에 대한 context vector는 다음과 같이 구성할 수 있다. 우선 앞서 CoOp 논문에서 사용했던 수식에 추가로 설명하면,

[ t = (V)_1(V)_2 \cdots (V)_M (\text{CLASS}) ]

정해진 갯수의 학습 가능한 prompt vector를 다음과 같이 설계하고, meta token $\pi = h_\theta(f)$를 각 vector에 더한 conditioned vector $V_m(f) = (V)_m + \pi$를 prompt로 대체할 수 있다.

[ t(f) = (V_1(f),~V_2(f), \cdots ,~V_M(f), (\text{CLASS}) ]

그렇게 되면 기존의 prediction probability는 다음과 같이 수정된다.

[ p(y = i \vert x) = \frac{\exp (\cos (g(t_i(f)), f) / \tau)}{\sum_{j=1}^K \exp(\cos (g(t_j(f)), f)/\tau)} ]

학습 과정에서는 meta network의 parameter $\theta$context vector $V_m$이 함께 gradient based로 update된다.

Meta network는 정말 심플하게 2개의 layer를 가지는 structure로, Linear-ReLU-Linear의 MLP로 구성했다고 한다. 보다 복잡한 구조는 future work로 남기겠다고 했는데, 여기에 더 복잡한 구조를 써서 유의미한 결과를 얻어내는 것 자체는 큰 contribution이 안될 것 같다(논문 주제 하나 또 잃었네).


Experiments

실험에 사용한 dataset은 기존 CoOp 연구에서의 $11$개의 dataset을 그대로 사용하였다. 실험 결과 확인한 accuracy 평균은 다음과 같다.

CoCoOp을 사용한 것이 New(unseen class), H(Base + New) 모두 CoOp보다 좋은 결과를 보여주었다. 물론 CLIP이 unseen class에 대한 zero-shot 성능은 제일 좋았으나, base dataset에 대한 성능이 $11\%$ 차이가 난다는 점에서 CoCoOp 방법이 seen class와 unseen class 모두에 적용될 수 있는 방법이라는 것을 보여준다.

각 dataset에 대한 unseen class와 base class 성능 비교는 위와 같다. Base class에 대해서는 CoOp 성능이 더 높은 경우가 많았는데, 이는 overfitting 덕분이라는 분석이 있기 때문에 유의미하지 않았고 unseen class에 대해서는 CoCoOp이 기존 방법을 모두 뛰어넘었다는 점이 주목할만한 점이다.

Generalization 성능에 대해서 언급한 만큼 domain generalization에 대한 실험도 빠질 수 없는데, 실제로 CoCoOp이 source domain과 target domain의 차이가 커짐에도 CoOp보다 robust한 성능을 보였다고 한다.


결론

CoOp 논문, CoCoOp 논문 둘 다 prompt learning을 기반으로 CLIP downstream task의 성능을 높이고자 한 방법론을 다룬다. CoOp에서는 NLP task의 여러 prompt learning 기법 중에서 최적화 과정에 적용할 수 있는 continous prompt learning 방법을 사용, downstream task의 성능을 높이는데 집중했으며 CoCoOp에서는 CoOp에서 성능을 높이면서 놓친 seen class에 대한 overfitting 문제를 다룸으로써 unseen class에 대한 generalization 방법도 meta-network를 통해 제시하였다.

NLP task에서 적용되던 domain generalization 방법을 VL network에 적용하면서 VLP 연구와 orthogonal하게 독자적인 논문들을 작성했다는 점이 contribution이 될 것 같으며, 사실 CoOp이나 CoCoOp은 성능에서도 충분히 좋은 논문이지만 ‘이렇게 써야 논문을 쓸 수 있겠구나’라는 생각을 하게 된 paper라고 생각된다.

A n o t h e r p o s t i n c a t e g o r y