Knowledge Distillation 실전 가이드: 대형 LLM을 소형 모델로 압축하는 증류 기법과 DistilBERT, TinyLlama 사례 분석

Knowledge Distillation이란?

Knowledge Distillation(지식 증류)은 크고 복잡한 모델(Teacher)의 지식을 작고 효율적인 모델(Student)로 전달하는 기법입니다. 2015년 Geoffrey Hinton이 제안한 이후, 모바일·엣지 디바이스에서 대형 AI 모델을 실용화하는 핵심 기술로 자리잡았습니다.

핵심 아이디어: “큰 모델의 출력 확률 분포(soft targets)가 hard labels보다 더 많은 정보를 담고 있다”

일반적인 학습은 정답 레이블(hard label)만 사용하지만, Distillation은 Teacher 모델의 소프트맥스 출력 전체를 학습 신호로 활용합니다. 예를 들어 “고양이” 이미지 분류 시:

  • Hard label: [0, 1, 0, 0] (고양이=1, 나머지=0)
  • Soft target: [0.05, 0.85, 0.08, 0.02] (개 5%, 고양이 85%, 호랑이 8%, …)

두 번째 분포가 “고양이는 개보다 호랑이와 더 유사하다”는 관계 정보를 담고 있습니다.


Distillation Loss 함수

학생 모델은 두 가지 손실의 가중합을 최소화합니다:

<br/>L=αLCE(y,y^<em>student)+(1α)L</em>KD(y^<em>teacher,y^</em>student)<br/><br /> L = \alpha L_{\text{CE}}(y, \hat{y}<em>{\text{student}}) + (1-\alpha) L</em>{\text{KD}}(\hat{y}<em>{\text{teacher}}, \hat{y}</em>{\text{student}})<br />

  • LCEL_{\text{CE}}: 정답 레이블 yy와 학생 예측 간 Cross-Entropy (정확도 보장)
  • LKDL_{\text{KD}}: 교사·학생 소프트맥스 출력 간 KL Divergence (지식 전달)
  • α\alpha: 균형 하이퍼파라미터 (보통 0.1~0.5)
  • Temperature TT: 소프트맥스를 부드럽게 만드는 매개변수

Temperature Scaling

소프트맥스에 온도 TT를 적용하면:

<br/>pi=exp(zi/T)jexp(zj/T)<br/><br /> p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}<br />

  • T=1T=1: 일반 소프트맥스 (날카로운 분포)
  • T&gt;1 (예: 3~5): 확률이 평탄해져 작은 확률값들도 학습에 기여
  • 추론 시에는 T=1T=1로 복원

주요 Distillation 기법 비교

기법 특징 적용 사례
Response Distillation 최종 출력층만 증류 DistilBERT, TinyBERT
Feature Distillation 중간 히든 레이어도 매칭 BERT-PKD, MiniLM
Attention Distillation Transformer Attention Map 전달 TinyBERT, DynaBERT
Self-Distillation 같은 모델 구조로 반복 증류 Born-Again Networks
Online Distillation 교사·학생 동시 학습 Deep Mutual Learning

사례 1: DistilBERT

DistilBERT(2019, Hugging Face)는 BERT-base를 절반 크기로 압축한 모델입니다.

구조

항목 BERT-base DistilBERT
레이어 12 6
파라미터 110M 66M (40% 감소)
속도 1x 1.6x
정확도 100% 97%

증류 전략

  1. Token-level distillation: 각 토큰의 출력 확률 매칭
  2. Cosine embedding loss: 히든 벡터 방향 정렬
  3. No [NSP] task: Next Sentence Prediction 제거로 단순화
# DistilBERT 사용 예시
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

inputs = tokenizer("Knowledge distillation is powerful", return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits  # 분류 결과

사례 2: TinyLlama

TinyLlama(2024)는 Llama 2 아키텍처를 1.1B 파라미터로 축소한 소형 LLM입니다.

특징

  • 파라미터: 1.1B (Llama 2 7B 대비 85% 감소)
  • 학습 데이터: 3T 토큰 (RedPajama, SlimPajama)
  • 컨텍스트: 2048 토큰
  • 성능: GPT-3.5 수준의 추론 능력 (특정 태스크)

증류 기법

  1. Logits Distillation: Llama 2 7B의 next-token 확률 분포 모방
  2. Sequence-level KL: 전체 시퀀스 생성 패턴 학습
  3. Data Filtering: 고품질 데이터만 선별하여 효율 극대화
# TinyLlama 추론 예시
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "Explain knowledge distillation in one sentence:"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0]))

활용 사례: 라즈베리파이·스마트폰 온디바이스 AI, 실시간 챗봇, 엣지 추론


사례 3: Gemma 시리즈

Gemma(2024, Google)는 Gemini의 지식을 소형 모델로 증류한 오픈 LLM입니다.

라인업

모델 파라미터 용도
Gemma 2B 2B 모바일·IoT
Gemma 7B 7B 일반 서버 추론
Gemma 27B 27B 고성능 태스크

증류 방식

  • Multi-teacher distillation: Gemini Ultra + Pro 조합 사용
  • Instruction distillation: 명령어-응답 쌍을 증류 데이터로 활용
  • Safety alignment: 교사 모델의 안전성 정책도 함께 전달

실전 구현: PyTorch 예시

간단한 이미지 분류 모델 증류 코드입니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.3, temperature=3.0):
        super().__init__()
        self.alpha = alpha
        self.T = temperature
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Hard label loss
        hard_loss = self.ce_loss(student_logits, labels)

        # Soft target loss (KL Divergence)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.T ** 2)

        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss

# 학습 루프
teacher_model.eval()  # 교사는 평가 모드
student_model.train()

for images, labels in dataloader:
    with torch.no_grad():
        teacher_logits = teacher_model(images)

    student_logits = student_model(images)
    loss = distillation_loss(student_logits, teacher_logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

성능 vs 효율 트레이드오프

실무에서는 압축률과 성능 손실의 균형이 중요합니다.

압축률 예상 정확도 유지 적용 시나리오
30~50% 95~98% 클라우드 비용 절감
50~70% 90~95% 모바일 앱
80% 이상 80~90% IoT·엣지 디바이스

실무 팁: 먼저 Response Distillation으로 빠르게 프로토타입을 만들고, 성능이 부족하면 Feature/Attention Distillation을 추가하세요.


Distillation 체크리스트

실전 프로젝트에 적용하기 전 확인할 사항:

  • [ ] 교사 모델 성능 검증: 충분히 높은 정확도인가?
  • [ ] 데이터셋 준비: 교사 모델 출력(logits)을 미리 저장했는가?
  • [ ] 학생 구조 설계: 파라미터를 어디서 줄일 것인가? (레이어/히든 크기/어텐션 헤드)
  • [ ] 하이퍼파라미터 튜닝: α\alpha, TT 실험 (보통 α=0.3\alpha=0.3, T=3T=3부터 시작)
  • [ ] 추론 속도 측정: 실제 타겟 디바이스에서 레이턴시 확인
  • [ ] A/B 테스트: 실제 서비스에서 사용자 반응 비교

최신 트렌드: LLM 시대의 Distillation

1. Speculative Decoding

작은 모델이 여러 토큰을 빠르게 생성 → 큰 모델이 검증하는 하이브리드 방식 (속도 2~3배)

2. QLoRA + Distillation

Quantization과 Low-Rank Adaptation을 결합하여 메모리 효율 극대화

3. On-Policy Distillation

학생 모델이 생성한 샘플을 교사가 평가하여 재학습 (강화학습 스타일)


마무리

Knowledge Distillation은 대형 모델의 성능을 유지하면서 추론 비용을 획기적으로 줄이는 핵심 기술입니다. 핵심 요약:

  1. 소프트 타겟 활용: 정답 레이블보다 풍부한 정보 전달
  2. Temperature Scaling: 작은 확률값까지 학습 신호로 활용
  3. 사례별 전략:
    DistilBERT: 레이어 절반 + 코사인 임베딩 손실
    TinyLlama: 3T 토큰 학습 + 로짓 증류
    Gemma: 멀티 교사 + 명령어 튜닝 증류
  4. 실무 적용: Response → Feature → Attention 순으로 단계적 적용
  5. 효율 vs 성능: 타겟 디바이스와 요구사항에 맞춰 압축률 조정

GPT-4, Claude 같은 초거대 모델도 결국 더 작고 빠른 형태로 증류되어 일상 디바이스에 탑재될 것입니다. 지금이 Distillation 기술을 마스터할 최적의 시점입니다.

이 글이 도움이 되셨나요?

Buy me a coffee

댓글

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다

TODAY 87 | TOTAL 286