The PyTorch team has officially released PyTorch 2.0, which was first previewed back in December 2022 at the PyTorch Conference.
PyTorch is a machine learning framework that was originally developed by Meta and is now part of the Linux Foundation.
This release includes a high-performance implementation of the Transformer API. It now supports more use cases, such as models that use cross-attention, Transformer decoders, and model training. The team explained that the goal of the new API is to make training and deploying Transformer models more cost-effective.
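The accelerated Transformer path is built around scaled dot-product attention, softmax(QKᵀ/√d)·V. Cross-attention differs from self-attention only in where its inputs come from: the queries come from the decoder, while the keys and values come from the encoder output. The plain-Python sketch below shows just that math on small lists; it is an illustration, not PyTorch's fused kernels, and it omits the learned projections and multiple heads a real attention layer adds.

```python
import math


def softmax(scores):
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries, keys, values):
    """Compute softmax(Q K^T / sqrt(d)) V, where each argument is a list of row vectors."""
    d = len(queries[0])
    scale = 1.0 / math.sqrt(d)
    outputs = []
    for q in queries:
        # One attention score per key: the scaled dot product q . k
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights = softmax(scores)
        # Each output row is the attention-weighted sum of the value rows
        outputs.append(
            [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        )
    return outputs


# Cross-attention: queries come from the decoder, while keys and values come
# from the encoder output (self-attention would pass the same sequence three times).
decoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 target positions, d = 2
encoder_states = [[1.0, 0.0], [0.0, 1.0]]              # 2 source positions, d = 2
context = scaled_dot_product_attention(decoder_states, encoder_states, encoder_states)
```

In PyTorch 2.0 itself, `torch.nn.functional.scaled_dot_product_attention(query, key, value)` computes the same formula on tensors and can dispatch to fused kernels, which is where the speedups come from.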
PyTorch 2.0 also introduces torch.compile as the main API for wrapping a model and returning a compiled version of it. The feature is completely additive, so PyTorch 2.0 remains backward compatible with existing code.
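torch.compile follows a capture-then-generate-code pattern: it records the operations a model performs and emits optimized code for them, while the original model is left untouched. The plain-Python toy below illustrates only that pattern. It captures operations through operator overloading (closer in spirit to torch.fx symbolic tracing than to TorchDynamo's bytecode-level capture), performs no real optimization, and `Proxy` and `toy_compile` are hypothetical names. With PyTorch installed, the real call is simply `compiled_model = torch.compile(model)`.

```python
class Proxy:
    """Stands in for a real argument and records each operation applied to it."""

    def __init__(self, graph, name):
        self.graph = graph  # shared list of generated statements
        self.name = name    # variable name this value has in the generated code

    def _record(self, op, other):
        rhs = other.name if isinstance(other, Proxy) else repr(other)
        out = f"t{len(self.graph)}"
        self.graph.append(f"{out} = {self.name} {op} {rhs}")
        return Proxy(self.graph, out)

    def __add__(self, other):
        return self._record("+", other)

    def __sub__(self, other):
        return self._record("-", other)

    def __mul__(self, other):
        return self._record("*", other)


def toy_compile(fn, arg_names):
    """Hypothetical helper: trace fn once with proxies, then exec straight-line source."""
    graph = []
    result = fn(*(Proxy(graph, name) for name in arg_names))
    lines = [f"def compiled({', '.join(arg_names)}):"]
    lines += [f"    {stmt}" for stmt in graph]
    lines.append(f"    return {result.name}")
    namespace = {}
    exec("\n".join(lines), namespace)
    compiled = namespace["compiled"]
    compiled.source = "\n".join(lines)  # keep the generated code for inspection
    return compiled


def model(x, w, b):
    return x * w + b - 1


compiled_model = toy_compile(model, ["x", "w", "b"])
```

Because `toy_compile` returns a new callable and never modifies `model`, code that keeps calling `model` behaves exactly as before, which is the same reason an additive API like torch.compile does not break existing programs.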
torch.compile is built on four new technologies:
- TorchDynamo, which uses Python Frame Evaluation Hooks to safely capture PyTorch programs
- AOTAutograd, which generates ahead-of-time backward traces
- PrimTorch, which condenses more than 2,000 PyTorch operators into a closed set of about 250 primitive operators that developers can target to build a complete PyTorch backend, significantly lowering the barrier to entry
- TorchInductor, a deep learning compiler that makes use of OpenAI Triton
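PrimTorch's payoff is that a backend author only has to implement a small primitive set, because the larger operator set arrives as decompositions of those primitives. The plain-Python sketch below shows that idea with two composite operators; the six primitives and all helper names are hypothetical and are not PrimTorch's actual primitive set.

```python
import math

# Hypothetical primitive set: a toy backend only has to implement these.
PRIMS = {
    "exp": lambda xs: [math.exp(x) for x in xs],
    "mul": lambda xs, ys: [x * y for x, y in zip(xs, ys)],
    "sub_scalar": lambda xs, s: [x - s for x in xs],
    "div_scalar": lambda xs, s: [x / s for x in xs],
    "sum": lambda xs: sum(xs),
    "max": lambda xs: max(xs),
}

executed = []  # log of the primitives actually dispatched to the backend


def prim(name, *args):
    executed.append(name)
    return PRIMS[name](*args)


# Composite operators are defined only as decompositions into primitives,
# so they run on any backend that implements PRIMS.
def dot(xs, ys):
    return prim("sum", prim("mul", xs, ys))


def softmax(xs):
    shifted = prim("sub_scalar", xs, prim("max", xs))
    exps = prim("exp", shifted)
    return prim("div_scalar", exps, prim("sum", exps))
```

Adding another composite operator here means writing one more decomposition, not touching the backend, which is the reduction in effort the 2,000-to-250 figure describes.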
“We have achieved major speedups for training transformer models and in particular large language models with Accelerated PyTorch 2 Transformers using a combination of custom kernels and torch.compile(),” the PyTorch team wrote in a blog post.
This release also adds support for 60 new operators to the Metal Performance Shaders (MPS) backend, which provides GPU-accelerated training on macOS. This brings total MPS coverage to 300 operators to date.
AWS customers will see improved performance on AWS Graviton processors compared with previous releases. These improvements focus on GEMM kernels, bfloat16 support, primitive caching, and the memory allocator.
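bfloat16 keeps float32's sign bit and 8-bit exponent but only the top 7 mantissa bits, so a float32 becomes a bfloat16 by keeping its upper 16 bits. The plain-Python sketch below shows one common conversion with round-to-nearest-even; it only illustrates the number format, assumes the input fits in float32 range, omits NaN handling, and is not the Graviton kernel code.

```python
import struct


def float32_bits(x):
    # Reinterpret x as an IEEE 754 single-precision value and return its 32-bit pattern
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16(x):
    """Round x to bfloat16 with round-to-nearest-even; NaN handling is omitted."""
    bits = float32_bits(x)
    # Add just under half of the dropped range, plus 1 when the kept LSB is odd,
    # so exact ties round to the even bfloat16 value
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def from_bfloat16(b):
    # bfloat16 is the upper half of a float32, so widening is a 16-bit shift
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]
```

For example, 3.14159 round-trips to 3.140625, which shows the trade-off: bfloat16 keeps float32's full exponent range while giving up most of its precision.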
This release also includes several beta updates to the PyTorch domain libraries, such as TorchAudio, TorchVision, and TorchText, as well as to other libraries.
Several features are also at the prototype stage, including TensorParallel, DTensor, 2D parallel, TorchDynamo, AOTAutograd, PrimTorch, and TorchInductor.