Knowledge Distillation이란?
Knowledge Distillation(지식 증류)은 크고 복잡한 모델(Teacher)의 지식을 작고 효율적인 모델(Student)로 전달하는 기법입니다. 2015년 Geoffrey Hinton이 제안한 이후, 모바일·엣지 디바이스에서 대형 AI 모델을 실용화하는 핵심 기술로 자리잡았습니다.
핵심 아이디어: “큰 모델의 출력 확률 분포(soft targets)가 hard labels보다 더 많은 정보를 담고 있다”
일반적인 학습은 정답 레이블(hard label)만 사용하지만, Distillation은 Teacher 모델의 소프트맥스 출력 전체를 학습 신호로 활용합니다. 예를 들어 “고양이” 이미지 분류 시:
- Hard label:
[0, 1, 0, 0](고양이=1, 나머지=0) - Soft target:
[0.05, 0.85, 0.08, 0.02](개 5%, 고양이 85%, 호랑이 8%, …)
두 번째 분포가 “고양이는 개보다 호랑이와 더 유사하다”는 관계 정보를 담고 있습니다.
Distillation Loss 함수
학생 모델은 두 가지 손실의 가중합을 최소화합니다:
- : 정답 레이블 와 학생 예측 간 Cross-Entropy (정확도 보장)
- : 교사·학생 소프트맥스 출력 간 KL Divergence (지식 전달)
- : 균형 하이퍼파라미터 (보통 0.1~0.5)
- Temperature : 소프트맥스를 부드럽게 만드는 매개변수
Temperature Scaling
소프트맥스에 온도 를 적용하면:
- : 일반 소프트맥스 (날카로운 분포)
- T>1 (예: 3~5): 확률이 평탄해져 작은 확률값들도 학습에 기여
- 추론 시에는 로 복원
주요 Distillation 기법 비교
| 기법 | 특징 | 적용 사례 |
|---|---|---|
| Response Distillation | 최종 출력층만 증류 | DistilBERT, TinyBERT |
| Feature Distillation | 중간 히든 레이어도 매칭 | BERT-PKD, MiniLM |
| Attention Distillation | Transformer Attention Map 전달 | TinyBERT, DynaBERT |
| Self-Distillation | 같은 모델 구조로 반복 증류 | Born-Again Networks |
| Online Distillation | 교사·학생 동시 학습 | Deep Mutual Learning |
사례 1: DistilBERT
DistilBERT(2019, Hugging Face)는 BERT-base를 절반 크기로 압축한 모델입니다.
구조
| 항목 | BERT-base | DistilBERT |
|---|---|---|
| 레이어 | 12 | 6 |
| 파라미터 | 110M | 66M (40% 감소) |
| 속도 | 1x | 1.6x |
| 정확도 | 100% | 97% |
증류 전략
- Token-level distillation: 각 토큰의 출력 확률 매칭
- Cosine embedding loss: 히든 벡터 방향 정렬
- No [NSP] task: Next Sentence Prediction 제거로 단순화
# DistilBERT 사용 예시
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
inputs = tokenizer("Knowledge distillation is powerful", return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits # 분류 결과
사례 2: TinyLlama
TinyLlama(2024)는 Llama 2 아키텍처를 1.1B 파라미터로 축소한 소형 LLM입니다.
특징
- 파라미터: 1.1B (Llama 2 7B 대비 85% 감소)
- 학습 데이터: 3T 토큰 (RedPajama, SlimPajama)
- 컨텍스트: 2048 토큰
- 성능: GPT-3.5 수준의 추론 능력 (특정 태스크)
증류 기법
- Logits Distillation: Llama 2 7B의 next-token 확률 분포 모방
- Sequence-level KL: 전체 시퀀스 생성 패턴 학습
- Data Filtering: 고품질 데이터만 선별하여 효율 극대화
# TinyLlama 추론 예시
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
prompt = "Explain knowledge distillation in one sentence:"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0]))
활용 사례: 라즈베리파이·스마트폰 온디바이스 AI, 실시간 챗봇, 엣지 추론
사례 3: Gemma 시리즈
Gemma(2024, Google)는 Gemini의 지식을 소형 모델로 증류한 오픈 LLM입니다.
라인업
| 모델 | 파라미터 | 용도 |
|---|---|---|
| Gemma 2B | 2B | 모바일·IoT |
| Gemma 7B | 7B | 일반 서버 추론 |
| Gemma 27B | 27B | 고성능 태스크 |
증류 방식
- Multi-teacher distillation: Gemini Ultra + Pro 조합 사용
- Instruction distillation: 명령어-응답 쌍을 증류 데이터로 활용
- Safety alignment: 교사 모델의 안전성 정책도 함께 전달
실전 구현: PyTorch 예시
간단한 이미지 분류 모델 증류 코드입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
def __init__(self, alpha=0.3, temperature=3.0):
super().__init__()
self.alpha = alpha
self.T = temperature
self.ce_loss = nn.CrossEntropyLoss()
def forward(self, student_logits, teacher_logits, labels):
# Hard label loss
hard_loss = self.ce_loss(student_logits, labels)
# Soft target loss (KL Divergence)
soft_student = F.log_softmax(student_logits / self.T, dim=1)
soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.T ** 2)
return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
# 학습 루프
teacher_model.eval() # 교사는 평가 모드
student_model.train()
for images, labels in dataloader:
with torch.no_grad():
teacher_logits = teacher_model(images)
student_logits = student_model(images)
loss = distillation_loss(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
성능 vs 효율 트레이드오프
실무에서는 압축률과 성능 손실의 균형이 중요합니다.
| 압축률 | 예상 정확도 유지 | 적용 시나리오 |
|---|---|---|
| 30~50% | 95~98% | 클라우드 비용 절감 |
| 50~70% | 90~95% | 모바일 앱 |
| 80% 이상 | 80~90% | IoT·엣지 디바이스 |
실무 팁: 먼저 Response Distillation으로 빠르게 프로토타입을 만들고, 성능이 부족하면 Feature/Attention Distillation을 추가하세요.
Distillation 체크리스트
실전 프로젝트에 적용하기 전 확인할 사항:
- [ ] 교사 모델 성능 검증: 충분히 높은 정확도인가?
- [ ] 데이터셋 준비: 교사 모델 출력(logits)을 미리 저장했는가?
- [ ] 학생 구조 설계: 파라미터를 어디서 줄일 것인가? (레이어/히든 크기/어텐션 헤드)
- [ ] 하이퍼파라미터 튜닝: , 실험 (보통 , 부터 시작)
- [ ] 추론 속도 측정: 실제 타겟 디바이스에서 레이턴시 확인
- [ ] A/B 테스트: 실제 서비스에서 사용자 반응 비교
최신 트렌드: LLM 시대의 Distillation
1. Speculative Decoding
작은 모델이 여러 토큰을 빠르게 생성 → 큰 모델이 검증하는 하이브리드 방식 (속도 2~3배)
2. QLoRA + Distillation
Quantization과 Low-Rank Adaptation을 결합하여 메모리 효율 극대화
3. On-Policy Distillation
학생 모델이 생성한 샘플을 교사가 평가하여 재학습 (강화학습 스타일)
마무리
Knowledge Distillation은 대형 모델의 성능을 유지하면서 추론 비용을 획기적으로 줄이는 핵심 기술입니다. 핵심 요약:
- 소프트 타겟 활용: 정답 레이블보다 풍부한 정보 전달
- Temperature Scaling: 작은 확률값까지 학습 신호로 활용
- 사례별 전략:
– DistilBERT: 레이어 절반 + 코사인 임베딩 손실
– TinyLlama: 3T 토큰 학습 + 로짓 증류
– Gemma: 멀티 교사 + 명령어 튜닝 증류 - 실무 적용: Response → Feature → Attention 순으로 단계적 적용
- 효율 vs 성능: 타겟 디바이스와 요구사항에 맞춰 압축률 조정
GPT-4, Claude 같은 초거대 모델도 결국 더 작고 빠른 형태로 증류되어 일상 디바이스에 탑재될 것입니다. 지금이 Distillation 기술을 마스터할 최적의 시점입니다.
이 글이 도움이 되셨나요?
Buy me a coffee
답글 남기기