PyTorch Custom Op to ONNX
앞선 포스트에서 몇 번 언급했던 것처럼 최근 머신러닝 애플리케이션은 연구자 (조직마다 다르지만 보통 Researcher, Data Scientist, Applied Scientist 등으로 표현하는 포지션)가 PyTorch 환경에서 모델 학습 및 실험으로 성능을 검증한 후 해당 모델을 배포할 환경에 따라 TFLite, ONNX 등으로 모델 변환을 하는 것이 일반적인데요. 특히 ONNX는 여러 최적화 기법을 공짜 점심으로 변환 시 제공하고 있어서 많이 활용합니다. 다만 우리가 만든 PyTorch 모델의 연산들이 항상 ONNX가 지원하는 Operator라는 보장이 없을 수 있는데요. 우선 ONNX에서 지원하는 Operator 목록은 링크를 통해 확인하실 수 있습니다
목록을 한 번 살펴보시면 눈치채셨겠지만, 최근 LLM 계열에서 많이 활용하는 RMSNorm이나 RotaryEmbedding 같은 연산들은 아직 ONNX에서 지원하지 않는 것을 확인하실 수 있는데요. 사실 이런 사용자 정의 연산자를 ONNX에서 매번 제공하는 것도 좋은 방향은 아닐 것 같습니다. 해당 연산이 정말로 많이 사용되고 이후에도 계속 사용된다면 괜찮겠지만, 특정 시점에 잠깐 유행하고 대체제가 나온다면 또 해당 연산을 지원해 줘야 하니까요. 우선은 이렇게 지원되지 않는 PyTorch 연산자를 ONNX로 export 하면 어떻게 결과물이 나오는지부터 확인해 보는 것이 좋겠습니다.
HuggingFace 저장소에 공개되어 있는 GemmaRMSNorm을 export 해보겠습니다. PyTorch에서의 구현체는 다음과 같습니다
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
rsqrt나 pow, mean과 같은 연산의 조합으로 norm을 처리하는 것을 확인할 수 있습니다. 저희의 기대대로라면 RMSNorm이 ONNX 그래프에서도 같이 RMSNorm으로 표현이 된다면 좋겠지만, 아마도 Operator Schemas의 지원 목록에 없었기 때문에 그렇지 않을 것이란 것은 이미 어느 정도 짐작할 수 있는데요. 실제로 ONNX로 export 후 그래프 형태를 Netron으로 시각화해 보면 아래와 같이 나타납니다.
지원하지 않는 RMSNorm 대신 ONNX에서 현재 지원하는 Unit 연산들, 다시 말해서 Add, ReduceMean, Sqrt, Div와 같은 연산의 조합으로 그래프가 표현되고 있는 것을 확인할 수 있습니다. 이렇게 원본 PyTorch 그래프와 export 된 ONNX 그래프의 개형 차이가 발생하는 경우가 사용자 정의 연산으로 인해 꽤 빈번할 텐데요. 일반적인 Inference 환경에서라면 대부분 큰 문제는 없을 것입니다. Export 과정에서 Graph tracing이 제대로 되었다면 수학적으로 만들어지는 결과는 PyTorch와 ONNX 양쪽 모두에서 같거나 약간의 오차만 있어야 하기 때문입니다. 다만, 우리가 이런 그래프를 Quantization, 특히 Activation Quantization을 진행한다면 생각해야 하는 부분이 꽤 생기는데요. PyTorch에서는 한 개의 연산으로 표현되는 RMSNorm이므로 한 개의 Activation Quantization만을 고려하면 되지만, ONNX에서는 여러 Unit 연산으로 쪼개지므로 예외 처리를 하지 않는 이상 여러 개의 Activation Quantization이 발생하게 되며 이는 아마도 우리가 기대하는 결과와는 꽤 많은 차이를 발생시킬 것 같습니다
이런 유스케이스 때문에 PyTorch와 ONNX에서는 사용자 정의 연산을 원하는 형태로 Export 할 수 있는 기능을 제공하는데요. 요약하면 사용자 정의 연산을 torch.autograd.Function을 상속 후 symbolic 함수를 구현해주면 됩니다. PyTorch 문서에서 제공하는 예제 링크는 간단한 연산인 ReLU를 예제로 사용하고 있는데요. 이번 포스트에서도 비슷하게 ReLU로 진행해보겠습니다. 우선 다음과 같이 사용자 정의 ReLU를 정의하고 시작하겠습니다.
class ReluFunc(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
inputs = args[0]
ctx.save_for_backward(inputs)
return inputs.clamp(min=0)
@staticmethod
def symbolic(g: GraphContext, *args: Any) -> torch.Value:
inputs = args[0]
return g.op('my_domain::CustomReLU', inputs).setType(inputs.type())
class UserDefinedReLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return ReluFunc.apply(x)
torch.autograd.Function을 상속한 ReluFunc를 정의하고 forward 함수와 symbolic을 구현한 것을 볼 수 있습니다. 이 때 symbolic 함수의 시그니처는 다음 문서를 보시면 도움이 될 것 같은데요.
symbolic의 첫 번째 인자는 GraphContext이며 그 뒤는 positional arguments를 받습니다
그 후 GraphContext의 op를 정의해서 반환을 해야하는데요. 순서대로 연산자의 이름 (이 때 domain::type_name 형태로 정의할 수 있음)과 positional arguments와 keyword arguments등을 g.op(...)에 채워넣을 수 있습니다. 위 예제에서 저는 우선 도메인은 my_domain 그리고 연산자의 이름은 CustomReLU로 정의했고 ReLU는 간단한 연산이므로 argument는 입력값만 넘겨주면 되겠네요! 뒤의 setType 설정은 쉽게 빠뜨리기 쉬운 부분인데요. PyTorch 문서를 확인하시면 Shape inference를 위해 setType을 설정하기를 권장하고 있는 내용이 나옵니다.
When the user registers symbolic for custom/contrib ops, it is highly recommended to add shape inference for that operator via setType API, otherwise the exported graph may have incorrect shape inference in some extreme cases. An example of setType is test_aten_embedding_2 in test_operators.py.
자 이제 필요한 작업은 다 한 것 같으니 실제로 export된 CustomReLU의 그래프 형태를 한 번 살펴보겠습니다.
우리가 PyTorch 코드에서 정의한 domain과 type이 잘 설정된 것을 확인할 수 있네요 :) 만약에 symboilc 함수를 정의하지 않고 export하면 어떻게 나오는 지도 한 번 확인해봐야겠죠? 결과는 아래와 같습니다.
따로 symbolic 함수를 정의하지 않은 CustomReLU에 대해서는 ONNX가 기존에 지원하는 Unit 연산 중 하나인 Clip으로 그래프가 표현되고 있는 것을 확인할 수 있습니다. 포스트에서는 간단한 ReLU에 대해서만 진행해봤지만, 필요하다면 더 복잡한 연산에 대해서도 당연히 symbolic을 정의 후 ONNX export 할 수 있습니다. 본문의 RMSNorm과 같은 연산을 하나의 연산으로 export 하는 것은 과제로 남겨두겠습니다 :)
이번 포스트에서는 다음과 같은 내용을 다뤘습니다.
- ONNX에서 지원하는 Operator 목록
- PyTorch의 사용자 정의 연산을 export할 때 ONNX 그래프의 개형
- PyTorch 사용자 정의 연산을 원하는 형태로 export 하기 위한 symbolic 정의
댓글
이 글 공유하기
다른 글
-
Stable Diffusion 3 ONNX Export 트러블 슈팅
Stable Diffusion 3 ONNX Export 트러블 슈팅
2024.11.24 -
Full Stack Optimization of Transformer Inference: a Survey (1)
Full Stack Optimization of Transformer Inference: a Survey (1)
2024.04.14 -
PyTorch 모델 프로파일링 및 성능 개선기
PyTorch 모델 프로파일링 및 성능 개선기
2024.03.03 -
외부 API가 의도한 대로 동작을 안 해요 (2) - contextlib 활용
외부 API가 의도한 대로 동작을 안 해요 (2) - contextlib 활용
2023.12.24