Test time domain adaptation
Published on March 14, 2023 by JunYoung
Deep learning DA Continual learning Test time adaptation
19 min READ
최근 test time domain adaptation에 대한 논문이 많이 나오는 중이다. 본인이 판단했을 때 TTA(Test time adaptation) 혹은 TTDA(Test time domain adaptation)이 활발하게 연구된 이유는 다음과 같다.
첫번째 이유는 UDA같이 수많은 연구가 진행된 쪽으로는 더이상 성능 mining이 쉽지 않다. 요 며칠간 UDA와 관련된(classification, detection, segmentation, cross-domain 등등) 연구들을 찾아본 결과, 최근 연구 방향으로는 더이상 domain adaptation의 성능을 증가시킬 수 있는 방법론이 제시되는 것이 아니라 representation learning을 어떻게 하면 잘 진행함으로써 feature alignment를 효율적으로 할 것인지에 대한 내용이나 기존 방법론들을 이것저것 같이 사용해서 성능을 올리자는 쪽의 페이퍼가 많았던 것 같다.
두번째 이유는 UDA는 실제 domain adaptation이 사용될 환경과는 부합하지 않는다는 것이다. 보통 domain adaptation을 진행하는 상황이라면 특정 modality 및 domain에서 학습된 pre-trained network가 존재하고, 이를 inference 환경에서 다양한 domain shift가 발생했을 때 robust한 예측을 추구하는 것이다. 자율 주행을 예로 든다면 차가 주행하는 환경이(비가 오거나, 눈이 오거나 등등 날씨가 좋지 않은 경우) 학습 데이터에 존재하지 않았거나 상이한 경우(날씨가 좋은 경우, 외국 배경에 대해 학습되어있는데 한국의 도로에서 사용될 경우)를 생각해볼 수 있다.
이러한 상황에서 상식적으로 생각했을 때, 도로 주행 환경에서 사용되는 딥러닝 네트워크는 차량 임베디드 시스템에 내장되어야한다. 그런데 갑자기 주행 환경의 날씨가 변하는 등 domain shift가 발생했다고 해서 UDA 방식처럼 source dataset과 target dataset를 가지고 학습하려면 차량 임베디드 시스템에 source dataset가 저장되어야 하고, 이는 현실적으로 불가능함을 알 수 있다. 데이터셋을 마음대로 쓸 수 없는 것은 하드웨어의 용량 문제도 있겠지만, 라이센스 문제나 사생활 문제와 같은 법적인 bottleneck도 포함한다.
이번 글에서 다룰 내용은 바로 위와 같은 실제 상황에서 활용되기 힘든 domain adaptation assumption보다, 실제로 딥러닝 기술이 활용될 상황에서의 효과적인 domain adaptation 방법을 연구한 TTA/TTDA에 대해서 살펴볼 것이다. 그 중에서 baseline으로 가장 많이 언급되는 논문인 TENT와, continual test time domain adaptation이라는 조금 더 까다로운 domain shift 문제를 정의한 논문을 살펴볼 것이다.
단일 데이터셋(Image classification에서의 CIFAR10, ImageNet 등등)과 같이 정해진 distribution을 가지는 training dataset과 test dataset에 대해 학습하는 것은 deep learning에서 비교적 해결하기 어렵지 않은 task에 속한다. SOTA를 달성한 네트워크 중 하나를 무작위로 fine-tuning하거나 데이터셋이 충분하다면 scratch부터 학습시키는 방법을 사용할 수 있다. 하지만 학습된 기존 데이터셋과 다른 새로운 데이터셋이 지나치게 한정적이거나(수가 적거나) ground truth가 없는 경우 등등 학습된 representation을 활용하기에 부적절한 경우 generalization 성능이 떨어지게 된다. 이러한 문제를 dataset shift라고 부른다. 바로 이러한 문제를 해결하려는 task가 domain adaptation이며, 우리는 그 중에서 test time adaptation에 대해 살펴볼 것이다.
보통 우리가 알고 있기로는 testing하는 과정에서 네트워크의 parameter는 고정되어있고 변하지 않는 것으로 가정하지만 test time adaptation의 경우에는 다르다. 예를 들어 도로 주행 환경에서 사용되는 딥러닝 시스템이 가지고 있는 것은 각 시점에서의 네트워크의 파라미터와 target data가 전부가 된다. 이러한 상황에서 source data를 사용하기 힘들고, 딥러닝 시스템은 오직 자신이 마주한 현재의 dataset에 의존하여 최선의 prediction을 내려야한다.
위의 그림에 대해서는 뒤에서 조금 더 디테일하게 보겠지만 우선 간단하게만 언급하자면, TENT 논문에서 다루고자 하는 task는 조금 더 디테일하게는 ‘fully test-time adaptation’이라고 부르며, 논문의 저자들이 가정하는 것은 해당 상황에서는 source data에 대한 어떤 정보도 얻을 수 없는 상황을 가정하고 오직 ground truth도 존재하지 않는 target data(test data)로만 네트워크의 parameter를 조정하는 작업이 된다. Fully test-time adaptation는 결국 실생활에서 사용되는 시스템에서 데이터, 연산량 그리고 성능에 대해 다음과 같은 목적 의식을 가진다고 할 수 있다.
이와는 조금 차이가 있는 것이 ‘Continual test time adaptation’이 되며, continual test time adaptation(CoTTA) 논문에서는 위의 표에서 보는 것과 같이 inference 상황에서 지속적인 domain shift를 겪는 상황을 가정할 것이다.
앞서 말했던 것과 같이 test dataset에 대해서 네트워크를 adaptation하는 것은 까다로운 작업이다. TENT라는 논문에서는 target dataset에 대한 unsupervised learning 메커니즘에서 어떠한 방식으로 supervision을 주는가에 초점을 맞추었고, 해결하는 방법으로는 논문 제목(fully test-times adaptation by entropy minimization)에서 힌트를 얻을 수 있듯이 entropy minimization을 선택하였다. 방법론에서도 언급되겠지만 entropy minimization의 대상이 되는 것이 바로 test data에 대한 prediction이기 때문에 test entropy라는 objective name을 사용하였으며 이에 따라 해당 방법이 ‘TENT’라는 이름을 가지게 되었다.
그렇다면 과연 ‘엔트로피’란 어떤 의미를 가지는지 생각해볼 수 있다. 예를 들어 동물을 분류하는 네트워크가 있다고 생각해보자. 극단적으로는 고양이/개를 이진 분류하는 task를 수행하는 하나의 네트워크($F_\theta$)가 있다고 가정해보자. 해당 네트워크가 아주 잘 학습(overfitting)이 되어있다는 가정 하에, 갑자기 해당 네트워크에 다음과 같이 개와 고양이가 같이 있는 사진 $I$를 넣어주면 어떻게 될까?
네트워크는 사진에서 강아지에 대한 특징도 찾을 수 있고, 그와 동시에 고양이에 대한 특징도 찾을 수 있다. 단순히 이미지 하나를 하나의 class로 분류하는 이진 분류에 적응된 네트워크이므로, 아마도 위의 사진을 네트워크 $F_\theta$에 통과시킨다면 굉장히 혼란스러워할 것이다.
[ softmax(F_\theta(I)) = (0.51, 0.49) ]
네트워크의 score value에 softmax를 취하게 되면 각 class일 확률이 나오게 되는데, 이때 class $0$일 확률과 class $1$일 확률이 거의 엇비슷해지는 문제가 발생한다. 이처럼 네트워크가 혼란스러울수록(확신이 없을수록) ‘엔트로피는 증가한다’라고 표현된다. 이는 정보 이론에서 언급하는 entropy로 생각하면 보다 이해가 빠르다.
정보 이론에서의 Entropy란 random variable이 특정 확률을 기반으로 추출된다고 했을때, 정보의 불확정성을 지표화한 공식이다.
[ H(X) := -\sum_{x \in X} p(x) \log p(x) = \mathbb{E}(-\log p(X)) ]
어떠한 이미지가 분류될 수 있는 class index인 랜덤 변수가 $0$과 $1$인 distribution을 가정했을 때, 네트워크의 output에 softmax를 취한 값이 서로 비슷할수록 entropy는 증가하는 것을 볼 수 있다. 이처럼 네트워크는 학습 시에 사용된 data distribution를 기준으로 접하지 못했던 이미지나 class를 inference 상황에서 마주하게 될 경우 제대로 분류하지 못하거나 특정 class에 대한 confidence를 제대로 주지 못하는 경우가 발생한다.
Confidence는 특정 class index에 해당되는 probability가 다른 class index와 유의미한 차이를 보일수록 높아지며, entropy와 같은 의미로 쓰이지만 방향만 다르다고 생각하면 된다. 실제로 저자들이 실험해본 결과 Entropy가 낮을수록 corrupted data와 같이 domain shift에 대한 error 비율이 높아졌으며, 이는 실제로 dataset이 어려운 정도(corruption level)가 심해질수록 entropy와 loss가 증가하는 경향성과 맞닿아있다고 할 수 있다. 따라서 이를 통해 내릴 수 있는 결론은 다음과 같다.
Dataset이 달라지면 취할 수 있는 학습 방법은 사실 기존에도 존재했었다. Transfer learning의 한 방법인 fine-tuning은 source dataset에 대해 학습된 네트워크를 supervision이 존재하는 target dataset로 학습하는 방법을 사용한다. 만약 target label이 없는 상황이라면 unsupervised domain adaptation(UDA)을 사용할 수 있고, 이때는 필수적으로 두 도메인 사이의 alignment를 위해 $\mathcal{L}(x^s, x^t)$를 최적화해야하기 때문에 ground truth를 포함한 source dataset이 학습에 관여해야한다. Domain adaptation과 필요로 하는 데이터셋은 같지만 meta learning과 비슷한 방식으로 접근한 TTT(Test-time training)은 source dataset에 대한 supervised loss $\mathcal{L}(x^s, y^s)$ 그리고 self-supervised loss $\mathcal{L}(x^s)$를 동시에 최적화하고, test 상황에서는 test dataset에 대한 self-supervised loss를 최적화하는 방법을 사용한다.
위의 방식들은 대부분 target dataset에 대한 supervision이 있는 상황을 가정하거나, 만약 없다면 source dataset을 학습 과정에 필수로 사용하는 방법론에 해당된다. 이와는 다르게 fully test-time adaptation은 위의 그림의 (a)와 같은 source dataset에 대한 supervised learning 과정과 독립적으로, 오직 test dataset에 대한 entropy만으로 adaptation을 진행하게 된다.
위에서 test-time adaptation과 관련된 대략적인 concept은 모두 설명했다. 실제로 논문을 작성할 때 task를 결정하게 되면 해결하기 위해 필요한 부분들을 정의해야하는데, TENT는 test time에서의 domain shift를 정의했기 때문에 source dataset, target dataset에 대해서 학습될 compatible model이 필요하고, target dataset에 대해 어떠한 loss function을 적용할 지 결정해야한다. 그리고 추가로 이렇게 정의된 loss function을 네트워크 parameter에 어떤 부분에 적용할 지(부분적으로 적용 혹은 전체에 적용) 결정하는 단계를 거친다. 물론 test-time adaptation 상황에서는 source dataset이 주어지지 않기 때문에 네트워크는 이미 source dataset에 대해 학습되었다는 가정을 하게 된다.
Loss는 앞에서 이미 정보 이론에서의 엔트로피를 정의하고 왔기 때문에 짧게 소개하자면, 네트워크의 output인 vector($f_\theta(x^t) = \hat{y}$)에 대한 entropy $H(\hat{y})$를 최소화하게끔 학습된다. 이때 entropy는 Shannon entropy를 최소화한다.
[ H(\hat{y}) = -\sum_c p(\hat{y}_c) \log p(\hat{y}_c) ]
위에서의 $c$는 각 class에 대한 index를 나타내며, $\hat{y}_c$는 $c$번째 class score이며 이에 softmax를 취한 $p(\hat{y}_c)$가 해당 class 변수에 대한 확률값으로서 엔트로피가 계산된다고 볼 수 있다. 네트워크의 예측이 잘못되는 경우 등등 각 샘플 단위로 위의 entropy를 loss로 사용하게 되면 noisy한 학습이 진행될 수 있기 때문에 배치 단위로 prediction을 진행한 뒤 전체에 대한 parameter를 최적화하는 방법을 취하였다. Entropy loss를 사용하는 것에 대해 저자는 unsupervised learning이지만 prediction value를 그대로 사용한다는 점에서 supervised task와 직접적인 연관이 있다고 한다. 이러한 해석이 나온 배경은 다음과 같다.
보통 self-supervised learning setting에서 사용하는 proxy task(auxiliary task)는 ground truth를 직접 예측하는 primary task가 불가능할 때 loss로 사용하게 된다. Class annotation와 같은 정보는 explicit한 정보이기 때문에 image $x$를 토대로 직접 매핑하지 않는 한 딥러닝이 자체적으로 학습이 불가능하지만, $x_t$를 사용하여 도출한 self-supervised label $y^\prime$은 task label 없이 supervision으로 사용될 수 있다. 예를 들어 이미지에 의도적으로 각도 회전을 준 뒤 rotation prediction을 진행한다거나, 이미지를 패치 단위로 나눈 뒤 context를 예측하는 task를 생각해볼 수 있다. 모두 흔히 볼 수 있는 self-supervised learning 방법들 중 하나이다.
하지만 이러한 SSL 방법과 관련된 연구들은 supervised task와 함께 활용되기 어려울 정도로 관련이 없기 때문에 독립적으로 발전하였고, proxy task가 compatible할 수 있는 네트워크의 구조에 제한될 수 밖에 없는 문제가 생긴다. SSL을 토대로 학습된 representation을 supervised downstream task에 사용함에 있어서 primary task와 밸런스를 잘 맞추는 것도 관건이다. 이러한 디테일한 부분들을 굳이 entropy minimization에서는 고려하지 않아도 상관없다는 장점이 있다.
네트워크의 parameter $\theta$는 source dataset의 representation을 모두 학습된 상태로 세팅이 된 상태다. 학습된 parameter $\theta$도 결국 어찌 보면 ground truth인 one-hot encoding에 따라 entropy minimization을 진행한 것과 같은데, 문제는 이 parameter가 source dataset에만 최적화가 되어있다는 것이다. Test time adaptation 상황에서는 실제 최적화 과정에서 source dataset의 접근이 제한되기 때문에 parameter $\theta$를 모두 update하는 것은 기존에 학습된 source dataset의 representation을 망쳐버릴 수도 있다. 특히 deep learning network $f$는 고차원의 input/output을 가지는 nonlinear function이며 parameter 역시 고차원의 텐서로 구성되기 때문에 test-time에서 output이 parameter 변화에 민감하게 반응할 수 있다.
안정성과 효율성을 위해 저자들은 affine(multiplication, translation)하며 차원 수가 낮은(channel-wise) parameter인 batch normalization statistics($\gamma, \beta$)에 집중했다. Batch normalization은 익히 들어서 알고 있겠지만 해당 레이어로 들어온 feature map에 대해 배치 단위로 정규화를 진행한 뒤($\bar{x} = (x-\mu)/\sigma$) 학습된 scale factor $\gamma$와 bias $\beta$로 재정규화를 진행한다($x^\prime = \gamma \bar{x} + \beta$). 따라서 source network에 대해서는 정상적으로 batch normalization layer로 학습된 parameter를 test time optimization 과정에서는 entropy minimization의 목적 함수를 통해 업데이트하는 구조로 수정한 것이다.
파라미터의 초기화 상태는 transformation과 관련된 변수($\gamma, \beta$)를 제외하고는 source network parameter를 그대로 사용하게 되고 source dataset의 normalization statistics인 $\mu, \sigma$는 모두 무시한다. Target data(test dataset)이 학습에 활용될 때 각 step에서 normalization 변수들인 $\mu, \sigma$ 그리고 transformation 변수들이 조정된다. Normalization statistics는 각 레이어에서 차례대로 estimation이 된다. 이 부분은 batch normalization이랑 동일하다고 생각하면 된다. 그리고 학습 가능한 parameter인 $\gamma, \beta$는 entropy minimization loss를 기준($\nabla H(\hat{y})$)으로 backward pass 과정에 업데이트된다. 계산 효율성을 위해 transformation 변수의 update가 다음 배치에만 영향을 주는 scheme을 그대로 사용했다고 한다.
Dataset으로는 대표적인 classification dataset인 CIFAR-10/CIFAR-100 그리고 ImageNet에 corruption을 준 데이터셋이나 MNIST/SVHN/MNIST-M/USPS의 숫자 데이터셋을 사용하여 실험을 진행하였다.
학습 네트워크 구조는 corruption(CIFAR-10, CIFAR-100 and ImageNet)의 경우에는 ResNet-26(CIFAR) 혹은 ResNet-50(ImageNet)을 사용하였다. 그리고 숫자 데이터셋에서는 ResNet-26을 고정으로 사용하였다. Residual neural network 구조는 batch normalization을 포함하기 때문에 test dataset에 대한 최적화가 가능하다.
Corruption benchmark로 사용된 CIFAR dataset에 대한 결과는 위와 같다. Corruption dataset은 gaussian, defocus를 포함한 여러 형태의 natural domain shift에 대해 그 강도를 severity로 정의하여 구성한 데이터셋인데, 좌측 표의 결과는 그 중 가장 강한(corruption이 심한) 샘플에 대한 결과이다. 우측 그래프는 각 domain shift에 따라 severity 마다의 error를 평균을 낸 결과로, TENT(파란색)가 전반적으로 좋은 결과를 보여주고 있는 것을 볼 수 있다. ANT라고 된 부분이 기존 SOTA인 adversarial noise training 방식인데, 대부분의 corruption에 대해 해당 성능을 뛰어넘은 것으로 볼 수 있다.
표에서 보이는 RG, UDA-SS, TTT는 unsupervised domain adaptation 방식인 adversarial domain adaptation과 self-supervised domain adaptation 그리고 앞서 언급했던 test-time training 방식을 의미한다. BN 방식은 단순히 test 상황에서 target data에 맞게 normalization statics를 업데이트하는 방식이다. PL은 pseudo-labeling으로, supervision이 없는 target data에 대해 confidence threshold를 주어 pseudo-label을 만들어내는 작업이다. RG, UDA-SS는 unsupervised domain adaptation 방식이고 BN, PL은 TENT와 같이 test time adaptation 방식에 해당된다.
앞서 본 내용은 사실상 domain adaptation이라기 보다는 dataset에 corruption이 발생했을때 자연적으로 domain shift가 발생하는데, 이럴 경우 generalization이 효과적으로 진행될 수 있는지 확인하는 과정이었고 SVHN dataset을 기준으로 MNIST와 같은 dataset에 대해 domain adaptation 결과를 본 모습은 위와 같다. 모든 dataset에 대해 SOTA를 달성한 것은 아니지만, source dataset 없이도 단일 epoch에 대해 test dataset으로 학습하는 방법이 UDA 방법에 필적할 만큼, 심지어 더 좋은 성능을 보일 수 있음을 입증하였다.
또한 저자는 TENT가 효과적으로 entropy를 감소시킬 수 있음을 보임과 동시에 loss의 변화와 entropy의 변화가 높은 상관관계를 가짐을 보여주었다. 좌측 그래프는 CIFAR-100-C에 대한 모든 corruption type/level에 대한 loss 변화와 entropy 변화 그래프이다. 우측 그래프를 보게 되면 TENT가 효과적으로 domain distribution을 따라갈 수 있음을 알 수 있는데, 뒤쪽에 있는 노란색 distribution이 corruption이 발생하기 전이고 앞쪽이 실제로 layer에서 나온 feature map의 distribution에 해당된다. Source(source dataset에 대해서만 학습된 네트워크)의 경우에는 전혀 분포를 따라가지 못하는 것을 볼 수 있지만, TENT는 기존의 BN보다 실제 corrupted dataset에 대한 분포를 잘 따라가고 있는 형태를 보여주며, Oracle(target에 대해 supervised learning을 한 네트워크)와 거의 유사한 분포를 보여줌으로써 supervision 없이도 domain shift에 적응할 수 있음을 잘 보여준다.
위의 방법은 효과적으로 test dataset만을 활용한 adaptation을 제시했지만 근본적인 문제가 있다. Test dataset에 대해 학습되는 과정에서 domain distribution이 변하지 않는다는 가정이 있어야한다. 하지만 실제론 그렇지가 않은 것을 알 수 있다. 예를 들어 주행 환경이 터널(어두운 곳)으로 변하게 되면, 어두운 환경에 적응하기 전에 터널을 빠져나오게 될 것이고, 갑자기 비가 내리다가 별안간 눈으로 바뀌는 경우와 같이 시시각각 inference domain이 달라지게 되면 test-time adaptation이 제대로 작동하지 못한다는 것이다.
이를 해결하고자 하는 것이 바로 continual learning이고, test time adaptation이 해결하고자 하는 문제랑 사실상 거의 맞닿아있다고 할 수 있다. 아무튼 앞서 entropy minimization을 통해 test-time adaptation을 제시한 TENT에 이어 continual learning 상황에서의 test time adaptation을 task로 제시한 CoTTA 논문에 대해서 살펴보도록 하겠다.
기존 방식들이 test time adaptation을 해결하는 방식은 TENT와 같이 entropy regularization을 하거나 pseudo-label을 활용한 self training 방식이 전부였다. 이러한 방법들은 test dataset이 stationary domain이라는 가정이 있어야하는데, 예컨데 pseudo-label을 사용하는 상황에서 효과적인 학습을 기대하는 것은, 일부 샘플에 대해 약간의 오차 범위 내에서 네트워크가 prediction을 하여 ground truth를 만들더라도 이에 대한 noise를 연속된 iteration을 통해 정규화할 수 있다는 점이다. 그러나 변하는 상황에서 pseudo-label을 사용하게 되면 이러한 정규화를 기대해볼 수가 없고, 단순히 noise가 중첩되면서 네트워크가 엉망이 되어버리는 것이다. 이러한 문제는 TENT와 같은 entropy minimization에서도 마찬가지의 영향력을 행사한다.
또한 가장 큰 문제는 mis-calibrated(제대로 정렬되지 않은) domain이 지속적으로 network를 학습시키다 보면 source dataset에 대한 접근이 제한되는 TTDA의 상황에서 source domain에 대한 knowledge가 손상되게 되고, 이는 catastrophic forgetting으로 이어진다.
이러한 문제들을 해결하기 위해 지속적으로 변하는 환경에서 adaptation을 진행할 수 있는 방법을 저자들은 제시하게 된다. CoTTA는 TTDA 기존 방법들의 문제로 두 가지를 제시하며 시작한다. 첫번째는 error accumulation으로, domain이 변하며 생긴 pseudo-label의 error를 경감시킬 방법에 대한 논의가 된다. 저자들은 pseudo-label의 퀄리티 및 학습 효율을 높이기 위해 두 가지의 서로 다른 방법(weight-averaged teacher model/augmentation-averaged prediction)을 제시한다. 두번째는 source knowledge에 대한 보존이다. TTDA 방식들에서는 학습 가능한 parameter를 제한하는 등의 방법을 사용했지만 continual setting에서는 같은 방법을 사용하기엔 부적절한 면이 많다. 따라서 source dataset에 pre-trained된 parameter를 부분적으로 restore하는 방식을 제안하였다. 방법론들에 대한 내용은 모두 뒤에서 하나씩 살펴보도록 하겠다.
방법들에 대해 살펴보기 전에 continual test time domain adaptation이 수행할 환경에 대한 setting을 먼저 살펴보면 다음과 같다.
Source dataset인 $(X^S, Y^S)$에 대해 사전 학습된 parameter $\theta$를 가지는 model $f_\theta(x)$가 있다고 해보자. 이 네트워크를 inference 상황에서 continual하게 변하는 target domain에 대해 최적화하는 것이 목적이고, 이때 기존의 TTDA의 setting과 동일하게 source dataset에 대해서는 접근이 제한된 상황을 가정한다. Unlabeled target domain dataset $X^T$가 네트워크에 연속적으로 주어지고, model은 오직 현재의 dataset domain에 대해서만 최적화할 수 있게 된다. 시간에 따라 target dataset의 도메인이 일정하지 않기 때문에 sequential한 상황을 나타내기 위해 time step $t$에서의 input을 target data $x_t^T$로 가정해보자. 이때 같은 time step $t$에서의 model $f_{\theta_t}$는 나름의 prediction $f_{\theta_t}(x_t^T)$을 내리고, 이를 통해 다음 input에 적용될 network의 parameter($\theta_{t+1}$)를 최적화하게 된다.
기존의 test time adaptation 방법들은 대체로 source model을 학습하는 과정에서 domain generalization에 도움이 될 process를 auxiliary task로 추가하게 된다. 예를 들어 TTT(Test time training)의 경우 source dataset에 대한 self-supervised task로 rotation prediction branch를 추가하여 함께 학습하는 식으로 supervised learning과 함께 최적화에 사용한다. 이러한 방식은 source model을 그대로 사용하지 않고 새로운 학습법을 사용하여 source dataset에 대한 학습 과정을 거쳐야한다는 번거로움이 있다. 하지만 CoTTA에서는 TENT와 마찬가지로 source dataset 학습 방법 자체를 바꾸지 않는다는 점, batch normalization이라는 구조적 특징이 필수적이었던 TENT와는 다르게 네트워크 구조 또한 자유롭게 사용할 수 있다는 점에서 source dataset에 대한 off-the-shelf라고 언급한다.
각 time step $t$에 따라 domain이 계속 변하는 과정에서의 pseudo label(network의 prediction)은 $\hat{y}_t^T = f_{\theta_t}(x_t^T)$로, 불안정하며 error accumulation 우려가 있다는 것을 앞서 이 논문에서 문제로 삼았다고 언급하고 넘어왔다. 기존 연구들 중 self-distillation과 같이 weight-averaged model의 안정성이 높다는 내용을 토대로, 저자들은 mean teacher 방식을 활용한 안정성 있는 weight averaged pseudo label을 생성하는 방법을 고안하였다.
방법은 위의 그림과 동일하며 수식을 통해 표현하면 다음과 같다. 처음 setting은 student network와 teacher network 모두 source dataset에 대해 pre-trained된 network의 parameter $\theta_0$을 기준으로 하고, teacher network는 loss를 통해 weight를 update하지 않고 mean-average 방식을 통해 student network의 parameter를 조금씩 가져오는 업데이트 방식을 취하며, 실질적으로 학습되는 것은 student network라고 생각하면 된다. Time step $t$에서teacher network($\theta^\prime_t$)로 하여금 추출된 pseudo-label을 $\hat{y^\prime}_t^T = f_{\theta_t^\prime}(x_t^T)$라고 할 때, student $f_{\theta_t}$는 teacher network의 prediction에 대한 cross entropy를 loss로 하여 최적화된다.
[ \mathcal{L}_{\theta_t}(x_t^T) = -\sum_c \hat{y^\prime}_{tc}^T \log \hat{y}^T_{tc} ]
위의 loss를 통해 optimization된 student network의 parameter에 대해($\theta_t \rightarrow \theta_{t+1}$) teacher network의 weight는 exponentially moving하는 과정을 거친다.
[ \theta_{t+1}^\prime = \alpha \theta_t^\prime +(1-\alpha) \theta_{t+1} ]
$\alpha$는 smoothing factor로서, 얼마나 previous parameter를 보존할 지 결정하는 hyperparameter가 된다.
Training time에서의 data augmentation은 흔히 overfitting을 방지하고자 하는 정규화 목적으로 사용된다. 그와는 반면에 test time augmentation은 domain shift에 상관없이 예측의 robustness를 높이기 위한 방법으로 제시되었다. CoTTA에서는 prediction의 confidence를 기준으로 domain shift 정도를 판단하고, augmentation averaged pseudo label을 사용하여 error accumulation 효과를 줄이는 방법을 사용하였다.
[ \begin{aligned} \tilde{y^\prime}_t^T =& \frac{1}{N} \sum_{i=0}^{N-1} f_{\theta_t^\prime} (\text{aug}_i (x_t^T)) \newline {y^\prime}_t^T =&\begin{cases} \hat{y^\prime}_t^T & \text{if conf}(f_{\theta_0}(x_t^T)) \ge p_{th} \newline \tilde{y^\prime}_t^T & \text{otherwise} \end{cases} \end{aligned} ]
수식에서 확인할 수 있듯 저자들은 hypothesis를 제시했는데, 이는 바로 confidence threshold $p_{th}$를 기준으로 domain distance를 구분한 것이다. 만약 confidence가 threshold보다 높다면 target data의 distribution이 source distribution과 유사하기 때문에 굳이 augmentation average를 사용하지 않고, 반대로 작다면 source distribution와 상이하다고 판단하여 $N$개의 서로 다른 augmentation input에 대한 prediction 평균을 pseudo label로 사용하게 된다.
앞서 논문에서 제시한 두 가지 문제 중 첫번째는 mean teacher/augmentation으로 해결했고, 두번째는 생각보다 심플한 방식으로 해결한다. Catastrophic forgetting 문제를 해결하기 위해 그냥 simple하게 연산에서 사용되는 특정 layer의 convolution parameter $W_{t+1}$에 대해,
[ M \sim \text{Bernoulli}(p) ]
$p = 0.01$을 $1$(occurence)에 대한 확률 값으로 가지는 베르누이 분포로 mask를 만들고(이렇게 되면 파라미터 $100$개 중 대략 $1$개 꼴로 $1$이 value인 mask가 생성됨), 이를 weight에 element-wise로 곱하게 된다.
[ W_{t+1} = M \odot W_0 + (1-M) \odot W_{t+1} ]
for nm, m in self.model.named_modules():
for npp, p in m.named_parameters():
if npp in ['weight', 'bias'] and p.requires_grad:
mask = (torch.rand(p.shape)<self.rst).float().cuda()
with torch.no_grad():
p.data = self.model_state[f"{nm}.{npp}"] * mask + p * (1.-mask)
Official github에서 masking을 하는 코드는 위와 같다. 모든 실험에서의 config에서 restoration mask의 probability는 $0.01$를 사용하였다. Stochastic restoration은 학습 과정에서 특정 node를 초기화한다는 점에서 다른 방식의 dropout으로 생각해볼 수 있다. 네트워크 학습 framework 전체를 요약한 구조는 아래와 같다.
실험 셋팅의 경우에는 TENT, SHOT과 같은 기존 논문들과 크게 다르진 않은데, TENT와 차이점이라고 하면 SVHN/MNIST variation에 대한 데이터셋이 제외되었다는 점이다. 그리고 TENT에서는 segmentation task에 대한 성능 부분이 main에 제시되지는 않는데, CoTTA에서는 아무래도 main figure를 자율 주행(ACDC)으로 설정했기 때문에 continual 상황에서의 domain adaptation을 보여주기에 가장 적절한 task라고 판단했다고 보인다. 또한 TENT는 애초에 test-time adaptation 기준으로 설계된 네트워크이므로 corruption dataset을 제외하고 비교를 할 때는 같은 continual 상황에 맞게끔 학습 후 비교하는 것을 볼 수 있다. Segmentation을 기준으로 TENT에서 segmentation backbone으로는 HRNet-W를 사용했고 CoTTA에서의 비교 과정에서는 Segformer를 baseline으로 삼았다.
위의 결과는 corruption과 관련된 CIFAR 성능 지표고, WideResNet28/ResNext 구조를 사용하였다. 따라서 segmentation이나 classification 모두 backbone 자체를 기존 TTDA와 다르게 가져간 것을 알 수 있는데, 이 부분에 대해서는 의문이 들었다(왜 ResNet-26, 50을 baseline으로 잡지 않았는가 혹은 굳이 다른 구조를 사용한 이유가 무엇인지…).
Segmentation test time adaptation 상황에서 성능 평가에 사용된 데이터셋은 adversarial한 주행환경에서의 road view segmentation을 위한 데이터셋인 ACDC(Adverse Conditions Dataset)이다. 실제로 논문에서 베르누이 probability factor나 조절할 수 있는 기타 hyperparameter에 대한 ablation을 추가로 진행하지는 않았고, 표에서 볼 수 있듯이 CIFAR-10(corupted) dataset에 대해 논문에서 제시한 세 방법들(weighted average prediction, augmentation average prediction, stochastic restore)의 효과에 대한 성능 비교를 통해 feasibility를 본 것이 전부이다.
본 논문은 continual test time adaptation이라는 task를 새롭게 정의함으로써 보다 real world system에 맞는 inference에 대한 domain generalization을 학습 목표로 삼았다. 그러나 논문에서 충분한 ablation이 진행되지 않은 점, TENT와는 다르게 학습 네트워크를 고정하지 않은 점(학습 방법에 의한 성능 향상인지, 단순히 네트워크 구조에 따른 성능 향상인지 확인하기 애매함)이 의문점으로 남았다.
아래에 utterance로 저자가 직접 답변해주셨는데, 실제로 TENT author는 ResNet26을 사용했다고 paper에서 밝혔지만 official code에서는 WRN28을 사용했다고 한다. 왜 오피셜에서와 페이퍼에서 사용된 구조가 다른지에 대해 정확한 이유는 모르겠지만 논의된 내용을 보니 ResNet26에 대한 코드를 공개하지 않은 것 같다. 그리고 hyperparameter에 대한 추가 실험은 supplementary material에 추가되었다고 한다(아래 댓글 참고).