Vision Transformer (ViT) vs Swin Transformer vs MaxViT: 이미지 분류 트랜스포머 완벽 비교

Vision Transformer 시대의 시작

CNN이 지배하던 컴퓨터 비전 분야에 2020년 Google Research의 Vision Transformer (ViT)가 등장하면서 패러다임 변화가 시작되었습니다. 이후 Swin Transformer, MaxViT 등 개선된 아키텍처들이 속속 등장하며 각자의 강점을 입증했습니다.

트랜스포머 기반 비전 모델은 이제 ImageNet, COCO 등 주요 벤치마크에서 CNN을 능가하며 실무의 새로운 표준으로 자리잡고 있습니다.

이 글에서는 세 가지 대표 아키텍처를 비교하고, Hybrid 모델을 실전에 적용하는 방법을 안내합니다.

Vision Transformer (ViT) 핵심 개념

ViT는 이미지를 패치(patch) 단위로 나누어 NLP의 Transformer를 그대로 적용한 혁신적 접근입니다.

동작 원리

  1. 패치 분할: 입력 이미지 H×W×CH \times W \times CNN개의 패치 P×PP \times P로 분할
  2. 선형 임베딩: 각 패치를 1D 벡터로 flatten 후 선형 변환
  3. 위치 임베딩 추가: 패치 순서 정보를 위한 learnable position embedding
  4. Transformer Encoder: Multi-Head Self-Attention과 MLP 블록 반복
  5. 분류 토큰: [CLS] 토큰의 최종 출력으로 이미지 분류

수식으로 표현하면:

z0=[xclass;xp1E;xp2E;;xpNE]+Eposz_0 = [x_{class}; x_p^1E; x_p^2E; \cdots; x_p^NE] + E_{pos}

여기서:
xpix_p^i: ii번째 패치의 flatten된 벡터
EE: 패치 임베딩 행렬 (P2C)×D(P^2 \cdot C) \times D
EposE_{pos}: 위치 임베딩 (N+1)×D(N+1) \times D
DD: 임베딩 차원

import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTImageProcessor

# ViT 모델 로드 (사전학습 가중치)
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

# 이미지 전처리 및 추론
from PIL import Image
image = Image.open('example.jpg')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()

ViT의 장점과 한계

장점:
– 단순하고 우아한 구조
– 대규모 데이터셋(JFT-300M 등)에서 뛰어난 성능
– Global receptive field (패치 간 직접 attention)

한계:
– 작은 데이터셋에서 CNN보다 낮은 성능 (inductive bias 부족)
– 고해상도 이미지에서 계산량 폭증 (O(N2)O(N^2) 복잡도)
– 다양한 스케일 특징 추출 어려움

Swin Transformer: 계층적 설계의 혁신

Microsoft Research의 Swin Transformer는 ViT의 한계를 극복하기 위해 CNN의 계층적 구조를 도입했습니다.

핵심 메커니즘

1. Shifted Window Attention

일반 window 단위로 self-attention을 수행하다가, 다음 레이어에서 window를 이동(shift)시켜 윈도우 간 정보 교환:

  • Window Partition: 이미지를 M×MM \times M 크기 윈도우로 분할
  • Local Attention: 각 윈도우 내에서만 attention 계산 → O(N)O(N) 복잡도
  • Shifted Window: 윈도우를 M2\frac{M}{2}만큼 이동하여 교차 영역 생성

2. 계층적 특징 맵

4단계 stage로 구성되며, 각 stage마다 해상도를 1/2로 줄이고 채널을 2배 증가:

Stage 해상도 채널 역할
1 H/4 × W/4 C 초기 특징 추출
2 H/8 × W/8 2C 중간 특징
3 H/16 × W/16 4C 고수준 특징
4 H/32 × W/32 8C 최종 특징 맵
from transformers import SwinForImageClassification, AutoImageProcessor

model = SwinForImageClassification.from_pretrained('microsoft/swin-base-patch4-window7-224')
processor = AutoImageProcessor.from_pretrained('microsoft/swin-base-patch4-window7-224')

image = Image.open('sample.jpg')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits

Swin의 강점

  • 선형 복잡도: Window attention으로 대용량 이미지 효율적 처리
  • 멀티스케일 특징: 계층 구조로 다양한 스케일 표현 학습
  • 객체 탐지/세그멘테이션: Backbone으로 사용 시 FPN 같은 구조와 자연스러운 결합
  • 작은 데이터셋: ImageNet-1K에서도 ViT보다 우수한 성능

MaxViT: Global + Local의 완벽한 조화

MaxViT (Multi-Axis Vision Transformer)는 2022년 Google Research에서 발표한 하이브리드 아키텍처로, 블록 단위 local attentiongrid 단위 global attention을 번갈아 적용합니다.

Multi-Axis Attention

하나의 블록에서 두 가지 attention을 순차 수행:

Block Attention (Local)

이미지 → 블록 분할 → 각 블록 내 self-attention

Grid Attention (Global)

블록들을 재배열 → 같은 위치의 블록들끼리 attention

이를 통해 선형 복잡도전역 수용 영역을 달성합니다.

MBConv 통합

MaxViT는 각 stage 앞에 MBConv(EfficientNet의 구성 요소)를 배치하여 CNN의 inductive bias를 활용:

Stage: [MBConv]  [MaxViT Block × N]
import timm

# timm 라이브러리로 MaxViT 사용
model = timm.create_model('maxvit_base_tf_224', pretrained=True, num_classes=1000)
model.eval()

# 추론
import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image = Image.open('test.jpg')
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
    output = model(input_tensor)
    pred = output.argmax(dim=1)

세 아키텍처 종합 비교

특성 ViT Swin Transformer MaxViT
Attention 방식 Global (모든 패치) Shifted Window (Local) Block + Grid (Hybrid)
복잡도 O(N2)O(N^2) O(N)O(N) O(N)O(N)
계층 구조 ✗ 단일 해상도 ✓ 4단계 피라미드 ✓ 4단계 + MBConv
Inductive Bias 거의 없음 약간 (locality) 강함 (CNN 통합)
작은 데이터셋 약함 보통 강함
고해상도 처리 느림 빠름 빠름
객체 탐지 Backbone 부적합 적합 매우 적합
ImageNet Top-1 84.5% (Base) 85.2% (Base) 86.5% (Base)
파라미터 효율 보통 좋음 매우 좋음

선택 가이드: 대규모 데이터 + 분류 전용 → ViT, 작은 데이터 + 다운스트림 태스크 → Swin, 성능 최우선 + 효율성 → MaxViT

Hybrid 모델 실전 적용 사례

사례 1: 의료 영상 분류 (X-ray 폐렴 진단)

문제: 작은 데이터셋(5,000장), 고해상도 이미지(1024×1024)

해결책: Swin-Base + Transfer Learning

import torch
import torch.nn as nn
from transformers import SwinModel

class MedicalImageClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()
        # ImageNet 사전학습 Swin backbone
        self.swin = SwinModel.from_pretrained('microsoft/swin-base-patch4-window7-224')
        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, pixel_values):
        outputs = self.swin(pixel_values=pixel_values)
        pooled = outputs.pooler_output  # [batch, 1024]
        logits = self.classifier(pooled)
        return logits

# Fine-tuning
model = MedicalImageClassifier(num_classes=3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 첫 10 epoch는 backbone freeze
for param in model.swin.parameters():
    param.requires_grad = False

# 훈련 루프...

결과: 기존 ResNet50(82%) 대비 89% 정확도 달성

사례 2: 위성 영상 객체 탐지

문제: 극소 객체(차량, 건물) 탐지, 4096×4096 해상도

해결책: MaxViT Backbone + Cascade R-CNN

import detectron2
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.MODEL.BACKBONE.NAME = "MaxViT-Base"
cfg.MODEL.RESNETS.DEPTH = 50
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10

# MaxViT의 계층적 특징 맵을 FPN에 연결
cfg.MODEL.FPN.IN_FEATURES = ["stage2", "stage3", "stage4"]
cfg.INPUT.MIN_SIZE_TRAIN = (800, 1024, 1280)

model = build_model(cfg)
# 훈련...

성능: DOTA 데이터셋 mAP 76.3% (Swin 대비 +2.1%p)

사례 3: 실시간 동영상 분류

문제: 30fps 실시간 처리, 제한된 GPU 메모리

해결책: ViT-Small + Knowledge Distillation

from transformers import ViTForImageClassification
import torch.nn.functional as F

# Teacher: ViT-Large
teacher = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224')
teacher.eval()

# Student: ViT-Small (경량화)
student = ViTForImageClassification.from_pretrained('google/vit-small-patch16-224')

# Distillation Loss
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=3.0):
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean'
    ) * (temperature ** 2)

    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss

# 훈련 시 teacher의 soft labels 활용

결과: ViT-Large 정확도의 97% 유지하며 추론 속도 3.2배 향상

구현 시 주의사항

1. 입력 해상도와 패치 크기

  • ViT는 패치 크기(16 or 32)에 따라 계산량이 크게 달라짐
  • Swin/MaxViT는 윈도우 크기(7×7 권장)가 성능에 영향
  • 고해상도 입력 시 메모리 오버플로 주의

2. 위치 임베딩 보간

사전학습 시 224×224로 학습된 모델을 384×384로 fine-tuning할 때:

import torch.nn.functional as F

# 위치 임베딩 interpolation
old_pos_embed = model.embeddings.position_embeddings  # [1, 197, 768]
num_patches = (384 // 16) ** 2  # 576
new_pos_embed = F.interpolate(
    old_pos_embed.reshape(1, 14, 14, 768).permute(0, 3, 1, 2),
    size=(24, 24),
    mode='bicubic'
).permute(0, 2, 3, 1).flatten(1, 2)

3. Mixed Precision Training

트랜스포머는 메모리를 많이 사용하므로 FP16/BF16 학습 필수:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for images, labels in dataloader:
    optimizer.zero_grad()

    with autocast():
        outputs = model(images)
        loss = criterion(outputs, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

마무리

Vision Transformer 계열은 이제 컴퓨터 비전의 주류 아키텍처로 자리잡았습니다. 각 모델의 특징을 정리하면:

  • ViT: 단순하고 강력하지만 대규모 데이터 필수. 연구용 baseline으로 적합
  • Swin Transformer: 계층적 설계로 다양한 비전 태스크에 범용적. 실무 첫 선택지
  • MaxViT: Local + Global attention 조합으로 최고 성능. 계산 자원 여유 시 최선

핵심 포인트: 데이터 규모, 태스크 유형, 계산 자원을 고려해 아키텍처를 선택하고, Transfer Learning과 Hybrid 접근으로 실전 성능을 극대화하세요.

최신 트랜스포머 기술을 활용해 더 나은 비전 시스템을 구축하시길 바랍니다!

이 글이 도움이 되셨나요?

Buy me a coffee

댓글

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다

TODAY 46 | TOTAL 245