개발 이야기/Machine learning

PyTorch 모델 프로파일링 및 성능 개선기

가마뫼 2024. 3. 3. 20:45

Motivation

저희 팀은 전사적으로 사용하는 경량화 소프트웨어를 개발하고 있습니다. 여러 도메인의 팀 (자율주행, XR, LLM 등)이 사용하다 보니 여러 가지 문의가 항상 생기는데요. 최근에 Computer Vision 모델을 경량화 소프트웨어를 이용해서 전처리 후 다시 작업을 할 때 원본 모델에서보다 너무 느리다는 문의가 들어왔습니다. 보고된 Forward pass의 수행 시간이 거의 15배 이상 차이가 났었는데요. 그 과정에 대해서 어떻게 해결할 수 있었는지 예제를 통해서 소개하려고 합니다.

Profiling result

처음에는 해당 모델이 GPU로 실행되지 않고, CPU로 실행되고 있다던가 흔히 저지르기 쉬운 실수에 대해서 먼저 체크를 했었는데요. 여러 가지로 검토했을 때 그런 단순한 실수는 원인이 아니었던 것 같았고, 좀 더 정확하게 원인을 파악하고자 PyTorch 공식 페이지에서 제공하는 Profiler 코드를 일부 기본으로 하여 결과를 얻어보았습니다

Profiler 결과

Source location의 stack trace를 읽어보니 뭔가 문제가 되는 nn.Module은 모델 구성 요소 중 GatherNd_0 인 것 같습니다. TensorFlow나 ONNX와 다르게 PyTorch는 GatherNd 연산을 native로 지원하지 않으므로, 저희는 Custom으로 만든 GatherNd를 사용하고 있는데요. 역시나 해당 구현이 조금 비효율적인 것이 아닌가라는 의심을 프로파일링을 통해 시작할 수 있었습니다

PyTorch Performance Tuning Guide

PyTorch 공식 홈페이지에서는 많은 사용자들이 겪었던 성능 저하 사례들에 대한 정리 및 개선 방안을 이미 정리해 두고 있었는데요. 다음 링크 link에서 다양한 사례를 확인할 수 있습니다. 일반적인 최적화인 추론 시 torch.no_grad()를 활용하는 팁이나, 파라미터의 gradient를 초기화 시 None 사용해야 하는 이유 등 우리가 흔히 지나치고 넘어갈 만한 내용들입니다. 그 외에 CPU, GPU 또는 Distributed 환경에서의 특정 최적화도 소개되어 있는데요

아마도 저희가 봐야 할 부분은 GPU specific optimizations에 있을 것 같습니다 :) Keras 이전 TensorFlow에 비해서 PyTorch는 기존 Python 코드에 쉽게 연결해서 사용할 수 있어서 인기가 있었습니다. 다만 이런 이유로, 성능 저하를 만들 수 있는 경우가 꽤 자주 발생하는데요. 그 대표적인 내용이 바로 GPU 최적화 항목의 Avoid unnecessary CPU-GPU synchronization 입니다. PyTorch 튜토리얼에서 제시하는 대표적인 synchronization 케이스는 아래와 같습니다

  • print(cuda_tensor)
  • cuda_tensor.item()
  • memory copies: tensor.cuda(), cuda_tensor.cpu() and equivalent tensor.to(device) calls
  • cuda_tensor.nonzero()
  • python control flow which depends on results of operations performed on CUDA tensors e.g. if (cuda_tensor != 0).all()

예제의 1번과 5번처럼 Python land에서의 연산인 print나 if statement 역시 PyTorch에서는 특별히 문제없이 사용할 수 있기 때문에 실수하기 좋은데요. 앞에서 수행했던 Profiling 결과를 다시 떠올려볼까요? 네 맞습니다, 해당 profiler에서는 aten::item이 병목 후보군이었고, 이는 동기화를 일으키는 대표적인 예시 중 2번째에 해당합니다. 이를 통해, 저희는 이제 불필요한 CPU-GPU 동기화가 병목현상의 원인일 것이라는 어느 정도 추정을 해볼 수 있게 되었습니다. 이것이 정말 병목인지 확인하려면, 실제 코드를 확인하고 병목을 일으키지 않는 코드로 개선 후 다시 Profiling을 해보면 될 것 같습니다

GatherNd

우선 GatherNd가 어떤 연산자인지 간단하게 설명하고 넘어가는 것이 좋을 것 같습니다

GatherNd

쉽게 얘기하면 입력으로 주어지는 Tensor에 대해서 마찬가지로 주어지는 Indices에 해당하는 Tensor만 뽑아오는 연산이라고 생각하셔도 괜찮을 것 같습니다. 다만 Tensor와 Indices의 차원이 단순하지 않을 수 있으므로 경우를 다 고려하면 구현이 쉽지는 않을 것 같습니다. 이번 포스트의 목적은 모델의 프로파일링 및 최적화이므로 Gather나 Scatter와 같은 연산에 대해 좀 더 자세히 알고 싶으신 분들은 아래에 첨부한 References 링크를 참고하세요. 프로파일러에서 병목으로 추정되는 모델의 GatherNd의 구현체는 아래와 같습니다. 코드를 찬찬히 읽어보시고, 어디가 동기화로 인한 병목을 일으키는지 답을 보기 전에 한번 찾아보시는 것도 재미있을 것 같습니다

class GatherNd(torch.nn.Module):
    """ GatherNd op implementation"""

    def __init__(self, batch_dim: int):
        super().__init__()
        self.batch_dims = batch_dim

    def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """
        Forward-pass routine for GatherNd op
        """
        data_rank = len(data.shape)

        assert indices.shape[-1] <= data_rank

        batch_dims_shape = []

        batch_dims_size = 1

        for i in range(self.batch_dims):
            batch_dims_shape.append(indices.shape[i])
            batch_dims_size *= indices.shape[i]

        output_shape = (
            batch_dims_shape + list(indices.shape)[self.batch_dims:-1]
            if (indices.shape[-1] == data_rank - self.batch_dims)
            else batch_dims_shape + list(indices.shape)[self.batch_dims:-1] + list(data.shape)[self.batch_dims + indices.shape[-1]:])

        if torch.jit.is_tracing():
            return torch.zeros(*output_shape, device=data.device)

        output_data_buffer = []

        reshaped_indices = indices.reshape(batch_dims_size, -1, indices.shape[-1])

        reshaped_data = data.reshape((batch_dims_size,) + data.shape[self.batch_dims:])

        for batch_dim in range(reshaped_indices.shape[0]):
            for outer_dim in range(reshaped_indices.shape[1]):
                gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
                output_data_buffer.append(reshaped_data[(batch_dim, *gather_index)])

        if output_data_buffer[0].dim() == 0:
            return torch.tensor(output_data_buffer, device=data.device).reshape(output_shape)
        return torch.cat(output_data_buffer).reshape(output_shape)

어딘지 찾으셨나요? 제 생각엔 아무래도 gather_index = tuple(reshaped_indices[batch_dim][outer_dim]) 여기가 의심스럽습니다. tuple에 넣고 초기화하는 객체들은 모두 PyTorch 객체들이고 아마도 GPU 위에 존재할 텐데요. 이것을 Python land의 tuple에 옮겨오면서 암시적으로 CPU-GPU 동기화가 발생했을 것 같네요

Optimize and profiling again

문제가 되는 모델은 batch_dim이 0이었고, 그 외의 임의의 batch_dim에 대해서 모두 구현하기에는 시간이 여유롭지 않아 우선은 batch_dim이 0인 경우에 대해서만 불필요한 동기화 없이 최적화된 GatherNd 코드를 작성하였습니다. 저희 팀은 코드 작성 시 Incremental development를 지향하는데요, 이것 역시 일종의 점진적 개발이라고 볼 수 있을 것 같습니다. 모든 케이스에 대해서 완벽한 코드를 한 번에 작성할 수 있다면 좋지만, 그렇지 않다면 작은 단위로 먼저 개발을 하고, 일정을 작은 단위로 계획하는 것이 좀 더 유연하니까요!

class GatherNd(torch.nn.Module):
    """ GatherNd op implementation"""

    def __init__(self, batch_dim: int):
        super().__init__()
        self.batch_dims = batch_dim

    def forward(self, data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """
        Forward-pass routine for GatherNd op
        """
        if self.batch_dims != 0:
            raise ValueError("Not implemented")

        r, m = len(data.shape), indices.shape[-1]
        total_samples = indices.shape[:-1].numel()
        output_shape = indices.shape[:-1] + data.shape[m:]
        indices_ = torch.split(
            tensor=indices.reshape(total_samples, m).transpose(0, 1),
            split_size_or_sections=1,
        )

        return data[indices_].reshape(output_shape).contiguous()

다시 코드로 돌아와서, 이전 코드와 비교해 보면 불필요하게 GPU의 PyTorch tensor가 CPU로 동기화되는 부분은 없어 보입니다. 그렇지만 확실히 하기 위해서 다시 프로파일링을 해보고 결과를 보는 것이 좋겠죠?

최적화 후 Profiler 결과

프로파일러 결과를 봐도 이전에 있었던 aten::item은 찾아볼 수가 없네요. 즉, 다시 말해 이제 GatherNd의 과정은 모두 GPU 위에서 실행된다고 봐도 될 것 같습니다. 실제로 해당 패치를 적용 후 15배 정도 느리다고 문제를 제기했던 모델의 forward pass의 수행 시간이 원본 모델과 거의 차이 없이 실행됨을 확인할 수 있었습니다

Conclusion

이번 포스트에서는 아래와 같은 내용들을 간단하게 소개해 봤습니다

  • PyTorch에서 제공하는 Performance Tuning Guide
  • Unnecessary CPU-GPU Synchronization
  • PyTorch Profiler 활용 및 결과 분석

비슷한 상황이 있을 때 여러분들도 프로파일러를 활용하여 병목을 해결해 보실 수 있기를 기대합니다 :)

References

반응형