머신러닝 & 딥러닝/딥러닝

Gradient Accumulation 문제 및 해결 방법과 Batch 비교

Haru_29 2024. 11. 27. 19:52

Transformer Gradient Accumulation 문제 및 해결 방법

1. 문제 개요

  • 문제: Gradient Accumulation(GA)은 수학적으로 Full Batch Training과 동일해야 하지만, 실험 결과 특정 조건에서 Training Loss가 일치하지 않음.
  • 현상: GA 크기가 클수록 Training Loss가 증가하는 이상 현상 발생. 예를 들어 bsz=1, ga=16와 bsz=16, ga=1의 Training Loss가 다름.
  • 원인: Cross Entropy(CE) Loss의 Normalizer가 제대로 작동하지 않음. 이는 L2 Norm이 증가하는 결과를 초래.

 

2. 문제 원인 분석

  1. Cross Entropy Loss 정상화 문제:
    • CE Loss의 정규화 과정에서 Gradient Accumulation 단계별로 잘못된 분모(Denominator)를 사용.
    • 이는 다양한 시퀀스 길이(Sequence Length) 조건에서 특히 심각한 영향을 줌.
  2. L2 Norm의 이상 증가:
    • GA 크기가 증가할수록 L2 Norm이 커져서 Loss 계산에 영향을 미침.
    • 실험 결과 bsz=16, ga=16과 같은 조건에서 L2 Norm이 10배 이상 증가.
  3. 수학적 근거:
    • GA는 수학적으로 Batch 내 모든 Gradient의 합(또는 평균)과 동일해야 함.
    • 그러나 CE Loss 정상화 문제로 인해 GA가 Full Batch와 일치하지 않음.

 

3. 해결 과정

  1. 문제 재현:
    • 동일한 조건에서 다양한 GA 크기로 실험해 문제를 재현.
    • bsz=16, ga=1 대비 bsz=1, ga=16의 Loss 차이를 관찰.
  2. CE Loss 비정규화(Un-normalized):
    • CE Loss를 정규화 없이 계산한 경우, 모든 Training Loss가 일치.
  3. 정규화 재설정:
    • CE Loss의 분모를 정확히 계산하도록 수정.
    • GA 단계별로 모든 Gradient를 합산한 뒤 전체 데이터 수로 나눔.
  4. 수정 결과 검증:
    • 수정 후 모든 Training Loss 곡선(bsz=1, ga=16, bsz=16, ga=1)이 완벽히 일치.

 

4. 실험 결과

  1. L2 Norm 비교:
    • 수정 전: GA 크기가 커질수록 L2 Norm 증가.
    • 수정 후: L2 Norm 일정.
  2. Training Loss 곡선 비교:
    • 수정 전: GA 크기에 따라 Loss 불일치.
    • 수정 후: 모든 GA 크기에서 Loss 곡선 일치.
  3. 업데이트된 라이브러리:
    • unsloth 라이브러리에 수정 사항 반영.
    • pip install --upgrade --no-cache-dir unsloth로 업데이트 가능.

 

5. Gradient Accumulation의 수학적 원리

  1. Batch와 GA 관계:
    • Batch 크기를 N개로 설정:
      • Batch N = Batch K + GA N/K + Steps N/K
    • 예시:
      • Batch 4 = Batch 2 + GA 2 + Steps 2배
      • Batch 8 = Batch 4 + GA 2 + Steps 2배 = Batch 2 + GA 4 + Steps 4배
  2. GA 작동 방식:
    • Batch 내부의 Gradient는 모두 더한 뒤 평균을 계산.
    • GA를 통해 여러 Step에서 Gradient를 누적한 뒤 한 번 업데이트.
  3. 업데이트 횟수 감소:
    • GA 크기가 증가하면 업데이트 횟수 감소.
    • 하지만 실제 Batch 크기와 동일한 효과를 냄.

 

6. 결론 및 영향

  • 문제의 영향:
    • 모든 Gradient Accumulation을 사용하는 프레임워크(DP, DDP 포함)에 동일한 문제가 존재.
    • 특히 다양한 시퀀스 길이를 처리하는 모델에서 손실(Loss) 계산 오류.
  • 해결 방법의 중요성:
    • CE Loss 정상화 수정으로 모든 GA 크기에서 Full Batch Training과 동일한 결과 보장.
    • 멀티 GPU 학습 및 대규모 모델 학습 시 안정성 향상.
  • 다음 단계:
    • 관련 라이브러리와 워크플로 업데이트 권장.
    • 다양한 조건에서 추가 테스트 및 최적화 진행.