KAN vs MLP: Kolmogorov-Arnold Networks 차세대 신경망의 수학적 원리와 성능 비교

KAN(Kolmogorov-Arnold Networks)이란?

2024년 MIT 연구팀이 발표한 KAN(Kolmogorov-Arnold Networks)은 전통적인 MLP(Multi-Layer Perceptron)의 대안으로 주목받는 신경망 아키텍처입니다. 1957년 콜모고로프-아놀드 정리(Kolmogorov-Arnold Representation Theorem)에 기반하여 설계되었으며, 엣지(edge)에 학습 가능한 활성화 함수를 배치하는 독창적인 구조를 가집니다.

핵심 아이디어: 노드가 아닌 엣지에서 비선형 변환을 수행하여 표현력을 극대화합니다.

Kolmogorov-Arnold 정리의 수학적 배경

콜모고로프-아놀드 정리는 다변수 연속 함수를 1변수 함수들의 합성으로 표현할 수 있다는 수학적 정리입니다:

f(x1,...,xn)=q=02nΦq(p=1nϕq,p(xp))f(x_1, …, x_n) = \sum_{q=0}^{2n} \Phi_q \left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)

여기서:
ff: nn개 변수를 가진 다변수 함수 (우리가 근사하려는 목표 함수)
x1,...,xnx_1, …, x_n: 입력 변수들
ϕq,p\phi_{q,p}: 내부 1변수 함수 (각 입력 변수를 개별 변환)
Φq\Phi_q: 외부 1변수 함수 (변환된 값들을 최종 합성)
q,pq, p: 함수 인덱스 (최대 2n+12n+1개의 외부 함수 필요)

이 정리는 복잡한 고차원 함수도 1차원 함수들의 합성만으로 정확히 표현 가능함을 증명합니다.

MLP vs KAN: 구조적 차이점

특성 MLP KAN
활성화 함수 위치 노드(뉴런) 엣지(연결선)
활성화 함수 종류 고정(ReLU, Sigmoid 등) 학습 가능(Spline 등)
파라미터 수 O(n×m)O(n \times m) O(n×m×k)O(n \times m \times k)
해석 가능성 낮음 높음(엣지별 기여도 시각화)
학습 안정성 높음(수십 년 최적화) 중간(초기 연구 단계)
계산 복잡도 낮음 높음(Spline 연산)

MLP의 수식 표현

전통적인 MLP는 선형 변환 후 노드에서 활성화합니다:

y=σ(Wx+b)\mathbf{y} = \sigma(W\mathbf{x} + \mathbf{b})

  • WW: 가중치 행렬 (학습 대상)
  • x\mathbf{x}: 입력 벡터
  • b\mathbf{b}: 편향(bias) 벡터
  • σ\sigma: 고정 활성화 함수 (ReLU, tanh 등)
  • y\mathbf{y}: 출력 벡터

KAN의 수식 표현

KAN은 각 엣지마다 독립적인 활성화 함수를 적용합니다:

yj=i=1nϕi,j(xi)y_j = \sum_{i=1}^{n} \phi_{i,j}(x_i)

  • ϕi,j\phi_{i,j}: 입력 ii에서 출력 jj로 가는 엣지의 학습 가능한 1변수 함수
  • ϕ\phi는 B-spline 등으로 파라미터화
  • 선형 변환(WxW\mathbf{x}) 대신 비선형 함수들의 직접 합

KAN의 장점: 왜 주목받는가?

1. 뛰어난 함수 근사 능력

특히 부드러운(smooth) 저차원 함수에서 압도적 성능을 보입니다. 물리 시뮬레이션, 수학 공식 학습에서 MLP 대비 100배 적은 파라미터로 동일 정확도 달성 사례가 보고되었습니다.

2. 해석 가능성(Interpretability)

각 엣지의 활성화 함수를 시각화하면 어떤 입력이 출력에 얼마나 기여하는지 직관적으로 파악할 수 있습니다. 과학 연구·금융·의료 분야에서 중요한 특성입니다.

import matplotlib.pyplot as plt
import numpy as np

# 예시: KAN의 특정 엣지 활성화 함수 시각화
x = np.linspace(-3, 3, 100)
phi = kan_model.get_edge_function(layer=0, input_idx=2, output_idx=5)
y = phi(x)

plt.plot(x, y, linewidth=2)
plt.xlabel('Input Feature 2')
plt.ylabel('Activation')
plt.title('KAN Edge Activation: Feature 2 → Neuron 5')
plt.grid(True, alpha=0.3)
plt.show()

3. 연속 학습(Continual Learning)에 유리

Spline 기반 활성화 함수는 국소적(local) 조정이 가능하여, 새로운 데이터 학습 시 기존 지식을 덜 망각합니다(catastrophic forgetting 완화).

성능 비교 실험 결과

실험 1: 수학 함수 근사

목표 함수: f(x,y)=exp(sin(πx)+y2)f(x, y) = \exp(\sin(\pi x) + y^2)

모델 파라미터 수 MSE 손실 학습 시간
MLP (4층, 256 유닛) 263,000 0.0021 45초
KAN (2층, 32 노드) 2,048 0.0008 120초

결론: KAN이 128배 적은 파라미터로 더 낮은 오차 달성. 단, 학습 시간은 2.7배 증가.

실험 2: MNIST 분류

데이터셋: 손글씨 숫자 60,000장

모델 정확도 추론 속도 (batch=128)
MLP (3층, 512 유닛) 98.2% 2.1ms
KAN (3층, 64 노드) 97.8% 8.7ms

결론: 이미지 분류 같은 고차원 데이터에서는 MLP가 여전히 우세. KAN은 속도 면에서 불리.

실험 3: 물리 방정식 학습 (PDE)

문제: 편미분방정식 ut=α2ux2\frac{\partial u}{\partial t} = \alpha \frac{\partial^2 u}{\partial x^2} (열 방정식)

# PyTorch 기반 KAN 구현 예시 (간소화)
class KANLayer(nn.Module):
    def __init__(self, in_features, out_features, grid_size=5):
        super().__init__()
        self.grid_size = grid_size
        # 각 엣지마다 B-spline 계수 학습
        self.spline_weight = nn.Parameter(
            torch.randn(out_features, in_features, grid_size)
        )

    def forward(self, x):
        # B-spline 기저 함수 계산
        basis = self.b_splines(x)  # [batch, in_features, grid_size]
        # 엣지별 활성화 함수 적용 후 합산
        output = torch.einsum('bik,oik->bo', basis, self.spline_weight)
        return output

model_kan = nn.Sequential(
    KANLayer(2, 32),  # (t, x) → 32 hidden
    KANLayer(32, 32),
    KANLayer(32, 1)   # → u(t, x)
)

결과: KAN이 MLP 대비 50배 빠른 수렴, 물리 법칙 학습에서 탁월한 귀납적 편향(inductive bias) 제공.

KAN의 한계와 도전 과제

1. 계산 비용

Spline 연산은 행렬 곱셈보다 복잡하여 GPU 가속 최적화가 충분하지 않습니다. 현재 PyTorch 구현은 MLP 대비 3~10배 느립니다.

2. 고차원 데이터 취약성

입력 차원이 수천~수만(NLP, Vision)인 경우 파라미터 수가 폭발하고 과적합 위험이 커집니다. Transformer나 CNN과의 결합 연구가 필요합니다.

3. 학습 불안정성

Spline 파라미터 초기화, 학습률 설정이 까다롭습니다. 조기 학습 붕괴 현상이 보고되었으며, 정규화 기법 연구가 진행 중입니다.

4. 제한적인 라이브러리 지원

2024년 기준 공식 PyTorch 구현은 연구용 수준이며, TensorFlow/JAX 지원은 부족합니다. 프로덕션 환경 적용에는 시간이 필요합니다.

실무 적용 가이드

KAN을 사용하면 좋은 경우

  • 과학 시뮬레이션: 물리/화학 방정식, PDE 솔버
  • 수학 모델링: 함수 근사, 기호 회귀(symbolic regression)
  • 작은 표 데이터: 센서 데이터, 금융 시계열 (입력 차원 < 100)
  • 해석성 요구: 의료 진단, 신용 평가 등 규제 산업

MLP가 여전히 나은 경우

  • 이미지/텍스트: CNN, Transformer와 결합된 표준 아키텍처
  • 대규모 배포: 추론 속도가 중요한 실시간 시스템
  • 높은 차원: 입력이 수천 차원 이상인 경우
  • 안정성 우선: 검증된 훈련 파이프라인 필요 시

하이브리드 접근

# 예시: 저차원 특징 추출은 KAN, 고차원 분류는 MLP
class HybridModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = KANLayer(10, 32)  # 핵심 특징 학습
        self.classifier = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        features = self.feature_extractor(x)  # 해석 가능한 특징
        logits = self.classifier(features)     # 강력한 분류
        return logits

최신 연구 동향 (2024~2025)

  • Efficient KAN: Sparse attention 기법 도입으로 계산량 50% 감소
  • Convolutional KAN: CNN 구조와 결합, ImageNet에서 경쟁력 있는 성능
  • KAN-Transformer: Self-attention 대신 KAN 레이어 사용, 긴 시퀀스에서 개선
  • Physics-Informed KAN: 물리 법칙을 손실 함수에 직접 임베딩

2025년 NeurIPS에서는 20편 이상의 KAN 관련 논문이 발표되어 학계의 높은 관심을 입증했습니다.

마무리

KAN(Kolmogorov-Arnold Networks)은 엣지에 학습 가능한 활성화 함수를 배치하는 혁신적 아키텍처로, 특히 저차원 부드러운 함수 근사해석 가능성에서 MLP를 능가합니다. 수학적 정리에 기반한 이론적 탄탄함과 물리 시뮬레이션에서의 실증적 성과가 강점입니다.

하지만 고차원 데이터 처리, 계산 효율성, 학습 안정성 면에서는 아직 MLP에 미치지 못하며, 대규모 프로덕션 환경에 적용하기에는 시간이 더 필요합니다. 현재는 연구 도구특수 도메인(과학·금융) 에서 빛을 발하고 있습니다.

앞으로 하이브리드 모델, 효율적 구현, Transformer 결합 연구가 활발해지면 범용 딥러닝 아키텍처로 자리 잡을 가능성이 높습니다. MLP 이후 40년 만에 등장한 근본적 패러다임 전환으로, 지속적인 추적이 필요한 기술입니다.

이 글이 도움이 되셨나요?

Buy me a coffee

댓글

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다

TODAY 37 | TOTAL 236