문제 정의 및 현재 상황

이전 포스트에서 얘기한 것처럼 외부 API의 호출 전후로 그래프의 개형이 바뀌는 문제의 원인은 파악했는데요. 이제 그러면 우리가 원하는 동작을 하도록 코드를 수정하거나 개선을 할 차례입니다. 하고 싶은 일은 외부 API가 의존하는 onnxsim.simplify에 skipped_optimizers를 전달하는 것인데요. 자세한 내용 서술에 앞서서 전체적인 호출자 로직과 3rd Party API의 예제 구현체는 다음과 같습니다

호출자 로직

import onnx
from transformers import AutoTokenizer, GPTNeoForCausalLM

from third_party_api import prepare_model

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
    model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    simplified_model = prepare_model(model, inputs["input_ids"], "./data/model.onnx")
    onnx.save_model(simplified_model, "./data/simplified_model_monkey_patch.onnx")

비교적 작은 모델인 GPT-NEO 125M 모델을 로드하고 3rd-party API인 prepare_model을 호출하는 간단한 흐름입니다

3rd Party API 로직

# third_party_api.py
import onnx
import onnxsim
import torch.onnx
from torch import nn


class LinearModulePrinter:
    def __init__(self, name):
        self.name = name

    def print_module_info(self):
        print(f"This is a Linear module {self.name}")


class ConvModulePrinter:
    def __init__(self, name):
        self.name = name

    def print_module_info(self):
        print(f"This is a Conv module {self.name}")


PRINT_BY_MODULE_TYPE = {
    nn.Linear: LinearModulePrinter,
    nn.Conv2d: ConvModulePrinter,
    # Can be extended to other modules
}


def prepare_model(
    model: nn.Module, 
    dummy_input: torch.Tensor, 
    path: str,
):    
    # Print original model info
    for name, module in model.named_modules():
        if printer := PRINT_BY_MODULE_TYPE.get(type(module)):
            printer(name=name).print_module_info()

    # Export PyTorch model as ONNX format
    torch.onnx.export(model, dummy_input, path, opset_version=12)

    # Load ONNX model and apply simplification
    onnx_model = onnx.load(path)
    simplified_model, check = onnxsim.simplify(onnx_model)

    return simplified_model

3rd-party API의 prepare_model은 호출자로부터 model과 dummy_input 그리고 저장할 경로를 받고

  1. 원본 모델의 Module 정보를 출력
  2. ONNX 모델로 저장 및 다시 로드
  3. ONNX 모델을 Simplification

크게 3가지 과정으로 구성되어 있습니다

처음 생각했던 합리적인 방법은?

모든 개발 업무는 우리가 직접 해결해야만 하는 것은 아니라고 생각합니다. 대화로 풀거나 정책적으로도 해결할 수 있겠죠. 특히 이번 사례는 해당 API를 관리하는 팀에게 요청해서 필요한 인자를 받을 수 있게 수정을 요청하는 것이 깔끔한 방법일 테고요. 이를테면 아래와 같이 prepare_model API에서 skipped_optimizers를 받을 수 있게 하고, 그대로 simplify에 전달만 하면 될 텐데요

def prepare_model(
    model: nn.Module, 
    dummy_input: torch.Tensor, 
    path: str,
    skipped_optimizers: List[str],
):
    # 앞 부분 로직 생략
    simplified_model, check = onnxsim.simplify(onnx_model, skipped_optimizers=skipped_optimizers)

    return simplified_model

그러나 세상 모든 일이 그렇듯, 우리가 생각하는 합리적인 방안이 상대방에게는 그렇지 않을 수도 있습니다. 지금 당장 해당 API를 수정할 개발자가 부족해서, 다른 우선순위 업무가 높아서 등 여러 가지 이유가 있을 수 있기 때문입니다. 그렇다면 우리는 담당 팀의 지원이 아닌 다른 방법으로 문제를 해결해야겠네요 :(

Monkey Patch

Monkey Patch 의 정의는 다음과 같습니다 ( https://en.wikipedia.org/wiki/Monkey_patch )

Monkey patching is a technique used to dynamically update the behavior of a piece of code at run-time. A monkey patch (also spelled monkey-patch, MonkeyPatch) is a way to extend or modify the runtime code of dynamic languages (e.g. Smalltalk, JavaScript, Objective-C, Ruby, Perl, Python, Groovy, etc.) without altering the original source code.

즉, 런타임에 원본 소스 코드 수정 없이 우리가 원하는 목적에 맞게 동작하도록 하는 일종의 기교라고 볼 수 있는데요. 사실 말이 거창해서 그렇지, Python에서의 module method의 Monkey Patch는 아주 간단하게 가능합니다. prepare_model을 호출하기 전에 아래와 같이 patch를 해버리면 되는데요. 즉, 이제 런타임에서 onnxsim.simplify는 skipped_optimizers 인자가 설정된 상태로 실행이 됩니다

# Monkey Patch
patched_method = functools.partial(onnxsim.simplify, skipped_optimizers=["fuse_qkv"])
setattr(onnxsim, "simplify", patched_method)

simplified_model = prepare_model(model, inputs["input_ids"], "./data/model.onnx")

즉, prepare_model 내부가 아닌 호출자 쪽에서만 패치를 해서 우리가 원하는 목적은 얻을 수 있었습니다. 다만, 이렇게 강제로 덮어쓰기를 하는 것이 찝찝하기도 하고, 실제로 디버깅이나 테스팅 시에 문제가 있을 수 있는데요. 따라서 저희는 이런 Patch를 좀 더 아름답고 Pythonic 한 방식으로 개선해 보겠습니다

Context Manager를 활용하자!

Python을 많이 사용하시는 분들은 with open(...) as f 와 같은 구문을 많이 보셨을텐데요. 이를 Context manager라고 하며, 보통은 해당 블럭을 탈출하면서 파일 닫기, 커넥션 종료와 같은 용도로 많이 사용 합니다. 분량상 context manager의 자세한 내용은 표준 문서 ( https://docs.python.org/ko/3/library/contextlib.html ) 를 참고하세요. 우리는 탈출하면서 정리한다는 부분을 잘 활용하여 위에서 했던 Monkey Patch를 좀 더 아름답게 해보려고 하는데요. 바로 코드로 들어가면 다음과 같습니다

@contextlib.contextmanager
def patch_onnx_simplifier(
    skipped_optimizers: Optional[List[str]],
) -> contextlib.AbstractContextManager:
    original_method = getattr(onnxsim, "simplify")
    patched_method = functools.partial(
        onnxsim.simplify, skipped_optimizers=skipped_optimizers
    )

    try:
        setattr(onnxsim, "simplify", patched_method)
        yield
    finally:
        setattr(onnxsim, "simplify", original_method)

# Monkey Patch
with patch_onnx_simplifier(skipped_optimizers=["fuse_qkv"]):
    simplified_model = prepare_model(model, inputs["input_ids"], "./data/model.onnx")
# Will be recovered after exiting context manager

contextlib 데코레이터를 활용하여 context manager를 직접 구현하는데요. 큰 흐름은 다음과 같습니다

  1. 나중에 복구할 원본 onnxsim.simplify 메서드를 original_method 변수에 저장
  2. 우리의 목적을 위한 Patch된 메서드를 patched_method 변수에 저장
  3. try 블럭 -> context manager 내부에서는 onnx.simplify는 patched_method로 동작
  4. finally 블럭 -> context manager에서 탈출하면서 다시 onnx.simplify를 original_method로 복구

드디어 새로 만든 context manager를 통해 필요한 부분에서만 적절히 Monkey patch를 하고 탈출하면서 원본 메서드를 그대로 복구할 수 있게 되었습니다!

특정 Class method 구현을 바꿔치기 하고 싶다면?

앞의 onnxsim.simplify 예제로 Python module의 method patch는 알아봤는데요. 만약, 외부 API 일부 로직이 특정 Class에 의존하고 있고 해당 메서드를 변경하고 싶을 때는 어떻게 해야 할까요? 때는 앞의 patch와 비슷한 방식으로 해결이 가능한 예도 있지만, 다른 방식도 소개해 드리고 싶은데요. 마찬가지로 예제를 통해서 같이 알아보겠습니다

class LinearModulePrinter:
    def __init__(self, name):
        self.name = name

    def print_module_info(self):
        print(f"This is a Linear module {self.name}")


class ConvModulePrinter:
    def __init__(self, name):
        self.name = name

    def print_module_info(self):
        print(f"This is a Conv module {self.name}")


PRINT_BY_MODULE_TYPE = {
    nn.Linear: LinearModulePrinter,
    nn.Conv2d: ConvModulePrinter,
    # Can be extended to other modules
}


def prepare_model(model: nn.Module, dummy_input: torch.Tensor, path: str):
    # Print original model info
    for name, module in model.named_modules():
        if printer := PRINT_BY_MODULE_TYPE.get(type(module)):
            printer(name=name).print_module_info()

외부 API의 prepare_model의 앞 로직을 보면 PRINT_BY_MODULE_TYPE 딕셔너리 객체에 포함된 Printer 클래스들에 의존하고 있는데요. 저는 외부 API가 제공하는 Printer가 아니라 조금 다른 메시지를 출력하고 싶습니다. 역시 context manager와 결합하여 변경되는 부분을 최소화하고 싶은데요. 아래와 같이 해볼 수 있겠습니다 :)

class CustomLinearModulePrinter(LinearModulePrinter):
    def print_module_info(self):
        print(f"Module name: {self.name} Custom message of Linear module")


@contextlib.contextmanager
def patch_custom_module_printer() -> contextlib.AbstractContextManager:
    original_linear_printer = PRINT_BY_MODULE_TYPE[nn.Linear]

    try:
        PRINT_BY_MODULE_TYPE[nn.Linear] = CustomLinearModulePrinter
        yield
    finally:
        PRINT_BY_MODULE_TYPE[nn.Linear] = original_linear_printer


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
    model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    # Monkey Patch
    with patch_custom_module_printer():
        simplified_model = prepare_model(
            model, inputs["input_ids"], "./data/model.onnx"
        )

구현은 간단합니다. 상속을 통해서 재사용을 하는 예제라고도 볼 수 있겠는데요. LinearModulePrinter를 상속받아서 print_module_info 메서드를 재정의하고 이전처럼 Context Manager 내에서 적절히 바꿔치기 후 복원하면 됩니다. 실제 실행 결과는 다음과 같습니다

As-is

This is a Linear module transformer.h.0.attn.attention.k_proj
This is a Linear module transformer.h.0.attn.attention.v_proj
This is a Linear module transformer.h.0.attn.attention.q_proj
...

To-be (Context manager 적용 후)

Module name: transformer.h.0.attn.attention.k_proj Custom message of Linear module
Module name: transformer.h.0.attn.attention.v_proj Custom message of Linear module
Module name: transformer.h.0.attn.attention.q_proj Custom message of Linear module
...

여러 개의 Context Manager를 관리하고 싶어요

앞에서 우리는 원본 코드 수정 없이 Patch를 위해서 2개의 Context Manager를 구현하였는데요. 2가지 Context manager를 모두 patch 하고 싶다면 다음과 같이 작성하면 됩니다

    # Monkey Patch
    with patch_onnx_simplifier(
        skipped_optimizers=["fuse_qkv"]
    ), patch_custom_module_printer():
        simplified_model = prepare_model(
            model, inputs["input_ids"], "./data/model.onnx"
        )
    # Will be recovered after exiting context manager

그런데 이런 Context manager가 2개가 아니라 계속해서 늘어나거나, Programmable하게 변경하고 싶다면 좀 더 나은 방법이 있을까요? 그럴 때를 위해서 저희는 Contextlib 에서 제공하는 ExitStack()을 사용할 수 있을 것 같습니다

    # Monkey Patch
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch_onnx_simplifier(skipped_optimizers=["fuse_qkv"]))
        stack.enter_context(patch_custom_module_printer())

        simplified_model = prepare_model(
            model, inputs["input_ids"], "./data/model.onnx"
        )
    # Will be recovered after exiting context manager

이제 우리는 동적으로 stack에 ContextManager들을 넣을 수도 있고, List로 관리하고 있다면 간단히 iteration하면서 stack에 넣기만 하면 됩니다!

마무리

이번 포스트를 통해서 저희는 원본 코드 수정 없이 원하는 목적을 이루기 위해서 아래와 같은 내용들을 다뤄봤습니다

  • Monkey Patch / Class Inheritance
  • Context Manager
  • ExitStack

대부분의 Python 사용 사례는 이번 포스트에서 다룬 내용으로 대부분 커버가 될 것으로 생각합니다만
혹시 제가 생각하지 못한 다른 재미난 사례가 있다면 편하게 댓글로 공유해주세요!