배경

최근 Quantization Simulation 내용이 포함된 Machine Learning Pipeline 스크립트를 작성할 일이 있었다. 비교적 최근에 AI 조직으로 합류하게 된 엔지니어의 교육용이자 Hands-on을 위한 스크립트였는데, 각 블록이 어떤 역할을 하는지 코드 레벨로 확인할 수 있게 로그를 심는 부분이 있었다. 회사 내의 업무라 정확한 내용을 얘기하기는 좀 어렵지만, 다루는 모델의 크기가 크기 때문에 모델에 대한 특정 정보를 로깅 해서 단순히 남길 때 너무 많은 로그가 생성되고 그중 많은 부분이 중복되는 정보여서 여간 신경 쓰이는 것이 아니었다. 사실 교육용 스크립트고, 읽는 사람이 약간의 귀찮음을 감수하면 큰 문제는 없었지만, 코드 리뷰 때 비슷한 질의를 다른 엔지니어들에게도 몇 번 받았고 이런 부류의 문제는 어떤 식으로 해결하는지 개인적으로 궁금하기도 하여 비슷한 상황을 모사할 수 있는 스크립트와 해결 방식을 남겨둔다.

초기 구현

import logging
import torch
from torch.nn.modules.module import register_module_forward_pre_hook
from transformers import AutoModelForCausalLM
logger = logging.getLogger(__name__)
logger.setLevel('INFO')
stream_handler = logging.StreamHandler()
logger.addHandler(stream_handler)
checkpoint = 'HuggingFaceTB/SmolLM2-135M'
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
def module_type_hook(module, _) -> None:
logger.info(f'Model contains {type(module)} module')
register_module_forward_pre_hook(module_type_hook)
dummy_input = torch.zeros(1, 64, dtype=torch.long)
with torch.inference_mode():
model(dummy_input)

유사한 상황을 모사할 수 있는 초기 스크립트를 우선 작성해 봤다. 우선 Logger를 하나 생성 후 몇 가지 설정 (Level, Handler 등)을 하고 HuggingFace에서 제공하는 작은 크기의 모델 중 하나인 SmolLM2-135M를 가져온다. 해당 스크립트의 목적은 모델을 구성하고 있는 PyTorch module의 종류를 파악하기 위해서 hook 함수를 하나 만들고, 최종적으로는 dummy_input을 forward 함수에 태워서 로그를 출력하는 것이 목적이다. 아마 대부분의 ML Engineer에 해당 스크립트는 아주 어렵지 않을 것이지만, 혹시 Logger, HF transformers 패키지, 또는 PyTorch Hook에 대해서 익숙하지 않다면 다음과 같은 페이지들을 참고로 남겨둔다. 해당 페이지들을 읽은 후 위 코드를 직접 실행해 보면 의도를 더 쉽게 이해할 수 있을 것이다.

위 스크립트를 실행해 보면 다음과 같은 결과를 얻을 수 있다.

Model contains <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaModel'> module
Model contains <class 'torch.nn.modules.sparse.Embedding'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaRotaryEmbedding'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaMLP'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.activation.SiLU'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
# 이하 중략

작은 크기의 모델이라고는 해도 Decoder Layer가 30개나 되다 보니 전체 로그 수가 모듈 개수만큼 기록되고, 한눈에 봐도 중복되는 로그가 많이 보인다. 만약에 더 큰 크기의 모델을 같은 스크립트에 대해서 실행하면 더 많은 의미 없는 로그들이 많이 쌓이게 될 것은 쉽게 짐작할 수 있겠다. 즉, 우리는 의미 있는 정보 그러니까 우리가 관심 있는 정보인 모델을 구성하는 모듈 타입의 종류에 대해서만 적절하게 로그를 남길 수 있으면 좋겠다는 결론에 이르게 된다.

Cache를 활용한 트릭

이런 중복 로그를 남기지 않는 방법으로는 여러 구현 방식이 있겠지만, Python에서 제공하는 @functools.cache 또는 @functools.lru_cache(None)을 활용해 볼 수 있습니다. 기본적으로 캐시는 비싼 연산을 매번 계산하지 않기 위해서, 사전에 연산에 대해서 계산한 값을 저장 후 이후에 같은 입력에 대해서는 계산 없이 캐싱해 둔 값을 반환하는 개념인데요. 중복 로깅을 방지하는 트릭에서도  캐싱 기능을 활용해 볼 수 있습니다. 캐시 자체에 대한 개념은 해당 포스트의 범위를 조금 벗어나니 필요하다면 https://docs.python.org/3/library/functools.html#functools.cache 문서를 참고하세요. 거두절미하고 트릭을 적용한 스크립트를 보겠습니다

import functools
import logging
import torch
from torch.nn.modules.module import register_module_forward_pre_hook
from transformers import AutoModelForCausalLM
logger = logging.getLogger(__name__)
logger.setLevel('INFO')
stream_handler = logging.StreamHandler()
logger.addHandler(stream_handler)
@functools.lru_cache(None)
def info_once(msg):
logger.info(msg)
checkpoint = 'HuggingFaceTB/SmolLM2-135M'
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
def module_type_hook(module, _) -> None:
info_once(f'Model contains {type(module)} module')
register_module_forward_pre_hook(module_type_hook)
dummy_input = torch.zeros(1, 64, dtype=torch.long)
with torch.inference_mode():
model(dummy_input)

info_once라는 lru_cache로 데코레이트 된 메서드를 하나 선언하고, 기존에 hook 함수에서 호출하던 logger.info를 info_once로 바꿔주는 단순한 변경입니다. 다시 말하자면 처음 맞이하는 로그 메시지에 대해서는 logger.info(msg)를 호출하게 되고, 그 이후에 동일 msg에 대해서는 이미 캐싱이 되어있으므로 logger.info(msg)를 호출하는 대신 캐싱 된 결과를 반환하는데요. 이때 logger.info(msg)의 결과는 None이기 때문에 자연스럽게 중복된 로그에 대해서는 더 이상 로그가 찍히지 않게 됩니다. 실제 실행 결과는 아래와 같습니다

Model contains <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaModel'> module
Model contains <class 'torch.nn.modules.sparse.Embedding'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaRotaryEmbedding'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaDecoderLayer'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'> module
Model contains <class 'torch.nn.modules.linear.Linear'> module
Model contains <class 'transformers.models.llama.modeling_llama.LlamaMLP'> module
Model contains <class 'torch.nn.modules.activation.SiLU'> module

무시무시한 (?) LLM의 명성과 다르게 실제로 구성하고 있는 모듈의 타입은 꽤 단순하죠? Leaf 레벨의 모듈만 고려하면 6가지 모듈로만 구성되어 있습니다. 본 포스트에서는 단순한 시나리오에 대해서 중복 로그를 방지하는 예제 코드를 다뤘는데요. 조금 변형하면 각자의 상황에 맞춰서 더 개선할 수 있을 것입니다.

결론

캐싱의 성질을 활용하여 중복으로 남는 로그를 똑똑하게 방지할 수 있다!

References

반응형

댓글

댓글을 사용할 수 없습니다.