머신러닝 & 딥러닝/딥러닝

Stable Diffusion 3.5 Medium 모델 학습 최적화 가이드

Haru_29 2024. 11. 23. 18:39

Stable Diffusion 3.5 medium 모델 구조 

1. 핵심 학습 방법 최적화

노이즈 샘플링 전략

  • Logit-normal 분포를 사용한 노이즈 샘플링 구현
def logit_normal_sampling(t, m=0.0, s=1.0):
    u = torch.randn_like(t) * s + m
    return torch.sigmoid(u)
  • 파라미터 설정:
    • 위치 파라미터 m = 0.0
    • 스케일 파라미터 s = 1.0
  • 중간 타임스텝에 더 많은 학습 가중치 부여
  • EDM이나 기존 확산 모델보다 우수한 성능 달성

손실 함수 최적화

  • Conditional Flow Matching (CFM) 손실 함수 사용:
     
def compute_cfm_loss(model_output, noise, t):
    # v_θ(z, t) - u_t(z|ε) 계산
    target = compute_velocity_target(noise, t)
    return F.mse_loss(model_output, target)
  • 검증 시 8개의 균등 분할된 타임스텝으로 평가
  • w_t = -1/2 * λ'_t/b^2_t 가중치 적용

2. 아키텍처 최적화

멀티모달 설계

  • 텍스트와 이미지 모달리티를 위한 분리된 가중치 스트림 구현
     
class MultiModalTransformer(nn.Module):
    def __init__(self, depth):
        self.text_stream = TransformerStream(depth)
        self.image_stream = TransformerStream(depth)
        self.cross_attention = CrossAttention(depth)
  • 모델 크기에 따른 스케일링:
    • 히든 크기 = 64 * depth
    • MLP 크기 = 256 * depth
    • 어텐션 헤드 수 = depth

학습 안정성 향상

  • Q/K 정규화를 위한 RMSNorm 추가:
     
class RMSNormAttention(nn.Module):
    def __init__(self):
        self.rms_norm = RMSNorm()
        
    def forward(self, q, k):
        q = self.rms_norm(q)
        k = self.rms_norm(k)
        return torch.matmul(q, k.transpose(-2, -1))
  • AdamW 옵티마이저 설정:
    • epsilon = 1e-15
    • bfloat16 mixed precision 사용

3. 해상도 및 위치 인코딩

고해상도 학습

  • 256x256 사전학습 후 고해상도로 확장
  • 해상도 기반 타임스텝 시프팅 적용:
     
def compute_resolution_timestep(t, m, n):
    alpha = math.sqrt(m/n)
    return alpha * t / (1 + (alpha - 1) * t)
  • 1024x1024 해상도에서는 α = 3.0 권장

위치 인코딩 조정

  • 가변 종횡비 처리:
def create_position_grid(h_max, w_max, S):
    pos_h = torch.linspace(-h_max/2, h_max/2, h_max) * (256/S)
    pos_w = torch.linspace(-w_max/2, w_max/2, w_max) * (256/S)
    return torch.meshgrid(pos_h, pos_w)

4. 학습 프로세스 최적화

데이터 파이프라인

  • 사전 계산 및 캐싱:
     
class DataPreprocessor:
    def __init__(self):
        self.vae = AutoencoderKL.from_pretrained(...)
        self.clip = CLIPModel.from_pretrained(...)
        self.t5 = T5EncoderModel.from_pretrained(...)
        
    def precompute(self, image, text):
        latents = self.vae.encode(image)
        clip_emb = self.clip.encode_text(text)
        t5_emb = self.t5(text).last_hidden_state
        return latents, clip_emb, t5_emb
  • 캡션 비율:
    • 원본 캡션 50%
    • 합성 캡션 50%

학습 스케줄

  • 배치 크기: 4096
  • 학습률: 1e-4 (1000스텝 웜업)
  • 텍스트 인코더 드롭아웃: 46.3%
  • EMA 감쇠율: 0.99 (100스텝마다 갱신)

5. 메모리 및 연산 최적화

메모리 효율성

class MemoryEfficientTrainer:
    def __init__(self):
        self.scaler = torch.cuda.amp.GradScaler()
        
    def training_step(self, batch):
        with torch.cuda.amp.autocast():
            loss = self.compute_loss(batch)
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

분산 처리

  • 다중 GPU 환경에서의 배치 분산
  • 어텐션 연산 분산
  • 텍스트/이미지 스트림 간 메모리 밸런싱

6. 평가 및 모니터링

메트릭 추적

class MetricsTracker:
    def __init__(self):
        self.metrics = {
            'val_loss': [],
            'clip_score': [],
            'fid': []
        }
        
    def update(self, outputs):
        for k, v in outputs.items():
            self.metrics[k].append(v)

품질 검증

  • GenEval 메트릭을 통한 텍스트 이해도 평가
  • 인간 평가를 통한 시각적 품질 검증
  • 타이포그래피 및 공간 추론 능력 확인
  • 프롬프트 준수도 및 구성 품질 모니터링