diff --git a/_posts/2025-01-06-hi-po-low-bit-operators.md b/_posts/2025-01-06-hi-po-low-bit-operators.md
new file mode 100644
index 000000000000..c5243cff1bf6
--- /dev/null
+++ b/_posts/2025-01-06-hi-po-low-bit-operators.md
@@ -0,0 +1,133 @@
+---
+layout: blog_detail
+title: "High-Performance Low-Bit Operators for PyTorch"
+author: Scott Roy, Digant Desai, Kimish Patel
+---
+
+We are excited to announce the addition of embedding operators with low-bit weights (1-8 bit) and linear operators with 8-bit dynamically quantized activations and low-bit weights (1-8 bit) for Arm CPUs in TorchAO, PyTorch’s native low-precision library. These operators work seamlessly across all PyTorch surfaces, including eager, torch.compile, AOTI, and ExecuTorch, and are [available to use in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels).
+
+In developing these linear operators, our focus was on **code sharing between PyTorch and ExecuTorch** and on establishing a clear boundary between the higher-level operator and the lower-level kernel. This design **allows third-party vendors to easily swap in their own kernels**. We also set out to **create a place and infrastructure to experiment** with new CPU quantization ideas and to test them across the PyTorch ecosystem.
+
+
+## Universal low-bit kernels
+
+There is no native hardware support for low-bit arithmetic, so low-bit values must be unpacked to a supported type before compute. In what we call universal kernels, we explicitly separated, in a modular fashion, the logic that unpacks low-bit values into int8 values from the int8 GEMV kernel logic. We started with an 8-bit kernel, for example, this [1x8 8-bit GEMV kernel](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h#L64) that uses the Arm neondot instruction. Within the 8-bit kernel, we invoke an [inlined unpacking routine](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h#L169) to convert low-bit values into int8 values. This unpacking routine is force-inlined and templated on the bit width. Our experiments showed no performance difference between using a separate force-inlined unpacking routine and directly embedding the unpacking code inline.
+
+The advantage of this modular design is improved development speed and code maintainability. After writing an 8-bit kernel, we quickly achieved full low-bit coverage by writing [simple bitpacking routines](https://github.com/pytorch/ao/tree/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/bitpacking). In fact, developers who worked on the bitpacking routines did not need to be experts in GEMV/GEMM kernel writing. We also reused the same bitpacking routines from the linear kernels [within the embedding kernels](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h#L161). In the future, we could reuse the same bitpacking routines for universal GEMM kernels, or for kernels based on fma or i8mm instructions.
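+
+To make this structure concrete, here is a minimal, non-vectorized sketch. The names (`unpack_lowbit_to_int8`, `gemv_channelwise_lowbit`) and the scalar loops are illustrative only, and the single per-row/per-channel scales are a simplification of the groupwise scheme the real kernels use; the actual implementations are the NEON kernels linked above. The division of labor is the same, though: a force-inlined unpacking routine templated on the bit width feeds an ordinary int8 dot product, so only the unpacking step ever sees the bit width.
+
+```cpp
+#include <cstddef>
+#include <cstdint>
+
+// Force-inlined unpacking routine, templated on the bit width. It expands
+// `count` nbit-packed unsigned values into signed, zero-centered int8 values.
+// (The real routines unpack whole vectors at a time; this one is bit-serial.)
+template <int nbit>
+__attribute__((always_inline)) inline void unpack_lowbit_to_int8(
+    int8_t* out, const uint8_t* packed, size_t count) {
+  static_assert(nbit >= 1 && nbit <= 8, "nbit must be in [1, 8]");
+  for (size_t i = 0; i < count; ++i) {
+    uint32_t v = 0;
+    for (int b = 0; b < nbit; ++b) {
+      size_t bit = i * nbit + b;
+      v |= uint32_t((packed[bit >> 3] >> (bit & 7)) & 1) << b;
+    }
+    out[i] = int8_t(int(v) - (1 << (nbit - 1)));
+  }
+}
+
+// int8 GEMV micro-kernel for one output channel: dot product of an 8-bit
+// dynamically quantized activation row with one low-bit weight row. The
+// unpack call is the only place the bit width appears; the arithmetic below
+// is plain int8.
+template <int nbit>
+float gemv_channelwise_lowbit(
+    const int8_t* activations, float activation_scale,
+    const uint8_t* packed_weights, float weight_scale, size_t k) {
+  // For this sketch, assume k is a multiple of 8 so each block of 8 packed
+  // values occupies exactly nbit bytes.
+  int32_t acc = 0;
+  int8_t unpacked[8];
+  for (size_t i = 0; i < k; i += 8) {
+    unpack_lowbit_to_int8<nbit>(unpacked, packed_weights + (i / 8) * nbit, 8);
+    for (size_t j = 0; j < 8; ++j) {
+      acc += int32_t(activations[i + j]) * int32_t(unpacked[j]);
+    }
+  }
+  return float(acc) * activation_scale * weight_scale;
+}
+```
+
+With this split, `gemv_channelwise_lowbit<3>` instantiates a 3-bit variant without any new kernel code; only the unpacking (and the corresponding bitpacking at quantization time) is bit-width specific.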
+
+
+## Shared code between PyTorch and ExecuTorch
+
+To achieve shared code between PyTorch and ExecuTorch, we wrote kernels [using raw pointers instead of PyTorch tensors](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/linear/linear.h). Moreover, we implemented the [linear operator in a header](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h#L259) that is included in separate [PyTorch](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp) and [ExecuTorch](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp) operator registration code. By using only features common to both ATen and ExecuTorch tensors, we ensured compatibility between the two frameworks. For multi-threaded compute, we introduced [torchao::parallel_1d](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/parallel.h#L13), which compiles to either [at::parallel_for](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/parallel-aten-impl.h) or [ExecuTorch’s threadpool](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/parallel-executorch-impl.h) based on compile-time flags.
+
+
+## Swappable kernels
+
+Our design for the higher-level multi-threaded linear operator is agnostic to the lower-level single-threaded kernels, allowing third-party vendors to swap in their own implementations. The interface between the operator and kernel is defined by a [ukernel config](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h#L14), which specifies kernel function pointers for preparing activation data, preparing weight data, and running the kernel. The operator, which is responsible for tiling and scheduling, interacts with kernels solely through this config.
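+
+As a rough illustration of this boundary, here is a simplified, hypothetical version of such a config, together with a toy GEMV-shaped operator loop that talks to the kernel only through it. The field names and signatures below are not the actual TorchAO ukernel config (see the linked header for that); they only sketch the shape of the interface.
+
+```cpp
+#include <cstddef>
+
+// Hypothetical ukernel config: everything the operator needs to know about a
+// kernel is carried by these function pointers and layout constants.
+struct UKernelConfig {
+  // Quantizes/packs a tile of fp32 activations into the kernel's format.
+  void (*prepare_activation_data)(void* out, const float* activations,
+                                  size_t m, size_t k);
+  // Packs a tile of quantized low-bit weights into the kernel's format
+  // (typically run once, at quantization time).
+  void (*prepare_weight_data)(void* out, const void* weights, size_t n, size_t k);
+  // Runs the single-threaded kernel on prepared data for one output tile.
+  void (*kernel)(float* out, const void* activation_data,
+                 const void* weight_data, size_t m, size_t n, size_t k);
+  size_t n_tile;                      // output channels handled per kernel call
+  size_t weight_data_bytes_per_tile;  // stride between prepared weight tiles
+};
+
+// Toy operator (single activation row, i.e. GEMV): it owns tiling and
+// scheduling and never looks inside the kernel's data formats.
+void linear_lowbit(const UKernelConfig& cfg, float* out, const float* activations,
+                   const void* prepared_weights, void* activation_scratch,
+                   size_t n, size_t k) {
+  // Activations are prepared per call because they are dynamically quantized;
+  // weights were prepared ahead of time via cfg.prepare_weight_data.
+  cfg.prepare_activation_data(activation_scratch, activations, /*m=*/1, k);
+
+  const size_t num_tiles = n / cfg.n_tile;  // assume n_tile divides n here
+  // In the real operator this loop runs under torchao::parallel_1d, which
+  // compiles to at::parallel_for or ExecuTorch's threadpool.
+  for (size_t tile = 0; tile < num_tiles; ++tile) {
+    const char* weight_tile = static_cast<const char*>(prepared_weights) +
+                              tile * cfg.weight_data_bytes_per_tile;
+    cfg.kernel(out + tile * cfg.n_tile, activation_scratch, weight_tile,
+               /*m=*/1, cfg.n_tile, k);
+  }
+}
+```
+
+A vendor kernel is swapped in by filling out these function pointers and layout constants; the operator's tiling, scheduling, and threading code does not change.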
+
+
+## Performance
+
+In the table below, we show Llama3.1 8B token generation performance using 6 CPU threads on an M1 MacBook Pro with 32GB of RAM.
+
+| Bitwidth x | torch.compile (Decode tokens/sec) | ExecuTorch (Decode tokens/sec) | ExecuTorch PTE size (GiB) |
+|---|---|---|---|
+| 1 | 24.18 | 17.86 | 1.46 |
+| 2 | 27.02 | 19.65 | 2.46 |
+| 3 | 21.01 | 22.25 | 3.46 |
+| 4 | 19.51 | 19.47 | 4.47 |
+| 5 | 14.78 | 16.34 | 5.47 |
+| 6 | 12.80 | 13.61 | 6.47 |
+| 7 | 8.16 | 11.73 | 7.48 |