diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl deleted file mode 100644 index 716c42e8ede..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform PRECISION restrict Block { - ivec4 out_limits; - ivec4 in_sizes; - // output dims - ivec4 out_ndims; - // x = output channels aligned to 4, y = input channels aligned to 4 - ivec2 channel_info; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(constant_id = 3) const int packed_dim = C_DIM; - -#extension GL_EXT_control_flow_attributes : require - -void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits.xyz))) { - return; - } - - VEC4_T outval = VEC4_T(0.0); - - // scale up output position's packed dim - pos[packed_dim] <<= 2; - - // index of packed dim in bchw format - const int in_packed_dim_bchw_index = 3 - packed_dim; - - // determine input position based on output position and permute map - // out_ndims is in BCHW format - ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w - in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x); - in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x); - in_bchw_pos[out_ndims[2]] = pos.y; - in_bchw_pos[out_ndims[3]] = pos.x; - - const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]]; - - [[unroll]] for (int j = 0, bchw_index = 
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) { - // terminate the loop if trying to access input texture out of bounds - if (bchw_index >= in_packed_dim_size) { - break; - } - // go to position in the input, that is mapped to the packed dim in the output - in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index; - - ivec3 fetch_pos; - - fetch_pos.xy = in_bchw_pos.wz; - // calculate input position in z axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively - fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y; - - // input tensor's packed dim lane corresponding to output tensor's pos - const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3; - - // scale down input tensor's packed dim pos to perform fetch - fetch_pos[packed_dim] >>= 2; - - // fetch input texel - VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos)); - outval[j] = inval[in_packed_dim_lane_index]; - } - - pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]); - - imageStore(t_out, pos, outval); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl new file mode 100644 index 00000000000..55b9e3dc9ea --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} +${layout_declare_ubo(B, "int", "out_numel")} + +layout(push_constant) uniform restrict Block { + ivec4 in_strides; + ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Convert output tensor index to input tensor index based on permutation +ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { + ivec4 in_tidx; + + // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i] + in_tidx[permute_dims.x] = out_tidx.x; + in_tidx[permute_dims.y] = out_tidx.y; + in_tidx[permute_dims.z] = out_tidx.z; + in_tidx[permute_dims.w] = out_tidx.w; + + return in_tidx; +} + +void main() { + const int out_bufi = ivec3(gl_GlobalInvocationID).x; + if (out_bufi >= out_numel) { + return; + } + + // Convert buffer index to tensor index for output + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); + + // Convert output tensor index to input tensor index using permutation + const ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + + // Convert input tensor index back to buffer index + const int in_bufi = tidx_to_bufi(in_tidx, in_strides); + + // Copy data from input to output + t_out[out_bufi] = t_in[in_bufi]; +} diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml similarity index 73% rename from backends/vulkan/runtime/graph/ops/glsl/permute.yaml rename to backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml index a90ddcb41ce..81675ae8917 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_buffer.yaml @@ -1,12 +1,10 @@ -permute: +permute_buffer: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - VALUE: int32 shader_variants: - - NAME: permute + - NAME: permute_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl new file mode 100644 index 00000000000..274077f4181 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.glsl @@ -0,0 +1,103 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in_sizes; + ivec4 permute_dims; // Permutation mapping: permute_dims[i] = j means output dim i comes from input dim j +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); +const lowp int out_packed_dim = unhash_packed_dim(out_layout); + +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); +const lowp int in_packed_dim = unhash_packed_dim(in_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Convert output tensor index to input tensor index based on permutation +ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { + ivec4 in_tidx; + + // Apply the permutation mapping: in_tidx[permute_dims[i]] = out_tidx[i] + in_tidx[permute_dims.x] = out_tidx.x; + in_tidx[permute_dims.y] = out_tidx.y; + in_tidx[permute_dims.z] = out_tidx.z; + in_tidx[permute_dims.w] = out_tidx.w; + + return in_tidx; +} + +// Check if we can use the fast path where texels from the input tensor can be +// copied directly into the output tensor. This occurs when the packed dimension +// is preserved in the permutation, i.e. reading a texel from the output tensor +// produces 4 texels along the same dimension as reading a texel from the input +// tensor. 
+bool can_use_fast_path() { + // Fast path is possible when the packed dimension is preserved in the permutation + // This means permute_dims[out_packed_dim] == in_packed_dim + return permute_dims[out_packed_dim] == in_packed_dim; +} + +void main() { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); + + if (any(greaterThanEqual(out_tidx, out_sizes))) { + return; + } + + if (can_use_fast_path()) { + // Fast path: packed dimension is preserved, so we can copy texels directly + ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); + VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + + write_texel_lpos(t_out, lpos, in_texel, out_axis_map); + } + else { + // Slow path: packed dimension is not preserved, so each element of the + // output texel may be "sourced" from a different texel in the input tensor. + // Therefore each output texel element is processed individually. 
+ VEC4_T out_texel = VEC4_T(0); + // Stop at the packed-dim boundary; padding lanes keep their zero value. + for (int texel_i = 0; texel_i < 4 && out_tidx[out_packed_dim] < out_sizes[out_packed_dim]; ++texel_i) { + ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); + int element_idx = in_tidx[in_packed_dim] % 4; + + VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + T selected_value = T(in_texel[element_idx]); + + out_texel[texel_i] = selected_value; + + out_tidx[out_packed_dim]++; + } + + write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml new file mode 100644 index 00000000000..f68b8dcdd3d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/permute_texture.yaml @@ -0,0 +1,10 @@ +permute_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: permute_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index fba3f03467b..6e6a6fa3bf2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -10,6 +10,7 @@ #include +#include #include #include #include @@ -100,54 +101,76 @@ void add_permute_node( const ValueRef out) { check_args(graph, in, permute_dims, out); - ivec4 out_dims{0, 1, 2, 3}; - - // Special cases of squeeze/unsqueeze. Because the input dim size can be - // different with output dim size. So pick graph.dim_of(in) if squeeze, and - // graph.dim_of(out) if unsqueeze to create parameter for permute. - const int64_t out_ndim = std::max(graph.dim_of(in), graph.dim_of(out)); - std::vector seen(out_ndim); + // Convert the permute dims to WHCN dimension order, which is the standard in + // our compute shaders. The following transformations are applied. + // 1. 
Change dimension index values from NCHW order value to WHCN order value + // 2. Reverse the order of the permute array from NCHW order to WHCN order + ivec4 whcn_permute_dims{0, 1, 2, 3}; { IntListPtr permute_dims_ptr = graph.get_int_list(permute_dims); - for (int i = 0; i < out_ndim; i++) { - int64_t permute_dim = permute_dims_ptr->at(i); - VK_CHECK_COND( - !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); - seen[permute_dim] = true; + const int32_t permute_ndim = + utils::safe_downcast(permute_dims_ptr->size()); + + for (int32_t nchw_i = permute_ndim - 1, whcn_i = 0; nchw_i >= 0; + nchw_i--, whcn_i++) { + const int32_t permute_dim_nchw = permute_dims_ptr->at(nchw_i); + const int32_t permute_dim_whcn = permute_ndim - 1 - permute_dim_nchw; - out_dims[(4u - out_ndim) + i] = - utils::safe_downcast(permute_dim + (4 - out_ndim)); + whcn_permute_dims[whcn_i] = permute_dim_whcn; } } std::string kernel_name = "permute"; kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - const int32_t out_channels = dim_at(graph.sizes_of(out)); - const int32_t in_channels = dim_at(graph.sizes_of(in)); + vkapi::ParamsBindList param_buffers; + std::vector push_constants; + vkapi::SpecVarList spec_vars; - const int32_t packed_dim = graph.packed_dim_of(in); - ivec2 channel_info = {out_channels, in_channels}; - if (packed_dim == WHCN::kChannelsDim) { - channel_info[0] = utils::align_up_4(channel_info[0]); - channel_info[1] = utils::align_up_4(channel_info[1]); - } + if (graph.is_buffer_storage(out)) { + param_buffers.append(graph.sizes_ubo(in)); + param_buffers.append(graph.strides_ubo(out)); + param_buffers.append(graph.numel_ubo(out)); + + // Buffer storage - use permute_buffer shader + push_constants = { + graph.strides_pc_of(in), + PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims)), + }; + + spec_vars = {graph.hashed_layout_of(out), 
graph.hashed_layout_of(in)}; + } else { + // Texture storage - use permute_texture shader + const int32_t out_channels = dim_at(graph.sizes_of(out)); + const int32_t in_channels = dim_at(graph.sizes_of(in)); + + const int32_t packed_dim = graph.packed_dim_of(in); + ivec2 channel_info = {out_channels, in_channels}; + if (packed_dim == WHCN::kChannelsDim) { + channel_info[0] = utils::align_up_4(channel_info[0]); + channel_info[1] = utils::align_up_4(channel_info[1]); + } + + push_constants = { + graph.sizes_pc_of(out), + graph.sizes_pc_of(in), + PushConstantDataInfo(&whcn_permute_dims, sizeof(whcn_permute_dims))}; - const vkapi::SpecVarList spec_vars = {packed_dim}; + spec_vars = {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}; + } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, {{out, vkapi::kWrite}, {in, vkapi::kRead}}, - {}, + // Parameter buffers + param_buffers, // Push Constants - {{graph.logical_limits_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo(&out_dims, sizeof(out_dims)), - PushConstantDataInfo(&channel_info, sizeof(channel_info))}}, + push_constants, // Specialization Constants spec_vars, // Resize Args diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 306a79fb8b8..c4de5d88f30 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -26,6 +26,9 @@ void add_unsqueeze_node( in_dim < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); int64_t dim = graph.extract_scalar(dim_ref); + if (dim < 0) { + dim += out_dim; + } std::vector permute_dims(out_dim); for (int i = 1; i <= dim; i++) { diff --git a/backends/vulkan/test/op_tests/cases.py 
b/backends/vulkan/test/op_tests/cases.py index 813807445f0..92f73268ebf 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -752,6 +752,13 @@ def get_permute_inputs(): "utils::kHeightPacked", "utils::kChannelsPacked", ] + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.dtypes = [ + "at::kFloat", + ] return test_suite @@ -990,9 +997,11 @@ def get_unsqueeze_inputs(): ((9, 9), 2), ((9,), 0), ((9,), 1), + ((1, 10), -1), ] ) test_suite.layouts = [ + "utils::kWidthPacked", "utils::kChannelsPacked", ] test_suite.data_gen = "make_seq_tensor" diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index b24879f660a..38a3ee93627 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -58,6 +58,8 @@ class ValueRef: src_cpp_type: str is_in: bool = False is_out: bool = False + fixed_storage_type: Optional[str] = None + fixed_memory_layout: Optional[str] = None requires_prepack: bool = False supports_prepack: bool = False # When is_dynamic_size is true, the underlying object size is not known @@ -137,20 +139,43 @@ def __init__( if arg.name in self.suite_def.prepacked_args: supports_prepack = True + fixed_storage_type = None + if arg.name in self.suite_def.arg_storage_types: + fixed_storage_type = self.suite_def.arg_storage_types[arg.name] + + fixed_memory_layout = None + if arg.name in self.suite_def.arg_memory_layouts: + fixed_memory_layout = self.suite_def.arg_memory_layouts[arg.name] + self.refs[arg.name] = ValueRef( name=f"{arg.name}_ref", src_cpp_name=arg.name, src_cpp_type=cpp_type, is_in=(cpp_type in InableCppType), + fixed_storage_type=fixed_storage_type, + fixed_memory_layout=fixed_memory_layout, requires_prepack=requires_prepack, supports_prepack=supports_prepack, ) ret_type = cpp.returns_type(self.f.func.returns, 
symint=False).cpp_type() self.out = ATenArg(name="out", cpp_type=ret_type, default=None) + + fixed_storage_type = None + if "out" in self.suite_def.arg_storage_types: + fixed_storage_type = self.suite_def.arg_storage_types["out"] + fixed_memory_layout = None + if "out" in self.suite_def.arg_memory_layouts: + fixed_memory_layout = self.suite_def.arg_memory_layouts["out"] + if ret_type == AT_TENSOR: self.refs["out"] = ValueRef( - name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True + name="out_ref", + src_cpp_name="out", + src_cpp_type=ret_type, + is_out=True, + fixed_storage_type=fixed_storage_type, + fixed_memory_layout=fixed_memory_layout, ) elif ret_type == TWO_TENSOR_TUPLE: self.refs["out"] = [ @@ -159,12 +184,24 @@ def __init__( src_cpp_name="std::get<0>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[0] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[0] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_second", src_cpp_name="std::get<1>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[1] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[1] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref", @@ -180,18 +217,36 @@ def __init__( src_cpp_name="std::get<0>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[0] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[0] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_second", src_cpp_name="std::get<1>(out)", src_cpp_type="at::Tensor", is_out=True, + fixed_storage_type=( + fixed_storage_type[1] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[1] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref_third", src_cpp_name="std::get<2>(out)", src_cpp_type="at::Tensor", is_out=True, + 
fixed_storage_type=( + fixed_storage_type[2] if fixed_storage_type else None + ), + fixed_memory_layout=( + fixed_memory_layout[2] if fixed_memory_layout else None + ), ), ValueRef( name="out_ref", @@ -302,7 +357,12 @@ def create_value_for( # noqa: C901 ret_str += f"{self.graph}{self.dot}" ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" ret_str += f"{ref.src_cpp_name}->sizes().vec(), " - ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n" + ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())"  # close from_at_scalartype(); the optional args below belong to add_tensor + if ref.fixed_storage_type: + ret_str += f", {ref.fixed_storage_type}" + if ref.fixed_memory_layout: + ret_str += f", {ref.fixed_memory_layout}" + ret_str += ");\n" elif prepack: ret_str += f"{self.graph}{self.dot}" ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), " @@ -385,7 +445,12 @@ def create_value_for( # noqa: C901 elif ref.src_cpp_type == AT_TENSOR and not prepack: ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" ret_str += f"{ref.src_cpp_name}.sizes().vec(), " - ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n" + ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())" + if ref.fixed_storage_type: + ret_str += f", {ref.fixed_storage_type}" + if ref.fixed_memory_layout: + ret_str += f", {ref.fixed_memory_layout}" + ret_str += ");\n" elif ref.src_cpp_type == AT_TENSOR and prepack: ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), " ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), " diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 5be4ddba6bf..250edf333bc 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -140,7 +140,13 @@ def call_data_gen_fn(self, arg: Argument, data: Any, terminate: bool = True) -> else self.suite_def.arg_data_range[arg.name] ) - 
ret_str = f"{self.suite_def.data_gen}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})" + data_gen_fn = ( + self.suite_def.data_gen + if arg.name not in self.suite_def.arg_data_gen_fn + else self.suite_def.arg_data_gen_fn[arg.name] + ) + + ret_str = f"{data_gen_fn}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})" if terminate: ret_str += ";" @@ -288,13 +294,29 @@ def generate_suite_cpp(self) -> str: if (dtype == at::kBool) return at::rand(sizes, at::device(at::kCPU)) > 0.5; - + if (high == 1.0 && low == 0.0) return at::rand(sizes, at::device(at::kCPU).dtype(dtype)); return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low; }} +at::Tensor make_zeros_tensor( + std::vector sizes, + at::ScalarType dtype = at::kFloat, + float low = 0.0, + float high = 1.0) {{ + return at::zeros(sizes, at::device(at::kCPU).dtype(dtype)); +}} + +at::Tensor make_ones_tensor( + std::vector sizes, + at::ScalarType dtype = at::kFloat, + float low = 0.0, + float high = 1.0) {{ + return at::ones(sizes, at::device(at::kCPU).dtype(dtype)); +}} + at::Tensor make_seq_tensor( std::vector sizes, at::ScalarType dtype = at::kFloat, diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index e7cf5ba92a5..c368c23c539 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -29,7 +29,6 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple void SetUp() override {{ GraphConfig config; - config.expect_dynamic_shapes = true; utils::StorageType default_storage_type; utils::GPUMemoryLayout default_memory_layout; std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); diff --git a/backends/vulkan/test/op_tests/utils/test_suite.py b/backends/vulkan/test/op_tests/utils/test_suite.py index 72ba457b5af..427864b0d5d 100644 --- 
a/backends/vulkan/test/op_tests/utils/test_suite.py +++ b/backends/vulkan/test/op_tests/utils/test_suite.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional ################################### ## Generic Test Suite definition ## @@ -23,6 +23,7 @@ def __init__(self, input_cases: List[Any]): self.data_range = (0, 1) self.arg_dtype = {} + self.arg_data_gen_fn: Dict[str, str] = {} self.arg_data_range = {} self.atol: str = "1e-5" @@ -48,3 +49,5 @@ def __init__(self, input_cases: List[Any]): self.layouts: List[str] = ["utils::kChannelsPacked"] self.data_gen: str = "make_rand_tensor" self.force_io: bool = True + self.arg_storage_types: Dict[str, str] = {} + self.arg_memory_layouts: Dict[str, str] = {}