Vision Transformer vs CNN Attention Map 비교 분석: 모델 해석 가능성을 높이는 5가지 핵심 기법

들어가며

딥러닝 모델이 왜 그런 예측을 했는지 이해하는 것은 실무에서 매우 중요합니다. 특히 의료 영상 진단이나 자율주행처럼 신뢰성이 필수적인 분야에서는 모델의 해석 가능성(Interpretability)이 성능만큼 중요합니다.

Vision Transformer(ViT)CNN은 각각 다른 방식으로 이미지를 처리하며, Attention Map을 통해 모델이 어디를 보고 있는지 시각화할 수 있습니다. 이 글에서는 두 아키텍처의 Attention 메커니즘을 비교하고, 실무에서 활용할 수 있는 해석 가능성 향상 기법을 소개합니다.

Vision Transformer와 CNN의 Attention 메커니즘 차이

CNN의 Attention: Spatial한 특징 강조

CNN은 합성곱 연산을 통해 지역적 패턴을 학습합니다. Attention 메커니즘을 추가하면 중요한 공간적 영역에 가중치를 부여할 수 있습니다.

특징 설명
메커니즘 Channel Attention, Spatial Attention
대표 기법 CBAM, SE-Net, Grad-CAM
강점 지역적 패턴 강조, 계산 효율적
약점 전역적 관계 파악 제한적

Vision Transformer의 Self-Attention: 전역적 관계 학습

ViT는 이미지를 패치로 나누고 Self-Attention을 통해 모든 패치 간의 관계를 학습합니다.

특징 설명
메커니즘 Multi-head Self-Attention
대표 기법 Attention Rollout, Attention Flow
강점 전역적 문맥 이해, 장거리 의존성
약점 계산 복잡도 높음, 데이터 많이 필요

핵심 차이점: CNN은 지역적 특징에 집중하고, ViT는 전역적 관계를 학습합니다.

모델 해석 가능성 향상을 위한 5가지 핵심 기법

1. Grad-CAM (CNN용)

Gradient-weighted Class Activation Mapping은 CNN에서 가장 널리 사용되는 시각화 기법입니다.

import torch
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# ResNet 모델 예시
model = torchvision.models.resnet50(pretrained=True)
target_layers = [model.layer4[-1]]

cam = GradCAM(model=model, target_layers=target_layers)
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

# 원본 이미지에 오버레이
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

실무 활용 예시:
– 의료 영상: 병변 영역이 실제로 진단에 사용되었는지 확인
– 제조업: 불량 검출 시 어떤 부분이 판단 근거인지 파악

2. Attention Rollout (ViT용)

ViT의 여러 레이어에 걸친 Attention을 누적하여 최종적으로 어떤 패치가 중요한지 계산합니다.

import numpy as np

def attention_rollout(attentions, discard_ratio=0.9):
    result = torch.eye(attentions[0].size(-1))

    for attention in attentions:
        # 각 레이어의 attention 평균
        attention_heads_fused = attention.mean(axis=1)

        # 낮은 값 제거 (노이즈 감소)
        flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
        _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
        flat[0, indices] = 0

        # 누적
        result = torch.matmul(attention_heads_fused, result)

    return result

3. Attention Flow

Attention Rollout의 개선 버전으로, 정보 흐름을 더 정확하게 추적합니다.

  • 각 레이어에서 Attention 가중치를 추적
  • Residual connection 고려
  • 더 정확한 패치 중요도 산출

4. Layer-wise Relevance Propagation (LRP)

CNN과 ViT 모두에 적용 가능한 기법으로, 역전파를 통해 각 픽셀의 기여도를 계산합니다.

장점:
– 수학적으로 엄밀한 기여도 분해
– 양/음의 기여도 구분 가능
– 노이즈에 강건함

5. Transformer Interpretability Beyond Attention

최근 연구에 따르면 Attention만으로는 불충분합니다.

# Value 벡터와 Attention 가중치를 함께 고려
def compute_relevance_with_values(attention_weights, value_vectors):
    # Attention * Value의 노름 계산
    weighted_values = attention_weights.unsqueeze(-1) * value_vectors
    relevance = torch.norm(weighted_values, dim=-1)
    return relevance

CNN vs ViT Attention Map 비교표

비교 항목 CNN (Grad-CAM) ViT (Attention Rollout)
해상도 높음 (원본과 유사) 낮음 (패치 단위)
전역 정보 제한적 우수함
계산 비용 낮음 중간
노이즈 적음 많음 (보정 필요)
의료/안전 분야 검증됨 연구 단계

실무 적용 시 주의사항

1. Attention ≠ Explanation

Attention Map이 높다고 해서 반드시 그 영역이 인과적으로 중요한 것은 아닙니다.

대응 방법:
– 여러 시각화 기법을 함께 사용
– Counterfactual 분석 병행
– 도메인 전문가와 검증

2. 하이퍼파라미터 튜닝

# Attention Rollout의 discard_ratio 조정
for ratio in [0.7, 0.8, 0.9]:
    rollout = attention_rollout(attentions, discard_ratio=ratio)
    # 가장 명확한 시각화 선택

3. 정량적 평가 필요

주관적 시각화만으로는 부족합니다.

  • Deletion/Insertion 메트릭: 중요 영역 제거 시 성능 변화 측정
  • IoU with ground truth: 실제 중요 영역과 비교
  • Faithfulness 점수: 설명과 모델 동작의 일치도

실전 활용 예시: 제품 불량 검출

class InterpretableDefectDetector:
    def __init__(self, model_type='cnn'):
        self.model_type = model_type
        if model_type == 'cnn':
            self.model = ResNet50(num_classes=2)
            self.explainer = GradCAM(model, target_layers)
        else:
            self.model = ViT(num_classes=2)
            self.explainer = AttentionRollout()

    def predict_with_explanation(self, image):
        # 예측
        prediction = self.model(image)

        # 설명 생성
        explanation = self.explainer(image)

        # 임계값 기반 알림
        if explanation.max() < 0.3:
            print("경고: 모델이 불확실한 영역을 보고 있습니다")

        return prediction, explanation

마무리

Vision Transformer와 CNN의 Attention Map 분석은 모델을 신뢰하고 개선하는 데 필수적입니다.

핵심 요약:

  1. CNN: Grad-CAM으로 지역적 특징 시각화, 실무 검증 완료
  2. ViT: Attention Rollout/Flow로 전역적 관계 파악, 노이즈 보정 필요
  3. 해석 가능성 향상: LRP, Value 벡터 활용, 정량적 평가 병행
  4. 실무 적용: 여러 기법 조합, 도메인 전문가 검증, Attention의 한계 인식
  5. 미래 방향: Transformer Interpretability Beyond Attention 연구 주목

모델 해석은 한 번의 시각화가 아닌, 지속적인 검증과 개선 프로세스입니다.

실무에서는 성능과 해석 가능성의 균형을 찾는 것이 중요합니다. CNN의 안정성과 ViT의 표현력을 적절히 활용하여, 신뢰할 수 있는 AI 시스템을 구축하시기 바랍니다.

이 글이 도움이 되셨나요? ☕

Buy me a coffee

코멘트

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다