Haru's 개발 블로그

[논문 리뷰] High-Resolution Image Synthesis with Latent Diffusion Models (LDM) 논문 리뷰 본문

논문 리뷰/Diffusion

[논문 리뷰] High-Resolution Image Synthesis with Latent Diffusion Models (LDM) 논문 리뷰

Haru_29 2024. 4. 20. 10:33

Link

 

코드 리뷰

 

AutoEncoder.py

AutoEncoder를 적대적 방식으로 훈련하여 패치 기반 판별기가 재구성 이미지 D(E(x))와 원본 이미지를 구별하도록 최적화를 진행합니다.

Discriminator와 loss function의 구현은 contperceptual.py에서 볼 수 있습니다.

 

정규화 방식은 두 가지가 있는데 VQ 방식의 경우 양자화 과정에서 발생하는 loss를 줄여 정규화를 진행합니다.

class VQModel(pl.LightningModule):

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff
        
    def training_step(self, batch, batch_idx, optimizer_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

 

KL 방식의 경우

정규 분포와의 Kullback-Leibler-term을 이용하여 정규화를 진행합니다.

class DiagonalGaussianDistribution

class AutoencoderKL(pl.LightningModule):

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
        
    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

 

Implementations of τθ

 

 

BERT tokenizer, CLIP embedder 등등 각종 전처리 모듈입니다.

modules.py

 

이때 τθ는 트랜스포머로 구현합니다.

x_transformer.py

class TransformerWrapper(nn.Module):
    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                    use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

    def init_(self):
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x
        
        return out

 

Cross attention

조건을 ablated U-Net에 주입하기 위해 U-Net의 self attention 계층을 T Block으로 구성된 트랜스포머로 교체합니다.

#T Block
class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in

 

Abstract

품질과 유연성을 유지하면서 제한된 계산 리소스에 대한 Diffusion Model training을 가능하게 하기 위해서 autoencoder을 사용 + cross attention을 도입한 latent diffusion models (LDMs)을 제안합니다. 이를 통하여 다양한 조건과 인페인팅, 초해상도 등 다양한 작업니 가능합니다.

 

Introduction

Diffusion Model의 문제점

  • 막대한 계산 비용 필요함
  • 많은 순차적 단계 때문에 훈련된 모델을 평가하는 데도 많은 시간이 소모됨

데이터 공간과 지각적으로 동일한 저차원 표현 공간(더 낮은 차원의 매니폴드에서도 충분히 이미지의 정보를 담을 수 있음)을 제공하는 autoencoder를 훈련하여 더 나은 학습된 잠재 공간에서 확산 모델을 훈련하며 이를 통한 결과 모델 클래스를 LDM(Latent Diffusion Models)이라고 부릅니다.

 

이 접근 방식은 장점은 인코딩 단계를 한 번만 훈련하면 여러 다른 확산 모델 훈련에 재사용하거나 완전히 다른 작업에도 사용할 수 있다는 것입니다.

 

Method

훈련 단계에서 손실 항을 언더샘플링하여 지각적으로 관련 없는 세부 사항을 무시할 수 있지만 여전히 평가 단계에서 많은 비용이 드므로 생성 학습 단계에서 압축을 명시적으로 분리하여 이러한 단점을 피합니다.

이를 위해 autoencoder를 사용하고 이 접근 방식의 장점 :

  • 저차원 공간에서 작업하므로 효율적임
  • 확산 모델의 귀납적 편향을 여전히 이용
  • 학습된 인코더 잠재공간을 다른 작업에 사용할 수 있음

Perceptual Image Compression

지각 압축 모델은 지각 손실, 패치 기반, 적대적 목표의 조합으로 훈련된 autoencoder로 구성된다. 이렇게 하면 L1, L2와 같은 픽셀 기반 손실로 인해 발생하는 흐릿함을 피할 수 있습니다.

패치 기반 판별자 Dψ가 재구성 이미지 D(E(x))와 원본 이미지를 구별하도록 적대적 방식으로 훈련하고 지각 손실항과 정규화항 추가를 진행합니다.

 

잠재공간 정규화 방법으로는 정규 분포와 비교하여 KL 페널티를 부여하는 KL-reg와 VQGAN과 같이 벡터 양자화를 이용하는 VG-reg 둘 중 하나를 사용함. 재구성 품질을 위해 정규화는 매우 작은 가중치를 적용합니다.

 

Latent Diffusion Models

일반적인 확산 모델의 목적 함수 :

효율적인 저차원 인코딩 잠재 공간에 접근할 수 있도록 변경 :

신경 백본은 time-conditional UNet(?)으로 구현.

 

Conditioning Mechanisms

조건 y를 사전 처리하는 도메인 특정 인코더 τθ 도입합니다.이때 다양한 입력 조건을 반영하기 위해 cross attention 사용합니다.  (ϕi(zt)는 U-Net의 flattened 중간 표현)

목적 함수 :

 

Experiments

 

레이아웃 조건

 

Semantic map 조건

 

초해상도

 

인페인팅

Comments