🤖

본 콘텐츠의 이미지 및 내용은 AI로 생성되었습니다.

⚠️

본 콘텐츠의 이미지 및 내용을 무단으로 복제, 배포, 수정하여 사용할 경우 저작권법에 의해 법적 제재를 받을 수 있습니다.

이미지 로딩 중...

Flash Attention 3과 Rotary Embeddings 완벽 분석 - 슬라이드 1/7
A

AI Generated

2026. 3. 31. · 0 Views

Flash Attention 3과 Rotary Embeddings 완벽 분석

AutoResearch 프로젝트의 train.py에 구현된 Flash Attention 3 커널 선택 로직, Rotary Position Embeddings(RoPE)의 수학적 원리와 구현, 그리고 Sliding Window Attention 패턴을 심도 있게 분석합니다.


목차

  1. Flash Attention 3 Hopper 전용 커널
  2. kernels-community 폴백 구현
  3. Rotary Position Embeddings RoPE 원리
  4. cos/sin 사전 계산과 적용
  5. RoPE 구현 apply_rotary_emb 함수
  6. Sliding Window Attention 패턴 SSSL
  7. 윈도우 사이즈 계산 로직

1. Flash Attention 3 Hopper 전용 커널

시작하며

어느 날 김개발 씨는 AutoResearch의 train.py 첫 부분을 열어보다가 낯선 import 문 하나에 시선이 멈추었습니다. from kernels import get_kernel이라는 한 줄이 전체 어텐션 계산의 운명을 결정짓고 있었기 때문입니다.

개요

**Flash Attention 3(FA3)**은 NVIDIA H100(Hopper 아키텍처) 전용으로 최적화된 어텐션 커널입니다. HBM(High Bandwidth Memory)에 대한 I/O를 극적으로 줄이고, **TMA(Tensor Memory Accelerator)**를 활용하여 메모리 액세스를 비동기적으로 처리합니다. 마치 고속도로에 전용 차선이 생긴 것과 같습니다. 기존 FA2 대비 Hopper GPU에서 압도적인 성능 향상을 보여줍니다.

코드 예제

from kernels import get_kernel

# GPU 컴퓨트_capability 확인
cap = torch.cuda.get_device_capability()

# Hopper(9,0) 전용 varunneal FA3, 아닌 경우 폴백
repo = ("varunneal/flash-attention-3"
       if cap == (9, 0)
       else "kernels-community/flash-attn3")
fa3 = get_kernel(repo).flash_attn_interface

# fa3.flash_attn_func(q, k, v, causal=True, window_size=...)

설명

"AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 세 번째 편에 오신 것을 환영합니다. 앞선 두 편에서는 AutoResearch의 전체 아키텍처와 GPT 모델의 뼈대인 CausalSelfAttention, Grouped Query Attention, MLP, Block 구조를 살펴보았습니다. 이번에는 그 모델이 실제로 어떻게 효율적으로 어텐션을 계산하는지, 그 핵심 기술들을 파헤쳐보겠습니다.

김개발 씨는 최근 팀 프로젝트에서 LLM 학습 성능을 높이는 미션을 받았습니다. "어텐션 계산이 병목이야. 이거 어떻게 최적화하지?" 선배 박시니어 씨가 슬쩍 train.py 파일을 열어 보였습니다. 첫 번째 줄부터 눈길을 사로잡는 코드가 있었습니다.

from kernels import get_kernel이라는 import가 바로 그 시작점입니다. kernels는 Hugging Face에 호스팅된 커스텀 CUDA 커널을 파이썬에서 즉시 로드할 수 있게 해주는 라이브러리입니다. 마치 앱스토어에서 필요한 앱을 다운로드받아 즉시 사용하는 것과 같습니다.

핵심은 다음 줄에 있습니다. torch.cuda.get_device_capability()로 현재 GPU의 컴퓨트 capability를 확인하고, 그에 따라 다른 커널을 선택합니다. Hopper 아키텍처, 즉 compute capability (9, 0)인 H100 GPU에서는 varunneal/flash-attention-3라는 전용 커널을 사용합니다.

왜 Hopper 전용일까요? FA3는 Hopper에 새로 추가된 **TMA(Tensor Memory Accelerator)**라는 하드웨어 기능을 적극적으로 활용합니다. TMA는 GPU 스레드 블록이 직접 글로벌 메모리에 접근하는 대신, 전용 하드웨어 유닛이 메모리를 비동기적으로 텐서 단위로 전송합니다. 마치 택배 기사가 물건을 직접 배달하는 대신, 물류 센터의 자동화 시스템이 일괄 처리하는 것과 같습니다.

이렇게 하면 메모리 액세스 지연시간을 숨길 수 있고, 연산 유닛이 더 많은 시간을 실제 계산에 쓸 수 있습니다. 또한 Hopper의 WGMMA(Warp Group Matrix Multiply-Accumulate) 명령어를 사용하여 연산 자체도 가속화합니다.

FA2와 비교하면, Hopper GPU에서 FA3는 약 2배 빠른 처리 속도를 보여줍니다. 이것이 5분 타임 버짓이라는 엄격한 제약 아래서 더 많은 실험을 돌릴 수 있는 핵심 비결 중 하나입니다.

하지만 여기서 한 가지 중요한 점이 있습니다. Hopper가 아닌 GPU에서는 어떻게 될까요? 바로 다음 카드에서 다룰 폴백(fallback) 로직이 그 답입니다.

실전 팁

  • FA3의 핵심 가속 요소는 TMA(비동기 메모리 전송)와 WGMMA(하드웨어 행렬 곱셈)입니다
  • get_device_capability()로 GPU 아키텍처를 런타임에 감지하여 최적 커널을 자동 선택합니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

2. kernels-community 폴백 구현

시작하며

김개발 씨가 궁금해서 물었습니다. "그런데 H100이 아니면 FA3를 못 쓰는 건가요?" 박시니어 씨가 미소를 지었습니다. "그래서 폴백 로직이 있는 거지. 실무에서는 항상 최악의 경우까지 대비해야 해요."

개요

kernels-community/flash-attn3는 community가 Triton으로 작성한 Flash Attention 3 호환 커널입니다. Hopper 전용 최적화는 없지만, 모든 CUDA 호환 GPU에서 동작합니다. 마치 정규 업체가 아닌 대체 인력이라도 일단 업무는 처리할 수 있는 것과 같습니다. 핵심은 API 호환성을 유지하여 fa3.flash_attn_func() 호출 방식이 동일하다는 점입니다.

코드 예제

# GPU 아키텍처에 따른 조건부 커널 선택
cap = torch.cuda.get_device_capability()

# Hopper 전용 vs 범용 폴백
repo = ("varunneal/flash-attention-3"
       if cap == (9, 0)
       else "kernels-community/flash-attn3")

# 두 커널 모두 동일한 flash_attn_func API 제공
fa3 = get_kernel(repo).flash_attn_interface
# fa3.flash_attn_func(q, k, v, causal=True, window_size=...)
# API가 동일하므로 호출부 코드를 변경할 필요가 없습니다

설명

실무에서 가장 중요한 원칙 중 하나는 "하나의 GPU에서만 동작하는 코드는 배포 불가"입니다. AutoResearch가 단일 GPU에서 실험을 자율적으로 수행하는 시스템이라면, 다양한 GPU 환경에서 동작해야 합니다.

Karpathy는 이 문제를 아주 우아하게 해결했습니다. 단 세 줄의 조건부 로직으로 Hopper 전용 커널과 범용 폴백 커널 사이를 전환합니다.

varunneal/flash-attention-3은 varunneal이라는 개발자가 작성한 Hopper 전용 FA3 커널입니다. H100의 TMA, WGMMA 같은 하드웨어 기능을 최대한 활용하여 극한의 성능을 뽑아냅니다.

kernels-community/flash-attn3는 커뮤니티가 OpenAI의 Triton 프레임워크로 구현한 버전입니다. Triton은 NVIDIA CUDA보다 높은 수준의 GPU 프로그래밍 언어로, 컴파일러가 하드웨어 최적화를 자동으로 수행합니다. Hopper 전용 명령어는 사용하지 못하지만, 어텐션 계산의 수학적 결과는 동일합니다.

이 두 커널의 핵심 설계 철학은 API 호환성입니다. flash_attn_interface라는 공통 모듈을 통해 flash_attn_func(q, k, v, causal=True, window_size=...)이라는 완전히 동일한 호출 방식을 제공합니다. 이것이 가능한 이유는 kernels 라이브러리가 Hugging Face에서 커널 패키지를 동적으로 로드하기 때문입니다.

마치 플러그인 시스템과 같습니다. USB-C 포트에 어떤 기기를 꽂아도 호환되는 것처럼, get_kernel(repo).flash_attn_interface라는 통일된 인터페이스 뒤에서 실제 구현이 교체됩니다.

이러한 폴백 패턴은 실무에서 매우 중요합니다. RTX 4090(Ampere 기반), RTX 3090, 심지어 클라우드의 T4 GPU에서도 동일한 학습 코드를 실행할 수 있게 해줍니다. AutoResearch의 에이전트가 다양한 환경에서 실험을 수행할 수 있는 기반이 바로 이 폴백 로직입니다.

실전 팁

  • 폴백 패턴은 "하나의 코드로 다양한 환경 지원"의 핵심 전략입니다
  • kernels 라이브러리는 Hugging Face의 커널 허브에서 CUDA/Triton 커널을 동적으로 로드합니다
  • RTX 4090, A100 등 비-Hopper GPU에서도 동일한 학습 코드가 동작합니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

3. Rotary Position Embeddings RoPE 원리

시작하며

박시니어 씨가 화이트보드에 그림을 그리기 시작했습니다. "어텐션에는 위치 정보가 필요해요. 단어가 문장의 어디에 있는지 알아야 의미를 파악할 수 있으니까요." 김개발 씨가 고개를 끄덕였습니다. "그래서 position embedding이 있는 거죠?" "맞아요. 하지만 전통적인 방식에는 한계가 있었어요."

개요

**Rotary Position Embedding(RoPE)**은 어텐션의 Q, K 벡터에 회전 행렬을 곱하여 위치 정보를 주입하는 기법입니다. 고정된 위치 임베딩을 더하는 방식과 달리, 벡터를 복소평면에서 회전시킵니다. 마치 시계 바늘이 각도로 시간을 표현하는 것과 같습니다. 이 방식은 상대적 위치 관계를 자연스럽게 포착하고, 길이 외삽(generalization)에 강한 장점이 있습니다.

코드 예제

# RoPE의 핵심 수학: 2차원 회전 행렬 적용
# x = [x1, x2] 벡터를 theta 각도만큼 회전
# | cos(theta)  -sin(theta) | | x1 |   | x1*cos - x2*sin |
# | sin(theta)   cos(theta) | | x2 | = | x1*sin + x2*cos |

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # (B, T, n_head, head_dim)
    d = x.shape[3] // 2  # head_dim의 절반
    x1, x2 = x[..., :d], x[..., d:]
    # 회전 변환: 복소수 곱셈과 동일
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)

설명

트랜스포머에서 가장 근본적인 질문 중 하나는 "모델이 단어의 순서를 어떻게 알 수 있을까?"입니다. 어텐션 메커니즘 자체는 순서에 무관(permutation invariant)하기 때문입니다. "나는 사과를 좋아해"와 "사과는 나를 좋아해"는 같은 단어로 구성되어 있지만 완전히 다른 의미입니다.

전통적인 트랜스포머(GPT-2 등)는 학습 가능한 **절대 위치 임베딩(learned positional embedding)**을 사용했습니다. 위치 0, 1, 2, ...마다 고유한 벡터를 학습하고, 토큰 임베딩에 단순히 더했습니다. 하지만 이 방식은 학습 시킨 최대 길이를 넘어서면 동작하지 않는다는 치명적인 단점이 있습니다.

RoPE는 완전히 다른 접근법을 취합니다. 위치 정보를 "더하는" 것이 아니라 "회전시키는" 것입니다.

비유하자면, RoPE는 시계와 같습니다. 시계 바늘은 12시, 3시, 6시처럼 각도로 시간을 표현합니다. 3시에서 6시까지의 각도 차이와, 9시에서 12시까지의 각도 차이는 같습니다(모두 90도). 이처럼 RoPE도 상대적 거리를 각도 차이로 표현합니다.

수학적으로 RoPE는 각 헤드 차원을 2차원 쌍으로 나누고, 각 쌍에 대해 2차원 회전 행렬을 적용합니다. 위치 t에서의 회전 각도는 theta_i = t * (base^(-2i/d))로 계산됩니다. 여기서 base는 보통 10000, d는 head_dim, i는 차원 인덱스입니다.

이 회전의 마법 같은 성질은 어텐션 점수 계산에 있습니다. Q와 K의 내적(Q @ K^T)을 계산할 때, 두 벡터가 각각 다른 위치에서 회전되었다면, 그 내적은 오직 **두 위치의 차이(상대적 거리)**에만 의존하게 됩니다. 마치 시계에서 "3시와 6시의 차이"와 "9시와 12시의 차이"가 같은 것처럼, 위치 2와 위치 5의 관계가 위치 100과 위치 103의 관계와 동일하게 모델링됩니다.

이것이 RoPE가 길이 외삽에 강한 이유입니다. 학습 시킨 길이보다 긴 시퀀스를 처리하더라도, 인접 토큰 간의 상대적 관계는 동일하게 유지됩니다.

AutoResearch의 train.py에서 RoPE는 apply_rotary_emb라는 함수로 구현되어 있으며, Q와 K에만 적용합니다. V에는 위치 정보를 주입하지 않는 것이 RoPE의 표준 설계입니다.

실전 팁

  • RoPE는 Q와 K에만 적용하고, V에는 적용하지 않는 것이 표준 설계입니다
  • 회전 각도는 base=10000, 주파수는 head_dim에 따라 기하급수적으로 감소합니다
  • 상대적 위치만 고려하므로 학습 길이보다 긴 시퀀스에도 잘 일반화됩니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

4. cos/sin 사전 계산과 적용

시작하며

"매번 어텐션을 계산할 때마다 cos과 sin을 새로 계산하면 비효율적이겠죠?" 김개발 씨의 질문에 박시니어 씨가 고개를 끄덕였습니다. "정확해요. 그래서 사전 계산(precomputation)을 사용하는 거예요. 이건 최적화의 기본 중 하나죠."

개요

AutoResearch는 모델 초기화 시점에 cos/sin 테이블을 미리 계산하여 버퍼에 저장합니다. _precompute_rotary_embeddings 메서드가 이 역할을 담당하며, base=10000의 역주파수(inverse frequency)를 생성하고 각 위치별 회전 각도의 cos/sin 값을 계산합니다. 마치 달력을 한 번 인쇄해두고 매일 펼쳐보는 것과 같습니다. 매번 날짜를 계산할 필요가 없습니다.

코드 예제

def _precompute_rotary_embeddings(self, seq_len, head_dim,
                                     base=10000, device=None):
    device = device or self.transformer.wte.weight.device
    # 0, 2, 4, ..., head_dim-2 까지의 채널 인덱스
    channel_range = torch.arange(0, head_dim, 2,
                                 dtype=torch.float32, device=device)
    # 역주파수: base^(-2i/d) - 저주파부터 고주파까지
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    # 위치별 각도: outer(t, inv_freq)
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(t, inv_freq)
    cos, sin = freqs.cos(), freqs.sin()
    return cos.bfloat16(), sin.bfloat16()

설명

최적화의 첫 번째 원칙은 "**반복되는 계산은 한 번만 하세요""입니다. RoPE에서 가장 비용이 많이 드는 부분은 cos와 sin의 계산입니다. PyTorch의 삼각함수 연산은 GPU에서 그렇게 가벼운 연산이 아닙니다.

AutoResearch는 모델이 생성될 때 _precompute_rotary_embeddings를 호출하여 전체 시퀀스 길이에 대한 cos/sin 테이블을 한 번에 계산합니다. 그리고 register_buffer로 등록하여 모델의 일부로 관리합니다.

코드를 단계별로 분석해보겠습니다. 먼저 channel_range는 0부터 head_dim까지 2씩 증가하는 인덱스 시퀀스입니다. head_dim이 128이면 [0, 2, 4, ..., 126]이 됩니다. 이는 각 2차원 쌍을 구분하기 위한 인덱스입니다.

다음으로 inv_freq가 핵심입니다. 1.0 / (base ** (channel_range / head_dim)) 공식에 따라 역주파수를 계산합니다. base=10000이므로, 저차원(작은 i)에서는 큰 역주파수(느린 회전), 고차원(큰 i)에서는 작은 역주파수(빠른 회전)를 얻습니다. 이것은 신호 처리에서 주파수 대역을 나누는 것과 동일한 원리입니다.

torch.outer(t, inv_freq)는 모든 위치-주파수 조합에 대한 각도 행렬을 만듭니다. 크기는 (seq_len, head_dim/2)입니다. 그리고 마지막으로 .cos().sin()으로 삼각함수 값을 구합니다.

주목할 점은 seq_lenconfig.sequence_len * 10으로 설정한다는 것입니다. 즉, 실제 학습 시퀀스 길이의 10배까지 cos/sin 값을 미리 계산합니다. 이는 긴 시퀀스에 대한 외삽을 지원하기 위한 여유분입니다.

계산된 cos/sin은 bf16(bfloat16)으로 변환됩니다. 어텐션 계산의 정밀도 요구사항이 그리 높지 않기 때문에, fp32 대신 bf16을 사용하여 메모리 사용량을 절반으로 줄입니다. 또한 register_buffer(..., persistent=False)로 등록하여 모델 가중치 파일에 저장하지 않고, 초기화 시점에 다시 계산하도록 합니다.

포워드 패스에서는 단순히 self.cos[:, :T], self.sin[:, :T]로 현재 시퀀스 길이만큼 슬라이싱하여 사용합니다. 한 번 계산해두고 필요한 만큼만 자르는 것입니다. 달력에서 오늘 날짜만 보는 것과 같습니다.

실전 팁

  • seq_len을 실제 길이의 10배로 설정하여 긴 시퀀스 외삽에 대비합니다
  • cos/sin을 bf16으로 저장하여 메모리 사용량을 절반으로 줄입니다
  • persistent=False로 설정하여 모델 체크포인트 크기를 줄입니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

5. RoPE 구현 apply rotary emb 함수

시작하며

이제 실제 코드로 돌아가 봅시다. 김개발 씨가 화면의 8줄짜리 함수를 보며 감탄했습니다. "이게 RoPE 전체 구현이라니..." 박시니어 씨가 말했습니다. "최고의 코드는 불필요한 것이 없는 코드예요."

개요

apply_rotary_emb 함수는 AutoResearch에서 RoPE의 핵심 구현입니다. 입력 텐서의 마지막 차원을 반으로 나누어 (x1, x2) 쌍을 만들고, 2차원 회전 공식을 적용합니다. 수학적으로 복소수 곱셈 (x1 + ix2) * (cos + i*sin)의 실수부와 허수부를 계산하는 것과 완전히 동일합니다. 마치 시계의 시침과 분침이 독립적으로 회전하면서 시간을 표현하는 것과 같습니다.

코드 예제

def apply_rotary_emb(x, cos, sin):
    # x: (B, T, n_head, head_dim)
    assert x.ndim == 4
    d = x.shape[3] // 2
    # head_dim을 두 그룹으로 분할
    x1, x2 = x[..., :d], x[..., d:]
    # 2차원 회전 행렬 적용
    y1 = x1 * cos + x2 * sin      # 실수부
    y2 = x1 * (-sin) + x2 * cos   # 허수부
    return torch.cat([y1, y2], 3)

설명

코드를 읽을 때 가장 좋은 방법은 "**이 코드가 무엇을 하는가?""가 아니라 "**왜 이렇게 구현했는가?""를 묻는 것입니다. 이 8줄의 함수는 RoPE의 전체 구현이지만, 각 줄에는 의도적인 설계 결정이 담겨 있습니다.

먼저 assert x.ndim == 4에서 4차원 텐서를 요구합니다. (Batch, Time, n_head, head_dim) 형태입니다. 어텐션 계산 전에 헤드 분할이 이미 완료된 상태에서 호출됩니다.

d = x.shape[3] // 2는 head_dim을 반으로 나눕니다. head_dim이 128이면 d는 64입니다. 이는 각 2차원 쌍을 처리하기 위한 것입니다. 차원 0과 1이 하나의 쌍, 2와 3이 또 하나의 쌍, 이런 식으로 64개의 2차원 회전이 병렬로 수행됩니다.

x[..., :d]x[..., d:]는 텐서를 두 부분으로 슬라이싱합니다. ...은 앞의 3차원(Batch, Time, n_head)을 그대로 유지하라는 의미입니다.

회전 공식의 본체인 y1 = x1 * cos + x2 * siny2 = x1 * (-sin) + x2 * cos는 2차원 회전 행렬을 벡터화한 것입니다. 이것은 복소수 표현 (x1 + i*x2) * (cos(t) + i*sin(t))를 전개한 것과 수학적으로 완전히 동일합니다.

여기서 cos와 sin 텐서의 shape이 중요합니다. 사전 계산에서 cos[None, :, None, :]로 차원을 확장했으므로, (1, T, 1, d) 형태입니다. 브로드캐스팅에 의해 (B, T, n_head, d)의 x1, x2와 자동으로 호환됩니다.

마지막으로 torch.cat([y1, y2], 3)에서 회전된 두 부분을 다시 하나의 텐서로 결합합니다. dim=3(head_dim 차원)을 따라 연결하므로, 최종 shape은 원래의 (B, T, n_head, head_dim)과 동일합니다.

이 함수는 CausalSelfAttention의 forward에서 Q와 K에 각각 호출됩니다. 그 직후에 norm(q), norm(k)으로 RMSNorm을 적용하는 것도 흥미로운 포인트입니다. RoPE 적용 후 Q와 K의 크기를 정규화하여 어텐션 점수의 안정성을 높입니다.

이 함수의 아름다움은 효율성에 있습니다. PyTorch의 벡터화된 연산을 사용하므로 GPU에서 64개의 2차원 회전이 단일 커널 호출로 병렬 처리됩니다. 루프가 없고, 조건문이 없고, 메모리 할당이 최소화되어 있습니다.

실전 팁

  • 이 함수는 복소수 곱셈을 실수 연산으로 풀어쓴 것으로, 수학적으로 완벽히 동일한 결과를 보장합니다
  • Q와 K에만 적용되며, V에는 RoPE를 적용하지 않는 것이 표준 설계입니다
  • RoPE 적용 후 RMSNorm을 추가로 적용하여 어텐션 스코어의 안정성을 확보합니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

6. Sliding Window Attention 패턴 SSSL

시작하며

김개발 씨가 GPTConfig를 다시 보며 의아해했습니다. "window_pattern이 SSSL로 설정되어 있는데, 이건 뭔가요?" 박시니어 씨가 설명했습니다. "어텐션이 모든 이전 토큰을 보면 비효율적이에요. 창문 너머는 가려두고 가까운 것만 보는 거죠."

개요

**Sliding Window Attention(SWA)**은 각 토큰이 자신의 근처 토큰들에게만 어텐션을 허용하는 기법입니다. AutoResearch는 SSSL 패턴을 사용하여 하위 레이어에서는 반(S = Short)만큼의 컨텍스트만 보고, 마지막 레이어에서는 전체(L = Long) 컨텍스트를 봅니다. 마치 돋보기를 사용할 때 가까운 곳부터 점차 멀리 확대해가는 것과 같습니다. 연산량을 줄이면서도 전체 문맥을 파악할 수 있습니다.

코드 예제

# GPTConfig의 윈도우 패턴 설정
@dataclass
class GPTConfig:
    window_pattern: str = "SSSL"  # S=Short, L=Long

# 실제 어텐션 호출 (CausalSelfAttention.forward)
y = fa3.flash_attn_func(
    q, k, v,
    causal=True,
    window_size=window_size  # 레이어별 다른 윈도우 적용
)

# S 레이어: window_size = sequence_len // 2
# L 레이어: window_size = sequence_len (전체)

설명

트랜스포머의 어텐션 메커니즘은 시퀀스 길이에 대해 **이차원(O(n^2))**의 연산 복잡도를 가집니다. 시퀀스가 두 배 길어지면 어텐션 연산량은 네 배가 됩니다. 이것은 긴 시퀀스를 처리할 때 가장 큰 병목입니다.

하지만 모든 토큰이 모든 다른 토큰과 동등하게 상호작용해야 할까요? 실제로 자연어에서는 인접한 단어들 사이의 관계가 먼 단어들 사이의 관계보다 훨씬 강합니다. "어제"와 "밥을"은 "어제"와 "맛있었다"보다 훨씬 가까운 문맥 관계를 가집니다.

이러한 관찰에서 **Sliding Window Attention(SWA)**이 탄생했습니다. 각 토큰이 이전 N개의 토큰에게만 어텐션을 허용하는 방식입니다. 연산량이 O(n*w)로 줄어듭니다(w는 윈도우 크기).

AutoResearch의 SSSL 패턴은 더 정교한 설계입니다. 각 레이어마다 다른 윈도우 크기를 적용합니다. 8개 레이어라면, 패턴 "SSSL"은 다음과 같이 매핑됩니다: 레이어 0(S), 레이어 1(S), 레이어 2(S), 레이어 3(L), 레이어 4(S), 레이어 5(S), 레이어 6(S), 레이어 7(L).

여기서 Ssequence_len // 2를 의미합니다. 예를 들어 sequence_len이 2048이면, S 레이어의 윈도우는 1024입니다. L은 전체 2048입니다.

왜 이런 패턴일까요? 하위 레이어는 주로 국소 패턴(문법, 구문 구조)을 학습합니다. 단어 사이의 인접 관계를 파악하는 데는 반 길이의 컨텍스트면 충분합니다. 반면 상위 레이어는 전역 문맥(문단 전체의 주제, 논리적 흐름)을 파악해야 합니다.

마지막 레이어를 L로 강제하는 것도 의도적인 설계입니다. _compute_window_sizes에서 window_sizes[-1] = (long_window, 0)으로 마지막 레이어를 항상 전체 컨텍스트로 설정합니다. 최종 예측을 내리기 전에는 전체 문맥을 반드시 참조해야 하기 때문입니다.

SSSL 패턴의 효과는 estimate_flops에서 정량화됩니다. S 레이어는 L 레이어보다 어텐션 FLOPs가 절반입니다. 8개 레이어 중 6개가 S, 2개가 L이라면, 어텐션 연산량을 약 25% 절감할 수 있습니다. 5분 타임 버짓에서 이 절감은 더 많은 학습 스텝을 의미합니다.

실전 팁

  • S 레이어에서는 반 길이, L 레이어에서는 전체 길이의 윈도우를 사용합니다
  • 마지막 레이어는 항상 L로 강제되어 전체 문맥을 보장합니다
  • FA3의 window_size 파라미터를 통해 하드웨어 수준에서 윈도우가 최적화됩니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

7. 윈도우 사이즈 계산 로직

시작하며

"이 패턴이 실제 코드로 어떻게 구현되어 있을까요?" 김개발 씨의 마지막 질문이었습니다. 박시니어 씨가 _compute_window_sizes 메서드를 가리키며 말했습니다. "이 함수가 핵심이에요. 패턴 문자열을 실제 숫자 배열로 변환하죠."

개요

_compute_window_sizeswindow_pattern 문자열(예: "SSSL")을 레이어별 윈도우 크기 튜플의 리스트로 변환합니다. Ssequence_len // 2, Lsequence_len에 매핑되며, 패턴이 레이어 수보다 짧으면 순환하며, 마지막 레이어는 항상 전체 길이를 보장합니다. 마치 재봉틀의 패턴이 천을 반복적으로 잘라내듯, 짧은 패턴 문자열이 여러 레이어에 걸쳐 반복 적용됩니다.

코드 예제

def _compute_window_sizes(self, config):
    pattern = config.window_pattern.upper()
    assert all(c in "SL" for c in pattern)
    long_window = config.sequence_len     # 예: 2048
    short_window = long_window // 2       # 예: 1024
    char_to_window = {
        "L": (long_window, 0),
        "S": (short_window, 0)
    }
    window_sizes = []
    for layer_idx in range(config.n_layer):
        # 패턴을 순환하며 레이어에 매핑
        char = pattern[layer_idx % len(pattern)]
        window_sizes.append(char_to_window[char])
    # 마지막 레이어는 항상 전체 컨텍스트
    window_sizes[-1] = (long_window, 0)
    return window_sizes

설명

설계 패턴에서 "선언적 설정, 명령적 변환"이라는 원칙이 있습니다. 사용자는 "SSSL"이라는 의미 있는 패턴을 선언하고, 시스템이 이를 실제 숫자 배열로 변환합니다. 이 분리는 설정의 가독성과 유연성을 동시에 높입니다.

_compute_window_sizes는 GPT 클래스의 __init__에서 호출되어 self.window_sizes에 저장됩니다. 이 배열은 이후 모든 포워드 패스에서 CausalSelfAttention의 forward에 전달됩니다.

코드의 첫 줄인 assert all(c in "SL" for c in pattern)은 방어적 프로그래밍의 좋은 예입니다. 오타나 잘못된 설정값(예: "SSXL")을 조기에 감지합니다. 실패하면 명확한 에러 메시지가 반환됩니다.

char_to_window 딕셔너리는 S와 L을 실제 숫자 튜플에 매핑합니다. 각 튜플의 두 번째 요소는 항상 0인데, 이는 FA3의 window_size 파라미터가 (window_size, 0) 형태의 튜플을 받기 때문입니다.

핵심 로직은 pattern[layer_idx % len(pattern)]입니다. 모듈로 연산(%)을 사용하여 패턴을 순환시킵니다. 예를 들어 8개 레이어에 "SSSL" 패턴을 적용하면: 인덱스 0%4=0(S), 1%4=1(S), 2%4=2(S), 3%4=3(L), 4%4=0(S), 5%4=1(S), 6%4=2(S), 7%4=3(L).

결과적으로 [S, S, S, L, S, S, S, L]이라는 배열이 생성됩니다. 즉, 레이어 3과 레이어 7만 전체 컨텍스트를 보고, 나머지는 절반만 봅니다.

마지막 줄 window_sizes[-1] = (long_window, 0)이 가장 중요합니다. 이것은 안전 장치입니다. 어떤 패턴이든 마지막 레이어는 항상 전체 컨텍스트를 봅니다. 모델의 최종 출력 전에는 전체 시퀀스의 정보가 집약되어야 하기 때문입니다.

이 윈도우 크기는 estimate_flops에서도 사용됩니다. FLOPs 추정 시 각 레이어의 실제 윈도우 크기를 반영하여, MFU(Model FLOPs Utilization)를 정확하게 계산합니다. 학습 로그에 표시되는 MFU 퍼센티지는 이 계산을 기반으로 합니다.

AutoResearch의 에이전트가 이 패턴을 수정하여 실험할 수도 있습니다. 예를 들어 "LLLL"로 설정하면 모든 레이어에서 전체 컨텍스트를 보게 되지만, 연산량이 늘어나 5분 안에 더 적은 스텝을 돌게 됩니다. "SSSS"로 설정하면 연산량은 줄지만, 긴 문맥 파악 능력이 저하될 수 있습니다. 이런 트레이드오프를 탐색하는 것이 AutoResearch의 핵심 미션입니다.

실전 팁

  • 패턴 문자열은 assert로 검증하여 잘못된 설정을 조기에 차단합니다
  • 모듈로 연산으로 짧은 패턴을 여러 레이어에 걸쳐 순환 적용합니다
  • 마지막 레이어 강제 L 설정은 전체 문맥 파악을 위한 안전 장치입니다
  • 에이전트가 이 패턴을 수정하여 연산량-성능 트레이드오프를 자율적으로 탐색할 수 있습니다
  • 다음 카드뉴스에서는 Value Embeddings(ResFormer) 아키텍처를 다룹니다
  • 이 카드뉴스는 "AutoResearch 완전 분석 - AI 자율 연구 에이전트" 코스의 3/8편입니다

#Python#FlashAttention3#RoPE#SlidingWindow#Transformer

댓글 (0)

댓글을 작성하려면 로그인이 필요합니다.

함께 보면 좋은 카드 뉴스

이전3/3
다음
Flash Attention 3과 Rotary Embeddings 완벽 분석 | CodeDeck