모델 증류(Distillation) 이해하기

최근 DeepSeek 관련 최신 뉴스를 살펴봤다면, "증류(Distillation)" 라는 용어를 접했을 가능성이 높습니다. 하지만 증류란 정확히 무엇이며, 왜 중요한 걸까요?

이 글에서는 먼저 증류라는 개념과 과정을 설명한 후, Pytorch를 활용한 실습 예제를 통해 이를 실제로 구현해 보려 합니다. 이 글을 끝까지 읽고 이해한다면, 모델 증류의 원리와 중요성을 한층 깊이 이해하지 않을까 합니다.

최근 DeepSeek 관련 최신 뉴스를 살펴봤다면, “증류(Distillation)” 라는 용어를 접했을 가능성이 높습니다. 하지만 증류란 정확히 무엇이며, 왜 중요한 걸까요?

이 글에서는 먼저 증류라는 개념과 과정을 설명한 후, Pytorch를 활용한 실습 예제를 통해 이를 실제로 구현해 보려 합니다. 이 글을 끝까지 읽고 이해한다면, 모델 증류의 원리와 중요성을 한층 깊이 이해하지 않을까 합니다.(참고 자료에서는 텐서플로우로 되어 있는 것을 파이토치로 변환해서 설명합니다.)

모델 증류는 어떻게 동작하는지?

모델 증류는 “더 크고 복잡한 모델(교사, teacher)“의 지식을 “더 작고 간단한 모델(학생, student)“이 학습하도록 하는 기법입니다. 이 과정에서 학생 모델은 정답 레이블(raw labels)이 아니라, 교사 모델이 제공하는 부드러운 확률 분포(softened probability outputs)를 학습합니다.

이를 통해 학생 모델은 교사 모델의 지식을 보다 압축된 형태로 흡수하여, 적은 매개변수로도 유사한 성능을 달성할 수 있습니다.

예를 들어, 이미지 분류에서 단순히 “강아지” 또는 “고양이”라는 정답만 학습하는 것이 아니라, 교사 모델의 신뢰도 점수(confidence scores) 를 활용하여 학습합니다.
예를 들어, 교사 모델이 한 이미지를 분석했을 때 다음과 같은 확률을 부여했다고 가정해 보겠습니다.

  • 강아지(dog): 80%
  • 고양이(cat): 15%
  • 여우(fox): 5%

학생 모델은 이처럼 세밀한 확률 정보를 학습함으로써 더 정교한 분류 능력을 갖출 수 있습니다. 이러한 모델 증류 기법은 모델 크기와 연산 비용을 줄이면서도 높은 정확도를 유지할 수 있다는 장점이 있습니다.

MNIST 데이터셋이란?

MNIST 데이터셋(Modified National Institute of Standards and Technology)은 머신러닝과 컴퓨터 비전 분야에서 널리 사용되는 벤치마크 데이터셋입니다.

이 데이터셋은 손으로 쓴 숫자(0~9)의 흑백 이미지 70,000장으로 구성되어 있으며,
각 이미지는 28×28 픽셀 크기로 되어 있습니다.

MNIST 데이터셋은 다음과 같이 구성됩니다.

  • 훈련 데이터: 60,000장
  • 테스트 데이터: 10,000장

이 데이터셋은 손글씨 숫자를 인식하는 다양한 머신러닝 및 딥러닝 모델을 훈련하고 평가하는 데 널리 활용됩니다.

우선, 교사(Teacher Model) 모델을 보겠습니다.

교사 모델(Teacher Model)과 학생 모델(Student Model)

교사 모델은 MNIST 데이터셋을 사용하여 학습된 CNN(합성곱 신경망) 모델입니다.

또한, 학생 모델(Student Model)도 존재하는데, 이 모델은 교사 모델보다 더 작고 단순한 구조를 갖습니다.

모델 증류의 목표

모델 증류의 핵심 목표는 더 작은 학생 모델(Student Model)이 더 큰 교사 모델(Teacher Model)의 성능을 모방하도록 학습하는 것입니다.
이를 통해 연산 비용과 훈련 시간을 줄이면서도 높은 성능을 유지할 수 있습니다.

모델 증류 과정
  1. 이를 통해 학생 모델이 점진적으로 교사 모델의 지식을 학습하게 됩니다.
  2. 교사 모델과 학생 모델이 동일한 데이터셋을 사용하여 예측을 수행합니다.
    • 두 모델의 출력값 차이를 측정하기 위해 Kullback-Leibler(KL) 발산을 계산합니다.
    • KL 발산은 두 확률 분포 사이의 차이를 정량적으로 평가하는 방법입니다.
  3. KL 발산 값을 기반으로 모델이 업데이트될 방향을 결정합니다.
    • 경사값(gradients)을 계산하여, 학생 모델이 교사 모델의 출력을 더 잘 모방하도록 조정합니다.
    • 이를 통해 학생 모델이 점진적으로 교사 모델의 지식을 학습하게 됩니다.

PyTorch로 구현하는 모델 증류 (Model Distillation) 예시

모델 증류(Model Distillation)는 작고 단순한 학생 모델(Student Model)더 크고 복잡한 교사 모델(Teacher Model)의 성능을 모방하도록 학습하는 기법입니다.
이를 통해 연산 비용을 줄이면서도 높은 정확도를 유지할 수 있습니다.

예제 시나리오:

  • 사 모델(Teacher Model): 깊은 합성곱 신경망(CNN)
  • 학생 모델(Student Model): 더 적은 층을 가진 얕은 CNN
  • 데이터셋: 손글씨 숫자 데이터셋인 MNIST 사용
  • 목표: 학생 모델이 교사 모델의 지식을 학습하여, 더 적은 연산량으로 유사한 성능을 내도록 학습
환경 설정 및 MNIST 데이터 불러오기

먼저 PyTorch를 사용하여 MNIST 데이터를 로드합니다.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST 데이터셋 불러오기
transform = transforms.Compose([
    transforms.ToTensor(),  # 이미지를 Tensor로 변환
    transforms.Normalize((0.1307,), (0.3081,))  # 평균과 표준편차로 정규화
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# DataLoader 설정
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 데이터 샘플 확인
images, labels = next(iter(train_loader))

# 첫 9개 이미지 출력
fig, axes = plt.subplots(1, 9, figsize=(10, 3))
for i, ax in enumerate(axes):
    ax.imshow(images[i].squeeze(), cmap='gray')
    ax.set_title(f"Label: {labels[i].item()}")
    ax.axis('off')
plt.show()
교사 모델(Teacher Model) 정의

합성곱 신경망(CNN) 기반의 교사 모델을 정의합니다.

class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # Softmax 없이 로짓 값 출력

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # Softmax 적용 안 함 (증류 과정에서 사용할 예정)
        return x

# 모델 생성
teacher_model = TeacherModel().to(device)
교사 모델 학습

Adam 옵티마이저와 Categorical CrossEntropy(Softmax 포함)를 사용하여 학습합니다.

def train_teacher(model, train_loader, epochs=5):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

# 교사 모델 학습
train_teacher(teacher_model, train_loader)

epoch 5회 학습하면 아래와 같습니다.

학생 모델(Student Model) 정의

더 가벼운 CNN 모델을 학생 모델로 설정합니다.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)  
        self.pool = nn.MaxPool2d(2, 2)  # MaxPooling2D((2,2))
        self.fc1 = nn.Linear(16 * 14 * 14, 64)  # Flatten 후 Fully Connected Layer
        self.fc2 = nn.Linear(64, 10)  # 최종 출력층 (10개의 클래스, No Softmax)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Conv → ReLU → MaxPooling
        x = x.view(-1, 16 * 14 * 14)  # Flatten
        x = F.relu(self.fc1(x))  # 첫 번째 Fully Connected Layer
        x = self.fc2(x)  # 최종 로짓 값 출력 (Softmax 없음)
        return x

student_model = StudentModel().to(device)
모델 증류 손실 함수 (Distillation Loss)

Knowledge Distillation은 이미 학습된 teacher 모델의 정보를 student 모델에 전달하여, student 모델이 보다 나은 일반화 능력과 성능을 갖추도록 하는 기법입니다. 이 과정의 핵심은 디스틸레이션 손실 함수를 통해 두 모델 간 예측 분포의 차이를 줄이는 데 있습니다. 아래에서 그 구체적인 과정을 단계별로 살펴보겠습니다.

1. Teacher 모델의 Soft Target 생성
  • Soft Target: teacher 모델은 입력 데이터에 대해 예측을 수행하여, 각 클래스에 대한 확률 분포(soft target)를 생성합니다. 이 확률 분포는 단순한 정답(하드 라벨)이 아니라, 각 클래스에 대해 모델이 가지는 확신 정도를 반영합니다.
2. Student 모델의 Soft Probability 계산
  • Soft Probability: student 모델도 동일한 입력 데이터에 대해 예측을 수행한 후, softmax 함수를 적용하여 각 클래스에 대한 확률 분포(soft probability)를 산출합니다. 이 분포는 teacher 모델의 soft target과 비교되어 학습 과정의 지표로 활용됩니다.
Soft Probabilities란?

Soft probabilities는 한 가지 정답만을 제공하는 hard label과 달리, 여러 가능한 결과에 대해 확률 분포를 제공합니다. 예를 들어, 이메일 분류 문제에서 단순히 “스팸” 혹은 “스팸 아님”으로 분류하는 대신, 이메일이 85% 확률로 스팸, 15% 확률로 스팸 아님이라는 정보를 제공할 수 있습니다. 이 방식은 모델이 데이터의 불확실성을 반영하고, 미세한 차이를 학습하는 데 큰 도움을 줍니다.

Softmax 함수와 Temperature 파라미터
  • Softmax 함수: 모델의 예측값을 확률 분포로 변환하는 역할을 합니다.
  • Temperature 파라미터: 이 파라미터는 softmax 함수가 산출하는 분포의 부드러움을 조절합니다.
    • 높은 Temperature: 확률 분포가 부드러워지며, 각 클래스 간 차이가 줄어들어 모델이 미세한 관계를 학습할 수 있습니다.
    • 낮은 Temperature: 확률 분포가 날카로워져 특정 클래스에 대한 확신이 강해지며, 하드 라벨에 가까워집니다.
Knowledge Distillation에서의 활용

Teacher 모델이 제공하는 soft probabilities는 단순한 정답 예측을 넘어서, 클래스 간의 미묘한 연관성과 차이를 student 모델이 학습하도록 돕습니다. 이를 통해 student 모델은 더 나은 일반화 성능과 응용 분야에서의 효율성을 갖출 수 있습니다.

3. KL Divergence를 통한 분포 차이 측정
  • KL Divergence: teacher와 student 모델이 생성한 두 확률 분포 간의 차이를 측정하기 위해 Kullback-Leibler Divergence를 사용합니다. 이 값은 두 분포가 얼마나 다른지를 수치화하며, 클수록 student 모델이 teacher 모델의 지식을 잘 모방하지 못했음을 의미합니다.
4. 디스틸레이션 손실 반환

최종적으로 계산된 KL Divergence 값을 디스틸레이션 손실로 반환합니다. 이 손실 값을 최소화하는 방향으로 student 모델을 학습시킴으로써, teacher 모델의 미세한 지식과 클래스 간 관계를 효과적으로 전달할 수 있습니다.

# ------------------------------------------------
# Knowledge Distillation을 위한 distillation loss 함수 정의
# ------------------------------------------------
def distillation_loss(y_true, student_logits, x_batch, teacher_model, temperature=5):
    """
    KL Divergence를 사용하여 teacher 모델과 student 모델의 soft probability 분포 차이를 계산합니다.
    
    파라미터:
    - y_true: 실제 레이블 (여기서는 사용되지 않음)
    - student_logits: student 모델의 출력 (logits)
    - x_batch: 입력 배치 (teacher 모델 예측에 사용)
    - teacher_model: 미리 학습된 teacher 모델
    - temperature: softmax의 temperature 값 (높을수록 확률 분포가 부드러워짐)
    """
    teacher_model.eval()  # teacher 모델은 평가 모드로 전환
    with torch.no_grad():
        teacher_logits = teacher_model(x_batch)
    
    # temperature를 적용하여 softmax와 log_softmax 연산 수행
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    
    # KL Divergence 계산 (배치 단위 평균)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return loss
학생 모델 학습

증류 손실을 사용하여 학생 모델을 학습합니다.

# ------------------------------------------------
# 단일 학습 스텝 함수 정의
# ------------------------------------------------
def train_step(x_batch, y_batch, student_model, teacher_model, optimizer, temperature=5):
    student_model.train()          # student 모델을 학습 모드로 전환
    optimizer.zero_grad()          # 기존 gradient 초기화
    
    # student 모델의 예측값 계산
    student_logits = student_model(x_batch)
    
    # distillation loss 계산
    loss = distillation_loss(y_batch, student_logits, x_batch, teacher_model, temperature)
    
    # 역전파 및 파라미터 업데이트
    loss.backward()
    optimizer.step()
    
    return loss.item()

# ------------------------------------------------
# 옵티마이저 설정 (student_model의 파라미터 업데이트)
# ------------------------------------------------
optimizer = optim.Adam(student_model.parameters())

# ------------------------------------------------
# 전체 학습 루프 (MNIST 데이터셋 사용)
# ------------------------------------------------
epochs = 5
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    for x_batch, y_batch in train_loader:
        # 데이터를 device로 이동
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        # 단일 학습 스텝 수행
        loss = train_step(x_batch, y_batch, student_model, teacher_model, optimizer, temperature=5)
        total_loss += loss
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

print("Student Model Training Complete!")

위의 train_step 함수는 하나의 학습 단계(training step) 를 수행하는 역할을 합니다.

  1. Student 모델의 예측값 계산
    • 입력 데이터를 기반으로 student 모델이 예측을 수행합니다.
  2. 디스틸레이션 손실(distillation loss) 계산
    • Teacher 모델의 예측값을 활용하여, student 모델과의 차이를 측정하는 디스틸레이션 손실을 계산합니다.
  3. Gradient 계산 및 Student 모델의 가중치 업데이트
    • 손실 함수를 기반으로 student 모델의 가중치를 조정하여 학습을 진행합니다.
Student 모델 학습 과정

Student 모델을 학습하기 위해 훈련 루프(training loop) 를 구성해야 합니다.
이 루프는 데이터셋을 여러 번 반복(iterate)하며, 각 단계에서 student 모델의 가중치를 업데이트합니다.

또한, 각 에포크(epoch)마다 손실(loss)을 출력하여 학습 진행 상황을 모니터링할 수 있습니다.

학습이 완료되면 다음과 같은 출력 결과를 확인할 수 있습니다.

학생 모델 평가
def evaluate(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
predicted = outputs.argmax(dim=1)
correct += (predicted == labels).sum().item()
total += labels.size(0)

print(f"학생 모델 정확도: {100 * correct / total:.2f}%")

# 학생 모델 평가
evaluate(student_model, test_loader)

예상대로, student 모델이 상당히 좋은 정확도를 달성했습니다.

Teacher 모델과 Student 모델을 사용한 예측 수행

이제 teacher 모델과 student 모델을 사용하여 예측을 수행할 수 있습니다.
MNIST 테스트 데이터셋의 숫자를 정확하게 예측할 수 있는지 확인해 보세요.

import torch
import numpy as np
import matplotlib.pyplot as plt

# MNIST 테스트 데이터 로드 (PyTorch 버전)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

# 모델을 평가 모드로 변경
teacher_model.eval()
student_model.eval()

# 테스트 데이터에서 5개 샘플을 확인
for index, (x_test, y_test) in enumerate(test_loader):
if index == 5: # 5개만 출력
break

# 이미지 출력
plt.figure(figsize=(2, 2))
plt.imshow(x_test.squeeze(), cmap="gray", interpolation="none")
plt.title(f"Digit: {y_test.item()}")
plt.xticks([])
plt.yticks([])
plt.show()

# 입력 데이터 변환
x_test = x_test.to(device) # GPU로 전송 (가능하면)

# 교사 모델 예측
with torch.no_grad():
teacher_logits = teacher_model(x_test)
teacher_pred = torch.argmax(teacher_logits, dim=1).cpu().numpy()

print(f"🔵 Predicted value by Teacher Model: {teacher_pred}")

# 학생 모델 예측
with torch.no_grad():
student_logits = student_model(x_test)
student_pred = torch.argmax(student_logits, dim=1).cpu().numpy()

print(f"🟢 Predicted value by Student Model: {student_pred}")


결론

이 글에서는 모델 증류(Model Distillation) 개념을 알아보았습니다. 이 기술을 활용하면 작은 규모의 student 모델이 보다 크고 복잡한 teacher 모델의 성능을 모방할 수 있습니다.

MNIST 데이터셋을 사용하여 teacher 모델을 학습한 후, 증류 기법을 적용하여 student 모델을 학습하는 과정을 단계별로 수행하였습니다.

그 결과, 더 적은 층과 낮은 복잡도를 가진 student 모델이 teacher 모델의 성능을 성공적으로 모방하면서도,훨씬 적은 연산 자원(computational resources)만으로 효율적으로 동작할 수 있음을 확인했습니다.

참고자료

https://ai.gopubby.com/understanding-model-distillation-991ec90019b6

댓글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다