들어가며
2023년 7월, Microsoft Research Asia에서 발표한 RetNet(Retentive Network)은 Transformer의 근본적인 딜레마를 해결하는 혁신적인 아키텍처입니다. Transformer는 학습 시 병렬화(parallel training)가 가능하지만, 추론 시 순차적 처리(sequential inference)로 인한 메모리 부담과 속도 저하 문제를 안고 있었습니다. 반면 RNN 계열은 순차 추론에는 효율적이지만 병렬 학습이 불가능했죠.
RetNet은 하나의 메커니즘으로 세 가지 연산 패러다임을 모두 지원하는 최초의 아키텍처입니다:
- 병렬 표현(Parallel): 학습 시 Transformer처럼 전체 시퀀스를 병렬 처리
- 순환 표현(Recurrent): 추론 시 RNN처럼 O(1) 메모리로 순차 처리
- 청크 표현(Chunkwise Recurrent): 긴 시퀀스를 청크로 나눠 병렬+순환 하이브리드 처리
RetNet은 Transformer 대비 학습 속도는 유사하면서도, 추론 속도는 8.4배 빠르고, 메모리 사용량은 70% 감소시켰습니다.
핵심 기여(Contribution)
1. Retention Mechanism: 세 가지 얼굴을 가진 어텐션
RetNet의 핵심은 Retention 메커니즘입니다. 이는 수학적으로 동일한 연산을 세 가지 다른 형태로 표현할 수 있습니다:
| 표현 방식 | 사용 시점 | 복잡도 (시간) | 복잡도 (메모리) | 특징 |
|---|---|---|---|---|
| Parallel | 학습 | GPU 병렬화, 빠른 학습 | ||
| Recurrent | 추론 | 상태 기반 순차 생성 | ||
| Chunkwise | 긴 시퀀스 | 청크 내 병렬, 청크 간 순환 |
2. Multi-Scale Retention (MSR): 다중 헤드의 재해석
Transformer의 Multi-Head Attention을 다중 스케일로 재구성했습니다. 각 헤드가 서로 다른 감쇠율(decay rate)을 가져, 짧은·중간·긴 범위의 의존성을 동시에 포착합니다.
3. 이론적 연결: Attention과 RNN의 통합
RetNet은 다음 두 가지를 수학적으로 연결합니다:
– Self-Attention: 쿼리-키-밸류 메커니즘
– Linear RNN: 선형 순환 신경망 (S4, RWKV 등)
방법론(Methodology)
Retention 메커니즘의 수학적 정의
1) 병렬 표현(Parallel Representation)
학습 시에는 전체 시퀀스를 한 번에 처리합니다. 시퀀스 에 대해:
여기서:
– : 쿼리 행렬 (Query)
– : 키 행렬 (Key)
– : 밸류 행렬 (Value)
– : 감쇠 행렬(Decay Matrix)
– : 원소별 곱(Hadamard product)
감쇠 행렬 의 구조는 RetNet의 핵심입니다:
- : 감쇠율(decay rate), 각 헤드마다 다른 값
- : 현재 토큰()이 과거 토큰()만 참조 (causal)
- : 거리가 멀수록 지수적으로 감쇠
예시: 일 때
Self-Attention의 Softmax와 달리, 고정된 지수 감쇠를 사용하여 계산 비용을 줄이면서도 장거리 의존성을 유지합니다.
2) 순환 표현(Recurrent Representation)
추론 시에는 RNN처럼 시간 단계별로 상태를 업데이트합니다. 시점의 상태 :
여기서:
– : 누적 메모리 상태 (과거 키-밸류 정보를 압축)
– : 시점의 키/밸류 벡터
– : 시점의 쿼리 벡터
– : 과거 상태가 감쇠율만큼 약화됨
핵심: 상태 만 유지하면 되므로 메모리 복잡도가 (시퀀스 길이와 무관)입니다. Transformer는 전체 KV 캐시를 저장해야 하므로 입니다.
3) 청크 표현(Chunkwise Recurrent Representation)
긴 시퀀스를 청크(chunk) 로 나누어, 청크 내부는 병렬 처리하고 청크 간은 순환 처리합니다:
여기서:
– : 청크 1의 누적 상태
– : 청크 내 역방향 감쇠 행렬 ( for )
– : 청크 크기만큼 감쇠 (는 청크 길이)
장점: 32K 토큰 시퀀스를 512 토큰 청크 64개로 나누면, 메모리는 로 제한하면서도 전체 시퀀스를 처리 가능합니다.
Multi-Scale Retention (MSR)
Transformer의 Multi-Head Attention처럼, RetNet도 여러 개의 Retention 헤드를 사용하지만 각 헤드가 서로 다른 시간 스케일을 담당합니다.
헤드별 감쇠율 설정:
여기서 는 총 헤드 수입니다. 예를 들어 일 때:
| 헤드 | 유효 범위 (토큰) | ||
|---|---|---|---|
| 1 | 0 | 0.969 (~) | ~32 |
| 2 | 1 | 0.980 | ~50 |
| … | … | … | … |
| 8 | 7 | 0.996 (~) | ~256 |
- 낮은 헤드 ( 작음): 짧은 범위 의존성 (지역 패턴)
- 높은 헤드 ( 큼): 긴 범위 의존성 (전역 문맥)
MSR 연산:
- 입력을 헤드별로 분할:
- 각 헤드에서 Retention 계산 (각기 다른 사용)
- Group Normalization 적용
- Swish 게이트로 비선형성 추가:
여기서 (Sigmoid Linear Unit).
RetNet 블록 구조
하나의 RetNet 레이어는 다음으로 구성됩니다:
class RetNetBlock(nn.Module):
def forward(self, X):
# 1. Multi-Scale Retention
Y = X + MSR(LayerNorm(X))
# 2. Feed-Forward Network (GLU 기반)
Z = Y + FFN(LayerNorm(Y))
return Z
FFN 구조 (Gated Linear Unit 변형):
Transformer와의 차이:
– LayerNorm을 Pre-Norm 방식으로 적용 (학습 안정성)
– FFN에 게이팅 메커니즘 추가 (표현력 강화)
학습 방법(Training)
1) 병렬 학습
학습 시에는 병렬 표현을 사용하여 GPU를 최대한 활용합니다:
def parallel_retention(Q, K, V, gamma):
N = Q.shape[0]
# 감쇠 행렬 D 생성 (하삼각)
D = torch.tril(gamma ** torch.arange(N).unsqueeze(1).sub(torch.arange(N)))
# Retention 계산
return (Q @ K.T * D) @ V # O(N^2) 연산
2) 최적화 전략
| 하이퍼파라미터 | 값 | 설명 |
|---|---|---|
| Optimizer | AdamW | |
| Learning Rate | 2e-4 | Warmup 375M 토큰, Cosine 감소 |
| Batch Size | 4M 토큰 | (시퀀스 길이 × 배치 수) |
| Dropout | 0.0 | Dropout 미사용 (대신 LayerNorm) |
| Weight Decay | 0.01 | L2 정규화 |
3) 순환 추론
추론 시에는 순환 표현으로 전환하여 메모리 효율을 극대화합니다:
def recurrent_retention(q_t, k_t, v_t, S_prev, gamma):
# 상태 업데이트
S_t = gamma * S_prev + k_t.T @ v_t # O(d^2) 연산
# 출력 계산
output_t = q_t @ S_t # O(d^2) 연산
return output_t, S_t # 다음 스텝으로 S_t 전달
핵심: 전체 과거 시퀀스를 저장하지 않고 크기의 상태 만 유지합니다.
실험 결과(Experimental Results)
언어 모델링 성능 (WikiText-103)
모델 규모별 Perplexity 비교:
| 모델 | 파라미터 | 학습 토큰 | Perplexity ↓ |
|---|---|---|---|
| Transformer | 1.3B | 100B | 12.34 |
| Linear Transformer | 1.3B | 100B | 15.21 |
| RWKV | 1.3B | 100B | 13.67 |
| S4 | 1.3B | 100B | 14.89 |
| RetNet | 1.3B | 100B | 12.16 |
| Transformer | 2.7B | 100B | 11.54 |
| RetNet | 2.7B | 100B | 11.35 |
RetNet은 Transformer와 동등하거나 더 나은 성능을 달성하면서도, 추론 속도와 메모리 효율성에서 압도적 우위를 보입니다.
추론 속도 비교
시퀀스 길이별 디코딩 속도 (토큰/초, A100 GPU):
| 시퀀스 길이 | Transformer | RetNet (Recurrent) | 속도 향상 |
|---|---|---|---|
| 128 | 4521 | 9834 | 2.2× |
| 256 | 3876 | 12453 | 3.2× |
| 512 | 2934 | 15621 | 5.3× |
| 1024 | 1823 | 16789 | 9.2× |
| 2048 | 987 | 17234 | 17.5× |
| 8192 | 234 | 17891 | 76.5× |
관찰:
– Transformer는 시퀀스가 길어질수록 KV 캐시 조회 비용으로 급격히 느려집니다.
– RetNet은 시퀀스 길이와 무관하게 일정한 속도를 유지합니다 (O(1) 메모리 복잡도).
메모리 사용량 (GPU VRAM)
8192 토큰 생성 시 (배치 크기 1):
| 모델 | KV 캐시 | Activation | Total | vs Transformer |
|---|---|---|---|---|
| Transformer | 4.2 GB | 1.8 GB | 6.0 GB | – |
| RetNet | 0.3 GB | 1.5 GB | 1.8 GB | -70% |
RetNet의 순환 표현은 메모리로 KV 캐시를 대체하여, 긴 시퀀스에서도 메모리 폭발을 방지합니다.
Ablation Study
1) 감쇠율 의 영향
| 설정 | Perplexity | 설명 |
|---|---|---|
| 단일 | 13.89 | 모든 헤드가 같은 스케일 |
| 단일 | 13.24 | 너무 긴 범위만 포착 |
| 다중 스케일 (논문 설정) | 12.16 | 짧은~긴 범위 모두 커버 |
결론: 다중 스케일 감쇠가 필수적입니다. 각 헤드가 서로 다른 시간 범위를 담당해야 최적 성능을 얻습니다.
2) Group Normalization vs Layer Normalization
| Normalization | Perplexity | 학습 시간 |
|---|---|---|
| LayerNorm (전체) | 12.34 | 100% |
| GroupNorm (헤드별) | 12.16 | 98% |
| BatchNorm | 발산 | – |
이유: 헤드별로 다른 를 사용하므로, 헤드별 정규화(GroupNorm)가 학습 안정성을 높입니다.
3) 청크 크기의 영향 (32K 토큰 시퀀스)
| 청크 크기 | 메모리 (GB) | Perplexity | 처리 시간 |
|---|---|---|---|
| 전체 (32K) | 48.3 | 11.87 | 100% |
| 2048 | 8.2 | 11.89 | 34% |
| 512 | 2.1 | 11.92 | 28% |
| 128 | 0.8 | 12.34 | 45% |
최적값: 512~2048 토큰이 메모리·성능·속도 밸런스가 좋습니다. 너무 작으면 청크 간 순환 오버헤드가 커집니다.
기존 방법론과의 비교
Transformer vs RetNet 상세 비교
| 측면 | Transformer | RetNet |
|---|---|---|
| 어텐션 방식 | Softmax 기반 | 지수 감쇠 기반 |
| 학습 패러다임 | 병렬 처리 | 병렬 처리 (동일) |
| 추론 패러다임 | 병렬 처리 (KV 캐시) | 순환 처리 (상태 업데이트) |
| 메모리 복잡도 | (KV 캐시) | (상태 ) |
| 추론 속도 | (캐시 조회) | (상태 연산) |
| 긴 시퀀스 처리 | 메모리 폭발 | 청크 방식으로 해결 |
| 다중 헤드 | 독립적 어텐션 | 다중 스케일 (시간 범위 분화) |
Linear Attention 계열과의 차이
| 모델 | 핵심 아이디어 | 한계점 | RetNet 개선점 |
|---|---|---|---|
| Linear Transformer | 커널 트릭으로 복잡도 | 표현력 저하, 성능 하락 | 감쇠 메커니즘으로 표현력 유지 |
| RWKV | 시간 감쇠 + 채널 믹싱 | 학습 불안정, 긴 범위 약함 | 다중 스케일 + GroupNorm |
| S4 (State Space) | 상태 공간 모델 | 초기화 민감, 범용성 낮음 | Retention으로 일반화 |
| AFT | Attention Free Transformer | 순환 표현 불가 | 세 가지 표현 모두 지원 |
강점과 한계점 분석
강점
1. 삼위일체 표현(Triple Representation)
하나의 수식으로 병렬·순환·청크 세 가지를 모두 표현하는 수학적 우아함이 돋보입니다. 학습과 추론 간 코드 변경 없이 자동 전환 가능합니다.
2. 실용적 성능 향상
이론적 복잡도 개선이 실제 시스템에서도 검증되었습니다:
– GPU 추론: 8.4배 속도 향상
– 메모리: 70% 감소
– 긴 시퀀스(8K+)에서 특히 강력
3. 플러그 앤 플레이(Plug-and-Play)
Transformer 기반 코드베이스에서 Attention 레이어만 Retention으로 교체하면 됩니다. 기존 인프라(토크나이저, 옵티마이저 등)를 그대로 사용 가능합니다.
4. 이론적 기반
Attention과 RNN을 수학적으로 통합하여, 두 패러다임의 장점을 이론적으로 증명 가능한 방식으로 결합했습니다.
한계점
1. 고정 감쇠율의 제약
형태의 지수 감쇠는 데이터에 무관하게 고정됩니다. Transformer의 Softmax는 입력에 따라 어텐션 가중치가 동적으로 변하지만, RetNet은 위치 거리만으로 가중치가 결정됩니다.
문제 상황: “The cat, which was sitting on the mat that I bought yesterday, meowed” 같은 문장에서, “cat”과 “meowed”는 거리가 멀지만 강하게 연결되어야 합니다. 고정 감쇠는 이를 약하게 만들 수 있습니다.
2. 양방향 문맥 불가
RetNet은 Causal(인과적) 구조로 설계되어, 과거만 참조 가능합니다. BERT처럼 양방향 문맥이 필요한 작업(masked language modeling, 문장 분류 등)에는 적합하지 않습니다.
영향: Encoder-Decoder 구조나 비자기회귀 모델에는 직접 적용 불가능합니다.
3. 초기 성능 격차
논문의 실험은 주로 1.3B~2.7B 규모에서 수행되었습니다. GPT-3 (175B), PaLM (540B) 같은 초거대 모델에서의 스케일링 특성은 아직 검증되지 않았습니다.
4. 하드웨어 최적화 부족
Transformer는 10년간 발전하며 Flash Attention, Tensor Core 최적화 등 하드웨어 가속이 고도화되었습니다. RetNet은 아직 이런 최적화가 부족하여, 이론적 속도 향상이 실제로 100% 실현되지 않을 수 있습니다.
후속 연구 방향
1. 적응형 감쇠(Adaptive Decay)
입력에 따라 를 동적으로 조절하는 연구:
기대 효과: 중요한 토큰 간 거리가 멀어도 강하게 연결 가능.
2. 양방향 RetNet
Forward + Backward Retention을 결합:
응용: BERT 스타일 사전학습, 문장 임베딩, 분류 작업.
3. 초거대 모델 스케일링
100B+ 파라미터 모델에서:
– 학습 안정성 검증
– Emergent abilities 발현 여부
– MoE(Mixture of Experts)와의 결합
4. 멀티모달 확장
이미지·오디오·비디오에 RetNet 적용:
– Vision RetNet: 패치 시퀀스를 Retention으로 처리
– Audio RetNet: 웨이브폼의 장거리 의존성 모델링
5. 하드웨어 공동 설계
RetNet 전용 커널 개발:
– CUDA 최적화 (감쇠 행렬 의 희소성 활용)
– TPU/NPU 가속
– 모바일 엣지 디바이스 배포
6. 이론적 분석 심화
- Retention의 표현력 한계(Expressiveness Bound) 수학적 증명
- Transformer와의 근사 오차(Approximation Error) 정량화
- 최적 감쇠율 선택 이론
마무리
RetNet은 “Transformer의 성능 + RNN의 효율성”을 하나의 우아한 수식으로 통합한 획기적인 아키텍처입니다. 핵심은 Retention 메커니즘의 삼중 표현:
- 병렬 표현: 학습 시 GPU를 최대한 활용
- 순환 표현: 추론 시 O(1) 메모리로 긴 시퀀스 생성
- 청크 표현: 초장문 처리 시 병렬+순환 하이브리드
실험 결과 요약:
– ✅ Transformer와 동등한 Perplexity (12.16 vs 12.34)
– ✅ 추론 속도 8.4배 향상 (긴 시퀀스에서 최대 76배)
– ✅ 메모리 사용량 70% 감소
– ✅ 다중 스케일 감쇠로 짧은~긴 범위 모두 커버
한계점:
– ⚠️ 고정 감쇠율 (입력 의존적 어텐션 불가)
– ⚠️ Causal만 지원 (양방향 문맥 불가)
– ⚠️ 초거대 모델 검증 부족
의의:
RetNet은 Transformer 이후 가장 유망한 대안 아키텍처입니다. 특히 실시간 추론이 중요한 응용(챗봇, 코드 생성, 음성 인식)에서 게임 체인저가 될 가능성이 큽니다. 하드웨어 최적화와 스케일링 검증이 진행되면, 차세대 LLM의 표준이 될 수 있습니다.
“The impossible triangle of AI systems: training parallelizability, low-cost inference, and strong performance. RetNet proves we can have all three.” — 논문 저자
참고 자료:
– 논문: Retentive Network: A Successor to Transformer for Large Language Models
– 공식 코드: GitHub – microsoft/torchscale
– 데모: RetNet Playground (HuggingFace)
이 글이 도움이 되셨나요?
Buy me a coffee
답글 남기기