Vision Transformer 시대의 시작
CNN이 지배하던 컴퓨터 비전 분야에 2020년 Google Research의 Vision Transformer (ViT)가 등장하면서 패러다임 변화가 시작되었습니다. 이후 Swin Transformer, MaxViT 등 개선된 아키텍처들이 속속 등장하며 각자의 강점을 입증했습니다.
트랜스포머 기반 비전 모델은 이제 ImageNet, COCO 등 주요 벤치마크에서 CNN을 능가하며 실무의 새로운 표준으로 자리잡고 있습니다.
이 글에서는 세 가지 대표 아키텍처를 비교하고, Hybrid 모델을 실전에 적용하는 방법을 안내합니다.
Vision Transformer (ViT) 핵심 개념
ViT는 이미지를 패치(patch) 단위로 나누어 NLP의 Transformer를 그대로 적용한 혁신적 접근입니다.
동작 원리
- 패치 분할: 입력 이미지 를 개의 패치 로 분할
- 선형 임베딩: 각 패치를 1D 벡터로 flatten 후 선형 변환
- 위치 임베딩 추가: 패치 순서 정보를 위한 learnable position embedding
- Transformer Encoder: Multi-Head Self-Attention과 MLP 블록 반복
- 분류 토큰: [CLS] 토큰의 최종 출력으로 이미지 분류
수식으로 표현하면:
여기서:
– : 번째 패치의 flatten된 벡터
– : 패치 임베딩 행렬
– : 위치 임베딩
– : 임베딩 차원
import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTImageProcessor
# ViT 모델 로드 (사전학습 가중치)
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# 이미지 전처리 및 추론
from PIL import Image
image = Image.open('example.jpg')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
ViT의 장점과 한계
장점:
– 단순하고 우아한 구조
– 대규모 데이터셋(JFT-300M 등)에서 뛰어난 성능
– Global receptive field (패치 간 직접 attention)
한계:
– 작은 데이터셋에서 CNN보다 낮은 성능 (inductive bias 부족)
– 고해상도 이미지에서 계산량 폭증 ( 복잡도)
– 다양한 스케일 특징 추출 어려움
Swin Transformer: 계층적 설계의 혁신
Microsoft Research의 Swin Transformer는 ViT의 한계를 극복하기 위해 CNN의 계층적 구조를 도입했습니다.
핵심 메커니즘
1. Shifted Window Attention
일반 window 단위로 self-attention을 수행하다가, 다음 레이어에서 window를 이동(shift)시켜 윈도우 간 정보 교환:
- Window Partition: 이미지를 크기 윈도우로 분할
- Local Attention: 각 윈도우 내에서만 attention 계산 → 복잡도
- Shifted Window: 윈도우를 만큼 이동하여 교차 영역 생성
2. 계층적 특징 맵
4단계 stage로 구성되며, 각 stage마다 해상도를 1/2로 줄이고 채널을 2배 증가:
| Stage | 해상도 | 채널 | 역할 |
|---|---|---|---|
| 1 | H/4 × W/4 | C | 초기 특징 추출 |
| 2 | H/8 × W/8 | 2C | 중간 특징 |
| 3 | H/16 × W/16 | 4C | 고수준 특징 |
| 4 | H/32 × W/32 | 8C | 최종 특징 맵 |
from transformers import SwinForImageClassification, AutoImageProcessor
model = SwinForImageClassification.from_pretrained('microsoft/swin-base-patch4-window7-224')
processor = AutoImageProcessor.from_pretrained('microsoft/swin-base-patch4-window7-224')
image = Image.open('sample.jpg')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
Swin의 강점
- 선형 복잡도: Window attention으로 대용량 이미지 효율적 처리
- 멀티스케일 특징: 계층 구조로 다양한 스케일 표현 학습
- 객체 탐지/세그멘테이션: Backbone으로 사용 시 FPN 같은 구조와 자연스러운 결합
- 작은 데이터셋: ImageNet-1K에서도 ViT보다 우수한 성능
MaxViT: Global + Local의 완벽한 조화
MaxViT (Multi-Axis Vision Transformer)는 2022년 Google Research에서 발표한 하이브리드 아키텍처로, 블록 단위 local attention과 grid 단위 global attention을 번갈아 적용합니다.
Multi-Axis Attention
하나의 블록에서 두 가지 attention을 순차 수행:
Block Attention (Local)
이미지 → 블록 분할 → 각 블록 내 self-attention
Grid Attention (Global)
블록들을 재배열 → 같은 위치의 블록들끼리 attention
이를 통해 선형 복잡도로 전역 수용 영역을 달성합니다.
MBConv 통합
MaxViT는 각 stage 앞에 MBConv(EfficientNet의 구성 요소)를 배치하여 CNN의 inductive bias를 활용:
Stage: [MBConv] → [MaxViT Block × N]
import timm
# timm 라이브러리로 MaxViT 사용
model = timm.create_model('maxvit_base_tf_224', pretrained=True, num_classes=1000)
model.eval()
# 추론
import torch
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open('test.jpg')
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
output = model(input_tensor)
pred = output.argmax(dim=1)
세 아키텍처 종합 비교
| 특성 | ViT | Swin Transformer | MaxViT |
|---|---|---|---|
| Attention 방식 | Global (모든 패치) | Shifted Window (Local) | Block + Grid (Hybrid) |
| 복잡도 | |||
| 계층 구조 | ✗ 단일 해상도 | ✓ 4단계 피라미드 | ✓ 4단계 + MBConv |
| Inductive Bias | 거의 없음 | 약간 (locality) | 강함 (CNN 통합) |
| 작은 데이터셋 | 약함 | 보통 | 강함 |
| 고해상도 처리 | 느림 | 빠름 | 빠름 |
| 객체 탐지 Backbone | 부적합 | 적합 | 매우 적합 |
| ImageNet Top-1 | 84.5% (Base) | 85.2% (Base) | 86.5% (Base) |
| 파라미터 효율 | 보통 | 좋음 | 매우 좋음 |
선택 가이드: 대규모 데이터 + 분류 전용 → ViT, 작은 데이터 + 다운스트림 태스크 → Swin, 성능 최우선 + 효율성 → MaxViT
Hybrid 모델 실전 적용 사례
사례 1: 의료 영상 분류 (X-ray 폐렴 진단)
문제: 작은 데이터셋(5,000장), 고해상도 이미지(1024×1024)
해결책: Swin-Base + Transfer Learning
import torch
import torch.nn as nn
from transformers import SwinModel
class MedicalImageClassifier(nn.Module):
def __init__(self, num_classes=3):
super().__init__()
# ImageNet 사전학습 Swin backbone
self.swin = SwinModel.from_pretrained('microsoft/swin-base-patch4-window7-224')
self.classifier = nn.Linear(1024, num_classes)
def forward(self, pixel_values):
outputs = self.swin(pixel_values=pixel_values)
pooled = outputs.pooler_output # [batch, 1024]
logits = self.classifier(pooled)
return logits
# Fine-tuning
model = MedicalImageClassifier(num_classes=3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 첫 10 epoch는 backbone freeze
for param in model.swin.parameters():
param.requires_grad = False
# 훈련 루프...
결과: 기존 ResNet50(82%) 대비 89% 정확도 달성
사례 2: 위성 영상 객체 탐지
문제: 극소 객체(차량, 건물) 탐지, 4096×4096 해상도
해결책: MaxViT Backbone + Cascade R-CNN
import detectron2
from detectron2.config import get_cfg
from detectron2.modeling import build_model
cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = "MaxViT-Base"
cfg.MODEL.RESNETS.DEPTH = 50
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10
# MaxViT의 계층적 특징 맵을 FPN에 연결
cfg.MODEL.FPN.IN_FEATURES = ["stage2", "stage3", "stage4"]
cfg.INPUT.MIN_SIZE_TRAIN = (800, 1024, 1280)
model = build_model(cfg)
# 훈련...
성능: DOTA 데이터셋 mAP 76.3% (Swin 대비 +2.1%p)
사례 3: 실시간 동영상 분류
문제: 30fps 실시간 처리, 제한된 GPU 메모리
해결책: ViT-Small + Knowledge Distillation
from transformers import ViTForImageClassification
import torch.nn.functional as F
# Teacher: ViT-Large
teacher = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224')
teacher.eval()
# Student: ViT-Small (경량화)
student = ViTForImageClassification.from_pretrained('google/vit-small-patch16-224')
# Distillation Loss
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=3.0):
soft_loss = F.kl_div(
F.log_softmax(student_logits / temperature, dim=1),
F.softmax(teacher_logits / temperature, dim=1),
reduction='batchmean'
) * (temperature ** 2)
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
# 훈련 시 teacher의 soft labels 활용
결과: ViT-Large 정확도의 97% 유지하며 추론 속도 3.2배 향상
구현 시 주의사항
1. 입력 해상도와 패치 크기
- ViT는 패치 크기(16 or 32)에 따라 계산량이 크게 달라짐
- Swin/MaxViT는 윈도우 크기(7×7 권장)가 성능에 영향
- 고해상도 입력 시 메모리 오버플로 주의
2. 위치 임베딩 보간
사전학습 시 224×224로 학습된 모델을 384×384로 fine-tuning할 때:
import torch.nn.functional as F
# 위치 임베딩 interpolation
old_pos_embed = model.embeddings.position_embeddings # [1, 197, 768]
num_patches = (384 // 16) ** 2 # 576
new_pos_embed = F.interpolate(
old_pos_embed.reshape(1, 14, 14, 768).permute(0, 3, 1, 2),
size=(24, 24),
mode='bicubic'
).permute(0, 2, 3, 1).flatten(1, 2)
3. Mixed Precision Training
트랜스포머는 메모리를 많이 사용하므로 FP16/BF16 학습 필수:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for images, labels in dataloader:
optimizer.zero_grad()
with autocast():
outputs = model(images)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
마무리
Vision Transformer 계열은 이제 컴퓨터 비전의 주류 아키텍처로 자리잡았습니다. 각 모델의 특징을 정리하면:
- ViT: 단순하고 강력하지만 대규모 데이터 필수. 연구용 baseline으로 적합
- Swin Transformer: 계층적 설계로 다양한 비전 태스크에 범용적. 실무 첫 선택지
- MaxViT: Local + Global attention 조합으로 최고 성능. 계산 자원 여유 시 최선
핵심 포인트: 데이터 규모, 태스크 유형, 계산 자원을 고려해 아키텍처를 선택하고, Transfer Learning과 Hybrid 접근으로 실전 성능을 극대화하세요.
최신 트랜스포머 기술을 활용해 더 나은 비전 시스템을 구축하시길 바랍니다!
이 글이 도움이 되셨나요?
Buy me a coffee
답글 남기기