본문 바로가기

PR113

[3줄 LLM] Rethinking Optimization and Architecture for Tiny Language Models(Arxiv, Feb 24)

 

 

Paper: https://arxiv.org/abs/2402.02791

 

1. Mobile용으로 사용하기에 LLM은 computation/memory cost가 매우 크다

  - 그래서 tiny language model(TLM)이 필요하기도 하고,

  - 다양한 학습 방법론이 있지만 데이터가 아닌 efficient training strategy에 집중하는 work은 적기도 하고,

  - 충분히 분석하기도 힘들고 cost때문에 optimization strategy를 다양하게 시도해 볼 수도 없었음

  - 그래서 TLM을 구축하는 다양한 방법론에 대해 실험해보고, 이걸로 좋은 모델을 만들어보자

 

2. Training methods for TLMs

  - 2-1. Neural Architecture

    - 2.1.1. Compact Tokenizer

      - 100K vocab. size로 initialize 시켜봤더니, 48k vocab.만으로 training corpus의 97.86%을 커버할 수 있다

      - Model parameter가 작아질수록 vocab. embedding matrix의 비율이 매우 커진다

    - 2.1.2. Architecture Tweak

      * 48k vocab, 1B parameter로 고정하고 실험

      - 깊어질수록 model performance는 올라가지만 inference cost 또한 올라가는데, 대충 20 정도가 좋은 것 같음

  - 2.2.Parameter Initialization

    - 2.2.1. Random Initialization

      - 보통 N(0, std^2)로 init. 하는데, std를 어떻게 주는 것이 좋을 지 다양하게 실험

      - MHA는 layer마다 다르게, MLP는 똑같게. 자세한 내용은 Appendix. B 참고

    - 2.2.2. Parameter Inheritance

      - 일부 layer를 skip하면서 실험해본 결과(Figure. 5) 초반/마지막 layer는 성능에 영향을 많이 미치고, 가운데 부분은 조금 덜 영향을 미침

        - Intra-layer에 redundancy가 있는 건 아닐까?

        - 무언가 metric으로 layer selection을 해야하는데, loss 관점에서는 L1, L2, Talyor 등보다 learnable mask를 쓰니 초반부터 coverge 시점까지 loss가 낮더라

  - 2.3 Model Optimization

    - 2.3.1. Batch Size & Learining Rate

      - lr=(bs/bs0)r×lr0 공식 사용, deafult batch size bs0 = 1M, learning ratge lr0 = 10^(-4), increment rate r = 0.5 or 1

      - Figure 7. 실험 결과를 보면, 너무 크지 않은 batch size(~4M) 정도면 convergence speed가 비슷하게 나옴

      - Large batch size일 경우엔 optimization strategy를 좀 더 조정할 필요가 있음

    - 2.3.2. Multiple-Round Training

      - 보통 1 round 학습하는데, catastrophic forgetting 문제도 존재하고 특히 tiny model은 limited capacity때문에 더 크게 찾아올 수 있음

        - Batch 별로 loss 계산 후, 높은 loss(hard to fit)를 보이는 loss는 다음 학습 때 높은 확률로 포함되도록 세팅

 

3. 위 process를 통해 얻은 결과로 PanGu 만들어서 비교했더니, 잘 나오더라!

    - 타 모델들(ex. Phi, TinyLLaMA, etc...) 성능이 매우 처참한 것처럼 보이는데(MMLU), 맞는지 모르겠다.

    - 특히 C- 붙은 task가 중국어 메인인 것 같은데, 중국어가 메인이 아닌 모델들도 실험에 들어가있지 않나...?