From a719008d9a5c7515c382cb356a924a4f2504ec8f Mon Sep 17 00:00:00 2001
From: Thomas Jannaud
Date: Thu, 3 Apr 2025 10:18:59 -0700
Subject: [PATCH] RMSNorm support - Executorch (#9844)

Summary:
This follows D72014553 which adds support for RMSNorm (cpu backend)
This is a separate diff for Executorch / Github

Reviewed By: Vysarat

Differential Revision: D72258890
---
 backends/cadence/aot/ops_registrations.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 9e604ae42aa..dec6feb1b8d 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -139,6 +139,7 @@
     "int in_zero_point, bool channel_last=False) -> (Tensor out)"
 )
 lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
+lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")
 lib.define(
     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -210,6 +211,9 @@
     "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
+lib.define(
+    "rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
@@ -615,6 +619,15 @@ def linalg_vector_norm_meta(
     return X.new_empty([], dtype=X.dtype)
 
 
+@register_fake("cadence::rms_norm")
+def rms_norm_meta(
+    X: torch.Tensor,
+    eps: float,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    return X.new_empty(X.shape, dtype=X.dtype)
+
+
 @register_fake("cadence::requantize")
 def requantize_meta(
     input: torch.Tensor,