Stable Diffusion 3.5 medium 모델 구조
1. 핵심 학습 방법 최적화
노이즈 샘플링 전략
- Logit-normal 분포를 사용한 노이즈 샘플링 구현
def logit_normal_sampling(t, m=0.0, s=1.0):
u = torch.randn_like(t) * s + m
return torch.sigmoid(u)
- 파라미터 설정:
- 위치 파라미터 m = 0.0
- 스케일 파라미터 s = 1.0
- 중간 타임스텝에 더 많은 학습 가중치 부여
- EDM이나 기존 확산 모델보다 우수한 성능 달성
손실 함수 최적화
- Conditional Flow Matching (CFM) 손실 함수 사용:
def compute_cfm_loss(model_output, noise, t):
# v_θ(z, t) - u_t(z|ε) 계산
target = compute_velocity_target(noise, t)
return F.mse_loss(model_output, target)
- 검증 시 8개의 균등 분할된 타임스텝으로 평가
- w_t = -1/2 * λ'_t/b^2_t 가중치 적용
2. 아키텍처 최적화
멀티모달 설계
- 텍스트와 이미지 모달리티를 위한 분리된 가중치 스트림 구현
class MultiModalTransformer(nn.Module):
def __init__(self, depth):
self.text_stream = TransformerStream(depth)
self.image_stream = TransformerStream(depth)
self.cross_attention = CrossAttention(depth)
- 모델 크기에 따른 스케일링:
- 히든 크기 = 64 * depth
- MLP 크기 = 256 * depth
- 어텐션 헤드 수 = depth
학습 안정성 향상
- Q/K 정규화를 위한 RMSNorm 추가:
class RMSNormAttention(nn.Module):
def __init__(self):
self.rms_norm = RMSNorm()
def forward(self, q, k):
q = self.rms_norm(q)
k = self.rms_norm(k)
return torch.matmul(q, k.transpose(-2, -1))
- AdamW 옵티마이저 설정:
- epsilon = 1e-15
- bfloat16 mixed precision 사용
3. 해상도 및 위치 인코딩
고해상도 학습
- 256x256 사전학습 후 고해상도로 확장
- 해상도 기반 타임스텝 시프팅 적용:
def compute_resolution_timestep(t, m, n):
alpha = math.sqrt(m/n)
return alpha * t / (1 + (alpha - 1) * t)
- 1024x1024 해상도에서는 α = 3.0 권장
위치 인코딩 조정
- 가변 종횡비 처리:
def create_position_grid(h_max, w_max, S):
pos_h = torch.linspace(-h_max/2, h_max/2, h_max) * (256/S)
pos_w = torch.linspace(-w_max/2, w_max/2, w_max) * (256/S)
return torch.meshgrid(pos_h, pos_w)
4. 학습 프로세스 최적화
데이터 파이프라인
- 사전 계산 및 캐싱:
class DataPreprocessor:
def __init__(self):
self.vae = AutoencoderKL.from_pretrained(...)
self.clip = CLIPModel.from_pretrained(...)
self.t5 = T5EncoderModel.from_pretrained(...)
def precompute(self, image, text):
latents = self.vae.encode(image)
clip_emb = self.clip.encode_text(text)
t5_emb = self.t5(text).last_hidden_state
return latents, clip_emb, t5_emb
- 캡션 비율:
- 원본 캡션 50%
- 합성 캡션 50%
학습 스케줄
- 배치 크기: 4096
- 학습률: 1e-4 (1000스텝 웜업)
- 텍스트 인코더 드롭아웃: 46.3%
- EMA 감쇠율: 0.99 (100스텝마다 갱신)
5. 메모리 및 연산 최적화
메모리 효율성
class MemoryEfficientTrainer:
def __init__(self):
self.scaler = torch.cuda.amp.GradScaler()
def training_step(self, batch):
with torch.cuda.amp.autocast():
loss = self.compute_loss(batch)
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
분산 처리
- 다중 GPU 환경에서의 배치 분산
- 어텐션 연산 분산
- 텍스트/이미지 스트림 간 메모리 밸런싱
6. 평가 및 모니터링
메트릭 추적
class MetricsTracker:
def __init__(self):
self.metrics = {
'val_loss': [],
'clip_score': [],
'fid': []
}
def update(self, outputs):
for k, v in outputs.items():
self.metrics[k].append(v)
품질 검증
- GenEval 메트릭을 통한 텍스트 이해도 평가
- 인간 평가를 통한 시각적 품질 검증
- 타이포그래피 및 공간 추론 능력 확인
- 프롬프트 준수도 및 구성 품질 모니터링
'머신러닝 & 딥러닝 > 딥러닝' 카테고리의 다른 글
UNet과 Text Encoder의 학습 방법 (0) | 2024.11.25 |
---|---|
Stable Diffusion 3.5 Large Fine-tuning Tutorial 정리 (1) | 2024.11.24 |
Flux perfermence improved by Flash attention3 + Triton (2) | 2024.11.22 |
Flux 모델 최적화를 위한 TorchAO 및 torch.compile() 활용 가이드 (1) | 2024.11.21 |
Langchain 문서(Simple Start) (0) | 2024.11.19 |