본문 바로가기

AI

Train Once, Test Anywhere : Zero-Shot Learning for Text Classification

반응형

지도학습과 비지도 학습 등 최근에는 다양한 인공지능 학습 방법에 대한 연구가 이뤄지고 있다. 그중에서도 Zero shot learning이라는 학습 방법에 대해 궁금해서 관련 논문으로 개념을 알아보려 한다.

Train Once, Test Anywhere - Zero-Shot Learning for Text Classification

Zero shot Learning

- zero-shot learning은 훈련하는 동안 훈련하지 않은 class에 대해서 추적할 수 있는 능력이다.
- 정확한 감독(explicit supervised) 없이 새로운 것에 대해서 생성하고 인식할 수 있는 능력을 학습한다.
- 예를 들어, transfer learning은 각 class 데이터 세트에 대해서 모델에 fine tune이나 훈련해야 한다. 그런데 zero-shot learning은 감정 분석과 새로운 뉴스 분류 같은 task를 task-specific 훈련 없이도 바로 할 수 있다.

Zero shot learning 뭐가 좋을까?

- 주요 목적은 효율성의 극대화
- 학습 과정에서 모델이 의미 정보(semantic information)를 적절히 배우도록 하기 위해서
- 각 도메인에 속하는 다양한 task에 적응
- 입력 데이터의 자체에 대한 이해와 표현력을 높이기 위한 비지도 학습(unsupervised learning)과 자기 지도학습

Zero shot learning-Train once , Test Anywhere

이 논문에서 zero shot classification을 위해서, X classes에서 텍스트를 분류하는 것이 아니라 새롭게 특정 클래스 각각을 binary classification으로 re-formulation 했다.

연구의 필요

- unseen classes에 대해서 예측하는 작업이 필요한 경우가 많다.
- 따라서 category tree를 사용해서 학습한 적 없는 extended category를 학습한 category의 latent space에서 relativeness를 찾아서 예측하는 zero-shot learning을 제안한다.

과정

- 데이터 수집 - 4.2 million news headlines를 크롤링했다.
- 워드 임베딩 Google의 사전 학습된 Word2vec을 사용했다
- 모델 설계 - 3가지 다른 모델 아키텍처를 제안했다.

모델 1

- word embedding 평균값과 label embedding 값을 가지고 concatenate해서 fully connected layer를 거치게 한다. 

모델 2

 

두 번째 모델은 word embedding의 평균값으로 하는 것이 아니라 word embedding을 LSTM을 거쳐서 마지막 hidden state를 문장 vector로 다룬다. 그리고 label embedding과 합쳐서 동일하게 fully connected layer를 거쳐서 binary classification을 하도록 한다.

모델 3

세 번째 모델에서는 word embedding과 label embedding을 concat을 먼저 하고 LSTM을 통과해서 last hidden state를 사용한다. last hidden state로 fully connected layer를 통과해서 binary classification을 하도록 한다.

훈련

News headlines Dataset 을 반으로 나눠서 반은 실제 label과 잘 mapping 되게 구성하고 나머지 반은 무작위로 선택된 관련 없는 label로 선택되도록 구성한다.
- 손실함수는 binary cross entropy를 사용한다
- 최적화 함수는 Adam을 사용한다.

결론

- train data set에서는 72%, 72.6% and 74%의 정확도를 모델 각각 달성했고 test 데이터 세트에서는 76% and 81%로 상대적으로 더 높게 나타냈다.
- category tree를 사용해서 threshold를 0.5로 두고 binary classification을 수행한 결과 64%, 53% and 64.5%의 정확도를 보였다. 라벨을 모두 학습한 모델인 SVC나 나이브 베이즈 모델은 74%, 78%의 정확도를 보였다.
- Architecture 3을 가지고 category tree를 사용하지 않고 직접적인 class label을 사용해서 훈련한 경우에는 49%의 정확도밖에 얻을 수 없었다.
이런 결과가 나올 수 있었던 이유는 아마도 문장과 데이터 세트에 있는 단어를 넘어서 확장된 단어 사이에 관련성의 개념을 학습했다고 할 수 있다. 그리고 이것은 위의 실험뿐만 아니라 다양한 future work 범위에 적용될 수 있음을 시사한다.

참고 자료

https://amitness.com/2020/05/zero-shot-text-classification/
https://arxiv.org/abs/1712.05972
https://paperswithcode.com/task/zero-shot-learning

반응형

'AI' 카테고리의 다른 글

Seq2Seq(시퀀스 투 시퀀스)  (0) 2024.01.15
EDA를 왜 해야 할까?  (0) 2024.01.15
Transformer - Encoder(어텐션 메커니즘)  (0) 2024.01.15
머신러닝 요약(ML Summary)  (0) 2024.01.15
분류 모델의 종류(classification model)  (0) 2024.01.15