개발 이야기/Machine learning

Stable Diffusion 3 ONNX Export 트러블 슈팅

가마뫼 2024. 11. 24. 17:06

오늘은 업무 중 발생했던 Stable Diffusion 3 모델의 ONNX Export 과정에서 발생했던 이슈와 해결 방법에 대해서 공유해보려고 합니다. Qualcomm의 하드웨어에 AI/ML 모델을 배포하기 위해서 Qualcomm® Neural Processing SDK를 활용하는데요. 해당 SDK에서 제공하는 도구 중 Converter는 여러 가지 프레임워크 (e.g., ONNX, TensorFlow, ...) 를 지원하지만, 대체로 ONNX를 많이 활용하고 있습니다. 아마도 요즘 대부분의 Researcher들이 익숙한 딥러닝 프레임워크가 PyTorch고 PyTorch에서 ONNX Export는 간편하게 지원하는 것이 첫 번째 이유일 것이라고 짐작하며, 두 번째로는 ONNX를 활용함으로써 얻을 수 있는 공짜 점심들이 있기 때문일 것입니다. ONNX 자체적으로 지원하는 최적화 (Constant Folding, Graph Fusion, ...) 외에도 onnx-simplifier 같은 도구를 통해서도 여러 가지 검증된 최적화를 수행할 수 있기 때문입니다. 즉, ONNX를 사용하지 않을 이유는 그렇게 많지 않다고 생각되고, 이런 최적화에 관한 내용은 이번 글의 주제를 조금 벗어나기 때문에 앞으로 따로 글을 써보도록 하겠습니다. 여하튼 이러한 이유로 오픈 소스로 제공되는 모델 자체 또는 그것을 어느 정도 Fine-tuning 한 모델을 ONNX로 먼저 Export 하는 것이 최우선 단계일 텐데요. 해당 과정부터 무언가 문제가 있다는 얘기를 다른 팀 엔지니어로부터 전해 들었고 그 원인을 파악해 보았습니다.

Stable Diffusion 3 모델 구조

우선 Stable Diffusion 3의 모델 구조에 대해서 간략하게 설명하고 넘어가야할 것 같은데요. Stable Diffusion에 익숙하신 분들이라면 대부분 이 모델의 구조가 Text Encoder, UNet, 그리고 VAE의 조합으로 구성된 것을 어느 정도 잘 알고 계실 것 같습니다. Stable Diffusion 3도 큰 틀에서의 구조는 변함은 없지만 세세한 부분에서 변경사항이 좀 있는데요

https://encord.com/blog/stable-diffusion-3-text-to-image-model/

위 그림과 같이 3 종류의 Text Encoder를 사용하고 있으며, UNet 대신 MM-DiT를 활용하는 점이 이전 Stable Diffusion의 모델 구조와의 차이점이라고 볼 수 있을 것 같습니다. HuggingFace의 아래 모델 Definition을 보면 조금 더 이해에 도움이 될 것 같네요!

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py#L131

Stable Diffusion 3 그 자체에 대해서는 또 글이 너무 길어질 것 같아서 이번 글에서는 모델 구조에 대해서만 간략하게 얘기하고 넘어가겠습니다. 좀 더 자세한 내용이 필요하신 분들은 stability.ai의 research paper와 위 그림의 출처인 https://encord.com/blog/stable-diffusion-3-text-to-image-model 링크를 참고하세요. 문제가 되는 컴포넌트는 MM-DiT였는데요. 어떤 Error가 발생했는지 밑에서 좀 더 자세히 설명하도록 하겠습니다.

Error 내용

전달 받은 Error는 그래프 내부에 특정 Tensor의 이름이 비어있고, 그것이 Graph Compile 과정에서 에러를 발생시킨다는 내용이었는데요. AI/ML 모델 시각화로 많이 사용하는 Netron 을 활용해서 정말로 그런 노드가 존재하는 지 확인을 해보았습니다.

입력 Tensor부터 거슬러가다보니 생각보다 빠르게 수상한 노드를 찾을 수 있었는데요. 다름 아닌 LayerNormalization 노드였습니다. 이제 수상한 노드를 찾았으니 실제 코드가 어떤식으로 구성되어있는지 찾아보면 문제를 해결할 수 있을 것 같네요!

MM-DiT의 모델의 forward pass를 보니 JointTransformerBlock 를 순회하고 있고, JointTransformerBlock은 아래와 같이 self.norm1을 가장 먼저 통과하게 되는데 이번 경우는 AdaLayerNormZero를 사용하고 있습니다. 거의 다 온 것 같은데요!

class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206
    이하 중략
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
    ):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        if use_dual_attention:
            self.norm1 = SD35AdaLayerNormZeroX(dim)
        else:
            self.norm1 = AdaLayerNormZero(dim)
        # ...

해당 레이어의 Definition이 눈에 띄네요. 아래와 같이 elementwise_affine=False로 주고 있고 이는 일반적인 LayerNormalization layer의 인자와 다른 값입니다. 기본값은 elementwise_affine=True죠. 이제 원인을 거의 다 찾은 것 같습니다. 아마도 LayerNorm에 대한 PyTorch 문서만 확인해보면 될 것 같네요 :)

class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
        super().__init__()
        if num_embeddings is not None:
            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        elif norm_type == "fp32_layer_norm":
            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
        else:
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )
		# 중략

PyTorch 공식 문서의 LayerNorm Parameter에 대한 설명은 아래와 같습니다

elementwise_affine은 기본값이 True이며 이는 scale과 zero_point를 각각 learnable parameter로 초기화한다고 되어있습니다. 따라서 해당 값을 False로 주게되면 scale과 zero_point에 대한 Tensor의 이름이 제대로 채워지지 않았을 것이라고 추측할 수 있었습니다. 이에 대한 부분은 사실 ONNX Export의 책임이므로 저희가 어떻게 해결하기는 어려운 문제였는데요. 다만 저희는 pre-trained 모델을 사용하고 향후에 QAT와 같은 fine-tuning의 요구사항이 없었기 때문에 아래와 같은 Workaround를 적용해서 문제를 해결했습니다.

Workaround

문제 해결은 심플했는데요. PyTorch 모델을 순회하면서 nn.Module 중 LayerNorm이면서 elementwise_affine이 False로 설정되어있는 경우 해당 Flag를 변경하고, scale와 zero_point를 적절히 채워주었습니다. 코드 스니펫은 다음과 같습니다. 

for name, module in model.named_modules():
    if isinstance(module, nn.LayerNorm) and not module.elementwise_affine:
        module.weight = nn.Parameter(torch.ones(module.normalized_shape))
        module.bias = nn.Parameter(torch.zeros(module.normalized_shape))

수식 상 weight * inputs + bias이므로 weight과 bias를 각각 1과 0으로 명시적 초기화하면 math invariant하기 때문에 문제가 발생하지 않습니다! 위와 같은 임시 해결책을 이슈를 공유한 엔지니어에게 전달하고 무사히 ONNX Export 및 Graph compile을 완료할 수 있었습니다.

Key takeaways

해결 방법이 생각보다 쉬워서 김이 좀 빠지셨나요? 실제로 위와 같은 이슈는 최신 PyTorch 및 ONNX 환경에서는 발생하지 않는 것으로 보입니다. 다만, 우리가 작업하는 환경이 항상 최신 버전을 사용할 수 없으므로 필요하다면 직접 소스 코드를 백포팅하거나 위와 같은 Workaround를 활용해야 할 때가 있음을 인지하고 있으면 좋겠죠? 이번 이슈 트러블슈팅에서의 몇 가지 Key takeaway는 아래와 같았습니다

  • 로그를 잘 확인하자. 로그를 통해서 Tensor 이름이 비어 있는 값임을 알 수 있었다
  • 시각화 도구 (e.g., Netron 등)를 잘 활용하면 가시적으로 문제를 파악하는 데 도움이 된다
  • 상황에 따라 우리가 의존하는 패키지를 직접 수정하기 어려운 때도 있다. 이럴 때는 요구 사항을 잘 확인한 후 적절한 Workaround를 제공하는 유연함이 필요하다
반응형