머신러닝 & 딥러닝/딥러닝

[AI 기초 다지기] Set2Set 논문 분석 및 코드 구현

Haru_29 2024. 11. 14. 20:30

Order Matters: Sequence to Sequence for Sets의 완벽 가이드

목차

  1. 연구 배경 및 동기
  2. 이론적 프레임워크
  3. 아키텍처 상세 설명
  4. 입력 집합 처리
  5. 출력 집합 처리
  6. 실험 및 결과 분석
  7. 구현 가이드
  8. 결론 

1. 연구 배경 및 동기

1.1 기존 Sequence-to-Sequence의 한계

  • 고정된 입력/출력 순서 가정
  • 비순차적 데이터 처리의 어려움
  • 조합적 문제에서의 제한사항

1.2 주요 해결 과제

challenges = {
    "입력": {
        "가변 길이": "입력 집합의 크기가 동적",
        "순서 독립성": "입력 순서에 불변한 표현 필요",
        "계산 효율성": "O(n²) 이하의 복잡도 목표"
    },
    "출력": {
        "순서 최적화": "최적의 출력 순서 결정",
        "다중 정답": "여러 유효한 출력 순서 존재",
        "학습 효율성": "효율적인 학습 알고리즘 필요"
    }
}

1.3 접근 방법의 혁신성

  1. Read-Process-Write 아키텍처 도입
  2. 순서 탐색을 포함한 학습 방법
  3. 주목 메커니즘의 새로운 활용

2. 이론적 프레임워크

2.1 수학적 정의

class SetEncodingFramework:
    def __init__(self):
        self.encoder = SetEncoder()
        self.decoder = OrderAwareDecoder()
        
    def encode_set(self, input_set):
        """
        입력 집합을 순서 불변 표현으로 인코딩
        
        Parameters:
        - input_set: 입력 집합 {x₁, x₂, ..., xₙ}
        
        Returns:
        - encoding: 순서 불변 벡터 표현
        """
        return self.encoder(input_set)
        
    def conditional_probability(self, output_seq, input_set):
        """
        조건부 확률 P(Y|X) 계산
        
        P(Y|X) = ∏ᵢ P(yᵢ|y₁,...,yᵢ₋₁,X)
        """
        return self.decoder(output_seq, self.encode_set(input_set))

2.2 주목 메커니즘 상세

class AttentionMechanism:
    def __init__(self, hidden_size):
        self.W1 = nn.Linear(hidden_size, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)
        
    def forward(self, query, keys, values):
        # 에너지 계산
        scores = self.V(torch.tanh(
            self.W1(query.unsqueeze(1)) + self.W2(keys)
        ))
        
        # 주목 가중치
        weights = F.softmax(scores, dim=1)
        
        # 컨텍스트 벡터
        context = torch.sum(weights * values, dim=1)
        return context, weights

3. 아키텍처 상세 설명

3.1 Read-Process-Write 컴포넌트

3.1.1 Reader 모듈

class Reader(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.normalizer = nn.LayerNorm(hidden_dim)
        
    def forward(self, input_set):
        # 각 입력 원소를 임베딩
        embeddings = [self.embedding(x) for x in input_set]
        # 정규화
        embeddings = [self.normalizer(e) for e in embeddings]
        return embeddings

3.1.2 Processor 모듈

class Processor(nn.Module):
    def __init__(self, hidden_dim, num_steps):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.attention = MultiHeadAttention(hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim)
        
    def forward(self, memories):
        state = None
        for _ in range(self.num_steps):
            # 주목 메커니즘 적용
            context = self.attention(memories)
            # LSTM 처리
            output, state = self.lstm(context, state)
        return output, memories

3.1.3 Writer 모듈

class Writer(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.attention = AttentionMechanism(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, query, memories):
        # 주목 기반 디코딩
        context, weights = self.attention(query, memories)
        # 출력 생성
        output = self.output_proj(context)
        return output, weights

3.2 순서 최적화 알고리즘

class OrderOptimizer:
    def __init__(self, model, temperature=1.0):
        self.model = model
        self.temperature = temperature
        
    def sample_order(self, X, Y):
        """
        출력 순서를 확률적으로 샘플링
        """
        with torch.no_grad():
            logits = self.model(X)
            probs = F.softmax(logits / self.temperature, dim=-1)
            return torch.multinomial(probs, num_samples=1)
    
    def train_step(self, X, Y):
        """
        순서 탐색을 포함한 학습 단계
        """
        # 여러 가능한 순서에 대해 확률 계산
        orders = self.generate_candidate_orders(Y)
        losses = []
        
        for order in orders:
            # 각 순서에 대한 손실 계산
            Y_ordered = self.reorder(Y, order)
            loss = self.model.compute_loss(X, Y_ordered)
            losses.append(loss)
            
        # 최적의 순서 선택
        best_order = orders[torch.argmin(torch.tensor(losses))]
        
        # 모델 업데이트
        self.model.train_with_order(X, Y, best_order)

4. 입력 집합 처리

4.1 집합 임베딩 전략

def generate_sorting_dataset(n_samples, seq_length):
    """
    정렬 문제를 위한 데이터셋 생성
    """
    X = []  # 입력 시퀀스
    Y = []  # 정렬된 출력
    
    for _ in range(n_samples):
        # 랜덤 시퀀스 생성
        seq = np.random.uniform(0, 1, seq_length)
        # 정렬된 인덱스
        sorted_idx = np.argsort(seq)
        
        X.append(seq)
        Y.append(sorted_idx)
        
    return np.array(X), np.array(Y)

4.2 정렬 실험 세부 사항

4.2.1 데이터셋 생성

def generate_sorting_dataset(n_samples, seq_length):
    """
    정렬 문제를 위한 데이터셋 생성
    """
    X = []  # 입력 시퀀스
    Y = []  # 정렬된 출력
    
    for _ in range(n_samples):
        # 랜덤 시퀀스 생성
        seq = np.random.uniform(0, 1, seq_length)
        # 정렬된 인덱스
        sorted_idx = np.argsort(seq)
        
        X.append(seq)
        Y.append(sorted_idx)
        
    return np.array(X), np.array(Y)

4.2.2 성능 평가 메트릭

class SortingMetrics:
    @staticmethod
    def accuracy(pred, true):
        """완벽한 정렬 정확도"""
        return np.mean(np.all(pred == true, axis=1))
    
    @staticmethod
    def kendall_tau(pred, true):
        """순서 상관계수"""
        scores = []
        for p, t in zip(pred, true):
            score, _ = stats.kendalltau(p, t)
            scores.append(score)
        return np.mean(scores)

5. 출력 집합 처리

5.1 출력 순서 최적화

5.1.1 사전 학습 단계

def pretrain_uniform(self, n_steps=1000):
    """균일 분포로 사전 학습"""
    for _ in range(n_steps):
        batch = self.get_batch()
        for X, Y in batch:
            # 무작위 순서로 Y 재배열
            random_order = np.random.permutation(len(Y))
            Y_shuffled = Y[random_order]
            
            # 학습 수행
            loss = self.train_step(X, Y_shuffled)
            self.optimizer.step()

5.1.2 최적 순서 탐색

class OrderSearch:
    def __init__(self, model):
        self.model = model
        
    def beam_search(self, X, beam_size=10):
        """빔 탐색으로 최적 순서 탐색"""
        sequences = [[list(), 0.0]]
        for t in range(len(X)):
            candidates = []
            for seq, score in sequences:
                if len(seq) == len(X):
                    candidates.append((seq, score))
                    continue
                    
                # 가능한 다음 원소들
                logits = self.model.predict_next(X, seq)
                next_probs = F.softmax(logits, dim=-1)
                
                # 상위 beam_size개 후보 선택
                values, indices = torch.topk(next_probs, beam_size)
                for value, idx in zip(values, indices):
                    candidates.append((seq + [idx.item()],
                                    score - torch.log(value).item()))
                    
            # 상위 beam_size개 시퀀스 선택
            sequences = sorted(candidates, key=lambda x: x[1])[:beam_size]
            
        return sequences[0][0]  # 최적 순서 반환

6. 실험 및 결과 분석

6.1 벤치마크 결과 상세 분석

6.1.1 언어 모델링 실험

모델 설정Perplexity학습 시간메모리 사용량

자연 순서 86 24h 8GB
역순 86 24h 8GB
3단어 96 26h 8GB

 

6.1.2 정렬 성능 분석

모델N=5N=10N=15

Ptr-Net 90% 28% 4%
RPW (P=1) 92% 44% 5%
RPW (P=5) 94% 57% 4%
RPW (P=10) 94% 50% 10%

6.2 성능 영향 요인 분석

6.2.1 처리 단계 수의 영향

def analyze_processing_steps():
    results = {}
    for steps in [1, 5, 10, 15, 20]:
        model = RPWModel(processing_steps=steps)
        acc = evaluate_model(model)
        results[steps] = {
            'accuracy': acc,
            'training_time': measure_training_time(model),
            'memory_usage': measure_memory_usage(model)
        }
    return results

6.2.2 주목 메커니즘 분석

def attention_analysis(model, input_set):
    """주목 가중치 시각화 및 분석"""
    attentions = []
    for step in range(model.num_steps):
        # 각 처리 단계에서의 주목 가중치 수집
        attention = model.get_attention_weights(step)
        attentions.append(attention)
    
    # 시각화 및 통계 분석
    visualize_attention_patterns(attentions)
    analyze_attention_statistics(attentions)

7. 구현 가이드

7.1 모델 구현 핵심 사항

### 7.1 모델 구현 핵심 사항

class CompleteRPWModel(nn.Module):
    def __init__(self, 
                 input_dim, 
                 hidden_dim,
                 num_processing_steps,
                 num_heads=8):
        super().__init__()
        
        # 컴포넌트 초기화
        self.reader = Reader(input_dim, hidden_dim)
        self.processor = Processor(hidden_dim, num_processing_steps)
        self.writer = Writer(hidden_dim, input_dim)
        self.construct_attention(hidden_dim, num_heads)
        
    def construct_attention(self, hidden_dim, num_heads):
        """다중 헤드 주목 메커니즘 구성"""
        self.attention_heads = nn.ModuleList([
            AttentionHead(hidden_dim // num_heads)
            for _ in range(num_heads)
        ])
    
    def forward(self, input_set):
        # 읽기 단계
        memories = self.reader(input_set)
        
        # 처리 단계
        processed, attention_weights = self.processor(memories)
        
        # 쓰기 단계
        outputs = []
        hidden = None
        
        for step in range(len(input_set)):
            output, hidden = self.writer(
                processed,
                attention_weights,
                hidden
            )
            outputs.append(output)
            
        return torch.stack(outputs)
        
    def compute_loss(self, pred, target, order=None):
        """손실 함수 계산"""
        if order is not None:
            target = self.reorder_target(target, order)
            
        return F.cross_entropy(
            pred.view(-1, pred.size(-1)),
            target.view(-1)
        )

 

7.2 학습 파이프라인

class TrainingPipeline:
    def __init__(self, 
                 model,
                 optimizer,
                 scheduler=None,
                 order_search=True):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.order_search = order_search
        
    def train_epoch(self, dataloader):
        """한 에포크 학습"""
        self.model.train()
        total_loss = 0
        
        for batch_idx, (X, Y) in enumerate(dataloader):
            self.optimizer.zero_grad()
            
            if self.order_search:
                # 최적 순서 탐색
                order = self.find_optimal_order(X, Y)
                loss = self.train_with_order(X, Y, order)
            else:
                # 일반 학습
                pred = self.model(X)
                loss = self.model.compute_loss(pred, Y)
            
            # 역전파 및 최적화
            loss.backward()
            # 그래디언트 클리핑
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), 
                max_norm=1.0
            )
            self.optimizer.step()
            
            if self.scheduler is not None:
                self.scheduler.step()
            
            total_loss += loss.item()
            
            # 진행상황 로깅
            if batch_idx % 100 == 0:
                self.log_progress(batch_idx, total_loss)
                
        return total_loss / len(dataloader)

    def train_with_order(self, X, Y, order):
        """주어진 순서로 학습 수행"""
        Y_ordered = self.reorder_target(Y, order)
        pred = self.model(X)
        return self.model.compute_loss(pred, Y_ordered)

 

7.3 추론 및 디코딩

class InferenceEngine:
    def __init__(self, model, beam_size=5):
        self.model = model
        self.beam_size = beam_size
        
    def predict(self, input_set):
        """빔 탐색을 사용한 추론"""
        with torch.no_grad():
            # 초기 상태
            memories = self.model.reader(input_set)
            processed = self.model.processor(memories)
            
            # 빔 탐색 수행
            beams = [([], 0.0)]  # (sequence, score)
            
            for step in range(len(input_set)):
                candidates = []
                
                for seq, score in beams:
                    if len(seq) == len(input_set):
                        candidates.append((seq, score))
                        continue
                    
                    # 다음 토큰 예측
                    logits = self.model.writer.predict_next(
                        processed, seq
                    )
                    probs = F.softmax(logits, dim=-1)
                    
                    # 상위 beam_size개 후보 선택
                    values, indices = torch.topk(
                        probs, self.beam_size
                    )
                    
                    for value, idx in zip(values, indices):
                        new_seq = seq + [idx.item()]
                        new_score = score - torch.log(value)
                        candidates.append((new_seq, new_score))
                
                # 상위 beam_size개 시퀀스 유지
                beams = sorted(
                    candidates,
                    key=lambda x: x[1]
                )[:self.beam_size]
            
            return beams[0][0]  # 최적 시퀀스 반환

    def batch_predict(self, input_sets):
        """배치 단위 추론"""
        predictions = []
        for input_set in input_sets:
            pred = self.predict(input_set)
            predictions.append(pred)
        return predictions

 

7.4 모델 평가 및 분석

class ModelAnalyzer:
    def __init__(self, model):
        self.model = model
        self.metrics = {}
        
    def analyze_attention_patterns(self, input_set):
        """주목 패턴 분석"""
        attention_weights = self.model.get_attention_weights(input_set)
        
        analysis = {
            'entropy': self.compute_attention_entropy(attention_weights),
            'sparsity': self.compute_attention_sparsity(attention_weights),
            'coverage': self.compute_attention_coverage(attention_weights)
        }
        
        return analysis
        
    def compute_attention_entropy(self, weights):
        """주목 가중치의 엔트로피 계산"""
        entropy = -torch.sum(
            weights * torch.log(weights + 1e-9),
            dim=-1
        )
        return entropy.mean().item()
        
    def compute_attention_sparsity(self, weights):
        """주목 가중치의 희소성 계산"""
        threshold = 0.1
        active_weights = (weights > threshold).float()
        return active_weights.sum(dim=-1).mean().item()
        
    def compute_attention_coverage(self, weights):
        """입력 요소별 주목 커버리지 계산"""
        coverage = weights.sum(dim=0)
        return coverage.mean().item()

    def visualize_attention(self, input_set):
        """주목 가중치 시각화"""
        weights = self.model.get_attention_weights(input_set)
        
        plt.figure(figsize=(10, 6))
        sns.heatmap(weights.cpu().numpy(), 
                   cmap='viridis',
                   annot=True)
        plt.title('Attention Weights Heatmap')
        plt.xlabel('Input Elements')
        plt.ylabel('Processing Steps')
        plt.show()

 

7.5 하이퍼파라미터 최적화

class HyperparameterOptimizer:
    def __init__(self, 
                 model_class, 
                 param_grid,
                 n_trials=50):
        self.model_class = model_class
        self.param_grid = param_grid
        self.n_trials = n_trials
        self.results = []
        
    def optimize(self, train_data, val_data):
        """하이퍼파라미터 탐색 수행"""
        for trial in range(self.n_trials):
            # 파라미터 샘플링
            params = self.sample_parameters()
            
            # 모델 학습
            model = self.model_class(**params)
            trainer = TrainingPipeline(model)
            val_score = trainer.train_and_validate(
                train_data, 
                val_data
            )
            
            # 결과 저장
            self.results.append({
                'params': params,
                'score': val_score
            })
            
        # 최적 파라미터 반환
        best_trial = max(self.results, 
                        key=lambda x: x['score'])
        return best_trial['params']
        
    def sample_parameters(self):
        """파라미터 공간에서 샘플링"""
        params = {}
        for param_name, param_config in self.param_grid.items():
            if param_config['type'] == 'categorical':
                params[param_name] = np.random.choice(
                    param_config['values']
                )
            elif param_config['type'] == 'uniform':
                params[param_name] = np.random.uniform(
                    param_config['min'],
                    param_config['max']
                )
            elif param_config['type'] == 'log_uniform':
                params[param_name] = np.exp(np.random.uniform(
                    np.log(param_config['min']),
                    np.log(param_config['max'])
                ))
        return params

 

8. 결론 및 향후 연구 방향

8.1 주요 성과

  1. 순서 불변성을 가진 입력 처리 방법 제시
  2. 효율적인 출력 순서 최적화 방법 개발
  3. 다양한 태스크에서의 성능 입증

8.2 한계점

  1. 계산 복잡도 개선 필요
  2. 매우 긴 시퀀스에서의 성능 저하
  3. 출력 순서 탐색의 효율성

8.3 향후 연구 방향

  1. 계층적 주목 메커니즘 도입
  2. 순서 탐색의 효율적 알고리즘 개발
  3. 더 복잡한 구조화된 출력 처리