4️⃣Mamba

맘바는 유명한 트랜스포머의 잠재적 라이벌로 불리며 AI 커뮤니티에서 큰 반향을 불러일으키고 있습니다. 맘바의 명성은 긴 시퀀스에 대한 인상적인 확장 능력에 있습니다. 하지만 복잡한 시퀀스 모델링 환경에서 맘바가 실제로 차별화되는 점은 무엇일까요?

맘바의 위치를 이해하기 위해 기존 모델을 간략히 살펴봅시다.

  1. Transformers: 시퀀스의 어떤 부분이 다른 부분과 동적으로 상호작용할 수 있는 주의 메커니즘으로 잘 알려진 트랜스포머는 특히 인과적 주의 기능을 통해 시퀀스의 개별 요소를 처리하는 데 능숙합니다. 그러나 시퀀스 길이(L²)의 제곱에 따라 확장되는 상당한 계산 및 메모리 비용이 발생합니다.

  2. 순환 신경망(RNN): RNN은 현재 입력과 마지막 숨겨진 상태만 고려하여 숨겨진 상태를 순차적으로 업데이트합니다. 이 접근 방식은 잠재적으로 일정한 메모리 요구 사항으로 무한한 시퀀스 길이를 처리할 수 있습니다. 하지만 RNN의 단순성은 단점이 될 수 있으며, 장기적인 종속성을 기억하는 능력이 제한될 수 있습니다. 또한 RNN의 시간 경과에 따른 역전파(BPTT)는 메모리 집약적일 수 있으며, LSTM과 같은 혁신에도 불구하고 그라데이션이 사라지거나 폭발하는 문제가 발생할 수 있습니다.

  3. 상태 공간 모델(S4): 이 모델은 유망한 특성을 보여주었습니다. 트랜스포머보다 메모리 효율이 높으면서도 RNN보다 장거리 종속성을 더 효과적으로 포착하는 균형을 제공합니다.

맘바의 접근 방식

이제 맘바가 어떤 기능을 제공하는지 자세히 살펴보겠습니다:

  • 선택적 상태 공간: 맘바는 상태 공간 모델의 개념을 기반으로 하지만 새로운 방식을 도입했습니다. 선택적 상태 공간을 활용하여 긴 시퀀스에서 관련 정보를 보다 효율적이고 효과적으로 캡처할 수 있습니다.

  • 선형 시간 복잡성: 트랜스포머와 달리 맘바는 시퀀스 길이와 관련하여 선형 시간으로 작동합니다. 이 특성은 기존 모델로는 어려움을 겪을 수 있는 매우 긴 시퀀스와 관련된 작업에 특히 적합합니다.

선택적 상태 공간

맘바는 선택적 상태 공간이라는 개념을 통해 기존 상태 공간 모델에 흥미로운 변형을 도입했습니다. 이 접근 방식은 표준 상태 공간 모델의 경직된 상태 전환을 약간 완화하여 적응성과 유연성을 높인 것으로, LSTM과 다소 유사합니다. 그러나 맘바는 상태 공간 모델의 효율적인 계산 특성을 유지하여 전체 시퀀스의 순방향 패스를 한 번에 수행할 수 있으며, 이는 트랜스포머를 연상시키는 기능입니다.

맘바를 이용한 훈련 및 추론

훈련 중에 맘바는 트랜스포머와 유사하게 작동하여 전체 시퀀스를 한 번에 처리합니다. 이 접근 방식은 모든 입력이 알려져 있어도 순방향 패스를 단계별로 계산해야 하는 LSTM과는 대조적입니다. 추론에서 Mamba의 동작은 기존의 순환 모델에 더 부합하여 시퀀스를 효율적으로 처리합니다.

이전 모델의 한계

기존 상태 공간 모델(SSM)의 주요 한계는 경직된 입력 불변형 구조입니다. 일반적으로 이러한 모델은 전체 시퀀스에 대해 일련의 고정된 매개변수(이를 A와 B라고 합니다)를 사용합니다. 이 구조는 신호의 변환이 이전의 숨겨진 상태와 입력에 따라 달라질 수 있는 LSTM과 같은 모델보다 훨씬 더 제한적입니다.

맘바 접근법(입력 의존적 전환)

맘바는 다음 숨겨진 상태로의 전환을 계산하는 방식에 패러다임의 변화를 도입했습니다. 맘바의 아키텍처에서는 현재 입력에 따라 전환이 달라질 수 있습니다. 이 접근 방식은 기존 SSM의 고정된 계산 백본과 순환 신경망의 입력 의존적 역동성 사이에서 균형을 이룹니다.

주요 구성 요소:

  1. 고정 백본: 한 숨겨진 상태에서 다음 상태로 전환할 때 고정된 계산(A 행렬로 정의됨)이 유지되므로 시퀀스 전체에서 사전 계산이 가능합니다.

  2. 입력 종속 변환: 입력이 다음 숨겨진 상태(B 행렬로 정의됨)에 영향을 미치는 방식은 이전 숨겨진 상태가 아니라 현재 입력에 따라 달라집니다. 이러한 입력 종속성은 기존 SSM에 비해 더 많은 유연성을 제공합니다.

컴퓨팅 문제 극복하기

이 접근 방식의 계산 수요를 해결하기 위해 Mamba는 하드웨어 인식 알고리즘을 사용합니다. 이 알고리즘은 컨볼루션 대신 스캔 연산을 사용하여 반복적으로 계산을 수행하므로 GPU에서 매우 효율적입니다. 이러한 효율성은 입력 의존적 전환으로 인한 알고리즘 복잡성에도 불구하고 높은 성능을 유지하는 데 매우 중요합니다.

맘바 대 선택적 상태 공간

Mamba와 선택적 상태 공간 모델은 동의어가 아니라는 점을 명확히 하는 것이 중요합니다. Mamba는 선택적 상태 공간의 개념을 사용하는 구현입니다. 이러한 구분은 계산 효율성을 유지하면서 보다 유연하고 입력에 빠르게 반응하도록 SSM 프레임워크를 조정하는 Mamba의 고유한 기여를 강조하기 때문에 매우 중요합니다.

맘바의 약속

맘바의 접근 방식은 시퀀스 모델링의 흥미로운 발전을 보여줍니다. 입력 의존적 전환을 허용하는 동시에 계산 효율성이 높은 백본을 유지함으로써, Mamba는 매우 유연하지만 계산 집약적인 Transformer와 효율적이지만 경직된 기존 SSM 사이의 간극을 메웁니다. 이러한 균형은 자연어 처리에서 게놈 시퀀싱에 이르기까지 다양한 영역에 걸쳐 긴 시퀀스를 처리하는 새로운 기능을 잠재적으로 열어줄 수 있습니다.

GPU 메모리: SRAM 및 HBM

GPU에는 두 가지 주요 메모리 유형이 있습니다: 고대역폭 메모리(HBM)와 정적 랜덤 액세스 메모리(SRAM)입니다. HBM은 대역폭은 높지만 훨씬 빠르지만 크기가 작은 SRAM에 비해 상대적으로 액세스 시간이 느립니다. 이러한 점을 이해한 Mamba는 계산의 핵심을 이루는 행렬 곱셈 시 빠른 액세스를 위해 전략적으로 SRAM을 사용합니다.

데이터 이동 병목현상 극복

계산의 주요 병목 현상은 계산 자체가 아니라 메모리 유형 간의 데이터 이동인 경우가 많습니다. 맘바는 대용량 데이터 전송의 필요성을 크게 줄임으로써 이 문제를 해결합니다. 이산 및 반복 계산과 같은 알고리즘의 중요한 부분을 SRAM에서 직접 실행하여 지연 시간을 줄입니다.

융합된 선택적 스캔 레이어

맘바는 퓨즈드 셀렉티브 스캔 레이어를 도입하여 메모리 요구 사항을 플래시 주의력을 사용하여 최적화된 트랜스포머 구현과 동등한 수준으로 끌어올렸습니다. 이 계층은 특히 모델에서 입력 종속 요소를 처리할 때 효율성을 유지하는 데 매우 중요합니다.

접두사 합산/병렬 스캔을 통한 효율적인 계산

맘바는 효율적인 계산을 위해 접두사 합 또는 병렬 스캔을 활용합니다. 일정한 커널이 필요한 컨볼루션과 달리 접두사 합은 Mamba의 입력 종속성으로 인해 발생하는 다양한 요소를 처리할 수 있습니다. 이 방법은 서로 다른 시간 간격으로 행렬의 누적 곱셈을 계산하는 데 필수적입니다.

실험 결과 및 확장

맘바는 긴 염기서열이 널리 사용되는 언어 모델링과 DNA 시퀀싱 분야에서 유망한 결과를 보여주었습니다. 긴 서열 길이에서도 확장성과 효율성이 뛰어나 일반 서열 모델 백본으로 강력한 후보가 될 수 있습니다.


딥러닝 영역에서 시퀀스 모델링은 여전히 어려운 과제로 남아 있으며, 종종 LSTM이나 트랜스포머와 같은 모델이 이를 해결합니다. 하지만 이러한 모델은 계산 집약적일 수 있습니다. 효율성과 효과를 위해 설계된 선형 시간 시퀀스 모델링 프레임워크인 Mamba를 소개합니다. 이 블로그 게시물에서는 이 혁신적인 접근 방식의 기술적 측면과 코드에 대해 설명하면서 PyTorch를 사용한 Mamba의 구현에 대해 자세히 살펴봅니다.

Importing Libraries and Setting Flags

구현은 필수 라이브러리를 가져오는 것으로 시작됩니다:

import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader, Dataset  
from torch.nn import functional as F  
from einops import rearrange  
from tqdm import tqdm  
  
import math  
import os  
import urllib.request  
from zipfile import ZipFile  
  
from transformers import AutoTokenizer  
  
torch.autograd.set_detect_anomaly(True)

Flags and Hyperparameters

# Configuration flags and hyperparameters  
USE_MAMBA = 1  
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0  
  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Defining Hyperparameters and Initialization

여기에는 모델 차원(d_model), 상태 크기, 시퀀스 길이, 배치 크기와 같은 하이퍼파라미터가 정의됩니다.

# User-defined hyperparameters  
d_model = 8  
state_size = 128 # Example state size  
seq_len = 100 # Example sequence length  
batch_size = 256 # Example batch size  
last_batch_size = 81 # only for the very last batch of the dataset  
current_batch_size = batch_size  
different_batch_size = False  
h_new = None  
temp_buffer = None

Defining the S6 Module

S6 클래스는 일련의 선형 변환과 이산화 프로세스를 통해 입력 시퀀스를 처리하는 Mamba 아키텍처 내의 정교한 구성 요소를 나타냅니다. 언어 모델링과 같은 시퀀스 모델링 작업의 핵심 요소인 시퀀스의 시간적 역학을 포착하는 데 중요한 역할을 합니다. 이 강의에서는 시퀀스 데이터의 복잡한 요구 사항을 처리하기 위한 텐서 연산 및 사용자 지정 이산화 방법과 같은 고급 기술을 소개합니다.

nn.Module에서 상속된 S6 클래스는 이산화 과정과 순방향 전파를 처리하는 Mamba 모델의 핵심 구성 요소입니다.

class S6(nn.Module):  
	def __init__(self, seq_len, d_model, state_size, device):  
	super(S6, self).__init__()  
	  
	self.fc1 = nn.Linear(d_model, d_model, device=device)  
	self.fc2 = nn.Linear(d_model, state_size, device=device)  
	self.fc3 = nn.Linear(d_model, state_size, device=device)  
	  
	self.seq_len = seq_len  
	self.d_model = d_model  
	self.state_size = state_size  
	  
	#self.A = nn.Parameter(torch.ones(d_model, state_size, device=device))  
	self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))  
	nn.init.xavier_uniform_(self.A)  
	  
	self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)  
	self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)  
	  
	self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)  
	self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)  
	self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)  
	  
	# h should have dimensions [batch_size, seq_len, d_model, state_size]  
	self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)  
	self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)  
  
  
def discretization(self):  
	# discretization function is defined based on the MAMBA paper's description using ZOH on page 28  
	# in Section C : Mechanics on Selective SSMs  
	# See also "Zero-order hold discretization" maths proof inside https://studywolf.wordpress.com/tag/zero-order-hold/  
	"""  
	Here is an explanation of the mathematical rationale for the formulation of Δt used in Mamba:  
	The key idea is that Δt controls the discretization rate of the continuous SSM dynamics. By making Δt input-dependent, it introduces selectivity into the discrete transition matrices.  
	Specifically, in Mamba they parameterize Δt as:  
	Δt = τΔ(Parameter + sΔ(xt))  
	Where:  
	- Parameter is a learned scalar parameter that controls the baseline discretization rate  
	- sΔ(xt) is a projection that makes Δt input-dependent by computing a value based on xt  
	- τΔ(x) = softplus(x) transforms the result to be positive through the softplus nonlinearity  
	The rationale for this formulation is:  
	- Parameter provides a reasonable default discretization rate  
	- sΔ(xt) injects input-dependence through the projection  
	- softplus ensures Δt is positive as required to be a valid timestep  
	- The projection sΔ allows the model to learn to modulate Δt based on the input xt  
	- This modulation creates selectivity in how rapidly or slowly the states update  
	So in summary, the learned input-dependent projection allows Δt, and thus the discrete dynamics, to become selective. The softplus and scalar parameter provide useful inductive biases on top of this flexibility.  
	The end result is discrete transition matrices that are selective on the input, enabling powerful sequence modeling capabilities.  
	Credit: Claude2 AI chatbot  
	"""  
	  
	# inverse() only supports square matrix  
	#dB = torch.matmul(torch.inverse(A * delta), torch.matmul(dA - torch.eye(A.shape[0]), B))  
	self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)  
	  
	# https://github.com/state-spaces/mamba/blob/0131c1e94a46fc9f70bcfc9d57962963bb2f0b9e/mamba_ssm/modules/mamba_simple.py#L240  
	#dA = torch.matrix_exp(A * delta) # matrix_exp() only supports square matrix  
	self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))  
	#print(f"self.dA.shape = {self.dA.shape}")  
	#print(f"self.dA.requires_grad = {self.dA.requires_grad}")  
	  
	return self.dA, self.dB  
  
def forward(self, x):  
	# Refer to Algorithm 2 in the MAMBA paper  
	self.B = self.fc2(x)  
	self.C = self.fc3(x)  
	self.delta = F.softplus(self.fc1(x))  
	  
	# Uses ZOH as in MAMBA, Hungry Hippo still uses bilinear transform for discretization  
	self.discretization()  
	  
	if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM: # this will trigger in-place runtime error if without using `h_new`  
	  
		global current_batch_size  
		current_batch_size = x.shape[0]  
		  
		if self.h.shape[0] != current_batch_size:  
			#print("Adjusting h_new for the different batch size of input data `x`")  
			different_batch_size = True  
			  
			# Resize self.h to match the current batch size  
			h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB  
			  
		else:  
			different_batch_size = False  
			h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB  
	  
	# y needs to have a shape of [batch_size, seq_len, d_model]  
	self.y = torch.einsum('bln,bldn->bld', self.C, h_new)  
	  
	# Update self.h with the detached state of h_new  
	# Only do this if retaining gradients for self.h is not necessary for backprop  
	# Otherwise, store h_new in a temporary list and update self.h after the loop  
	global temp_buffer  
	temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()  
	  
	return self.y  
	  
	else: # this will not trigger in-place runtime error  
	# h should have dimensions [batch_size, seq_len, d_model, state_size]  
	h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)  
	y = torch.zeros_like(x)  
	  
	h = torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB  
	  
	# y needs to have a shape of [batch_size, seq_len, d_model]  
	y = torch.einsum('bln,bldn->bld', self.C, h)  
	  
	return y

Defining the Mamba Block

MambaBlcok 클래스는 맘바 모델의 핵심 빌딩 블록으로 설계된 맞춤형 신경망 모듈입니다. 입력 데이터를 처리하기 위한 여러 레이어와 연산을 캡슐화합니다.

MambaBlcok 클래스는 선형 투영, 컨볼루션, 활성화 함수, 사용자 정의 S6 모듈 및 잔여 연결로 구성된 복잡한 신경망 블록을 나타냅니다. 이 블록은 데이터의 관련 패턴과 특징을 포착하기 위해 일련의 변환을 통해 입력 시퀀스를 처리하는 Mamba 모델의 기본 구성 요소입니다. 이러한 다양한 레이어와 작업의 조합을 통해 맘바블록은 복잡한 시퀀스 모델링 작업을 효과적으로 처리할 수 있습니다.

class MambaBlock(nn.Module):  
	def __init__(self, seq_len, d_model, state_size, device):  
		super(MambaBlock, self).__init__()  
		  
		self.inp_proj = nn.Linear(d_model, 2*d_model, device=device)  
		self.out_proj = nn.Linear(2*d_model, d_model, device=device)  
		  
		# For residual skip connection  
		self.D = nn.Linear(d_model, 2*d_model, device=device)  
		  
		# Set _no_weight_decay attribute on bias  
		self.out_proj.bias._no_weight_decay = True  
		  
		# Initialize bias to a small constant value  
		nn.init.constant_(self.out_proj.bias, 1.0)  
		  
		self.S6 = S6(seq_len, 2*d_model, state_size, device)  
		  
		# Add 1D convolution with kernel size 3  
		self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)  
		  
		# Add linear layer for conv output  
		self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)  
		  
		# rmsnorm  
		self.norm = RMSNorm(d_model, device=device)  
	  
	def forward(self, x):  
		"""  
		x_proj.shape = torch.Size([batch_size, seq_len, 2*d_model])  
		x_conv.shape = torch.Size([batch_size, seq_len, 2*d_model])  
		x_conv_act.shape = torch.Size([batch_size, seq_len, 2*d_model])  
		"""  
		# Refer to Figure 3 in the MAMBA paper  
		  
		x = self.norm(x)  
		  
		x_proj = self.inp_proj(x)  
		#print(f"x_proj.shape = {x_proj.shape}")  
		  
		# Add 1D convolution with kernel size 3  
		x_conv = self.conv(x_proj)  
		#print(f"x_conv.shape = {x_conv.shape}")  
		  
		x_conv_act = F.silu(x_conv)  
		#print(f"x_conv_act.shape = {x_conv_act.shape}")  
		  
		# Add linear layer for conv output  
		x_conv_out = self.conv_linear(x_conv_act)  
		#print(f"x_conv_out.shape = {x_conv_out.shape}")  
		  
		x_ssm = self.S6(x_conv_out)  
		x_act = F.silu(x_ssm) # Swish activation can be implemented as x * sigmoid(x)  
		#print(f"x_act.shape = {x_act.shape}")  
		  
		# residual skip connection with nonlinearity introduced by multiplication  
		x_residual = F.silu(self.D(x))  
		#print(f"x_residual.shape = {x_residual.shape}")  
		x_combined = x_act * x_residual  
		#print(f"x_combined.shape = {x_combined.shape}")  
		  
		x_out = self.out_proj(x_combined)  
		#print(f"x_out.shape = {x_out.shape}")  
		  
		return x_out

Defining the Mamba Model

Mamba 클래스는 일련의 맘바블록 모듈로 구성된 맘바 모델의 전체 아키텍처를 나타냅니다. 각 블록은 입력 데이터를 처리하는 데 기여하며, 한 블록의 출력은 다음 블록의 입력으로 사용됩니다. 이러한 순차적 처리를 통해 모델은 입력 데이터의 복잡한 패턴과 관계를 포착할 수 있으므로 시퀀스 모델링과 관련된 작업에 효과적입니다. 여러 블록을 쌓는 것은 모델이 데이터의 계층적 표현을 학습할 수 있게 해주므로 딥러닝 아키텍처의 일반적인 설계입니다.

이 클래스는 전체 맘바 모델을 정의하며, 모델의 아키텍처를 위해 여러 맘바블록 인스턴스를 체인으로 연결합니다.

class Mamba(nn.Module):  
	def __init__(self, seq_len, d_model, state_size, device):  
		super(Mamba, self).__init__()  
		self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)  
		self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)  
		self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)  
	  
	def forward(self, x):  
		x = self.mamba_block1(x)  
		x = self.mamba_block2(x)  
		x = self.mamba_block3(x)  
		return x

Defining RMSNorm

RMSNorm 클래스는 사용자 정의 정규화 레이어로, PyTorch의 nn.Module을 확장합니다. 이 레이어는 신경망의 활성화를 정규화하는 데 사용되며, 이를 통해 학습을 안정화하고 속도를 높일 수 있습니다.

RMSNorm은 신경망 아키텍처의 일반적인 기술인 평균제곱근 레이어 정규화를 위한 레이어입니다.

class RMSNorm(nn.Module):  
	def __init__(self,  
		d_model: int,  
		eps: float = 1e-5,  
		device: str ='cuda'):  
		super().__init__()  
		self.eps = eps  
		self.weight = nn.Parameter(torch.ones(d_model, device=device))  
	  
	  
	def forward(self, x):  
		output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight  
		  
		return output

Example Usage

x = torch.rand(batch_size, seq_len, d_model, device=device)  
# Create the Mamba model  
mamba = Mamba(seq_len, d_model, state_size, device)  
  
# rmsnorm  
norm = RMSNorm(d_model)  
x = norm(x)  
  
# Forward pass  
test_output = mamba(x)  
print(f"test_output.shape = {test_output.shape}") # Should be [batch_size, seq_len, d_model]

Data Preparation and Training Functions

Enwiki8Dataset 클래스는 언어 모델링과 같은 시퀀스 모델링 작업을 위해 구조화된 데이터셋으로 작업하도록 특별히 설계된 PyTorch의 데이터셋 클래스를 확장하는 사용자 지정 데이터셋 핸들러입니다.

class Enwiki8Dataset(Dataset):  
	def __init__(self, data):  
		self.data = data  
	  
	def __len__(self):  
		return len(self.data['input_ids'])  
	  
	def __getitem__(self, idx):  
		item = {key: val[idx].clone().detach() for key, val in self.data.items()}  
		return item

pad_sequences_3d 함수는 시퀀스 배치를 일정한 길이로 패딩하여 배치의 각 시퀀스가 동일한 수의 요소(또는 시간 단계)를 갖도록 설계되었습니다. 이는 입력 데이터의 모양이 일정해야 하는 많은 머신 러닝 작업에서 특히 중요합니다.

# Define a function for padding  
def pad_sequences_3d(sequences, max_len=None, pad_value=0):  
	# Assuming sequences is a tensor of shape (batch_size, seq_len, feature_size)  
	batch_size, seq_len, feature_size = sequences.shape  
	  
	if max_len is None:  
		max_len = seq_len + 1  
	  
	  
	# Initialize padded_sequences with the pad_value  
	padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)  
	# Pad each sequence to the max_len  
	padded_sequences[:, :seq_len, :] = sequences  
	  
	return padded_sequences

train 함수는 맘바 모델을 훈련하기 위해 설계되었습니다. 그 구성 요소를 살펴 보겠습니다:

def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):

이 함수에서 사용하는 매개변수:

  • model: 학습하고자 하는 모델명(Mamba in this case).

  • tokenizer: 입력 데이터를 처리할 토큰화 도구

  • data_loader: 학습을 위해 데이터 배치를 제공하는 이터러블

  • optimizer: 모델의 가중치를 업데이트하는 데 사용되는 최적화 알고리즘

  • criterion: 모델의 성능을 평가하는 데 사용되는 손실 함수

  • device: 모델이 실행될 장치(CPU 또는 GPU)

  • max_grad_norm: 그라데이션이 폭발하는 것을 방지하기 위한 그라데이션 클리핑 값

  • DEBUGGING_IS_ON: 디버깅 정보를 활성화하기 위한 플래그

def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):  
	model.train()  
	total_loss = 0  
		for batch in data_loader:  
		optimizer.zero_grad()  
		  
		input_data = batch['input_ids'].clone().to(device)  
		attention_mask = batch['attention_mask'].clone().to(device)  
		  
		# In most sequence modeling tasks, like language modeling, the target should be the next token  
		# in the sequence rather than the input token itself.  
		# This is because the model's goal is to predict the next word given the previous words.  
		# Shift the input data by one position to get the target, so that each target token  
		# is the next token following the input token.  
		target = input_data[:, 1:]  
		input_data = input_data[:, :-1]  
		  
		# Pad all the sequences in the batch:  
		input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)  
		target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)  
		  
		if USE_MAMBA:  
			output = model(input_data)  
			loss = criterion(output, target)  
		  
		loss.backward(retain_graph=True)  
		  
		# Clip gradients: gradients are modified in place  
		#torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  
		for name, param in model.named_parameters():  
			if 'out_proj.bias' not in name:  
				# clip weights but not bias for out_proj  
				torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)  
		  
		if DEBUGGING_IS_ON:  
			for name, parameter in model.named_parameters():  
				if parameter.grad is not None:  
					print(f"{name} gradient: {parameter.grad.data.norm(2)}")  
				else:  
					print(f"{name} has no gradient")  
		  
		if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:  
			model.S6.h[:current_batch_size, ...].copy_(temp_buffer)  
		  
		optimizer.step()  
		  
		total_loss += loss.item()  
	return total_loss / len(data_loader)

Evaluate Definition

evaluate 함수는 데이터 세트에서 맘바 모델의 성능을 평가하도록 설계되었습니다. 이 함수를 자세히 살펴보겠습니다:

def evaluate(model, data_loader, criterion, device):

This function takes four parameters:

  • model: 평가하고자 하는 모델

  • data_loader: 평가를 위해 데이터 배치를 제공하는 이터러블

  • criterion: 모델의 성능을 평가하는 데 사용되는 손실 함수

  • device: 평가를 수행할 장치(CPU 또는 GPU)

def evaluate(model, data_loader, criterion, device):  
	model.eval()  
	total_loss = 0  
	with torch.no_grad():  
		for batch in data_loader:  
			input_data = batch['input_ids'].clone().detach().to(device)  
			attention_mask = batch['attention_mask'].clone().detach().to(device)  
			  
			# In most sequence modeling tasks, like language modeling, the target should be the next token  
			# in the sequence rather than the input token itself.  
			# This is because the model's goal is to predict the next word given the previous words.  
			# Shift the input data by one position to get the target, so that each target token  
			# is the next token following the input token.  
			target = input_data[:, 1:]  
			input_data = input_data[:, :-1]  
			  
			# Pad all the sequences in the batch:  
			input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)  
			target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)  
			  
			if USE_MAMBA:  
				output = model(input_data)  
				loss = criterion(output, target)  
			total_loss += loss.item()  
		return total_loss / len(data_loader)

calculate_perplexity 함수는 Mamba와 같은 언어 모델의 성능을 평가하는 데 간단하지만 중요한 함수입니다.

def calculate_perplexity(loss):  
	return math.exp(loss)

load_enwiki8_dataset 함수는 언어 모델 벤치마킹에 일반적으로 사용되는 enwiki8데이터셋을 다운로드하고 추출하도록 설계되었습니다.

def load_enwiki8_dataset():  
	print(f"Download and extract enwiki8 data")  
	url = "http://mattmahoney.net/dc/enwik8.zip"  
	urllib.request.urlretrieve(url, "enwik8.zip")  
	  
	with ZipFile("enwik8.zip") as f:  
		data = f.read("enwik8").decode("utf-8")  
	  
	return data

encode_dataset 함수는 데이터 집합을 토큰화 및 인코딩하여 Mamba와 같은 신경망 모델에서 처리할 수 있도록 준비하기 위해 설계되었습니다.

# Tokenize and encode the dataset  
def encode_dataset(tokenizer, text_data):  
	def batch_encode(tokenizer, text_data, batch_size=1000):  
		# Tokenize in batches  
		batched_input_ids = []  
		for i in range(0, len(text_data), batch_size):  
			batch = text_data[i:i+batch_size]  
			inputs = tokenizer(batch, add_special_tokens=True, truncation=True,  
								padding='max_length', max_length=seq_len,  
								return_tensors='pt')  
			batched_input_ids.append(inputs['input_ids'])  
		return torch.cat(batched_input_ids)  
	  
	# Assuming enwiki8_data is a list of sentences  
	input_ids = batch_encode(tokenizer, enwiki8_data)  
	  
	# vocab_size is the number of unique tokens in the tokenizer's vocabulary  
	global vocab_size  
	vocab_size = len(tokenizer.vocab) # Note that for some tokenizers, we might access the vocab directly  
	print(f"vocab_size = {vocab_size}")  
	  
	# Create an embedding layer  
	# embedding_dim is the size of the embedding vectors (MAMBA model's D)  
	embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)  
	  
	# Pass `input_ids` through the embedding layer  
	# This will change `input_ids` from shape [B, L] to [B, L, D]  
	#encoded_input = embedding_layer(input_ids) ## this eats memory, so use batched_embedding_calls instead  
	def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):  
		# Check if input_ids is already a tensor, if not convert it  
		if not isinstance(input_ids, torch.Tensor):  
			input_ids = torch.tensor(input_ids, dtype=torch.long)  
		  
		# Calculate the number of batches needed  
		num_batches = math.ceil(input_ids.size(0) / batch_size)  
		  
		# List to hold the output embeddings  
		output_embeddings = []  
		  
		# Process each batch  
		for i in range(num_batches):  
			# Calculate start and end indices for the current batch  
			start_idx = i * batch_size  
			end_idx = start_idx + batch_size  
			  
			# Get the batch  
			input_id_batch = input_ids[start_idx:end_idx]  
			  
			# Call the embedding layer  
			with torch.no_grad(): # No need gradients for this operation  
				batch_embeddings = embedding_layer(input_id_batch)  
			  
			# Append the result to the list  
			output_embeddings.append(batch_embeddings)  
		  
		# Concatenate the embeddings from each batch into a single tensor  
		all_embeddings = torch.cat(output_embeddings, dim=0)  
		  
		return all_embeddings  
	  
	# `input_ids` is a list or tensor of the input IDs and `embedding_layer` is model's embedding layer  
	if USE_MAMBA:  
		# Set `batch_size` to a value that works for memory constraints  
		encoded_inputs = batch_embedding_calls(input_ids, 
							embedding_layer, 
							batch_size=1).float()  
	  
	attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)  
	return encoded_inputs, attention_mask

Tokenization

# Load a pretrained tokenizer  
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

Model Training

# Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]  
encoded_inputs_file = 'encoded_inputs_mamba.pt'  
  
  
if os.path.exists(encoded_inputs_file):  
	print("Loading pre-tokenized data...")  
	encoded_inputs = torch.load(encoded_inputs_file)  
else:  
	print("Tokenizing raw data...")  
	enwiki8_data = load_enwiki8_dataset()  
	encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)  
	torch.save(encoded_inputs, encoded_inputs_file)  
	print(f"finished tokenizing data")  
  
  
# Combine into a single dictionary  
data = {  
	'input_ids': encoded_inputs,  
	'attention_mask': attention_mask  
}  
  
# Split the data into train and validation sets  
total_size = len(data['input_ids'])  
train_size = int(total_size * 0.8)  
  
train_data = {key: val[:train_size] for key, val in data.items()}  
val_data = {key: val[train_size:] for key, val in data.items()}  
  
train_dataset = Enwiki8Dataset(train_data)  
val_dataset = Enwiki8Dataset(val_data)  
  
  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)  
  
  
# Initialize the model  
  
model = Mamba(seq_len, d_model, state_size, device).to(device)  
  
# Define the loss function and optimizer  
criterion = nn.CrossEntropyLoss()  
optimizer = optim.AdamW(model.parameters(), lr=5e-6)  
  
# Training loop  
num_epochs = 25 # Number of epochs to train for  
  
for epoch in tqdm(range(num_epochs)): # loop over the dataset multiple times  
	train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=False)  
	val_loss = evaluate(model, val_loader, criterion, device)  
	val_perplexity = calculate_perplexity(val_loss)  
	print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

Conclusion

지금까지 Mamba를 처음부터 구축하는 포괄적인 코드 워크스루를 마무리하면서, 이론을 실제에 적용하는 복잡한 구현 과정을 살펴봤습니다. 이번 탐험을 통해 맘바의 내부 작동 원리에 대한 이해가 더욱 깊어졌을 뿐만 아니라 강력한 AI 도구를 구현하는 데 필요한 실질적인 단계도 확인할 수 있었습니다. 이러한 실습을 통해 시퀀스 모델링의 미묘한 차이와 Mamba가 이 영역에 도입한 효율성에 대해 알아봤습니다. 이제 이러한 지식으로 무장하여 프로젝트에서 Mamba를 실험하거나 혁신적인 AI 솔루션을 더 깊이 있게 개발할 수 있게 되었습니다. 모든 코드 라인은 AI 기술을 마스터하기 위한 단계이며, Mamba는 이 흥미로운 분야의 무한한 가능성을 보여주는 증거라는 점을 기억하세요.


Reference

Building Mamba from Scratch: A Comprehensive Code Walkthrough https://medium.com/ai-insights-cobet/building-mamba-from-scratch-a-comprehensive-code-walkthrough-5db040c28049

Last updated