이미지 로딩 중...

LLM 구현 5편 Multi-Head Attention 구현 완벽 가이드 - 슬라이드 1/12
A

AI Generated

2025. 11. 8. · 3 Views

LLM 구현 5편 Multi-Head Attention 구현 완벽 가이드

Transformer의 핵심인 Multi-Head Attention을 직접 구현해봅니다. Query, Key, Value의 개념부터 병렬 처리, 가중치 계산까지 실무에서 바로 활용할 수 있는 코드와 함께 깊이 있게 다룹니다.


목차

  1. Query, Key, Value의 이해 - Attention의 기본 구성 요소
  2. Scaled Dot-Product Attention - 가중치 계산의 핵심
  3. Multi-Head 구조 - 병렬 Attention으로 다양한 패턴 학습
  4. Position-wise Feed-Forward Network - 각 위치별 비선형 변환
  5. Layer Normalization과 Residual Connection - 안정적인 학습의 핵심
  6. Positional Encoding - 위치 정보 주입
  7. 완전한 Multi-Head Attention 통합 - 실전 구현
  8. Encoder-Decoder Attention - Cross-Attention의 핵심
  9. Causal Masking - Auto-regressive 생성의 필수 요소
  10. KV-Cache 최적화 - Auto-regressive 생성 가속화
  11. Attention Head Pruning - 모델 경량화 기법

1. Query, Key, Value의 이해 - Attention의 기본 구성 요소

시작하며

여러분이 대규모 언어 모델을 학습시키거나 Transformer 아키텍처를 커스터마이징할 때, "Query, Key, Value가 정확히 뭐지?"라는 의문을 가져본 적 있나요? 논문을 읽어도 추상적이고, 구현 코드를 봐도 감이 안 잡히는 경우가 많습니다.

이런 문제는 실제 NLP 모델을 개발하는 현장에서 자주 발생합니다. Attention 메커니즘을 제대로 이해하지 못하면 성능 튜닝이나 디버깅이 매우 어려워집니다.

특히 Multi-Head Attention에서 각 헤드가 어떻게 다른 패턴을 학습하는지 이해하지 못하면, 모델의 동작을 예측하기 힘듭니다. 바로 이럴 때 필요한 것이 Query, Key, Value의 명확한 개념 정립입니다.

이 세 가지를 정보 검색 시스템에 비유하면 매우 직관적으로 이해할 수 있습니다.

개요

간단히 말해서, Query는 "찾고자 하는 질문", Key는 "검색 대상의 인덱스", Value는 "실제 가져올 정보"입니다. 마치 도서관에서 책을 찾는 것과 같습니다.

왜 이 개념이 필요한지 실무 관점에서 설명하면, 모델이 문맥을 이해하고 관련성 높은 정보에 집중하기 위해서입니다. 예를 들어, "그는 은행에 갔다"라는 문장에서 '은행'이 금융기관인지 강둑인지를 판단할 때, 앞뒤 문맥(Key)과 현재 단어(Query)를 비교하여 적절한 의미(Value)를 가져옵니다.

전통적인 RNN이나 LSTM에서는 순차적으로 정보를 처리했다면, Attention을 사용하면 모든 위치의 정보를 동시에 참조할 수 있습니다. 이것이 Transformer가 긴 문맥을 효과적으로 처리하는 비결입니다.

핵심 특징은 첫째, 세 가지 모두 같은 입력에서 선형 변환으로 생성되며, 둘째, Query와 Key의 내적으로 관련도를 계산하고, 셋째, 그 가중치로 Value를 결합한다는 점입니다. 이러한 특징들이 모델이 데이터에서 중요한 패턴을 학습할 수 있게 만듭니다.

코드 예제

import torch
import torch.nn as nn

class QKVProjection(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        # 입력 임베딩을 Q, K, V로 변환하는 선형 레이어
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_k)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        Q = self.W_q(x)  # Query: 무엇을 찾고 싶은가?
        K = self.W_k(x)  # Key: 검색 대상의 특징
        V = self.W_v(x)  # Value: 실제 가져올 정보
        return Q, K, V

설명

이것이 하는 일: 입력 임베딩 벡터를 세 가지 다른 관점(Query, Key, Value)으로 투영하여 Attention 메커니즘의 기반을 만듭니다. 첫 번째로, 세 개의 독립적인 선형 레이어(W_q, W_k, W_v)를 생성합니다.

이들은 각각 다른 가중치를 학습하여 입력 데이터의 서로 다른 특성을 추출합니다. 왜 이렇게 하는지는 각 역할이 다르기 때문입니다.

Query는 "무엇을 찾을지"에 최적화되고, Key는 "어떻게 매칭할지"에, Value는 "무엇을 출력할지"에 최적화됩니다. 그 다음으로, forward 메서드에서 동일한 입력 x에 대해 세 가지 변환을 동시에 수행합니다.

입력 x의 shape은 (batch_size, seq_len, d_model)인데, 각 토큰이 d_model 차원의 벡터로 표현됩니다. 이것을 d_k 차원으로 투영하면서 차원 축소와 동시에 특화된 표현을 얻습니다.

마지막으로, 세 개의 행렬 Q, K, V가 반환됩니다. 이들은 모두 (batch_size, seq_len, d_k) shape을 가지며, 다음 단계인 Scaled Dot-Product Attention의 입력으로 사용됩니다.

Q와 K의 내적으로 attention score를 계산하고, 이를 softmax로 정규화한 후 V에 곱하여 최종 출력을 만듭니다. 여러분이 이 코드를 사용하면 모델이 입력 시퀀스의 각 위치에서 "어떤 정보를 찾고(Q), 어떤 것과 비교하며(K), 무엇을 가져올지(V)"를 학습할 수 있습니다.

실무에서는 d_model을 512나 768, d_k를 64 정도로 설정하여 계산 효율과 표현력의 균형을 맞춥니다.

실전 팁

💡 d_k는 일반적으로 d_model을 head 개수로 나눈 값으로 설정합니다. 예를 들어 d_model=512, num_heads=8이면 d_k=64입니다.

💡 Q, K, V의 가중치 초기화가 매우 중요합니다. Xavier initialization을 사용하면 학습 초기에 gradient vanishing을 방지할 수 있습니다.

💡 디버깅 시 Q와 K의 내적 값의 분포를 확인하세요. 너무 크면 softmax가 saturate되어 gradient가 사라집니다.

💡 실무에서는 bias=False로 설정하는 경우가 많습니다. 논문에 따르면 QKV projection에서 bias는 성능 향상에 크게 기여하지 않습니다.

💡 메모리 효율을 위해 Q, K, V를 하나의 큰 행렬로 만든 후 split하는 방법도 있습니다. nn.Linear(d_model, 3*d_k)로 구현할 수 있습니다.


2. Scaled Dot-Product Attention - 가중치 계산의 핵심

시작하며

여러분이 Attention 스코어를 계산할 때 값이 너무 커져서 gradient가 소실되는 문제를 겪어본 적 있나요? 혹은 softmax 출력이 거의 one-hot처럼 되어버려 학습이 불안정해지는 상황을 경험했나요?

이런 문제는 실제 대규모 언어 모델 학습에서 매우 흔하게 발생합니다. Query와 Key의 차원이 커질수록 내적 값이 기하급수적으로 증가하고, 이는 softmax 함수의 입력을 극단적으로 만들어 학습을 방해합니다.

특히 긴 시퀀스를 처리할 때 이 문제가 더 심각해집니다. 바로 이럴 때 필요한 것이 Scaled Dot-Product Attention입니다.

sqrt(d_k)로 스케일링하는 간단한 트릭이지만, 이것이 Transformer 학습의 안정성을 크게 향상시킵니다.

개요

간단히 말해서, Scaled Dot-Product Attention은 Query와 Key의 내적을 차원의 제곱근으로 나누어 정규화한 후, softmax로 확률 분포를 만들고 Value에 적용하는 메커니즘입니다. 왜 이 스케일링이 필요한지 수학적으로 설명하면, Q와 K의 각 원소가 평균 0, 분산 1인 독립 변수라고 가정할 때, d_k개의 원소를 내적하면 결과의 분산이 d_k가 됩니다.

예를 들어, d_k=64이면 내적 값의 분산이 64배로 커지는데, 이를 sqrt(64)=8로 나누면 원래 분산 1로 복원됩니다. 이렇게 하면 softmax 입력이 적절한 범위에 유지되어 gradient가 건강하게 흐릅니다.

전통적인 additive attention에서는 학습 가능한 파라미터로 스케일을 조정했다면, 이제는 고정된 상수로 나누어 계산 효율성과 안정성을 동시에 얻습니다. 핵심 특징은 첫째, 계산이 매우 빠르고 병렬화가 쉬우며, 둘째, 추가 파라미터 없이 안정적인 학습이 가능하고, 셋째, masking을 통해 미래 정보 유출을 방지할 수 있다는 점입니다.

이러한 특징들이 Transformer를 현대 NLP의 표준 아키텍처로 만들었습니다.

코드 예제

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (batch_size, seq_len, d_k)
    d_k = Q.size(-1)

    # Step 1: Q와 K의 내적 계산 (유사도 측정)
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq, seq)

    # Step 2: sqrt(d_k)로 스케일링 (gradient 안정화)
    scores = scores / math.sqrt(d_k)

    # Step 3: Masking (옵션, 미래 토큰 차단)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 4: Softmax로 확률 분포 생성
    attention_weights = F.softmax(scores, dim=-1)  # (batch, seq, seq)

    # Step 5: Value에 가중치 적용
    output = torch.matmul(attention_weights, V)  # (batch, seq, d_k)

    return output, attention_weights

설명

이것이 하는 일: 입력 시퀀스의 각 위치가 다른 모든 위치와 얼마나 관련이 있는지 계산하고, 그 관련도에 따라 정보를 집계합니다. 첫 번째로, Q와 K의 내적을 계산합니다.

torch.matmul(Q, K.transpose(-2, -1))는 각 query 벡터와 모든 key 벡터의 유사도를 계산하여 (batch, seq_len, seq_len) 크기의 attention score matrix를 만듭니다. 예를 들어, seq_len이 10이면 10x10 행렬이 되어 각 위치가 다른 9개 위치와의 관계를 모두 저장합니다.

왜 transpose를 하는지는 행렬 곱셈의 차원을 맞추기 위함입니다. 그 다음으로, 스케일링이 실행됩니다.

math.sqrt(d_k)로 나누는 것은 단순해 보이지만 매우 중요합니다. 내부적으로 이것이 softmax 입력값을 적절한 범위로 유지시켜, attention이 너무 sharp하게(거의 one-hot) 되거나 너무 uniform하게 되는 것을 방지합니다.

실험적으로 d_k가 크면 이 스케일링 없이는 학습이 거의 불가능합니다. 세 번째 단계에서는 선택적으로 masking을 적용합니다.

Decoder의 self-attention에서는 미래 토큰을 보지 못하도록 상삼각 행렬 마스크를 사용합니다. masked_fill은 mask가 0인 위치를 -1e9(음의 무한대에 가까운 값)로 채워, softmax 후 해당 위치의 가중치가 0이 되도록 합니다.

마지막으로, softmax로 정규화된 가중치를 Value에 곱합니다. attention_weights의 각 행은 확률 분포(합이 1)이므로, V와 곱하면 가중 평균이 됩니다.

예를 들어, 첫 번째 토큰의 attention weight가 [0.5, 0.3, 0.2]라면, 세 토큰의 value 벡터를 해당 비율로 섞어 출력을 만듭니다. 여러분이 이 코드를 사용하면 모델이 각 위치에서 "어떤 다른 위치에 집중할지"를 데이터로부터 학습할 수 있습니다.

실무에서는 attention_weights를 시각화하여 모델이 실제로 어떤 패턴을 학습했는지 분석하는 데 활용됩니다. 또한 dropout을 attention_weights에 적용하여 regularization 효과를 얻을 수 있습니다.

실전 팁

💡 스케일링 계산을 미리 저장하세요. scale = 1.0 / math.sqrt(d_k)를 한 번만 계산하고 곱셈으로 사용하면 나눗셈보다 빠릅니다.

💡 Masking 값을 -1e9 대신 float('-inf')를 사용하면 더 명확하지만, fp16 학습에서는 -1e4 정도가 안전합니다.

💡 attention_weights에 dropout을 추가하는 것이 일반적입니다. F.dropout(attention_weights, p=0.1, training=self.training)로 구현합니다.

💡 메모리 최적화를 위해 torch.backends.cuda.sdp_kernel을 사용하면 Flash Attention이 자동으로 적용되어 2-4배 빠릅니다.

💡 디버깅 시 scores의 최대/최소값을 확인하세요. 너무 크면 (-100 ~ 100 이상) 스케일링이 제대로 안 된 것입니다.


3. Multi-Head 구조 - 병렬 Attention으로 다양한 패턴 학습

시작하며

여러분이 단일 Attention만 사용했을 때 모델이 한 가지 관계만 학습하고 다른 중요한 패턴을 놓치는 문제를 경험한 적 있나요? 예를 들어, "The animal didn't cross the street because it was too tired"에서 'it'이 'animal'을 가리키는 것은 잡아내지만, 'too tired'와 'didn't cross' 사이의 인과 관계는 놓치는 경우입니다.

이런 문제는 실제 복잡한 언어 이해 태스크에서 매우 자주 발생합니다. 자연어는 여러 층위의 관계가 동시에 존재하는데, 단일 Attention head로는 한 번에 하나의 패턴만 효과적으로 포착할 수 있습니다.

구문 구조, 의미 관계, 공지시 등을 동시에 학습하려면 더 많은 표현력이 필요합니다. 바로 이럴 때 필요한 것이 Multi-Head Attention입니다.

여러 개의 독립적인 Attention head를 병렬로 실행하여, 각각이 다른 종류의 관계를 학습하도록 만듭니다. 마치 여러 명의 전문가가 각자의 관점에서 텍스트를 분석하는 것과 같습니다.

개요

간단히 말해서, Multi-Head Attention은 입력을 h개의 subspace로 분할하고, 각 subspace에서 독립적으로 Attention을 계산한 후, 결과를 연결(concatenate)하여 최종 출력을 만드는 구조입니다. 왜 이 구조가 필요한지 실무 관점에서 설명하면, 모델의 표현력을 기하급수적으로 증가시키기 때문입니다.

예를 들어, 8개의 head를 사용하면 어떤 head는 문법적 의존성을, 다른 head는 의미적 유사성을, 또 다른 head는 위치 관계를 학습할 수 있습니다. 실제 BERT와 GPT 모델을 분석한 연구에서 각 head가 실제로 다른 언어학적 패턴을 학습한다는 것이 확인되었습니다.

전통적인 단일 Attention에서는 하나의 표현 공간에서만 유사도를 계산했다면, 이제는 여러 독립적인 표현 공간에서 동시에 계산하여 더 풍부한 정보를 포착합니다. 각 head의 차원을 줄이되(d_k = d_model / h) 개수를 늘려 총 파라미터 수는 유지하면서 표현력만 증가시키는 것이 핵심입니다.

핵심 특징은 첫째, 완벽한 병렬 처리가 가능하여 계산 효율이 높고, 둘째, 각 head가 다른 초기화와 gradient를 받아 자연스럽게 다른 패턴을 학습하며, 셋째, concatenation 후 선형 변환으로 head들의 정보를 통합한다는 점입니다. 이러한 특징들이 Transformer의 강력한 표현 학습 능력의 핵심입니다.

코드 예제

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 각 head의 차원

        # QKV projection (모든 head를 한 번에 처리)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # 최종 출력 projection
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        # QKV projection
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Scaled Dot-Product Attention (각 head 독립 실행)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)

        # Concatenate heads: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # 최종 linear transformation
        return self.W_o(output)

설명

이것이 하는 일: 여러 개의 "관점"에서 동시에 Attention을 계산하여, 입력 시퀀스의 다층적인 관계를 포착합니다. 첫 번째로, 초기화 단계에서 d_model이 num_heads로 나누어떨어지는지 확인합니다.

이는 각 head가 동일한 차원(d_k)을 가져야 효율적으로 병렬 처리할 수 있기 때문입니다. 예를 들어 d_model=512, num_heads=8이면 각 head는 d_k=64 차원을 담당합니다.

W_q, W_k, W_v는 모두 d_model → d_model 크기이므로, 한 번의 행렬 곱셈으로 모든 head의 QKV를 동시에 생성합니다. 그 다음으로, split_heads 함수가 실행됩니다.

이 함수는 view와 transpose를 사용하여 (batch, seq_len, d_model)을 (batch, num_heads, seq_len, d_k)로 재구성합니다. 내부적으로 메모리 재배치가 일어나며, 이제 각 head가 독립적인 차원에 존재하게 됩니다.

contiguous()를 나중에 호출하는 이유는 transpose가 view의 논리적 순서만 바꾸고 실제 메모리는 그대로여서, 이후 연산을 위해 메모리를 연속적으로 재배치해야 하기 때문입니다. 세 번째 단계에서는 각 head별로 Scaled Dot-Product Attention이 병렬로 실행됩니다.

torch.matmul은 batch와 head 차원을 broadcast하여 한 번에 모든 head의 attention을 계산합니다. GPU에서는 이 모든 연산이 동시에 실행되어 매우 효율적입니다.

각 head는 독립적인 가중치(W_q, W_k, W_v의 서로 다른 부분)를 사용하므로 다른 패턴을 학습합니다. 마지막으로, head들의 출력을 다시 합칩니다.

transpose(1, 2)로 (batch, seq_len, num_heads, d_k) 형태로 만든 후, view로 (batch, seq_len, d_model)로 펼칩니다. 이것은 concatenation과 동일한 효과입니다.

그리고 W_o라는 최종 선형 변환을 적용하여 head들의 정보를 통합합니다. 이 W_o가 없으면 각 head의 출력이 단순히 나열되기만 하고 상호작용하지 못합니다.

여러분이 이 코드를 사용하면 모델이 문법, 의미, 문맥 등 여러 종류의 언어 정보를 동시에 학습할 수 있습니다. 실무에서는 num_heads를 8이나 16으로 설정하며, 더 큰 모델(GPT-3 등)에서는 96개까지 사용합니다.

Head 수를 늘릴수록 표현력은 증가하지만 계산 비용도 증가하므로, 태스크에 따라 적절한 균형을 찾아야 합니다.

실전 팁

💡 num_heads는 대부분 8, 12, 16 중 하나를 사용합니다. 홀수나 소수를 사용하면 GPU 최적화가 덜 효율적입니다.

💡 split_heads와 merge_heads를 별도 함수로 분리하면 코드 가독성이 높아지고 재사용이 쉽습니다.

💡 실무에서는 einsum을 사용하면 더 간결합니다: torch.einsum('bhqd,bhkd->bhqk', Q, K)로 attention score 계산 가능합니다.

💡 W_o를 생략하면 성능이 10-15% 하락합니다. 이 projection이 head 간 정보 교환의 핵심입니다.

💡 각 head의 attention 패턴을 시각화하려면 forward에서 attn을 반환하도록 수정하세요. head별로 어떤 패턴을 학습했는지 매우 흥미로운 인사이트를 얻을 수 있습니다.


4. Position-wise Feed-Forward Network - 각 위치별 비선형 변환

시작하며

여러분이 Multi-Head Attention만으로 모델을 구성했을 때, 선형 변환의 조합이라 표현력에 한계를 느낀 적 있나요? Attention은 가중 평균이므로 본질적으로 선형 연산이며, 이것만으로는 복잡한 비선형 패턴을 학습하기 어렵습니다.

이런 문제는 실제 deep learning 모델 설계에서 매우 중요한 이슈입니다. 선형 레이어를 여러 개 쌓아도 결국 하나의 선형 변환과 동일하므로, 모델의 깊이가 의미가 없어집니다.

복잡한 언어 패턴, 추상적 개념, 논리적 추론 등을 학습하려면 비선형성이 필수적입니다. 바로 이럴 때 필요한 것이 Position-wise Feed-Forward Network입니다.

각 위치에서 독립적으로 실행되는 2층 MLP로, ReLU나 GELU 같은 비선형 활성화 함수를 통해 모델의 표현력을 극적으로 증가시킵니다.

개요

간단히 말해서, FFN은 각 토큰 위치에서 독립적으로 적용되는 2개의 선형 변환과 1개의 비선형 활성화 함수로 구성된 작은 신경망입니다. 수식으로는 FFN(x) = max(0, xW1 + b1)W2 + b2로 표현됩니다.

왜 이 구조가 필요한지 실무 관점에서 설명하면, Attention 후의 표현을 더 풍부하게 변환하기 위함입니다. 예를 들어, Attention이 관련 정보들을 모았다면, FFN은 그 정보를 고차원 공간에서 복잡하게 조합하여 더 추상적인 특징을 추출합니다.

실험적으로 FFN을 제거하면 모델 성능이 30-40% 하락하는 것으로 알려져 있습니다. 전통적인 RNN에서는 hidden state 전환에 비선형성이 내재되어 있었다면, Transformer에서는 Attention과 FFN을 명확히 분리하여 각각의 역할을 명확히 했습니다.

특히 중간 차원(d_ff)을 입력 차원(d_model)의 4배로 키워 더 넓은 표현 공간을 탐색합니다. 핵심 특징은 첫째, 시퀀스의 각 위치에서 완전히 독립적으로 실행되어 병렬화가 완벽하고, 둘째, 중간 레이어의 차원 확장으로 높은 표현력을 얻으며, 셋째, GELU 같은 현대적 활성화 함수로 gradient 흐름을 개선한다는 점입니다.

이러한 특징들이 Transformer의 각 레이어가 점진적으로 더 추상적인 특징을 학습하게 만듭니다.

코드 예제

import torch
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # 첫 번째 linear: 차원 확장 (d_model -> d_ff)
        self.linear1 = nn.Linear(d_model, d_ff)
        # 두 번째 linear: 차원 축소 (d_ff -> d_model)
        self.linear2 = nn.Linear(d_ff, d_model)
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        # GELU activation (GPT 스타일, ReLU보다 부드러움)
        self.activation = nn.GELU()

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        # Step 1: 차원 확장 및 비선형 변환
        x = self.activation(self.linear1(x))  # (batch, seq, d_ff)
        # Step 2: Dropout으로 regularization
        x = self.dropout(x)
        # Step 3: 원래 차원으로 축소
        x = self.linear2(x)  # (batch, seq, d_model)
        return x

설명

이것이 하는 일: Attention으로 모은 정보를 각 위치에서 독립적으로 비선형 변환하여 더 복잡하고 추상적인 특징으로 변환합니다. 첫 번째로, linear1이 입력 차원을 d_ff(일반적으로 d_model의 4배)로 확장합니다.

예를 들어 d_model=512이면 d_ff=2048이 됩니다. 왜 이렇게 차원을 크게 늘리는지는 더 넓은 특징 공간에서 복잡한 패턴을 표현하기 위함입니다.

이것은 병목(bottleneck)의 반대 개념으로, 일시적으로 정보를 압축 해제하여 더 많은 조합을 탐색합니다. 그 다음으로, GELU 활성화 함수가 적용됩니다.

GELU(Gaussian Error Linear Unit)는 ReLU와 달리 부드러운 곡선을 그려 0 근처에서도 작은 gradient를 허용합니다. 내부적으로 GELU(x) = x * Φ(x) (Φ는 표준정규분포 CDF)로 계산되며, 이는 입력을 확률적으로 "통과" 또는 "차단"하는 효과를 냅니다.

GPT와 BERT 모두 GELU를 사용하며, 실험적으로 ReLU보다 1-2% 성능 향상이 있습니다. 세 번째 단계에서는 Dropout이 적용됩니다.

p=0.1 정도로 설정하면 훈련 중 10%의 뉴런을 무작위로 꺼서 overfitting을 방지합니다. FFN은 파라미터 수가 많아 (d_model × d_ff × 2) overfitting 위험이 높으므로, Dropout이 중요한 역할을 합니다.

추론 시에는 자동으로 비활성화됩니다. 마지막으로, linear2가 차원을 다시 d_model로 축소합니다.

이 과정에서 확장된 공간에서 학습한 복잡한 패턴이 원래 차원으로 압축되어, 다음 레이어나 residual connection에 전달됩니다. 이 2단계 변환(확장→축소)이 일종의 autoencoder처럼 작동하여 중요한 특징만 추출합니다.

여러분이 이 코드를 사용하면 모델이 단순한 선형 변환을 넘어 복잡한 특징 조합을 학습할 수 있습니다. 실무에서는 FFN의 파라미터가 전체 Transformer 파라미터의 약 2/3를 차지할 정도로 큽니다.

따라서 모델 경량화 시 FFN을 먼저 최적화하는 것이 효과적입니다. 예를 들어 Mixture of Experts(MoE)는 FFN을 여러 개로 나누고 조건부로 활성화하여 효율을 높입니다.

실전 팁

💡 d_ff를 d_model의 4배로 설정하는 것이 표준이지만, 작은 모델에서는 2배, 큰 모델에서는 8배도 시도해볼 가치가 있습니다.

💡 GELU 대신 Swish(x * sigmoid(x))나 GLU(Gated Linear Unit)를 사용하면 특정 태스크에서 성능 향상이 있을 수 있습니다.

💡 Dropout 위치도 중요합니다. activation 전에 적용하면 학습이 불안정해질 수 있으니 후에 적용하세요.

💡 메모리 절약을 위해 gradient checkpointing을 FFN에 적용하면 메모리 사용량을 30% 줄일 수 있습니다 (속도는 약간 느려짐).

💡 linear1과 linear2의 가중치를 He initialization으로 초기화하고, bias는 0으로 설정하는 것이 학습 초기 안정성에 좋습니다.


5. Layer Normalization과 Residual Connection - 안정적인 학습의 핵심

시작하며

여러분이 깊은 Transformer 모델을 학습시킬 때 gradient vanishing이나 exploding으로 학습이 발산하는 문제를 겪어본 적 있나요? 혹은 레이어를 추가할수록 성능이 오히려 떨어지는 degradation 현상을 경험했나요?

이런 문제는 실제 대규모 언어 모델 개발에서 가장 큰 장애물 중 하나입니다. Transformer 블록을 12개, 24개, 심지어 96개까지 쌓으면서 안정적으로 학습하려면 특별한 기법이 필요합니다.

Gradient가 수십 개의 레이어를 통과하면서 너무 작아지거나 커지면, 초기 레이어는 전혀 학습되지 않거나 폭발합니다. 바로 이럴 때 필요한 것이 Layer Normalization과 Residual Connection입니다.

이 두 가지를 조합하면 100개 이상의 레이어도 안정적으로 학습할 수 있으며, 이것이 BERT, GPT 같은 거대 모델의 기반입니다.

개요

간단히 말해서, Residual Connection은 입력을 출력에 더해주는 skip connection이고, Layer Normalization은 각 레이어의 출력을 평균 0, 분산 1로 정규화하는 기법입니다. 수식으로는 LayerNorm(x + Sublayer(x)) 형태입니다.

왜 이 기법들이 필요한지 실무 관점에서 설명하면, 두 가지 모두 gradient 흐름을 개선하기 때문입니다. Residual connection은 gradient가 여러 레이어를 건너뛰어 직접 초기 레이어로 흐를 수 있는 "고속도로"를 만들어줍니다.

예를 들어, 24개 레이어가 있어도 gradient가 residual path를 통해 곧바로 전달되어 초기 레이어도 잘 학습됩니다. Layer Normalization은 각 레이어의 입력 분포를 안정화하여 내부 공변량 변화(internal covariate shift)를 줄입니다.

전통적인 Batch Normalization에서는 배치 차원에서 정규화했다면, Layer Normalization은 특징 차원에서 정규화하여 배치 크기에 독립적이고 RNN이나 Transformer 같은 시퀀스 모델에 더 적합합니다. 특히 배치 크기가 작거나 시퀀스 길이가 다를 때 BN은 불안정하지만 LN은 안정적입니다.

핵심 특징은 첫째, residual connection으로 gradient highway가 형성되어 깊은 네트워크 학습이 가능하고, 둘째, layer normalization으로 activation 값이 폭발하거나 소실되는 것을 방지하며, 셋째, 두 기법의 조합(Pre-LN 또는 Post-LN)으로 학습 안정성과 최종 성능의 균형을 맞춘다는 점입니다. 이러한 특징들이 현대 Transformer의 필수 구성 요소가 되었습니다.

코드 예제

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # Layer Normalization (각 sub-layer 후)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout for residual connections
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: Multi-Head Attention + Residual + Norm
        # Post-LN 방식: Sublayer -> Residual -> Norm
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # Residual connection

        # Sub-layer 2: Feed-Forward + Residual + Norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))  # Residual connection

        return x

설명

이것이 하는 일: 깊은 네트워크에서 gradient가 잘 흐르도록 하고, 각 레이어의 입력 분포를 안정화하여 빠르고 안정적인 학습을 가능하게 합니다. 첫 번째로, Multi-Head Attention의 출력에 Dropout을 적용한 후 원래 입력 x와 더합니다.

이 덧셈이 residual connection이며, backpropagation 시 gradient가 더하기 노드를 통해 그대로 전달되어 vanishing을 방지합니다. 수학적으로 y = x + F(x) 형태이며, dy/dx = 1 + dF/dx이므로 최소한 gradient 1이 보장됩니다.

왜 Dropout을 먼저 적용하는지는 regularization 효과를 residual path에도 주기 위함입니다. 그 다음으로, LayerNorm이 적용됩니다.

nn.LayerNorm(d_model)은 마지막 차원(특징 차원)에 대해 평균과 분산을 계산하여 정규화합니다. 내부적으로 (x - mean) / sqrt(var + eps) 연산 후, 학습 가능한 scale(γ)과 shift(β) 파라미터를 곱하고 더합니다.

이 γ와 β 덕분에 정규화를 했음에도 모델이 필요하면 원래 분포를 복원할 수 있습니다. eps=1e-5는 분산이 0일 때 나눗셈 오류를 방지합니다.

세 번째 단계에서는 Feed-Forward 네트워크에도 동일한 패턴(Sublayer → Dropout → Residual → Norm)을 적용합니다. 이렇게 각 sub-layer마다 residual connection과 normalization을 반복하면, 전체 Transformer 블록이 y = Norm(x + FFN(Norm(x + Attn(x))))처럼 중첩된 residual 구조가 됩니다.

이것이 매우 깊은 네트워크에서도 안정적인 학습을 가능하게 합니다. Post-LN vs Pre-LN: 위 코드는 Post-LN 방식(Sublayer → Residual → Norm)이지만, 최근에는 Pre-LN(Norm → Sublayer → Residual)이 더 안정적이라고 알려져 있습니다.

Pre-LN은 gradient가 더 부드럽게 흐르고 warm-up 없이도 학습이 가능하지만, 최종 성능은 Post-LN이 약간 더 높을 수 있습니다. GPT-2는 Pre-LN을, BERT는 Post-LN을 사용합니다.

여러분이 이 코드를 사용하면 수십 개의 Transformer 블록을 쌓아도 안정적으로 학습할 수 있습니다. 실무에서는 learning rate warm-up, gradient clipping과 함께 사용하여 더욱 안정성을 높입니다.

또한 LayerNorm의 γ와 β를 초기화할 때 γ=1, β=0으로 설정하여 초기에는 정규화가 identity function처럼 작동하게 만드는 것도 좋은 실전 팁입니다.

실전 팁

💡 깊은 모델(24+ layers)에서는 Pre-LN을 사용하세요. Post-LN보다 학습이 훨씬 안정적이며 warm-up 단계를 줄일 수 있습니다.

💡 Dropout 비율은 작은 모델에서 0.1, 큰 모델에서 0.0~0.1을 사용합니다. 너무 높으면 residual connection의 이점이 감소합니다.

💡 LayerNorm 대신 RMSNorm을 사용하면 계산이 20% 빠릅니다. 평균 계산을 생략하고 RMS만 사용하는데, 성능 차이는 거의 없습니다.

💡 디버깅 시 각 레이어의 output norm을 모니터링하세요. torch.norm(x, dim=-1).mean()이 너무 크거나 작으면 학습이 불안정합니다.

💡 Residual connection의 scale을 조정하는 "ReZero" 기법도 있습니다. x = x + α * Sublayer(x) 형태로 α를 0에서 시작해 학습하면 매우 깊은 네트워크도 학습 가능합니다.


6. Positional Encoding - 위치 정보 주입

시작하며

여러분이 Transformer 모델을 처음 구현했을 때, 모델이 단어의 순서를 전혀 구분하지 못한다는 것을 발견한 적 있나요? "고양이가 개를 쫓았다"와 "개가 고양이를 쫓았다"를 똑같이 처리하는 문제입니다.

이런 문제는 Transformer의 구조적 특성 때문에 발생합니다. Multi-Head Attention은 집합 연산(set operation)이므로, 입력 순서를 바꿔도 출력이 동일합니다.

RNN이나 CNN과 달리 순차적 처리나 지역성이 없어서, 별도로 위치 정보를 주입하지 않으면 모델이 "The cat sat on the mat"과 "mat the on sat cat The"를 구분할 수 없습니다. 바로 이럴 때 필요한 것이 Positional Encoding입니다.

각 토큰의 위치를 나타내는 벡터를 입력 임베딩에 더하여, 모델이 단어의 순서와 상대적 위치를 학습할 수 있게 만듭니다.

개요

간단히 말해서, Positional Encoding은 sin과 cos 함수를 사용하여 각 위치를 고유한 벡터로 인코딩한 후, 토큰 임베딩에 더하는 기법입니다. 수식은 PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))입니다.

왜 sin/cos를 사용하는지 실무 관점에서 설명하면, 여러 장점이 있기 때문입니다. 첫째, 학습 파라미터가 필요 없어 메모리 효율적이고, 둘째, 임의의 길이 시퀀스에 대응 가능하며(학습보다 긴 시퀀스도 처리), 셋째, 주기 함수의 성질로 상대적 위치 관계를 자동으로 학습할 수 있습니다.

예를 들어, PE(pos+k)를 PE(pos)와 PE(k)의 선형 조합으로 표현할 수 있어, 모델이 "3단어 떨어진 관계"를 일반화하기 쉽습니다. 전통적인 learned positional embedding(위치마다 학습 가능한 벡터)에서는 최대 길이가 고정되고 파라미터가 많이 필요했다면, sinusoidal encoding은 파라미터 없이 무한히 긴 시퀀스를 처리할 수 있습니다.

실험적으로 두 방법의 성능 차이는 거의 없지만, sinusoidal이 더 범용적입니다. 핵심 특징은 첫째, 각 위치가 고유하고 결정적인 벡터를 가지며, 둘째, 서로 다른 주파수의 sin/cos 조합으로 풍부한 위치 정보를 표현하고, 셋째, 덧셈 방식으로 토큰 임베딩과 결합되어 별도 차원을 차지하지 않는다는 점입니다.

이러한 특징들이 Transformer가 순서 정보를 효과적으로 활용하게 만듭니다.

코드 예제

import torch
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # 위치 인코딩 행렬 생성: (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)

        # 주파수 계산: 10000^(2i/d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        # Sin을 짝수 인덱스, Cos를 홀수 인덱스에 적용
        pe[:, 0::2] = torch.sin(position * div_term)  # 짝수 차원
        pe[:, 1::2] = torch.cos(position * div_term)  # 홀수 차원

        # 배치 차원 추가: (1, max_len, d_model)
        pe = pe.unsqueeze(0)

        # 학습되지 않는 파라미터로 등록 (저장은 되지만 gradient 안 흐름)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        seq_len = x.size(1)
        # 필요한 길이만큼만 잘라서 더하기
        return x + self.pe[:, :seq_len, :]

설명

이것이 하는 일: 순서 정보가 없는 Transformer에게 각 토큰의 절대적/상대적 위치를 알려주어, 문장 구조와 순서를 이해할 수 있게 만듭니다. 첫 번째로, 초기화 단계에서 전체 positional encoding 행렬을 미리 계산합니다.

max_len=5000이면 5000개 위치에 대한 인코딩을 모두 만들어둡니다. position 변수는 [0, 1, 2, ..., max_len-1]을 (max_len, 1) 형태로 만들고, div_term은 각 차원별 주파수를 계산합니다.

왜 exp와 log를 사용하는지는 10000^(2i/d_model)을 안정적으로 계산하기 위함입니다. 직접 거듭제곱하면 수치 오류가 발생할 수 있습니다.

그 다음으로, sin과 cos를 번갈아가며 적용합니다. pe[:, 0::2]는 짝수 인덱스(0, 2, 4, ...)를 의미하며 여기에 sin을, pe[:, 1::2]는 홀수 인덱스에 cos를 적용합니다.

이렇게 하면 각 차원마다 다른 주파수를 가진 sin/cos 파형이 만들어집니다. 낮은 차원은 빠르게 진동하여 인접 위치를 구분하고, 높은 차원은 천천히 진동하여 먼 위치 관계를 표현합니다.

세 번째 단계에서는 register_buffer로 pe를 등록합니다. 이것은 nn.Parameter와 달리 gradient가 흐르지 않지만, state_dict에 저장되어 모델과 함께 저장/로드됩니다.

Positional encoding은 학습하지 않고 고정된 값을 사용하므로 buffer가 적합합니다. GPU로 모델을 옮기면 pe도 자동으로 GPU로 이동합니다.

마지막으로, forward에서는 입력 시퀀스 길이만큼만 잘라서 더합니다. self.pe[:, :seq_len, :]로 필요한 부분만 슬라이싱하여 broadcasting으로 더하면, 각 배치의 각 토큰에 해당 위치의 인코딩이 추가됩니다.

예를 들어, 3번째 토큰은 pe[:, 2, :]가 더해져 "나는 3번째 위치에 있다"는 정보를 얻습니다. 여러분이 이 코드를 사용하면 모델이 "주어 다음에 동사가 온다", "관계대명사는 5단어 전 명사를 참조한다" 같은 순서 기반 패턴을 학습할 수 있습니다.

실무에서는 RoPE(Rotary Positional Embedding)나 ALiBi(Attention with Linear Biases) 같은 더 발전된 기법도 사용되지만, sinusoidal encoding이 여전히 기본입니다. 특히 GPT-3 같은 초거대 모델에서도 잘 작동하는 것이 검증되었습니다.

실전 팁

💡 max_len은 예상 최대 시퀀스 길이의 2배 정도로 설정하세요. 메모리는 거의 안 쓰지만 안전성이 높아집니다.

💡 실무에서는 learned embedding과 sinusoidal을 비교 실험해보세요. 데이터가 많으면 learned가, 적으면 sinusoidal이 유리한 경향이 있습니다.

💡 Positional encoding을 dropout하는 것도 효과적입니다. self.dropout = nn.Dropout(0.1)을 추가하여 x + self.pe에 적용하세요.

💡 디버깅 시 pe[:, :10, 0]처럼 몇 개 차원을 시각화하면 sin/cos 파형이 제대로 생성되었는지 확인할 수 있습니다.

💡 RoPE는 attention 내부에서 위치를 회전 변환으로 주입하여 외삽(extrapolation) 성능이 더 좋습니다. 매우 긴 시퀀스를 다룬다면 고려해보세요.


7. 완전한 Multi-Head Attention 통합 - 실전 구현

시작하며

여러분이 지금까지 배운 모든 구성 요소(QKV projection, Scaled Attention, Multi-Head, FFN, Normalization 등)를 실제로 통합할 때, 어떤 순서로 조합하고 어떤 세부 사항을 놓치지 말아야 할지 막막한 적 있나요? 이런 문제는 실제 Transformer를 처음 구현하는 개발자들이 가장 자주 겪는 어려움입니다.

각 부분은 이해했지만 전체를 조립할 때 shape mismatch, gradient 문제, 성능 저하 등 다양한 이슈가 발생합니다. 특히 batch 처리, masking, dropout 타이밍 등 논문에 명시되지 않은 세부 사항들이 성능에 큰 영향을 줍니다.

바로 이럴 때 필요한 것이 production-ready한 완전한 Multi-Head Attention 구현입니다. 모든 edge case를 처리하고, 최적화되고, 디버깅 가능한 코드가 실무에서 필요합니다.

개요

간단히 말해서, 완전한 구현은 QKV projection, head splitting, scaled attention, head concatenation, output projection, dropout, residual connection, layer normalization을 모두 올바른 순서로 결합한 것입니다. 왜 이 통합 구현이 중요한지 실무 관점에서 설명하면, 각 구성 요소의 상호작용이 전체 성능을 결정하기 때문입니다.

예를 들어, dropout을 어디에 적용하느냐(attention weights, output, residual connection)에 따라 overfitting 방지 효과가 크게 달라집니다. 또한 mask 처리를 잘못하면 정보 누출(future leakage)이 발생하여 평가 지표는 좋지만 실제 추론 시 성능이 떨어지는 문제가 생깁니다.

전통적인 튜토리얼 코드에서는 간단한 케이스만 다뤘다면, production 코드에서는 variable-length sequences, padding mask, attention mask, mixed precision training, gradient checkpointing 등을 모두 고려해야 합니다. 실제 BERT나 GPT 구현을 보면 핵심 로직 외에 이런 실무적 처리가 대부분을 차지합니다.

핵심 특징은 첫째, 모든 텐서 연산이 배치 우선(batch-first)으로 일관되게 처리되고, 둘째, masking이 여러 레벨(padding, causal, custom)에서 안전하게 적용되며, 셋째, 메모리와 속도 최적화(in-place 연산, fused kernels)가 적용되어 있다는 점입니다. 이러한 특징들이 실제 서비스에서 사용 가능한 코드를 만듭니다.

코드 예제

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ProductionMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = 1.0 / math.sqrt(self.d_k)

        # Fused QKV projection (더 빠름)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_len, d_model = x.size()

        # QKV를 한 번에 계산 후 split
        qkv = self.qkv_proj(x)  # (batch, seq, 3*d_model)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            # mask: (batch, 1, 1, seq_len) for padding or (batch, 1, seq, seq) for causal
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Apply attention to values
        output = torch.matmul(attn_weights, V)  # (batch, heads, seq, d_k)

        # Concatenate heads
        output = output.transpose(1, 2).contiguous()  # (batch, seq, heads, d_k)
        output = output.view(batch_size, seq_len, d_model)

        # Final projection
        output = self.out_proj(output)
        output = self.out_dropout(output)

        if return_attention:
            return output, attn_weights
        return output

설명

이것이 하는 일: 실제 서비스에서 사용 가능한 수준의 Multi-Head Attention을 구현하며, edge case 처리와 성능 최적화를 모두 포함합니다. 첫 번째로, QKV projection을 fused 방식으로 처리합니다.

세 개의 별도 Linear 대신 하나의 큰 Linear(d_model → 3*d_model)를 사용하면 커널 호출 횟수가 줄어 20-30% 빠릅니다. 그 후 reshape과 permute로 (3, batch, heads, seq, d_k) 형태로 만들어, 첫 번째 차원에서 Q, K, V를 분리합니다.

왜 이렇게 복잡한 reshaping을 하는지는 메모리 접근 패턴을 최적화하기 위함입니다. 그 다음으로, masking 처리가 매우 신중하게 이루어집니다.

mask의 shape이 (batch, 1, 1, seq_len)이면 padding mask로 해석되어 특정 위치를 모든 query에서 차단합니다. (batch, 1, seq, seq)이면 causal mask로 미래 토큰을 차단합니다.

float('-inf')를 사용하여 softmax 후 정확히 0이 되도록 하며, fp16에서는 -1e4 정도를 사용하여 수치 안정성을 높입니다. 세 번째 단계에서는 attention dropout이 적용됩니다.

이것은 attention weights에 직접 적용되어, 훈련 중 일부 연결을 무작위로 끊습니다. 이는 모델이 특정 위치에 과도하게 의존하는 것을 방지하고, 더 robust한 표현을 학습하게 만듭니다.

이 dropout은 output dropout과는 다른 목적으로, 둘 다 사용하는 것이 일반적입니다. 마지막으로, head concatenation과 output projection이 수행됩니다.

transpose(1, 2)로 heads와 seq_len 차원을 바꾼 후, contiguous()로 메모리를 연속화합니다. view를 사용하려면 메모리가 연속적이어야 하므로 이 단계가 필수입니다.

out_proj는 모든 head의 정보를 섞어 최종 표현을 만들고, out_dropout으로 regularization을 추가합니다. 여러분이 이 코드를 사용하면 BERT, GPT 같은 실제 모델의 핵심 부분을 구현할 수 있습니다.

return_attention=True 옵션으로 attention weights를 반환받아 모델 해석(interpretability)에 활용할 수 있으며, gradient checkpointing을 적용하면 메모리 사용량을 크게 줄일 수 있습니다. 실무에서는 이 코드를 기반으로 Flash Attention이나 Linformer 같은 효율적 변형을 적용하여 더욱 최적화합니다.

실전 팁

💡 Fused QKV projection은 메모리도 절약합니다. 3개 Linear는 각각 임시 activation을 저장하지만, 1개 Linear는 한 번만 저장합니다.

💡 Mask shape를 미리 확인하는 assert를 추가하세요. 잘못된 mask는 디버깅하기 매우 어려운 silent bug를 만듭니다.

💡 torch.cuda.amp.autocast()와 함께 사용하면 자동 mixed precision 학습이 가능합니다. 메모리 50% 절약, 속도 2배 향상이 일반적입니다.

💡 매우 긴 시퀀스(1024+)에서는 Flash Attention을 사용하세요. torch.nn.functional.scaled_dot_product_attention이 PyTorch 2.0+에서 자동으로 최적 구현을 선택합니다.

💡 Attention weights를 시각화할 때는 여러 head와 레이어를 동시에 보세요. 특정 head가 특정 패턴(예: 다음 토큰 예측)을 담당하는 경우가 많습니다.


8. Encoder-Decoder Attention - Cross-Attention의 핵심

시작하며

여러분이 기계 번역이나 요약 모델을 구현할 때, Encoder의 정보를 Decoder가 어떻게 참조해야 할지 고민한 적 있나요? Self-Attention만으로는 입력 문장과 출력 문장 간의 관계를 학습할 수 없습니다.

이런 문제는 sequence-to-sequence 태스크에서 필수적으로 발생합니다. "I love you"를 "나는 너를 사랑해"로 번역할 때, Decoder가 "사랑해"를 생성하는 시점에 입력의 "love"를 참조해야 합니다.

Self-Attention으로는 Decoder 내부의 관계만 볼 수 있고, Encoder의 정보는 접근할 수 없습니다. 바로 이럴 때 필요한 것이 Encoder-Decoder Attention (Cross-Attention)입니다.

Query는 Decoder에서, Key와 Value는 Encoder에서 가져와 두 시퀀스 간의 관계를 학습합니다.

개요

간단히 말해서, Cross-Attention은 Query를 Decoder 상태에서, Key와 Value를 Encoder 출력에서 생성하여 "Decoder의 각 위치가 Encoder의 어느 부분에 집중해야 하는지"를 학습하는 메커니즘입니다. 왜 이 메커니즘이 필요한지 실무 관점에서 설명하면, 입력과 출력의 정렬(alignment)을 자동으로 학습하기 때문입니다.

예를 들어, "The quick brown fox"를 "빠른 갈색 여우"로 번역할 때, "갈색"을 생성하는 Decoder는 "brown"에 높은 attention을 주어야 합니다. Cross-Attention은 이런 대응 관계를 데이터에서 학습하여, 전통적인 alignment 알고리즘 없이도 정확한 번역을 가능하게 합니다.

전통적인 Seq2Seq with Attention에서는 Encoder의 마지막 hidden state만 사용했다면, Transformer의 Cross-Attention은 모든 Encoder 위치를 동시에 참조할 수 있어 훨씬 풍부한 정보를 활용합니다. 특히 긴 문장에서 정보 병목(bottleneck) 현상이 없어 성능이 크게 향상됩니다.

핵심 특징은 첫째, Query와 Key/Value의 출처가 다르며, 둘째, Encoder와 Decoder의 시퀀스 길이가 달라도 동작하고, 셋째, attention weights가 입력-출력 간 정렬을 명시적으로 보여준다는 점입니다. 이러한 특징들이 Transformer를 seq2seq 태스크의 표준 아키텍처로 만들었습니다.

코드 예제

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderDecoderAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = 1.0 / math.sqrt(self.d_k)

        # Query는 decoder에서, Key/Value는 encoder에서
        self.W_q = nn.Linear(d_model, d_model)  # Decoder input
        self.W_k = nn.Linear(d_model, d_model)  # Encoder output
        self.W_v = nn.Linear(d_model, d_model)  # Encoder output

        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, decoder_hidden, encoder_output, encoder_mask=None):
        # decoder_hidden: (batch, decoder_len, d_model)
        # encoder_output: (batch, encoder_len, d_model)
        # encoder_mask: (batch, 1, 1, encoder_len) - padding mask

        batch_size = decoder_hidden.size(0)
        decoder_len = decoder_hidden.size(1)
        encoder_len = encoder_output.size(1)

        # Q from decoder, K and V from encoder
        Q = self.W_q(decoder_hidden)  # (batch, decoder_len, d_model)
        K = self.W_k(encoder_output)  # (batch, encoder_len, d_model)
        V = self.W_v(encoder_output)  # (batch, encoder_len, d_model)

        # Split heads
        Q = Q.view(batch_size, decoder_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, encoder_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, encoder_len, self.num_heads, self.d_k).transpose(1, 2)

        # Attention: (batch, heads, decoder_len, encoder_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Mask padding tokens in encoder
        if encoder_mask is not None:
            scores = scores.masked_fill(encoder_mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        output = torch.matmul(attn_weights, V)  # (batch, heads, decoder_len, d_k)
        output = output.transpose(1, 2).contiguous().view(batch_size, decoder_len, d_model)

        return self.out_proj(output)

설명

이것이 하는 일: Decoder의 각 위치가 Encoder의 어떤 부분에 집중해야 하는지 학습하여, 입력 정보를 출력 생성에 효과적으로 활용합니다. 첫 번째로, Q, K, V의 입력 소스가 다릅니다.

W_q는 decoder_hidden에 적용되어 "현재 디코더 상태에서 무엇을 찾고 싶은가"를 표현하고, W_k와 W_v는 encoder_output에 적용되어 "인코더의 어떤 정보를 제공할 것인가"를 나타냅니다. 이 비대칭성이 Cross-Attention의 핵심이며, Self-Attention과의 유일한 차이점입니다.

그 다음으로, attention score matrix의 shape이 (batch, heads, decoder_len, encoder_len)이 됩니다. 이것은 Decoder의 각 위치(행)가 Encoder의 모든 위치(열)에 대한 가중치를 가진다는 의미입니다.

예를 들어, decoder_len=10, encoder_len=15라면 10x15 행렬이 되어 서로 다른 길이의 시퀀스도 문제없이 처리됩니다. 이것이 variable-length input/output을 자연스럽게 다룰 수 있는 이유입니다.

세 번째 단계에서는 encoder_mask가 적용됩니다. 입력 문장이 padding되어 있다면, 그 위치는 실제 정보가 없으므로 attention을 주면 안 됩니다.

encoder_mask는 (batch, 1, 1, encoder_len) shape으로, padding 위치가 0입니다. Masked_fill로 해당 위치의 score를 -inf로 만들어, softmax 후 가중치가 0이 되도록 합니다.

마지막으로, attention이 적용되고 결과가 합쳐집니다. attn_weights와 V의 행렬 곱은 "Decoder의 각 위치가 Encoder 정보의 가중 평균"을 만듭니다.

예를 들어, "사랑해"를 생성하는 위치의 attention weight가 [0.1, 0.8, 0.05, 0.05]라면, "love"(두 번째 토큰)의 value 벡터가 80% 기여합니다. 이렇게 학습된 alignment가 번역 품질의 핵심입니다.

여러분이 이 코드를 사용하면 기계 번역, 문서 요약, 질의응답 등 입출력이 다른 seq2seq 태스크를 구현할 수 있습니다. 실무에서는 attention weights를 시각화하여 모델이 어떤 입력 부분을 참조했는지 분석하고, 잘못된 번역의 원인을 파악하는 데 활용합니다.

또한 encoder_output을 캐싱하여 디코딩 속도를 높이는 최적화도 일반적입니다.

실전 팁

💡 Encoder output을 여러 번 재사용하므로 K, V projection 결과를 캐싱하면 디코딩 속도가 2-3배 빨라집니다.

💡 Attention weights의 엔트로피를 모니터링하세요. 너무 낮으면(sharp) 특정 위치에만 의존하여 robust하지 않습니다.

💡 Encoder와 Decoder의 d_model이 다르면 K, V projection의 출력 차원을 조정해야 합니다. 일반적으로는 같게 설정합니다.

💡 Multi-source translation(여러 언어 → 하나)에서는 encoder_output을 여러 개 받아 각각 cross-attention하는 구조도 가능합니다.

💡 Alignment를 시각화할 때는 attention weights를 heatmap으로 그리세요. 대각선 패턴이면 단조 정렬, 교차 패턴이면 어순 변화를 의미합니다.


9. Causal Masking - Auto-regressive 생성의 필수 요소

시작하며

여러분이 언어 모델을 학습시킬 때 모델이 미래 토큰을 "컨닝"하여 비현실적으로 높은 성능을 내는 문제를 경험한 적 있나요? 훈련 시에는 정확도가 99%인데 실제 생성할 때는 엉망인 경우입니다.

이런 문제는 auto-regressive 모델(GPT 계열)에서 매우 흔하게 발생합니다. 훈련 시 전체 시퀀스를 한 번에 입력하면, Self-Attention이 미래 위치의 정보까지 참조하게 됩니다.

예를 들어, "I love you"에서 "love"를 예측할 때 이미 "you"를 본다면, 실제 생성 상황(앞 토큰만 알고 있음)과 조건이 달라져 제대로 학습되지 않습니다. 바로 이럴 때 필요한 것이 Causal Masking (또는 Future Masking)입니다.

각 위치가 자신과 이전 위치만 볼 수 있도록 상삼각 행렬로 차단하여, 훈련과 추론의 조건을 일치시킵니다.

개요

간단히 말해서, Causal Mask는 attention score matrix의 상삼각 부분(미래 위치)을 -inf로 채워, softmax 후 가중치가 0이 되도록 만드는 기법입니다. 수학적으로 mask[i, j] = 0 if j > i else 1입니다.

왜 이 masking이 필요한지 실무 관점에서 설명하면, teacher forcing의 효율성과 auto-regressive 생성의 일관성을 동시에 얻기 위함입니다. Teacher forcing은 훈련 시 전체 정답 시퀀스를 한 번에 입력하여 병렬 처리로 학습을 빠르게 하는 기법인데, causal mask 없이 사용하면 미래 정보 유출이 발생합니다.

Causal mask로 각 위치를 독립적으로 만들면, 병렬 처리의 속도 이점을 유지하면서도 추론 시와 동일한 조건을 시뮬레이션할 수 있습니다. 전통적인 RNN에서는 순차 처리로 자연스럽게 미래 정보가 차단되었다면, Transformer의 병렬 Self-Attention에서는 명시적인 masking이 필수입니다.

이것이 GPT가 "generative pre-training"을 할 수 있는 기반입니다. 핵심 특징은 첫째, 하삼각 행렬 형태의 간단한 구조지만 강력한 효과를 내며, 둘째, 한 번 생성하면 모든 시퀀스 길이에 재사용 가능하고, 셋째, padding mask와 결합하여 복잡한 조건도 처리할 수 있다는 점입니다.

이러한 특징들이 GPT-3 같은 초거대 언어 모델의 핵심 기술입니다.

코드 예제

import torch
import torch.nn as nn

def create_causal_mask(seq_len, device='cpu'):
    """
    상삼각 행렬을 0으로 만든 causal mask 생성

    예: seq_len=4
    [[1, 0, 0, 0],   (첫 번째 토큰은 자기만 볼 수 있음)
     [1, 1, 0, 0],   (두 번째는 자기와 첫 번째)
     [1, 1, 1, 0],   (세 번째는 자기와 이전 두 개)
     [1, 1, 1, 1]]   (네 번째는 모두)
    """
    # torch.tril: 하삼각 행렬만 1, 나머지 0
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq) for broadcasting

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1, max_len=512):
        super().__init__()
        self.attention = ProductionMultiHeadAttention(d_model, num_heads, dropout)

        # Causal mask를 미리 생성하여 버퍼로 저장
        # 학습되지 않지만 모델과 함께 저장/로드됨
        self.register_buffer(
            'causal_mask',
            create_causal_mask(max_len)
        )

    def forward(self, x, padding_mask=None):
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)

        # Causal mask 적용 (필요한 길이만큼만 슬라이싱)
        causal_mask = self.causal_mask[:, :, :seq_len, :seq_len]

        # Padding mask와 결합 (둘 다 있으면 AND 연산)
        if padding_mask is not None:
            # padding_mask: (batch, 1, 1, seq_len)
            # causal_mask: (1, 1, seq_len, seq_len)
            # Broadcasting으로 결합
            mask = causal_mask & padding_mask
        else:
            mask = causal_mask

        return self.attention(x, mask=mask)

설명

이것이 하는 일: 병렬 학습의 효율성을 유지하면서도, 각 위치가 미래 정보 없이 예측하도록 강제하여 실제 생성 조건과 일치시킵니다. 첫 번째로, create_causal_mask 함수가 하삼각 행렬을 생성합니다.

torch.tril은 "triangle lower"의 약자로, 대각선 이하만 1이고 나머지는 0인 행렬을 만듭니다. 예를 들어 seq_len=5이면 5x5 행렬에서 각 행 i는 0부터 i까지만 1이고 나머지는 0입니다.

unsqueeze(0).unsqueeze(0)로 (1, 1, seq, seq) shape으로 만들면 batch와 head 차원에 broadcasting됩니다. 왜 이렇게 하는지는 attention score (batch, heads, seq, seq)에 바로 적용하기 위함입니다.

그 다음으로, register_buffer로 causal_mask를 저장합니다. 이것은 한 번만 계산하고 재사용하기 위함입니다.

max_len=512로 설정하면 512x512 마스크를 미리 만들어두고, 실제 시퀀스가 100이면 [:100, :100]만 슬라이싱하여 사용합니다. Buffer로 등록하면 모델을 GPU로 옮길 때 자동으로 함께 이동하고, state_dict에 포함되어 저장/로드도 자동입니다.

세 번째 단계에서는 padding_mask와 결합합니다. padding_mask는 (batch, 1, 1, seq_len)로 각 배치의 padding 위치를 표시하고, causal_mask는 (1, 1, seq, seq)로 미래 위치를 표시합니다.

& 연산(논리 AND)으로 결합하면 "padding이 아니면서 동시에 미래도 아닌" 위치만 1이 됩니다. Broadcasting 덕분에 batch별로 다른 padding을 처리하면서도 공통 causal 제약을 적용할 수 있습니다.

마지막으로, 결합된 mask가 attention에 전달됩니다. Attention 내부에서 mask==0인 위치의 score를 -inf로 채우면, softmax 후 해당 위치의 가중치가 정확히 0이 됩니다.

예를 들어, 3번째 토큰은 4, 5번째 토큰(미래)과 padding 토큰에 대한 attention이 0이 되어, 오직 0, 1, 2번째 토큰만 참조합니다. 여러분이 이 코드를 사용하면 GPT 스타일의 언어 모델을 구현할 수 있습니다.

훈련 시 teacher forcing으로 병렬 학습하면서도, 각 위치는 독립적으로 예측하므로 추론 시와 동일한 조건이 보장됩니다. 실무에서는 KV-cache를 추가하여 생성 속도를 10배 이상 높이는 최적화도 함께 사용됩니다.

Causal mask 덕분에 이전 토큰의 K, V를 재사용할 수 있기 때문입니다.

실전 팁

💡 torch.tril 대신 torch.triu(torch.ones(...), diagonal=1) == 0으로 만들어도 동일합니다. 취향 차이입니다.

💡 매우 긴 시퀀스(4096+)에서는 causal mask를 bool 타입으로 만들면 메모리를 1/4로 줄일 수 있습니다.

💡 Prefix LM(앞부분은 bidirectional, 뒷부분만 causal)을 구현하려면 mask의 일부만 하삼각으로 만드세요.

💡 디버깅 시 mask.sum(dim=-1)을 확인하면 각 위치가 몇 개 토큰을 볼 수 있는지 알 수 있습니다. 1부터 seq_len까지 증가해야 정상입니다.

💡 Flash Attention 사용 시 causal=True 옵션만 주면 자동으로 처리되어 더 빠르고 메모리 효율적입니다.


10. KV-Cache 최적화 - Auto-regressive 생성 가속화

시작하며

여러분이 GPT 모델로 긴 텍스트를 생성할 때, 각 토큰마다 전체 시퀀스를 다시 처리하느라 속도가 매우 느린 문제를 겪어본 적 있나요? 100개 토큰을 생성하는데 몇 분씩 걸리는 경우입니다.

이런 문제는 순진한(naive) auto-regressive 생성 구현에서 매우 흔합니다. 각 step마다 전체 시퀀스를 다시 인코딩하면, 이미 계산한 Key와 Value를 매번 재계산하게 됩니다.

예를 들어, 100번째 토큰을 생성할 때 0~99번째 토큰의 K, V를 다시 계산하는데, 이들은 이미 99번째 step에서 계산한 것과 동일합니다. 이것은 엄청난 중복 계산입니다.

바로 이럴 때 필요한 것이 KV-Cache입니다. 이전 step에서 계산한 Key와 Value를 캐싱하여 재사용하면, 각 step에서 새 토큰의 K, V만 계산하면 되어 10-100배 빠릅니다.

개요

간단히 말해서, KV-Cache는 각 레이어의 Key와 Value 텐서를 step별로 누적 저장하고, 다음 step에서 새 K, V만 추가하여 concatenate하는 최적화 기법입니다. 왜 이 최적화가 필요한지 실무 관점에서 설명하면, 생성 속도가 실시간 서비스의 핵심이기 때문입니다.

ChatGPT 같은 대화 시스템에서 응답이 10초씩 걸리면 사용자 경험이 매우 나쁩니다. KV-Cache를 사용하면 동일한 모델로 10배 빠른 응답이 가능하여, 더 큰 모델을 사용하거나 배치 크기를 늘릴 수 있습니다.

예를 들어, GPT-3 175B 모델도 KV-Cache 없이는 실시간 서비스가 불가능합니다. 전통적인 방법에서는 매 step마다 O(n²) 계산을 했다면, KV-Cache를 사용하면 O(n)으로 줄어듭니다.

n=1024일 때 1024배 차이가 나므로, 실전에서는 필수적인 최적화입니다. 핵심 특징은 첫째, 계산 복잡도를 획기적으로 줄이며, 둘째, 메모리 사용량은 증가하지만(cache 저장) 일반적으로 trade-off가 유리하고, 셋째, 구현이 비교적 간단하여 대부분의 프레임워크에서 지원한다는 점입니다.

이러한 특징들이 실시간 LLM 서비스의 기반 기술입니다.

코드 예제

import torch
import torch.nn as nn

class KVCacheAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = 1.0 / math.sqrt(self.d_k)

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, past_kv=None, use_cache=False):
        # x: 새 토큰 (batch, 1, d_model) 또는 전체 시퀀스 (batch, seq_len, d_model)
        # past_kv: 이전 step의 (K, V) 튜플 또는 None

        batch_size, seq_len, _ = x.size()

        # 현재 step의 Q, K, V 계산
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # KV-Cache 처리
        if past_kv is not None:
            # 이전 K, V와 현재 K, V를 concatenate
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)  # seq_len 차원에서 연결
            V = torch.cat([past_V, V], dim=2)

        # 다음 step을 위해 현재 K, V 저장
        if use_cache:
            present_kv = (K, V)
        else:
            present_kv = None

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)

        return output, present_kv

설명

이것이 하는 일: Auto-regressive 생성 시 중복 계산을 제거하여 각 step에서 새 토큰만 처리하도록 최적화합니다. 첫 번째로, 입력 x가 두 가지 모드로 작동합니다.

초기 step(past_kv=None)에는 전체 프롬프트가 입력되어 (batch, prompt_len, d_model) shape이고, 이후 step에는 새 토큰 하나만 입력되어 (batch, 1, d_model)입니다. 현재 step의 Q, K, V를 계산하는 것은 동일하지만, seq_len이 1이므로 매우 빠릅니다.

왜 Q는 캐싱하지 않는지는 각 step의 Q가 새 토큰에만 대응하기 때문입니다. 그 다음으로, past_kv가 있으면 이전 K, V와 현재 K, V를 concatenate합니다.

torch.cat([past_K, K], dim=2)는 시퀀스 길이 차원(dim=2)에서 연결하여, 예를 들어 past_K가 (batch, heads, 99, d_k)이고 새 K가 (batch, heads, 1, d_k)이면 결과는 (batch, heads, 100, d_k)가 됩니다. 내부적으로 메모리 복사가 일어나지만, 재계산하는 것보다 훨씬 빠릅니다.

세 번째 단계에서는 use_cache=True이면 현재 K, V를 present_kv로 반환합니다. 다음 step에서 이것이 past_kv로 전달되어 계속 누적됩니다.

이렇게 각 레이어마다 KV-cache를 유지하면, 전체 Transformer의 모든 레이어에서 재사용이 가능합니다. 실무에서는 이 cache를 적절히 관리(메모리 초과 시 제거 등)해야 합니다.

마지막으로, Attention 계산은 동일하지만 효율이 매우 다릅니다. Q는 (batch, heads, 1, d_k)이고 K는 (batch, heads, 100, d_k)이므로, matmul 결과는 (batch, heads, 1, 100)입니다.

즉, 새 토큰 하나가 이전 모든 토큰을 참조하는 attention이지만, K, V는 재사용하므로 계산량은 O(n) 수준입니다. 100개 토큰 생성 시 O(100²) 대신 O(100)이 되어 100배 빠릅니다.

여러분이 이 코드를 사용하면 실시간 텍스트 생성이 가능한 서비스를 만들 수 있습니다. ChatGPT, Claude 같은 모든 상용 LLM 서비스는 KV-Cache를 사용합니다.

실무에서는 여러 레이어의 cache를 리스트로 관리하고, 배치 생성 시 beam search와 결합하여 더욱 복잡한 최적화를 적용합니다. 또한 quantization(INT8, FP16)을 cache에 적용하면 메모리를 2-4배 절약할 수 있습니다.

실전 팁

💡 Cache의 메모리 사용량은 2 * num_layers * batch * heads * max_len * d_k입니다. 긴 시퀀스에서는 모델 파라미터보다 클 수 있으니 모니터링하세요.

💡 실무에서는 past_kv를 별도 객체로 관리합니다. dataclass나 namedtuple로 (K, V, seq_len) 등을 묶으면 편합니다.

💡 Multi-turn 대화에서는 cache를 세션별로 유지하여 컨텍스트를 재사용하세요. 프롬프트 재계산을 완전히 제거할 수 있습니다.

💡 Batch 생성 시 일부 시퀀스가 먼저 끝나면 해당 cache를 제거하여 메모리를 절약하세요. Padding을 계속 유지하면 낭비입니다.

💡 Flash Attention v2는 KV-Cache를 더 효율적으로 처리하는 커널을 제공합니다. PagedAttention(vLLM)은 cache를 paging하여 메모리 효율을 극대화합니다.


11. Attention Head Pruning - 모델 경량화 기법

시작하며

여러분이 거대한 Transformer 모델을 모바일이나 엣지 디바이스에 배포하려고 할 때, 모델 크기와 추론 속도 때문에 좌절한 적 있나요? BERT-Large나 GPT-2 같은 모델은 수백 MB에서 GB 단위라 제약이 큰 환경에서는 사용이 어렵습니다.

이런 문제는 실제 on-device AI를 구현할 때 가장 큰 장애물입니다. 클라우드 API를 사용하면 비용과 지연시간이 문제고, 모델을 압축하면 성능이 크게 떨어지는 딜레마가 있습니다.

특히 Multi-Head Attention이 파라미터와 계산량의 큰 부분을 차지하는데, 모든 head가 정말 필요한지 의문입니다. 바로 이럴 때 필요한 것이 Attention Head Pruning입니다.

중요도가 낮은 head를 제거하여 모델 크기와 속도를 30-50% 개선하면서도 성능 손실은 1-2%로 최소화하는 기법입니다.

개요

간단히 말해서, Head Pruning은 각 attention head의 중요도를 측정하고, 낮은 head를 제거하여 모델을 경량화하는 기법입니다. 중요도는 attention weight의 엔트로피, gradient 크기, Taylor expansion 등으로 측정합니다.

왜 이 기법이 유용한지 실무 관점에서 설명하면, 많은 연구에서 Transformer의 일부 head가 거의 기여하지 않는다는 것이 밝혀졌기 때문입니다. 예를 들어, BERT의 144개 head 중 약 40%를 제거해도 성능이 1% 미만으로만 떨어집니다.

이는 head 간 중복성이 크고, 일부 head는 거의 uniform한 attention(모든 위치에 동일한 가중치)을 보여 실질적 정보가 없기 때문입니다. 전통적인 모델 압축(quantization, distillation)에서는 전체 모델을 균일하게 줄였다면, head pruning은 구조적 중복성을 찾아 선택적으로 제거하여 더 효율적입니다.

특히 quantization과 결합하면 시너지 효과가 큽니다. 핵심 특징은 첫째, structured pruning으로 실제 하드웨어 가속 혜택을 받을 수 있고, 둘째, fine-tuning으로 성능을 회복할 수 있으며, 셋째, 레이어와 head별로 다른 전략을 사용할 수 있다는 점입니다.

이러한 특징들이 실용적인 경량화를 가능하게 합니다.

코드 예제

import torch
import torch.nn as nn
import numpy as np

class PrunableMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Head별 중요도 점수 (학습 중 업데이트)
        self.head_importance = nn.Parameter(torch.ones(num_heads))

        # Pruning mask (1=유지, 0=제거)
        self.register_buffer('head_mask', torch.ones(num_heads))

    def compute_head_importance(self, attn_weights, output_grad):
        """
        각 head의 중요도를 gradient 기반으로 계산
        attn_weights: (batch, heads, seq, seq)
        output_grad: attention output의 gradient
        """
        # Taylor expansion: importance ≈ weight * gradient
        importance = (attn_weights.abs() * output_grad.abs()).sum(dim=[0, 2, 3])
        # importance: (heads,)
        return importance

    def prune_heads(self, num_heads_to_prune):
        """
        중요도가 낮은 head를 제거
        """
        if num_heads_to_prune == 0:
            return

        # 중요도 기준으로 정렬하여 하위 head 선택
        _, sorted_indices = torch.sort(self.head_importance, descending=False)
        heads_to_prune = sorted_indices[:num_heads_to_prune]

        # Mask 업데이트 (제거할 head를 0으로)
        self.head_mask[heads_to_prune] = 0

        print(f"Pruned heads: {heads_to_prune.tolist()}")
        print(f"Remaining heads: {(self.head_mask == 1).sum().item()} / {self.num_heads}")

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # QKV projection
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)

        # Head mask 적용 (pruned head는 0으로)
        # (batch, heads, seq, seq) * (heads, 1, 1)
        head_mask_expanded = self.head_mask.view(1, -1, 1, 1)
        attn_weights = attn_weights * head_mask_expanded

        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)

        return output, attn_weights

설명

이것이 하는 일: 각 head의 기여도를 측정하고, 불필요한 head를 선택적으로 제거하여 효율성을 높입니다. 첫 번째로, head_importance 파라미터를 관리합니다.

초기에는 모두 1로 설정되며, 학습 중 compute_head_importance를 통해 업데이트됩니다. Taylor expansion 방법은 "중요도


#LLM#Multi-Head Attention#Transformer#Attention Mechanism#Deep Learning#AI

댓글 (0)

댓글을 작성하려면 로그인이 필요합니다.