들어가며
Transformer 아키텍처는 자연어 처리(NLP)를 넘어 컴퓨터 비전, 음성, 단백질 구조 예측 등 거의 모든 AI 분야를 지배하고 있다. 그러나 Transformer의 핵심인 Self-Attention 메커니즘은 시퀀스 길이 에 대해 의 시간 및 메모리 복잡도를 가지며, 이것이 긴 시퀀스를 처리하는 데 가장 큰 병목이 되어 왔다.
이 문제를 해결하기 위해 수많은 근사(approximate) 어텐션 기법이 제안되었지만, 대부분 실제 wall-clock 속도 향상은 미미했고 모델 품질 저하가 동반되었다. FlashAttention은 완전히 다른 접근법을 택한다. 어텐션 알고리즘 자체를 근사하지 않고, GPU 메모리 계층(memory hierarchy)을 인식하는 IO-aware 알고리즘을 설계하여 정확한(exact) 어텐션을 훨씬 빠르고 메모리 효율적으로 계산한다.
핵심 아이디어: 알고리즘의 수학적 결과를 바꾸지 않으면서, GPU의 SRAM과 HBM 간 데이터 이동(IO)을 최소화하는 타일링(tiling) 기법으로 어텐션을 재구성한다.
이 글에서는 Tri Dao 등이 2022년 NeurIPS에서 발표한 “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” 논문을 상세히 리뷰한다.
논문 개요
| 항목 | 내용 |
|---|---|
| 제목 | FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness |
| 저자 | Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré |
| 소속 | Stanford University, University at Buffalo |
| 학회 | NeurIPS 2022 |
| 핵심 키워드 | Attention, IO-Awareness, Tiling, Memory Hierarchy, GPU Optimization |
| 코드 | github.com/Dao-AILab/flash-attention |
배경: 왜 Standard Attention이 느린가?
Standard Attention의 계산 과정
Transformer의 Self-Attention은 쿼리(), 키(), 밸류() 행렬을 입력으로 받아 다음과 같이 계산한다.
여기서:
– : 각각 쿼리, 키, 밸류 행렬 (은 시퀀스 길이, 는 head 차원)
– : 키 벡터의 차원 (스케일링 팩터)
– : 행(row) 단위로 적용
표준 구현의 문제점
표준 PyTorch 구현에서는 다음 단계를 순차적으로 수행한다:
- 계산 → HBM에 저장
- 계산 → HBM에 저장
- Dropout 적용 (학습 시)
- 계산
여기서 결정적인 병목은 크기의 중간 행렬 와 를 GPU의 고대역폭 메모리(HBM)에 읽고 쓰는 것이다.
GPU 메모리 계층 이해
이 논문의 핵심 통찰을 이해하려면 GPU 메모리 구조를 알아야 한다.
| 메모리 종류 | 용량 | 대역폭 | 특징 |
|---|---|---|---|
| SRAM (온칩) | ~20MB (A100 기준) | ~19TB/s | 매우 빠르지만 매우 작음 |
| HBM (오프칩) | 40~80GB (A100) | ~1.5~2.0TB/s | 크지만 SRAM 대비 10배 이상 느림 |
핵심 관찰: 현대 GPU에서 대부분의 연산은 compute-bound가 아니라 memory-bound이다. 즉, 실제 산술 연산보다 데이터를 HBM에서 읽고 쓰는 시간이 전체 실행 시간을 지배한다.
어텐션의 경우 연산량(FLOPs)은 이지만, HBM 접근(IO)은 이다. 가 64~128로 작은 반면 은 수천~수만이므로, IO 비용이 실질적 병목이 된다.
FlashAttention의 핵심 기법
기법 1: 타일링(Tiling)
FlashAttention의 첫 번째 핵심은 어텐션 행렬을 절대 전체적으로 구체화(materialize)하지 않는 것이다.
대신, 입력 , , 를 블록(block) 단위로 나누어 SRAM에 올리고, 각 블록에 대한 어텐션 출력을 점진적으로(incrementally) 계산한다.
알고리즘 개요:
1. Q를 블록 단위로 나눔: Q_1, Q_2, ..., Q_{T_r}
2. K, V를 블록 단위로 나눔: K_1, K_2, ..., K_{T_c}
3. 각 Q 블록에 대해:
- 모든 K, V 블록을 순회하며
- 블록 단위 어텐션을 SRAM 내에서 계산
- 결과를 점진적으로 누적
4. 최종 출력 O를 HBM에 한 번만 기록
블록 크기는 SRAM 용량에 맞추어 설정한다. A100 GPU의 경우 SRAM이 약 20MB이므로, 블록 크기를 적절히 조절하여 , , 의 블록과 중간 결과가 모두 SRAM에 들어가도록 한다.
기법 2: Online Softmax (Tiling과 Softmax의 양립)
타일링의 가장 큰 기술적 난관은 softmax이다. Softmax는 한 행의 모든 원소를 알아야 정규화 상수를 계산할 수 있다:
그런데 타일링에서는 한 번에 의 일부 블록만 보므로, 전체 행의 합을 모른다. FlashAttention은 Online Softmax 기법을 활용하여 이 문제를 해결한다.
Online Softmax의 원리
수치 안정성을 위해 softmax는 보통 행의 최댓값 을 빼서 계산한다:
Online Softmax는 이 최댓값 과 정규화 상수 를 점진적으로 갱신한다. 새로운 블록이 들어올 때마다:
이미 계산된 출력 도 스케일링 팩터를 곱해서 보정한다:
여기서:
– : 지금까지 본 블록들의 최댓값
– : 현재 블록의 최댓값
– : 지금까지의 softmax 분모(정규화 상수)
– : 현재 블록의 softmax 분모
– : 현재 블록의 softmax 값 (비정규화)
이 기법의 핵심: 모든 블록 처리가 끝나면, 최종 결과는 표준 어텐션과 수학적으로 동일하다. 어떤 근사도 없다.
기법 3: 역전파를 위한 재계산(Recomputation)
학습 시 역전파(backpropagation)를 위해 표준 구현은 크기의 와 행렬을 HBM에 저장해야 한다. 이것이 메모리의 원인이다.
FlashAttention은 이를 저장하지 않고, 역전파 시 , , 와 출력 , softmax 통계량(, )만을 저장한다. 역전파 시 필요한 와 는 다시 계산(recompute)한다.
| 저장 항목 | 표준 어텐션 | FlashAttention |
|---|---|---|
| () | ❌ 저장 안 함 | |
| () | ❌ 저장 안 함 | |
| softmax 통계 (, ) | – | |
| 총 메모리 |
재계산에 드는 추가 FLOPs가 있지만, HBM 접근 감소로 인한 속도 이득이 훨씬 크기 때문에 전체적으로 더 빠르다. 이는 gradient checkpointing의 아이디어와 유사하지만, 어텐션에 특화되어 훨씬 효율적이다.
알고리즘 상세 (Forward Pass)
FlashAttention의 전체 Forward 알고리즘을 의사코드로 정리한다.
def flash_attention_forward(Q, K, V, B_r, B_c):
"""
Q, K, V: (N, d) 행렬
B_r: Q 블록 크기 (행)
B_c: K/V 블록 크기 (열)
"""
N, d = Q.shape
# 블록 수 계산
T_r = ceil(N / B_r) # Q 블록 수
T_c = ceil(N / B_c) # K, V 블록 수
# 출력 및 통계량 초기화
O = zeros(N, d) # 출력
m = full(N, -inf) # 행별 최댓값
l = zeros(N) # 행별 softmax 분모
# 외부 루프: K, V 블록 순회
for j in range(T_c):
# K_j, V_j를 HBM에서 SRAM으로 로드
K_j = K[j*B_c : (j+1)*B_c] # (B_c, d)
V_j = V[j*B_c : (j+1)*B_c] # (B_c, d)
# 내부 루프: Q 블록 순회
for i in range(T_r):
# Q_i, O_i, m_i, l_i를 SRAM으로 로드
Q_i = Q[i*B_r : (i+1)*B_r] # (B_r, d)
O_i = O[i*B_r : (i+1)*B_r]
m_i = m[i*B_r : (i+1)*B_r]
l_i = l[i*B_r : (i+1)*B_r]
# 블록 어텐션 스코어 계산 (SRAM 내)
S_ij = Q_i @ K_j.T / sqrt(d) # (B_r, B_c)
# 블록의 행별 최댓값과 softmax 분모
m_block = S_ij.max(dim=-1) # (B_r,)
P_block = exp(S_ij - m_block) # (B_r, B_c)
l_block = P_block.sum(dim=-1) # (B_r,)
# Online softmax 갱신
m_new = maximum(m_i, m_block)
l_new = exp(m_i - m_new) * l_i + exp(m_block - m_new) * l_block
# 출력 갱신
O_i = (l_i * exp(m_i - m_new)).unsqueeze(-1) * O_i \
+ (exp(m_block - m_new)).unsqueeze(-1) * (P_block @ V_j)
O_i = O_i / l_new.unsqueeze(-1)
# 통계량 갱신
m_i = m_new
l_i = l_new
# SRAM에서 HBM으로 기록
O[i*B_r : (i+1)*B_r] = O_i
m[i*B_r : (i+1)*B_r] = m_i
l[i*B_r : (i+1)*B_r] = l_i
return O, m, l
IO 복잡도 분석
논문의 핵심 이론적 결과는 다음과 같다.
정리 1: FlashAttention의 HBM 접근 횟수는 이다. 여기서 은 SRAM 크기이다.
표준 어텐션의 HBM 접근 와 비교하면, (일반적으로 , )이므로 FlashAttention이 훨씬 적은 HBM 접근을 한다.
| 방법 | HBM 접근 | 추가 메모리 |
|---|---|---|
| 표준 어텐션 | ||
| FlashAttention |
구체적 예시로, , , (SRAM 일부)일 때:
– 표준: 접근
– FlashAttention: 접근
약 24배 감소이다.
Block-Sparse FlashAttention
논문은 추가로 Block-Sparse FlashAttention을 제안한다. 이는 어텐션 패턴에 대한 사전 지식(예: local attention, strided attention)이 있을 때, 불필요한 블록의 계산을 아예 건너뛰는 기법이다.
마스크 행렬 을 정의하고, 인 블록은 계산하지 않는다.
이 경우 IO 복잡도는:
여기서 는 마스크에서 0이 아닌 블록의 비율이다. 예를 들어 local attention으로 전체의 1/4만 계산하면 IO가 추가 4배 감소한다.
실험 결과
속도 비교
논문은 A100 GPU에서 다양한 시퀀스 길이에 대해 벤치마크를 수행했다.
| 시퀀스 길이 | PyTorch 표준 | Megatron (최적화) | FlashAttention | 속도 향상 |
|---|---|---|---|---|
| 512 | 기준 | 1.2× | 2.4× | vs 표준 |
| 1K | 기준 | 1.3× | 2.8× | vs 표준 |
| 2K | 기준 | 1.3× | 3.0× | vs 표준 |
| 4K | 기준 | OOM | 3.5× | vs 표준 |
| 8K | OOM | OOM | 가능 | – |
| 16K | OOM | OOM | 가능 | – |
FlashAttention은 시퀀스 길이가 길어질수록 상대적 이점이 커진다. 표준 구현이 OOM(Out of Memory)으로 실행 불가능한 길이에서도 FlashAttention은 정상 동작한다.
기존 근사 어텐션과의 비교
논문이 특히 강조하는 점은, 이론적으로 또는 복잡도를 가진 근사 어텐션 기법들이 실제로는 FlashAttention보다 느리다는 것이다.
| 방법 | 이론적 복잡도 | 실제 속도 (A100) | 정확도 | FLOPS 관점 |
|---|---|---|---|---|
| Standard Attention | 기준 | 정확 | ||
| Linformer | 기준 대비 느림 | 근사 | ||
| Performer | 기준 대비 느림 | 근사 | ||
| Local Attention | 기준 대비 유사 | 근사 | ||
| Longformer | 기준 대비 유사 | 근사 | ||
| FlashAttention | 2~4× 빠름 | 정확 |
이 결과는 FLOPs와 wall-clock 시간이 반드시 비례하지 않는다는 중요한 교훈을 준다. IO-aware 최적화가 이론적 복잡도보다 실질적으로 더 중요할 수 있다.
End-to-End 모델 학습 결과
논문은 FlashAttention을 실제 모델 학습에 적용하여 검증했다.
GPT-2 학습 (OpenWebText)
| 모델 | 컨텍스트 길이 | 학습 속도 | Perplexity |
|---|---|---|---|
| GPT-2 small (표준) | 1K | 기준 | 18.2 |
| GPT-2 small (Flash) | 1K | 1.5× 빠름 | 18.2 |
| GPT-2 small (Flash) | 4K | 가능 | 17.6 |
| GPT-2 medium (표준) | 1K | 기준 | 14.2 |
| GPT-2 medium (Flash) | 1K | 1.3× 빠름 | 14.2 |
| GPT-2 medium (Flash) | 4K | 가능 | 13.7 |
주목할 점:
– 동일 컨텍스트 길이에서 속도 향상과 함께 동일한 모델 품질 유지
– 더 긴 컨텍스트 사용이 가능해져 perplexity 추가 개선
Long-Range Arena (LRA) 벤치마크
LRA는 긴 시퀀스 처리 능력을 평가하는 벤치마크로, 1K~4K 길이의 시퀀스를 다룬다.
| 방법 | 평균 정확도 | ListOps | Text | Retrieval | Image | Pathfinder |
|---|---|---|---|---|---|---|
| Transformer | 61.36 | 36.37 | 64.27 | 57.46 | 42.44 | 71.40 |
| Performer | 51.18 | 18.01 | 65.40 | 53.82 | 42.77 | 77.05 |
| Local Attn | 52.98 | 15.82 | 52.98 | 53.39 | 41.46 | 66.63 |
| FlashAttention | 62.85 | 37.07 | 64.65 | 58.37 | 44.09 | 71.63 |
| Flash + Block-Sparse | 63.10 | 37.50 | 65.21 | 58.50 | 43.94 | 71.87 |
FlashAttention이 근사 기법들보다 일관되게 높은 성능을 달성했다.
BERT 학습 속도
| 설정 | 학습 시간 | MLM 정확도 |
|---|---|---|
| BERT-large (Megatron) | 기준 (100%) | 기준 |
| BERT-large (FlashAttention) | 85% (-15%) | 동일 |
Ablation Study
블록 크기의 영향
블록 크기는 SRAM 활용률과 직접 연관된다.
| 블록 크기 () | 상대 속도 | 비고 |
|---|---|---|
| 32 | 0.75× | SRAM 미활용 |
| 64 | 0.90× | – |
| 128 | 1.0× | 최적 |
| 256 | 0.95× | 레지스터 pressure 증가 |
최적 블록 크기는 GPU 아키텍처(SRAM 크기, 워프 수)에 따라 달라진다.
Recomputation의 영향
| 설정 | Forward 속도 | Backward 속도 | 메모리 | 전체 학습 속도 |
|---|---|---|---|---|
| 표준 (S, P 저장) | 기준 | 기준 | 기준 | |
| Recomputation 사용 | 동일 | 약간 느림 (-5%) | +15% 빠름 |
재계산으로 인한 FLOPs 증가(~30%)에도 불구하고, HBM 접근 감소 효과가 훨씬 커서 전체적으로 더 빠르다.
IO 복잡도 하한 증명
논문은 FlashAttention의 IO 복잡도가 최적에 가까움을 증명한다.
정리 2: 표준 어텐션을 정확히 계산하는 모든 알고리즘은 HBM 접근이 필요하다.
FlashAttention의 상한이 이므로, 점근적으로 최적(asymptotically optimal)이다.
기존 방법론과의 종합 비교
| 특성 | Standard | Sparse (Longformer) | Linear (Performer) | FlashAttention |
|---|---|---|---|---|
| 시간 복잡도 (FLOPs) | ||||
| 메모리 복잡도 | ||||
| HBM 접근 | ||||
| 정확도 | 정확 | 근사 | 근사 | 정확 |
| 실제 wall-clock | 기준 | ~1× | ~0.8× | 2~4× 빠름 |
| 구현 난이도 | 쉬움 | 중간 | 중간 | 어려움 (CUDA 커널) |
| 드롭인 교체 | – | 모델 수정 필요 | 모델 수정 필요 | 기존 모델 호환 |
논문의 강점
1. 실용적 영향력
FlashAttention은 이론과 실용의 간극을 극적으로 좁힌 연구이다. 발표 이후:
– PyTorch 2.0에 torch.nn.functional.scaled_dot_product_attention으로 통합
– Hugging Face Transformers 라이브러리에 기본 옵션으로 채택
– GPT-4, LLaMA, Mistral 등 주요 LLM 학습에 활용
– 후속 FlashAttention-2, FlashAttention-3으로 발전
2. 근본적인 관점 전환
“알고리즘을 바꾸지 말고 구현을 바꿔라”는 메시지가 명확하다. 수학적으로 동일한 결과를 내면서도 하드웨어 특성을 고려한 구현이 얼마나 중요한지를 보여주었다.
3. 엄밀한 이론적 분석
IO 복잡도의 상한과 하한을 모두 증명하여, FlashAttention이 단순히 실험적으로 좋은 것이 아니라 이론적으로도 최적임을 보였다.
4. 범용성
어떤 Transformer 모델에도 드롭인(drop-in)으로 적용 가능하며, 모델 아키텍처 변경이 필요 없다.
논문의 한계점
1. 구현 복잡도
최적화된 CUDA 커널을 직접 작성해야 하므로 구현이 매우 어렵다. 일반 연구자가 수정하거나 확장하기 힘들다.
2. 하드웨어 의존성
최적 블록 크기와 성능이 GPU 아키텍처에 크게 의존한다. A100에 최적화된 커널이 다른 GPU(V100, RTX 3090 등)에서 동일한 성능을 내지 못할 수 있다.
3. Causal Masking 지원의 한계
논문 발표 시점에서 causal masking(자기회귀 모델용)에 대한 최적화가 완전하지 않았다. (이후 FlashAttention-2에서 개선됨)
4. 추론(Inference) 최적화 미흡
이 논문은 주로 학습(training)에 초점을 맞추고 있으며, KV-cache가 필요한 자기회귀 추론에 대한 최적화는 부족하다.
5. Multi-GPU 확장
단일 GPU 내 메모리 계층 최적화에 집중하며, 다중 GPU 간 통신 최적화는 다루지 않는다.
후속 연구 방향
FlashAttention-2 (2023)
- Forward pass에서 parallelism을 Q 블록 축으로 변경하여 GPU 활용률 향상
- Causal masking 최적화
- FlashAttention 대비 약 2× 추가 속도 향상, 이론적 최대 FLOPs의 70% 이상 달성
FlashAttention-3 (2024)
- Hopper 아키텍처(H100) 전용 최적화
- 비동기 연산: WGMMA(행렬곱)와 softmax를 파이프라이닝
- FP8 지원으로 추가 속도 향상
관련 후속 연구
| 연구 | 핵심 아이디어 |
|---|---|
| PagedAttention (vLLM) | KV-cache를 페이지 단위로 관리하여 추론 메모리 효율화 |
| Ring Attention | 다중 GPU 간 KV 블록을 링 형태로 전달하여 초긴 시퀀스 처리 |
| FlashDecoding | 자기회귀 추론에 특화된 FlashAttention 변형 |
| Mamba / State Space Models | 아예 어텐션을 대체하는 아키텍처 |
| Mixture of Experts + FlashAttention | MoE와 결합하여 더 큰 모델을 효율적으로 학습 |
핵심 수식 정리
이 논문에서 기억해야 할 핵심 수식들을 정리한다.
1. 표준 어텐션
2. Online Softmax 갱신
3. IO 복잡도
마무리
FlashAttention은 Transformer 어텐션의 구현 패러다임을 근본적으로 바꾼 논문이다. 핵심을 요약하면:
-
IO-Aware 설계: GPU의 SRAM-HBM 메모리 계층을 인식하여, 데이터 이동을 최소화하는 타일링 기반 어텐션 알고리즘을 설계했다.
-
정확한 계산: 근사 어텐션과 달리 수학적으로 표준 어텐션과 완전히 동일한 결과를 산출한다. 어떤 품질 손실도 없다.
-
2~4배 속도 향상, 메모리: 실제 wall-clock 시간에서 2~4배 빠르며, 메모리 사용량을 시퀀스 길이에 대해 선형으로 줄였다.
-
이론적 최적: HBM 접근의 하한을 증명하여 FlashAttention이 점근적으로 최적임을 보였다.
-
광범위한 채택: PyTorch, Hugging Face, 주요 LLM 학습 파이프라인에 표준으로 자리잡으며, 현대 AI 인프라의 핵심 구성요소가 되었다.
“FLOPs를 줄이는 것보다 메모리 접근을 줄이는 것이 더 중요하다.” 이것이 FlashAttention이 전달하는 가장 중요한 메시지이며, 향후 하드웨어-소프트웨어 공동 최적화(hardware-software co-design) 연구의 방향성을 제시하는 통찰이다.
이 글이 도움이 되셨나요?
Buy me a coffee
답글 남기기