A Contrastive Framework for Neural Text Generation (NeurIPS 2022)

발표자료

느낀점

  • 잘 쓴 논문 같다
  • 간단한 아이디어지만 효과적
  • decoding 시간이 생각보다 덜 걸려서 신기했음
  • contrastive 방법론이 결국 비슷한 토큰 안나오게 하겠다인데, simCTG는 MLE 로 보완되지만 디코딩 부분은 아예 비슷한걸 견제하는 식으로 나오는데도 결과가 좋게 나오는게 신기 (물론 기존의 확률아 있어서 보완이 되지만) -> degeneration penalty를 크게 줘도 ppl 결과가 좋길래 신기했음)

Note

Author

Yixuan Su* Tian Lan** Yan Wang** Dani Yogatama+ Lingpeng Kong++ Nigel Collier*
*Language Technology Lab, University of Cambridge
**Tencent AI Lab +DeepMind
++Department of Computer Science, The University of Hong Kong

Abstract

  • 문제: maximization-based decoding methods (e.g., beam search) of neural language models often lead to degenerate solutions
    • the generated text is unnatural and contains undesirable repetitions
  • 시중에 나온 대안: Existing approaches introduce stochasticity via sampling or modify training objectives to decrease the probabilities of certain tokens (e.g., unlikelihood training)
    • 대안의 문제: However, they often lead to solutions that lack coherence
  • 본 논문에서 보인 것: an underlying reason for model degeneration is the anisotropic distribution of token representations.
  • present a contrastive solution:
    • (i) SimCTG, a contrastive training objective to calibrate the model’s representation space,
      • anisotropic 해소하겠다
    • (ii) a decoding method—contrastive search—to encourage diversity while maintaining coherence in the generated text.
      • 다양하게 뽑지만 coherence 유지해보겠다
  • SOTA를 이기는 결과 보여줬음

Introduction

  • the conventional approach of training a language model with maximum likelihood estimation (MLE) and decoding the most likely sequence is often not sufficient
    • 평범한 접근방법인 MLE 기반의 학습과 디코딩은 대체로 충분하지 않음
    • degeneration 결과를 보여주기도함
      • tend to be dull and contain undesirable repetitions at different levels (e.g., token-, phrase-, and sentence-level)
      • 해결방법중 하나는 less likely vocab에서 샘플링하는 디코딩방법을 사용하는 것 (To alleviate this problem, previous solutions modify the decoding strategy by sampling from less likely vocabularies)
      • 하지만 이런 방법은 의미적으로 안맞거나 반대되기도하는등 부작용이 있음
      • Another approach addresses the degeneration problem by modifying the model’s output vocabulary distribution with unlikelihood training (unlikelihood training도 비슷한 맥락)
  • 이러한 이유는 token representation distribution의 비대칭 때문이라고 주장해보겠음 (the degeneration of neural language models stems from the anisotropic distribution of token representations, i.e., their representations reside in a narrow subset of the entire space [10, 9, 44].)
  • Figure 1은 GPT-2의 token representation에 대한 cosine sim matrix임 대부분이 0,95 이상인걸 볼수 있음
    • In an ideal setting, the token representations should follow an isotropic distribution, i.e., the token similarity matrix should be sparse and the representations of distinct tokens should be discriminative
      image
  • 본 논문에서 제안하는 모델
    • SimCTG (a simple contrastive framework for neural text generation) that encourages the model to learn discriminative and isotropic token representations.
    • The Key intuition
      • (i) at each decoding step, the output should be selected from the set of most probable candidates predicted by the model to better maintain the semantic coherence between the generated text and the human-written prefix
      • (ii) the sparseness of the token similarity matrix of the generated text should be preserved to avoid degeneration.
  • PPL이나 휴먼 평가등에서도 개선된 결과를 보여줌
    • the experimental results verify that SimCTG improves the intrinsic qualities of the language model, as evaluated by perplexity and token prediction accuracy (§4.2 and Appendix D). Moreover, we demonstrate that the proposed contrastive search significantly outperforms previous state-of-the-art decoding methods in both human and automatic evaluations

Background

  • MLE로 학습하는 LM은 transformer-based model 구조에서 모델의 표현이 anisotropic distribution을 갖게 됨
  • Deterministic Sampling은 greedy, beam이고 highest probability에 의존해서 degeneration을 야기함
  • Stochastic Sampling은 top-k, nucleus sampling류임, 가끔 의미적으로 반대되는 단어까지 생성하기도 하는 부작용 있음

image

Methodology

  • how to apply contrastive learning to calibrate the representation space of the language model
  • introduce our proposed contrastive search decoding algorithm

Contrastive Training

  • Our goal is to encourage the language model to learn discriminative and isotropic token representations
    • cosine sim으로 유사한 토큰들은 더 큰 loss를 받는 구조 -> 붙어있는 토큰을 더 멀리 떨어뜨리게 하는 효과
      • Q) 벡터적으로 유사한 토큰표현을 떨어뜨리는 효과가 서로 구분할 수 있는 효과를 주지만, 학습은 잘 되게하는걸까? 이 부분은 MLE가 잘해야하는 구조인듯
    • ρ 값이 0 이면 적용 안하는거나 마찬가지 -> MLE만 쓰는 구조

image

  • 모델이 예측한 셋 안에서 확률 높은 후보들이되, 이전 문맥과 구분이 될 수 있어야함
  • 토큰 생성에 대한 확률값에 해당 토큰의 hidden states와 이전 토큰들의 hidden states의 유사도중 max값을 뽑아서 penalty term으로 줌
    • token들이 많으면 이거 계산시간 오래 걸리지 않을까?

image

Document Generation

  • Open-ended document generation에 적용
  • Our proposed approach is architecture-agnostic
    • GPT2 (117M)에 Loss_SimCTG 적용해서 파인튜닝하고, contrastive search 이용해서 decoding 해봤음
    • baseline은 GPT2를 evaluated benchmark에 대해서 finetuning하되 아래 방법으로 함
      • [1] MLE GPT2
      • [2] unlikelihood GPT2
  • Evaluation Benchmark는 Wikitext-103 데이터
  • Training은 SimCTG, MLE는 Wikitext-103 (40k training steps)데이터에 대해 파인튜닝했고 UL baseline에 대해서는 38.5K steps를 token-level, 1.5K steps를 sentence-level로 UL 학습함
    • bs: 128, max_seq_len: 256, optim: adam, lr: 2e-5
  • Decoding은 prefix를 32~128 length정도 되는 정보를 주고 시작함
    • deterministic method: greedy, beam (10) search
    • stochastic method: p=0.95
    • proposed contrastive search: k and α in Eq. (5) are set as 8 (top_k 8개 보고) and 0.6. (degeneration penalty에 점수를 좀 더 줬음)

Evaluation Metrics

평가 기준은 아래의 관점으로 정함

  • (1) language modelling quality
    • Perplexity on the test set of Wikitext-103.
    • Prediction Accuracy (토큰맞추기)
    • Prediction Repetition (next token의 top-1 예측이 prefix(이전입력)에 있으면 카운팅됨), 낮은게 좋음
  • (2) generation quality
    • Generation Repetition (sentence-level에서 n-grams의 반복을 카운팅) rep-n = 100 × (1.0 − ( |unique n-grams(xˆ )| / |total n-grams(x^)| ))
    • Diversity (n-gram levels에서 repetition을 계산함)
    • MAUVE (생성한거랑 human-written text와 token distribution closeness를 계산함)
    • Semantic Coherence (simCSE로 prefix와 generated text의 representation을 구해서 coherence score를 계산함)
    • Perplexity of Generated Text

image

Human Evaluation

  • 평가 기준 (5점 척도)
    • Coherence: Whether the generated text is semantically consistent with the prefix.
    • Fluency: Whether the generated text is fluent and easy to understand.
    • Informativeness: Whether the generated text is diverse and contains interesting content.
  • 평가 데이터
    • randomly select 200 prefixes with length of 32 from the test set of Wikitext-103
  • 큰 모델에서 가장 좋은 점수를 얻었기 때문에 GPT3등에 대해서도 future work 해볼 예정

image

Open-domain Dialogue Generation

  • Benchmark and Baselines
    • 영어랑 중국어 진행함
    • 영어: DailyDialog
    • 중국어: LCCC

image

Further Analysis

image

Token Representation Self-similarity

  • Self-similarity? 토큰끼리 sim의 평균
    image
  • Figure2. 보면 중간 레이어는 self-similarity가 비슷함
  • output layer에서는 차이가 확 나기 시작함

The Effect of Contrastive Loss Margin

  • contrastive loss margin ρ (Eq. (2))에 대해서 분석해보면 perplexity on the Wikitext-103 test set기준에서는 0.5값이 가장 적당한 마진임을 알 수 있음

Contrastive Search versus Nucleus Sampling

  • 두가지 관점에서 분석함
    • (1) generation diversity
    • (2) perplexity of the generated text (gen-ppl)

Decoding Latency Comparison

  • 생각보다 latency차이가 많이 안남 (신기)

아래 그림은 simCTG 모델 기준임
image

Case Study

  • 실제 예제들
    image

Comparison of Token Similarity Matrix

  • 각 기법별 token similarity matrix보면 제안 기법의 align이 잘 되어있는걸 볼 수 있음
    image

Conclusion

  • Neural LM의 degeneation의 문제는 token representation의 anisotropic distribution 문제임을 보임
  • SimCTG 제안함, isotropic, discriminative representation space 만들어줌
  • contrastive search이라는 디코딩 방식도 제안함
  • automatic and human evaluations에서 모두 가장 좋은 점수 얻고 SOTA보다 높은 점수 기록함

Appendix

Usage

  • installation

    1
    pip install simctg --upgrade.
  • example

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch
    # load the language model
    from simctg.simctggpt import SimCTGGPT model_name = r’cambridgeltl/simctg_wikitext103’ model = SimCTGGPT(model_name)
    model.eval()
    tokenizer = model.tokenizer
    # prepare input
    prefix_text = # The prefix text in Table 4
    print (’Prefix is: {}’.format(prefix_text))
    tokens = tokenizer.tokenize(prefix_text)
    input_ids = tokenizer.convert_tokens_to_ids(tokens) input_ids = torch.LongTensor(input_ids).view(1,-1)
    # generate result with contrastive search
    beam_width, alpha, decoding_len = 8, 0.6, 128
    output = model.fast_contrastive_search(input_ids=input_ids,
    beam_width=beam_width, alpha=alpha,
    decoding_len=decoding_len) print("Output:\n" + 100 * ’-’)
    print(tokenizer.decode(output))

Gen-ppl Results Measured by Different Models

  • 다른 모델들에 대한 결과를 보면 ppl 자체는 낮은 모델들도 있지만 human-written text와 가장 유사한건 역시 제안하는 모델
  • ppl이 낮은것 보다 사람이랑 유사한게 제일 좋은 것이다라고 주장 (이런 주장들은 기존에도 쭉 있었고 여기서도 같은 주장제기)
    image

A Contrastive Framework for Neural Text Generation (NeurIPS 2022)

https://eagle705.github.io/A-Contrastive-Framework-for-Neural-Text-Generation-NeurIPS-2022/

Author

Joosung Yoon

Posted on

2022-10-05

Updated on

2022-10-05

Licensed under

댓글