Transformer Gradient Accumulation 문제 및 해결 방법
1. 문제 개요
- 문제: Gradient Accumulation(GA)은 수학적으로 Full Batch Training과 동일해야 하지만, 실험 결과 특정 조건에서 Training Loss가 일치하지 않음.
- 현상: GA 크기가 클수록 Training Loss가 증가하는 이상 현상 발생. 예를 들어 bsz=1, ga=16와 bsz=16, ga=1의 Training Loss가 다름.
- 원인: Cross Entropy(CE) Loss의 Normalizer가 제대로 작동하지 않음. 이는 L2 Norm이 증가하는 결과를 초래.
2. 문제 원인 분석
- Cross Entropy Loss 정상화 문제:
- CE Loss의 정규화 과정에서 Gradient Accumulation 단계별로 잘못된 분모(Denominator)를 사용.
- 이는 다양한 시퀀스 길이(Sequence Length) 조건에서 특히 심각한 영향을 줌.
- L2 Norm의 이상 증가:
- GA 크기가 증가할수록 L2 Norm이 커져서 Loss 계산에 영향을 미침.
- 실험 결과 bsz=16, ga=16과 같은 조건에서 L2 Norm이 10배 이상 증가.
- 수학적 근거:
- GA는 수학적으로 Batch 내 모든 Gradient의 합(또는 평균)과 동일해야 함.
- 그러나 CE Loss 정상화 문제로 인해 GA가 Full Batch와 일치하지 않음.
3. 해결 과정
- 문제 재현:
- 동일한 조건에서 다양한 GA 크기로 실험해 문제를 재현.
- bsz=16, ga=1 대비 bsz=1, ga=16의 Loss 차이를 관찰.
- CE Loss 비정규화(Un-normalized):
- CE Loss를 정규화 없이 계산한 경우, 모든 Training Loss가 일치.
- 정규화 재설정:
- CE Loss의 분모를 정확히 계산하도록 수정.
- GA 단계별로 모든 Gradient를 합산한 뒤 전체 데이터 수로 나눔.
- 수정 결과 검증:
- 수정 후 모든 Training Loss 곡선(bsz=1, ga=16, bsz=16, ga=1)이 완벽히 일치.
4. 실험 결과
- L2 Norm 비교:
- 수정 전: GA 크기가 커질수록 L2 Norm 증가.
- 수정 후: L2 Norm 일정.
- Training Loss 곡선 비교:
- 수정 전: GA 크기에 따라 Loss 불일치.
- 수정 후: 모든 GA 크기에서 Loss 곡선 일치.
- 업데이트된 라이브러리:
- unsloth 라이브러리에 수정 사항 반영.
- pip install --upgrade --no-cache-dir unsloth로 업데이트 가능.
5. Gradient Accumulation의 수학적 원리
- Batch와 GA 관계:
- Batch 크기를 N개로 설정:
- Batch N = Batch K + GA N/K + Steps N/K
- 예시:
- Batch 4 = Batch 2 + GA 2 + Steps 2배
- Batch 8 = Batch 4 + GA 2 + Steps 2배 = Batch 2 + GA 4 + Steps 4배
- Batch 크기를 N개로 설정:
- GA 작동 방식:
- Batch 내부의 Gradient는 모두 더한 뒤 평균을 계산.
- GA를 통해 여러 Step에서 Gradient를 누적한 뒤 한 번 업데이트.
- 업데이트 횟수 감소:
- GA 크기가 증가하면 업데이트 횟수 감소.
- 하지만 실제 Batch 크기와 동일한 효과를 냄.
6. 결론 및 영향
- 문제의 영향:
- 모든 Gradient Accumulation을 사용하는 프레임워크(DP, DDP 포함)에 동일한 문제가 존재.
- 특히 다양한 시퀀스 길이를 처리하는 모델에서 손실(Loss) 계산 오류.
- 해결 방법의 중요성:
- CE Loss 정상화 수정으로 모든 GA 크기에서 Full Batch Training과 동일한 결과 보장.
- 멀티 GPU 학습 및 대규모 모델 학습 시 안정성 향상.
- 다음 단계:
- 관련 라이브러리와 워크플로 업데이트 권장.
- 다양한 조건에서 추가 테스트 및 최적화 진행.
'머신러닝 & 딥러닝 > 딥러닝' 카테고리의 다른 글
UNet과 Text Encoder의 학습 방법 (0) | 2024.11.25 |
---|---|
Stable Diffusion 3.5 Large Fine-tuning Tutorial 정리 (1) | 2024.11.24 |
Stable Diffusion 3.5 Medium 모델 학습 최적화 가이드 (0) | 2024.11.23 |
Flux perfermence improved by Flash attention3 + Triton (2) | 2024.11.22 |
Flux 모델 최적화를 위한 TorchAO 및 torch.compile() 활용 가이드 (1) | 2024.11.21 |