머신러닝 & 딥러닝/딥러닝

Stable Diffusion 3.5 Large Fine-tuning Tutorial 정리

Haru_29 2024. 11. 24. 20:07

참고

https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6

1. 환경 설정

Python 환경 설정

# Python 3.11.6 이상 필요
python -m venv .venv
source .venv/bin/activate
pip install -U poetry pip

# Poetry 설정
poetry config virtualenvs.create false
poetry install

CUDA 요구사항

  • CUDA 12.2 이상 권장
  • 낮은 버전 사용시 수정된 설치 방법 필요:
# CUDA 12.1 예시
pip install torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --index-url https://download.pytorch.org/whl/cu121
 

필수 의존성

dependencies = [
    "diffusers",
    "transformers==4.45.1",
    "accelerate==0.34.2",
    "bitsandbytes==0.44.1",
    "wandb==0.18.2"
]

2. 모델 구성

기본 모델 설정

model_path = "stabilityai/stable-diffusion-3.5-large"

텍스트 인코더 구성

  • CLIP L/14
  • CLIP G/14
  • T5 XXL (4.7GB VRAM 사용)

VAE 설정

vae_settings = {
    "channels": 16,  # 최적의 품질
    "downsampling_factor": 8
}

3. 데이터셋 준비

이미지 전처리

image_settings = {
    "base_resolution": 1024,
    "aspect_ratios": [
        (1024, 1024), 
        (1152, 896),
        (896, 1152),
        # ... 기타 지원 해상도
    ]
}

캡션 처리

caption_settings = {
    "original_ratio": 0.5,  # 원본 캡션 비율
    "synthetic_ratio": 0.5,  # 합성 캡션 비율
    "trigger_word": "k4s4"  # 트리거 워드 예시
}

데이터 증강

augmentation_settings = {
    "repeats": 5,  # 데이터 반복 횟수
    "token_shuffle": True,  # 토큰 셔플링 활성화
    "keep_tokens": 2  # 앞부분 유지할 토큰 수
}

4. 학습 설정

LoRA 기본 설정

{
    "model_type": "lora",
    "model_family": "sd3",
    "learning_rate": 1.05e-3,
    "batch_size": 6,
    "lora_rank": 768,
    "lora_alpha": 768,
    "max_train_steps": 24000,
    "checkpointing_steps": 400,
    "validation_steps": 200
}

메모리 최적화 설정 (24GB VRAM)

{
    "batch_size": 1,
    "lora_rank": 64,
    "gradient_checkpointing": true,
    "mixed_precision": "bf16"
}

스케줄러 설정

{
    "lr_scheduler": "cosine",
    "lr_warmup_steps": 2400,
    "snr_gamma": 5
}

5. 학습 단계별 가이드

1단계: 낮은 해상도 사전 학습

{
    "resolution": 256,
    "validation_resolution": 256,
    "training_steps": 10000
}

2단계: 고해상도 미세조정

{
    "resolution": 1024,
    "validation_resolution": 1024,
    "training_steps": 14000,
    "timestep_shift_factor": 3.0  # 고해상도 timestep 조정
}

3단계: DPO (Direct Preference Optimization)

dpo_settings = {
    "learning_rate": 5e-5,
    "training_steps": 2000,
    "lora_rank": 128
}

6. 고급 최적화 기법

레이어별 학습 제어

layer_control = {
    "detail_layers": range(30, 38),  # 세부 디테일 제어
    "composition_layers": range(12, 25),  # 구도 제어
    "context_layers": "context_only"  # 컨텍스트 레이어만 학습
}

마스크 손실 학습 설정

{
    "mask_settings": {
        "edge_refinement": true,
        "min_mask_size": 2000,
        "mask_blur": 3
    }
}

APG (Adaptive Projected Guidance) 설정

apg_settings = {
    "eta": 0.3,
    "norm_threshold": 1.0,
    "use_momentum": true,
    "momentum": 0.9
}

7. 검증 및 모니터링

검증 설정

{
    "validation_prompt": "촬영 구도와 상황 설명이 포함된 프롬프트",
    "validation_negative_prompt": "blurry, cropped, ugly",
    "validation_guidance_scale": 7.5,
    "validation_inference_steps": 30
}

체크포인트 평가

evaluation_metrics = {
    "fid_score": True,
    "clip_score": True,
    "human_evaluation": {
        "prompt_following": True,
        "aesthetic_quality": True,
        "text_accuracy": True
    }
}

8. 특수 케이스 최적화

전문 사진 데이터셋

photo_settings = {
    "freeze_layers": range(30, 38),  # 후반부 레이어 고정
    "learning_rate": 5e-5
}

해부학적 정확도 유지

anatomy_settings = {
    "trainable_layers": ["attention_layers"],
    "frozen_layers": ["linear_layers"]
}

소규모 데이터셋

small_dataset_settings = {
    "trainable_blocks": ["context_blocks"],
    "frozen_blocks": ["main_blocks"],
    "caption_dropout": 0.1
}

9. 보조 스크립트

마스크 생성 스크립트

mask_generation_command = """
python generate_dataset_masks.py \
    --input_dir /images/input \
    --output_dir /images/output \
    --text_input "person" \
    --edge_refinement \
    --min_mask_size 2000 \
    --mask_blur 3
"""

체크포인트 평가 스크립트

evaluation_command = """
python evaluate_checkpoints.py \
    --checkpoint_dir /path/to/checkpoints \
    --output_dir /path/to/results \
    --batch_size 1 \
    --validation_prompts prompts.txt
"""

10. 문제 해결 가이드

VRAM 부족 해결

  1. 텍스트 인코더 고정 (10.4GB 절약)
  2. LoRA 랭크 감소 (768 → 64)
  3. 배치 사이즈 감소 (6 → 1)
  4. 그래디언트 체크포인팅 활성화

학습 불안정성 해결

  1. 학습률 감소 (1.05e-3 → 9.5e-4)
  2. QK-normalization 적용
  3. 워밍업 스텝 증가
  4. 그래디언트 클리핑 적용

품질 저하 해결

  1. 레이어 선택적 고정
  2. 캡션 품질 향상
  3. APG 스케일링 적용
  4. 검증 주기 조정