From 60fce8e32e2c752f74565e750c6e8b0a4441d81d Mon Sep 17 00:00:00 2001 From: Thomas Jannaud Date: Wed, 9 Apr 2025 14:02:27 -0700 Subject: [PATCH] Adding RMSNorm support to arbitrary x and normalized_dim shapes (#9966) Summary: In D72014553, we were adding initial support for RMS norm for an input in 3 or 4 dimensions, and a weight of dimension 1 (same size as x[:-1]) In this diff, we allow for: - input of arbitrary shape - shape broadcasting of w (w must have dim <= 1) Reviewed By: Vysarat Differential Revision: D72484196 --- backends/cadence/aot/ops_registrations.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 368f499aa85..aca4965083d 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -139,7 +139,6 @@ "int in_zero_point, bool channel_last=False) -> (Tensor out)" ) lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)") -lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)") lib.define( "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)" @@ -211,9 +210,6 @@ "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)") -lib.define( - "rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)" -) lib.define( "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"