diff --git a/.gitmodules b/.gitmodules index 05143134bcf..8cb71f3a18e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,6 @@ [submodule "examples/demo-apps/android/jni/third-party/fbjni"] path = examples/demo-apps/android/jni/third-party/fbjni url = https://github.com/facebookincubator/fbjni.git +[submodule "backends/arm/third-party/ethos-u-core-driver"] + path = backends/arm/third-party/ethos-u-core-driver + url = https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git diff --git a/CMakeLists.txt b/CMakeLists.txt index f0281766aab..122d9006b20 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -334,6 +334,13 @@ if(EXECUTORCH_BUILD_QNN) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples/qualcomm) endif() +# Build Arm Baremetal backend +option(EXECUTORCH_BUILD_ARM_BAREMETAL + "Build the Arm Baremetal flow for Cortex-M and Ethos-U" OFF) +if(EXECUTORCH_BUILD_ARM_BAREMETAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/arm) +endif() + # Add selective build subdirectory if(BUILD_SELECTIVE_BUILD_TEST) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples/selective_build) diff --git a/backends/arm/CMakeLists.txt b/backends/arm/CMakeLists.txt new file mode 100644 index 00000000000..2b40086091b --- /dev/null +++ b/backends/arm/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright 2023 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +include(${EXECUTORCH_ROOT}/build/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) 
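+# Note: cmake/Dependencies.cmake below is expected to wire up the
+# ethos-u-core-driver submodule (declared in .gitmodules above) and to
+# provide DRIVER_ETHOSU_INCLUDE_DIR for the include paths further down.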
+ +include(cmake/Dependencies.cmake) + +set(_arm_baremetal_sources backends/arm/runtime/ArmBackendEthosU.cpp) +list(TRANSFORM _arm_baremetal_sources PREPEND "${EXECUTORCH_ROOT}/") + +add_library( + executorch_delegate_ethos_u + STATIC ${_arm_baremetal_sources} +) +target_include_directories( + executorch_delegate_ethos_u + PUBLIC + ${_common_include_directories} +) +target_include_directories( + executorch_delegate_ethos_u + PUBLIC + ${DRIVER_ETHOSU_INCLUDE_DIR} +) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 6b08d94e3aa..f0f285418c6 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -12,6 +12,8 @@ import logging import operator import os +import struct +import subprocess import tempfile from typing import final, List @@ -136,13 +138,89 @@ def dbg_tosa_dump(tosa_fb, path): fb = tosa_fb.serialize() js = tosa_fb.writeJson(filename) - f = open(path + filename, "wb") - f.write(fb) - f.close() + with open(path + filename, "wb") as f: + f.write(fb) - f = open(path + "desc.json", "w") - f.write(js) - f.close() + with open(path + "desc.json", "w") as f: + f.write(js) + + +# Output to Vela with current file-based compilation +# WARNING: if this changes, the runtime reader also needs to change +def vela_compile(tosa_fb): + with tempfile.TemporaryDirectory() as tmpdir: + tosaname = "out.tosa" + flatbuffer = tosa_fb.serialize() + with open(os.path.join(tmpdir, tosaname), "wb") as f: + f.write(flatbuffer) + + # invoke vela + vela_command = ( + f"cd {tmpdir}; vela --accelerator-config ethos-u55-128 {tosaname}" + ) + subprocess.run([vela_command], shell=True, check=True) + + np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz") + blocks = b"" + with np.load(np_path, allow_pickle=False) as data: + # Emit the NPZ regions as: + # - 16 byte block name null terminated string (padded to 16 if name shorter) + # - 4 bytes of int32 block length and 12 bytes of 0's + # - block data (padded to 16 byte alignment at end) + # Repeat for all blocks + for key in data.keys(): + block_name = bytes(key, "utf8")[:15] + block_name = block_name + b"\x00" * (16 - len(block_name)) + + block_data = b"" + if key in ("input_shape", "output_shape"): + inputs = data[key] + # Encode a struct of int len; and one or more int x,y,z,w shape; + input_struct = struct.pack(" 1: + raise RuntimeError( + "Currently only support one output in Vela ArmBackend" + ) + offset_struct = struct.pack("&1 | grep "^gcc"` + + +# +# Prepare and run clean build +# +rm -rf buck-out/ build/lib/ cmake-out/ +rm -rf cmake-corstone +mkdir cmake-corstone +cd cmake-corstone + +#cmake -DBUCK2=buck2 .. + +#cmake --toolchain backends/arm/cmake/arm-none-eabi-gcc.cmake .. +cmake -DFLATC_EXECUTABLE=flatc \ + -DEXECUTORCH_BUILD_XNNPACK=OFF \ + -DEXECUTORCH_BUILD_HOST_TARGETS=OFF \ + -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \ + -DCMAKE_SYSTEM_PROCESSOR=cortex-m55+nodsp+nofp \ + -DETHOSU_TARGET_NPU_CONFIG=ethos-u55-128 \ + --toolchain backends/arm/cmake/arm-none-eabi-gcc.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_ENABLE_LOGGING_RELEASE_MODE=ON \ + .. + +cd .. +cmake --build cmake-corstone -j9 --target ethos_u ethosu_core_driver executorch portable_ops_lib portable_kernels diff --git a/backends/arm/cmake/toolchain.sh b/backends/arm/cmake/toolchain.sh new file mode 100755 index 00000000000..92188ee982d --- /dev/null +++ b/backends/arm/cmake/toolchain.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Copyright 2023 Arm Limited and/or its affiliates. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+set -e

+# Cross compiler for Arm baremetal (e.g. Corstone-300 FVP or silicon)
+ARCH=$(uname -i)
+curl -o gcc.tar.xz https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-${ARCH}-arm-none-eabi.tar.xz
+tar xf gcc.tar.xz
+export PATH=${PATH}:`(cd arm-gnu-toolchain-12.3.rel1-${ARCH}-arm-none-eabi/bin/; pwd)`
diff --git a/backends/arm/runtime/ArmBackendEthosU.cpp b/backends/arm/runtime/ArmBackendEthosU.cpp
new file mode 100644
index 00000000000..17625bdf20d
--- /dev/null
+++ b/backends/arm/runtime/ArmBackendEthosU.cpp
@@ -0,0 +1,315 @@
+/*
+ * Copyright 2023 Arm Limited and/or its affiliates.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/*
+ * Arm backend for the Ethos-U baremetal driver stack; this relies on the
+ * ethos-u-core-driver for hardware interaction.
+ */
+
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/evalue.h>
+
+#include <ethosu_driver.h>
+#include <pmu_ethosu.h>
+
+using namespace std;
+
+namespace torch {
+namespace executor {
+
+// TODO: we should be in 0x31, to access a full 2MB SRAM
+// region and enable maximum program performance up to
+// 2MB, rather than 1.
+// SRAM (rwx) : ORIGIN = 0x31000000, LENGTH = 0x00200000
+#define CS300_SRAM_LOW ((void*)0x11000000)
+#define CS300_SRAM_HIGH ((void*)0x110FFFFF)
+
+class ArmBackend final : public PyTorchBackendInterface {
+ public:
+  ArmBackend() {}
+
+  ~ArmBackend() = default;
+
+  virtual bool is_available() const override {
+    // TODO: revise to use a register check/init function
+    return true;
+  }
+
+  Result<DelegateHandle*> init(
+      BackendInitContext& context,
+      FreeableBuffer* processed,
+      ArrayRef<CompileSpec> compile_specs) const override {
+    ET_LOG(Info, "ArmBackend::init %p", processed->data());
+
+    char* data = (char*)processed->data();
+    size_t size = processed->size();
+    char* foot = data + size - 16;
+
+    // Header and footer both being 16 byte aligned suggests a valid structure
+    // and means we won't walk off the end of the chunks and segfault
+    if (!((int)data == next_mul_16((uintptr_t)data))) {
+      ET_LOG(Error, "ArmBackend::init: Binary needs to be 16 byte aligned");
+      return Error::InvalidProgram;
+    }
+    if (!((int)foot == next_mul_16((uintptr_t)foot))) {
+      ET_LOG(Error, "ArmBackend::init: Footer expected to be 16 byte aligned");
+      ET_LOG(
+          Error,
+          "ArmBackend::init: Program expected to be multiple of 16 bytes");
+      return Error::InvalidProgram;
+    }
+    if (!(0 == strncmp(data, "vela_bin_stream", 15))) {
+      ET_LOG(Error, "ArmBackend::init: Binary passed is not a vela_bin_stream");
+      return Error::InvalidProgram;
+    }
+    if (!(0 == strncmp(foot, "vela_end_stream", 15))) {
+      ET_LOG(Error, "ArmBackend::init: Binary passed missing vela_end_stream");
+      return Error::InvalidProgram;
+    }
+    // Verify the address range is accessible; the current expectation is that
+    // the program is wholly stored in SRAM
+    // TODO: expect to improve capabilities here by supporting DRAM storage
+    // and only moving required data into SRAM.
+    if (!(data > CS300_SRAM_LOW && foot < CS300_SRAM_HIGH)) {
+      ET_LOG(Error, "ArmBackend::init: Expected program binary to be in SRAM");
+      ET_LOG(
+          Error,
+          "ArmBackend::init: program binary range %p:%p",
+          data,
+          foot + 16);
+      return Error::InvalidProgram;
+    }
+
+    // Return the same buffer we were passed - this data will be
+    // executed directly
+    return processed;
+  }
+
+  Error execute(
+      BackendExecutionContext& context,
+      DelegateHandle* input_handle,
+      EValue** args) const override {
+    FreeableBuffer* processed = (FreeableBuffer*)input_handle;
+
+    ET_LOG(Info, "ArmBackend::execute %p", processed->data());
+
+    VelaHandles handles;
+
+    // Command stream - we know at this point it's aligned
+    char* data = (char*)processed->data();
+
+    // Read key sections from the vela_bin_stream
+    if (!this->vela_read(data, &handles, processed->size())) {
+      ET_LOG(Error, "ArmBackend::vela_read: error, invalid binary layout");
+      return Error::InvalidProgram;
+    }
+
+    ET_LOG(
+        Debug,
+        "ArmBackend::execute: Running program data:\n cmd %p %d\n weight %p %d\n scratch %p %d\n",
+        handles.cmd_data,
+        handles.cmd_data_size,
+        handles.weight_data,
+        handles.weight_data_size,
+        handles.scratch_data,
+        handles.scratch_data_size);
+
+    // Write inputs into SRAM scratch area defined by Vela
+    for (int i = 0; i < handles.input_shapes.size(); i++) {
+      const char* input_addr = handles.scratch_data + handles.input_offset[i];
+      // Process input EValue into scratch
+      // TODO: Optimise into direct write from Vela into the SRAM or DRAM output
+      // for compatible data layouts.
+      int* input_address = (int*)input_addr;
+      auto tensor_in = args[i]->toTensor();
+      for (int j = 0; j < tensor_in.numel(); j++) {
+        // TODO: extend beyond tensors with 4 byte elements
+        input_address[j] = tensor_in.mutable_data_ptr<int>()[j];
+      }
+    }
+
+    // Allocate driver handle and synchronously invoke driver
+    ethosu_driver* drv = ethosu_reserve_driver();
+    if (drv == NULL) {
+      ET_LOG(Error, "ArmBackend::execute: ethosu_reserve_driver failed");
+      return Error::InvalidState;
+    }
+
+    // The Ethos-U low level driver expects base pointers in this order for the
+    // Ethos-U55: constant weight data, then scratch (which contains input and
+    // output); scratch is written above in this function.
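+    // Note: the order of bases[] below is part of the binary interface with
+    // the Vela output stream; if vela_compile in arm_backend.py changes its
+    // block layout, this code must change with it.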
+    uint64_t bases[2] = {
+        (uint64_t)handles.weight_data, (uint64_t)handles.scratch_data};
+    size_t bases_size[2] = {
+        handles.weight_data_size, handles.scratch_data_size};
+    int result = ethosu_invoke_v3(
+        drv,
+        (void*)handles.cmd_data,
+        handles.cmd_data_size,
+        bases,
+        bases_size,
+        2, /* fixed array of pointers to binary interface */
+        nullptr);
+
+    if (result != 0) {
+      ET_LOG(
+          Error,
+          "ArmBackend::execute: Ethos-U invocation failed error (%d)",
+          result);
+      return Error::InvalidProgram;
+    }
+
+    // Output data from the Ethos-U
+    // We only handle one output at the moment
+    const char* output_addr = handles.scratch_data + handles.output_offset[0];
+    // Outputs are in the index immediately after inputs
+    int output_index = handles.input_shapes.size();
+
+    if (handles.output_shapes.size() != 1) {
+      ET_LOG(
+          Error,
+          "ArmBackend::execute: currently only support one return tensor");
+      return Error::InvalidProgram;
+    }
+    // Process results into EValue storage
+    // TODO: optimise into direct write for compatible, contig layout
+    int* output_address = (int*)output_addr;
+    auto tensor_out = args[output_index]->toTensor();
+    for (int j = 0; j < tensor_out.numel(); j++) {
+      // TODO: extend beyond tensors with 4 byte elements
+      tensor_out.mutable_data_ptr<int>()[j] = output_address[j];
+    }
+
+    return Error::Ok;
+  }
+
+  void destroy(DelegateHandle* handle) const override {
+    return;
+  }
+
+ private:
+  typedef struct {
+    const char* cmd_data;
+    size_t cmd_data_size;
+    const char* weight_data;
+    size_t weight_data_size;
+    const char* scratch_data;
+    size_t scratch_data_size;
+    vector<int> input_offset;
+    vector<vector<int>> input_shapes;
+    vector<int> output_offset;
+    vector<vector<int>> output_shapes;
+  } VelaHandles;
+
+  typedef struct {
+    char name[16];
+    uint32_t size;
+    char _pad[12];
+    char data[];
+  } VelaBinBlock;
+
+  typedef struct {
+    int count;
+    int shape[][4];
+  } VelaShapes;
+
+  typedef struct {
+    int count;
+    int offsets[];
+  } VelaOffsets;
+
+  static int next_mul_16(int n) {
+    return ((n - 1) | 15) + 1;
+  }
+
+  int vela_read(char* data, VelaHandles* handles, int size) const {
+    constexpr const size_t header_size = 16;
+
+    // Read header string
+    if (strncmp(data, "vela_bin_stream", 15)) {
+      return 0;
+    }
+    data += header_size;
+
+    // Expect one or more 'VelaBinBlock's
+    while (1) {
+      VelaBinBlock* b = (VelaBinBlock*)data;
+      data += sizeof(VelaBinBlock) + next_mul_16(b->size);
+
+      // Exit with success on finding end of stream
+      if (!strncmp(b->name, "vela_end_stream", strlen("vela_end_stream")))
+        return 1;
+
+      if (!strncmp(b->name, "cmd_data", strlen("cmd_data"))) {
+        // This magic header confirms a valid command stream in binary
+        if (strncmp(b->data, "COP1", strlen("COP1")))
+          return 0;
+        handles->cmd_data = b->data;
+        handles->cmd_data_size = b->size;
+      }
+      if (!strncmp(b->name, "weight_data", strlen("weight_data"))) {
+        handles->weight_data = b->data;
+        handles->weight_data_size = b->size;
+      }
+      if (!strncmp(b->name, "scratch_data", strlen("scratch_data"))) {
+        handles->scratch_data = b->data;
+        handles->scratch_data_size = b->size;
+      }
+
+      // capture inputs and outputs
+      if (!strncmp(b->name, "input_offset", strlen("input_offset"))) {
+        VelaOffsets* offsets = (VelaOffsets*)b->data;
+        for (int i = 0; i < offsets->count; i++) {
+          handles->input_offset.push_back(offsets->offsets[i]);
+        }
+      }
+      if (!strncmp(b->name, "output_offset", strlen("output_offset"))) {
+        VelaOffsets* offsets = (VelaOffsets*)b->data;
+        for (int i = 0; i < offsets->count; i++) {
+          handles->output_offset.push_back(offsets->offsets[i]);
+        }
+      }
+
+      if (!strncmp(b->name, "input_shape", strlen("input_shape"))) {
+        VelaShapes* shapes = (VelaShapes*)b->data;
+        for (int i = 0; i < shapes->count; i++) {
+          vector<int> s = {
+              shapes->shape[i][0],
+              shapes->shape[i][1],
+              shapes->shape[i][2],
+              shapes->shape[i][3]};
+          handles->input_shapes.push_back(s);
+        }
+      }
+      if (!strncmp(b->name, "output_shape", strlen("output_shape"))) {
+        VelaShapes* shapes = (VelaShapes*)b->data;
+        for (int i = 0; i < shapes->count; i++) {
+          vector<int> s = {
+              shapes->shape[i][0],
+              shapes->shape[i][1],
+              shapes->shape[i][2],
+              shapes->shape[i][3]};
+          handles->output_shapes.push_back(s);
+        }
+      }
+    }
+  }
+};
+
+namespace {
+auto backend = ArmBackend();
+Backend backend_id{"ArmBackend", &backend};
+static auto registered = register_backend(backend_id);
+} // namespace
+
+} // namespace executor
+} // namespace torch
diff --git a/backends/arm/test/test_models.py b/backends/arm/test/test_models.py
index 3400a7c8f7c..46a57a601b8 100644
--- a/backends/arm/test/test_models.py
+++ b/backends/arm/test/test_models.py
@@ -25,6 +25,7 @@ class TosaProfile(Enum):
     BI = 0  # Base Inference
     MI = 1  # Main Inference
     MT = 2  # Main Training
+    BI_INT = 3  # integer only BI subset tests (for test graphs)
 
 
 class TorchBuilder:
@@ -67,6 +68,7 @@ class simple_add(torch.nn.Module):
         inputs = {
             TosaProfile.BI: (torch.ones(5),),
             TosaProfile.MI: (torch.ones(5),),
+            TosaProfile.BI_INT: (torch.ones(5, dtype=torch.int32),),
         }
 
         def __init__(self):
@@ -75,9 +77,28 @@ def __init__(self):
         def forward(self, x):
             return x + x
 
+    @register_test
+    class simple_add_2(torch.nn.Module):
+        inputs = {
+            TosaProfile.BI_INT: (
+                torch.ones(5, dtype=torch.int32),
+                torch.ones(5, dtype=torch.int32),
+            ),
+        }
+
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return x + y
+
     @register_test
     class simple_add_broadcast(torch.nn.Module):
         inputs = {
+            TosaProfile.BI_INT: (
+                torch.ones(10, 1, dtype=torch.int32),
+                torch.ones(10, 10, dtype=torch.int32),
+            ),
             TosaProfile.BI: (
                 torch.ones(10, 1),
                 torch.ones(10, 10),
@@ -110,7 +131,7 @@ def forward(self, x):
             x = self.fc(x)
             return x
 
-    @register_test
+    # @register_test
     class simple_conv2d(torch.nn.Module):
         inputs = {
             TosaProfile.BI: (
@@ -134,7 +155,7 @@ def forward(self, x):
             x = self.conv2d(x)
             return x
 
-    @register_test
+    # @register_test
     class block_two_conv2d(torch.nn.Module):
         inputs = {
             TosaProfile.BI: (torch.ones(1, 3, 256, 256),),
@@ -155,7 +176,7 @@ def forward(self, x):
             x = self.conv2d_2(x)
             return x
 
-    @register_test
+    # @register_test
     class simple_depthwise_conv2d(torch.nn.Module):
         inputs = {
             TosaProfile.BI: (
@@ -259,7 +280,7 @@ def __init__(self):
         def forward(self, x):
             return self.softmax(x)
 
-    @register_test
+    # @register_test
     class block_conv_norm_activation(torch.nn.Module):
         inputs = {
             TosaProfile.BI: (torch.ones(1, 3, 256, 256),),
@@ -281,7 +302,7 @@ def forward(self, x):
             x = self.relu6(x)
             return x
 
-    @register_test
+    # @register_test
     class block_bottleneck_residual(torch.nn.Module):
         # This is the essence of MobileNetV2
         # Ref: https://arxiv.org/abs/1801.04381
diff --git a/backends/arm/test/test_tosa.py b/backends/arm/test/test_tosa.py
index b3e59658641..9736503e626 100644
--- a/backends/arm/test/test_tosa.py
+++ b/backends/arm/test/test_tosa.py
@@ -17,6 +17,8 @@
 
 from executorch.exir.backend.backend_api import to_backend
 
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
 # Config for Capturing the weights, will be moved in the future
 _CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
 _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
     _check_ir_validity=False,
 )
@@ -37,9 +39,12 @@ def test_minimal_MI(self):
         for test_model in TestList:
             print(f"Running test {test_model}")
             model, inputs, outputs = prepare_model_and_ref(test_model, TosaProfile.MI)
-
-            model_edge, exec_prog = export_model(model, inputs, [])
-            # TODO: check there is a tosa delegate blob in the output
+            if inputs is None:
+                print(" Skipping, no inputs for this profile")
+                continue
+            model_edge, exec_prog = export_model(
+                model, inputs, [CompileSpec("output_format", bytes("tosa", "utf8"))]
+            )
 
     def test_minimal_BI(self):
         for test_model in TestList:
@@ -48,14 +53,31 @@ def test_minimal_BI(self):
             if inputs is None:
                 print(" Skipping, no inputs for this profile")
                 continue
-            model_edge, exec_prog = export_model(model, inputs, [])
-            # TODO: check there is a tosa delegate blob in the output
+            model_edge, exec_prog = export_model(
+                model, inputs, [CompileSpec("output_format", bytes("tosa", "utf8"))]
+            )
+
+    def test_minimal_BI_INT(self):
+        for test_model in TestList:
+            print(f"Running test {test_model}")
+            model, inputs, outputs = prepare_model_and_ref(
+                test_model, TosaProfile.BI_INT
+            )
+            if inputs is None:
+                print(" Skipping, no inputs for this profile")
+                continue
+            model_edge, exec_prog = export_model(
+                model, inputs, [CompileSpec("output_format", bytes("tosa", "utf8"))]
+            )
 
 
 def prepare_model_and_ref(test_model, profile=TosaProfile.MI):
     model = TestList[test_model]
     model_inputs = model.inputs.get(profile)
 
+    if model_inputs is None:
+        return model, model_inputs, None
+
     model.eval()
     if profile == TosaProfile.BI:
         # Quantize the model
@@ -72,10 +94,8 @@ def prepare_model_and_ref(test_model, profile=TosaProfile.MI):
         prepared_model(*model.inputs[profile])
         model = convert_pt2e(prepared_model)
 
-    if model_inputs is not None:
-        model_outputs = model.forward(*model_inputs)
-        return model, model_inputs, model_outputs
-    return model, model_inputs, None
+    model_outputs = model.forward(*model_inputs)
+    return model, model_inputs, model_outputs
 
 
 def export_model(model, inputs, compile_spec):
diff --git a/backends/arm/third-party/ethos-u-core-driver b/backends/arm/third-party/ethos-u-core-driver
new file mode 160000
index 00000000000..90f9df900ac
--- /dev/null
+++ b/backends/arm/third-party/ethos-u-core-driver
@@ -0,0 +1 @@
+Subproject commit 90f9df900acdc0718ecd2dfdc53780664758dec5
diff --git a/examples/arm/arm_ethosu_minimal.py b/examples/arm/arm_ethosu_minimal.py
new file mode 100644
index 00000000000..93b73909251
--- /dev/null
+++ b/examples/arm/arm_ethosu_minimal.py
@@ -0,0 +1,212 @@
+# Copyright 2023 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
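+#
+# Minimal example flow: lower a test model through the ArmPartitioner to a
+# TOSA flatbuffer, run it on the TOSA reference model, and compile an
+# Ethos-U55 command stream into the final .pte via Vela.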
+
+import json
+import os
+import subprocess
+
+import executorch.exir as exir
+
+import numpy as np
+from executorch.backends.arm.arm_backend import ArmPartitioner
+from executorch.backends.arm.test.test_models import TosaProfile
+from executorch.backends.arm.test.test_tosa import prepare_model_and_ref
+
+from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
+    DuplicateDequantNodePass,
+)
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+# Assumes you have these two tools on your path
+TOSA_REF_MODEL_PATH = "tosa_reference_model"
+VELA_COMPILER_PATH = "vela"
+
+# Basic config for graph capture
+_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
+_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
+    _check_ir_validity=False,
+)
+
+EXAMPLE_TEST_LIST = ["simple_add", "simple_add_2"]
+
+#
+#
+#
+#
+def tosa_ref_capture_inputs(
+    model_edge,
+    inputs,
+    path,
+    input_quantization_scales,
+    input_quantization_zps,
+    profile=TosaProfile.MI,
+):
+    # Emit TOSA test data from the model inputs - assumes whole graph lowered so we just have
+    # placeholders for the TOSA delegate. Emits data in tosa_ref_model expected layout.
+    # - Skips placeholders which are encoded as constants (i.e. are already captured weights)
+    # - Assumes argument order is fixed
+    argument_names = []
+    for node in model_edge.exported_program.graph.nodes:
+        gs = model_edge.exported_program.graph_signature
+        if node.op == "placeholder":
+            if node.name in gs.inputs_to_parameters:
+                pass
+            elif node.name in gs.inputs_to_buffers:
+                pass
+            else:
+                argument_names.append(node.name)
+        else:
+            break
+
+    for arg in zip(argument_names, inputs):
+        name = arg[0]
+        data = arg[1].detach().numpy()
+        file_path = path + "/" + name + ".npy"
+
+        # Torch is doing Input[FP32]->Q[INT8]->DQ[FP32]->Operator[FP32]->Q[INT]->DQ[FP32]->[Output]FP32
+        # Need to quantize the input to INT8 for TOSA consumption
+        if profile is TosaProfile.BI:
+            data_quantized = (
+                (data / input_quantization_scales[name]) - input_quantization_zps[name]
+            ).astype(np.int8)
+            np.save(file_path, data_quantized, allow_pickle=False)
+        else:
+            np.save(file_path, data, allow_pickle=False)
+
+
+#
+# Minimal sequence to take a model through the ArmPartitioner and produce
+# both TOSA intermediate output, and an Ethos-U55 command stream within
+# the ExecuTorch .pte binary
+#
+def run_test(op, profile=TosaProfile.MI, output_path="./ethosout/"):
+    #
+    # Minimal sequence to take a model through the TosaPartitioner and emit
+    # a tosaout/ debug directory containing the flatbuffer - assumes one delegate and will only save the last output
+    # tosaout is generated even for partial/broken subgraph capture to aid in debug
+    # delegated.pte containing the flatbuffer within the executorch flatbuffer binary
+    #
+    print(f"\n\033[96mProcessing:::{op}\033[0m")
+    print(f"\033[96mDebug output path for intermediates: {output_path}\033[0m")
+
+    os.makedirs(output_path, exist_ok=True)
+
+    # Debug output for TORCH
+    TORCH_OUT_PATH = os.path.join(output_path, op, "torch", "")
+    os.makedirs(TORCH_OUT_PATH, exist_ok=True)
+
+    # Debug output for TOSA
+    TOSA_OUT_PATH = os.path.join(output_path, op, "tosa", "")
+    os.makedirs(TOSA_OUT_PATH, exist_ok=True)
+
+    model, inputs, torch_output = prepare_model_and_ref(op, profile)
+
+    if inputs is None:
+        print("\033[96m Skipping, model has no inputs for TOSA profile \033[0m")
+        return
+
+    print(f" Model: {op}\n Inputs: {inputs}\n Outputs: {torch_output}")
+
+    # Export model
+    model_capture = exir.capture(model, inputs, _CAPTURE_CONFIG)
+    model_edge = model_capture.to_edge(_EDGE_COMPILE_CONFIG)
+
+    # Partition with ArmBackend
+    ArmPartitioner.compile_spec = [
+        CompileSpec("debug_tosa_path", bytes(TOSA_OUT_PATH, "utf8"))
+    ]
+    model_edge.exported_program = to_backend(
+        model_edge.transform(DuplicateDequantNodePass()).exported_program,
+        ArmPartitioner,
+    )
+    exec_prog = model_edge.to_executorch()
+
+    # Save .pte including delegated Vela section
+    with open(TORCH_OUT_PATH + "/delegated.pte", "wb") as fh:
+        fh.write(exec_prog.buffer)
+
+    # NOTE:
+    # Additional steps from here are optional but can be helpful with
+    # debug as they will capture the inputs and outputs as well as running
+    # the intermediate output on the tosa_reference_model.
+    # This can ensure the compilation flow is working correctly as part of
+    # a development loop, ahead of running the example on hardware.
+
+    # Save inputs for TOSA reference run
+    tosa_ref_capture_inputs(model_edge, inputs, TOSA_OUT_PATH, {}, {}, profile)
+
+    # Save ground truth results to file
+    with open(TORCH_OUT_PATH + "/torch_output.npy", "wb") as f:
+        np.save(f, torch_output.detach().numpy())
+
+    # Convert TOSA Flatbuffer into JSON format for human debugging
+    cmd_flatc = (
+        "flatc"
+        + " -o "
+        + TOSA_OUT_PATH
+        + " --raw-binary -t ./backends/arm/third-party/serialization_lib/schema/tosa.fbs -- "
+        + TOSA_OUT_PATH
+        + "/output.tosa"
+    )
+    subprocess.run([cmd_flatc], shell=True, check=True)
+
+    ### Run the TOSA flatbuffer through TOSA Ref_Model and print the results
+    DESC_FILE_NAME = "/desc.json"
+    DESC_FILE_PATH = TOSA_OUT_PATH + DESC_FILE_NAME
+    cmd_ref_model = TOSA_REF_MODEL_PATH + " --test_desc " + DESC_FILE_PATH
+    subprocess.run([cmd_ref_model], shell=True, check=True)
+
+    ## Load the desc.json and read the TOSA output filenames
+    with open(DESC_FILE_PATH) as desc_file:
+        desc_json = json.load(desc_file)
+    tosa_out_filenames = desc_json["ofm_file"]
+    for tosa_out_fm_file_name in tosa_out_filenames:
+        with open(TOSA_OUT_PATH + "/" + tosa_out_fm_file_name, "rb") as f:
+            tosa_output = np.load(f)
+
+    ## Read the Torch output
+    with open(TORCH_OUT_PATH + "/torch_output.npy", "rb") as torch_file:
+        torch_output = np.load(torch_file)
+
+    ## Compare TOSA and Torch results
+    if np.allclose(tosa_output, torch_output, rtol=1e-1, atol=1e-1, equal_nan=True):
+        print(
+            "\033[92m"
+            + "Torch and Tosa Reference results are matching for operator: "
+            + op
+            + " from "
+            + str(profile)
+            + "\033[0m"
+        )
+
+    else:
+        print("\033[91m" + "Sorry, Torch and TOSA reference results do not match!")
+        print("============================")
+        print("TOSA Output Shape is: " + str(tosa_output.shape))
+        print("TOSA Output is: ")
+        print(tosa_output)
+        print("\033[93m")
+        print("============================")
+        print("Torch Output Shape is: " + str(torch_output.shape))
+        print("Torch Output is: ")
+        print(torch_output)
+        print("\033[0m")
+
+    if profile in (TosaProfile.BI, TosaProfile.BI_INT):
+        cmd_vela = "cd " + TOSA_OUT_PATH + "; " + VELA_COMPILER_PATH + " ./output.tosa"
+        try:
+            subprocess.run([cmd_vela], shell=True, check=True)
+            print("\033[92m" + "Vela compile worked for: " + op + "\033[0m")
+        except subprocess.CalledProcessError:
+            print("\033[91m" + "Vela compile failed for: " + op + "\033[0m")
+    else:
+ "\033[0m") + + +# systest mode for running all models against both inference profiles +if __name__ == "__main__": + for op in EXAMPLE_TEST_LIST: + run_test(op, profile=TosaProfile.BI_INT) diff --git a/examples/arm/arm_tosa_e2e.py b/examples/arm/arm_tosa_e2e.py index 0dba4fa9866..80f1e19a357 100644 --- a/examples/arm/arm_tosa_e2e.py +++ b/examples/arm/arm_tosa_e2e.py @@ -144,8 +144,13 @@ def tosa_run_test(op, profile=TosaProfile.MI): # noqa: C901 TOSA_OUT_PATH = os.path.join(DEBUG_OUTPUT_PATH, op, "tosa", "") os.makedirs(TOSA_OUT_PATH, exist_ok=True) - # Debug flag for compilers - compile_spec = [CompileSpec("debug_tosa_path", bytes(TOSA_OUT_PATH, "utf8"))] + # Debug flags for compilers + # - Emit some debug files into /tmp + # - output_format TOSA for this test (and pure tosa flows) + compile_spec = [ + CompileSpec("debug_tosa_path", bytes(TOSA_OUT_PATH, "utf8")), + CompileSpec("output_format", bytes("tosa", "utf8")), + ] model, inputs, torch_output = prepare_model_and_ref(op, profile) diff --git a/examples/arm/ethos-u-setup/core_platform/patches/0007-Add-delegate-runner-test.patch b/examples/arm/ethos-u-setup/core_platform/patches/0007-Add-delegate-runner-test.patch new file mode 100644 index 00000000000..c1270961510 --- /dev/null +++ b/examples/arm/ethos-u-setup/core_platform/patches/0007-Add-delegate-runner-test.patch @@ -0,0 +1,300 @@ +From 0fe8caba3068da05021232912c069124a81e0d94 Mon Sep 17 00:00:00 2001 +From: Rob Elliott +Date: Wed, 4 Oct 2023 13:31:33 +0000 +Subject: [PATCH] Add delegate runner test + +Signed-off-by: Rob Elliott +--- + applications/executorch_tests/CMakeLists.txt | 27 ++- + .../executorch_tests/pte_to_header.py | 11 +- + .../executorch_tests/runner_delegate.cpp | 160 ++++++++++++++++++ + cmake/toolchain/arm-none-eabi-gcc.cmake | 6 +- + 4 files changed, 195 insertions(+), 9 deletions(-) + create mode 100644 applications/executorch_tests/runner_delegate.cpp + +diff --git a/applications/executorch_tests/CMakeLists.txt b/applications/executorch_tests/CMakeLists.txt +index c95d53e..835f824 100644 +--- a/applications/executorch_tests/CMakeLists.txt ++++ b/applications/executorch_tests/CMakeLists.txt +@@ -28,20 +28,24 @@ set(ET_DIR_PATH "<..>/executorch" CACHE PATH "Path to ExecuTorch dir") + set(ET_BUILD_DIR_PATH "${ET_DIR_PATH}/cmake-out" CACHE PATH "Path to ExecuTorch build dir") + set(ET_INCLUDE_PATH "${ET_DIR_PATH}/.." 
CACHE PATH "Path to ExecuTorch headers") + set(ET_PTE_FILE_PATH "${ET_PTE_FILE_PATH}" CACHE PATH "Path to ExecuTorch model pte") ++set(ET_PTE_DELEGATE_FILE_PATH "${ET_PTE_DELGATE__FILE_PATH}" CACHE PATH "Path to ExecuTorch delegate model pte") + + get_filename_component(ET_BUILD_DIR_PATH ${ET_BUILD_DIR_PATH} REALPATH) + get_filename_component(ET_DIR_PATH ${ET_DIR_PATH} REALPATH) + get_filename_component(ET_INCLUDE_PATH ${ET_INCLUDE_PATH} REALPATH) + get_filename_component(ET_PTE_FILE_PATH ${ET_PTE_FILE_PATH} REALPATH) ++get_filename_component(ET_PTE_DELEGATE_FILE_PATH ${ET_PTE_DELEGATE_FILE_PATH} REALPATH) + + message("**********************") + message("ExecuTorch dir (ET_DIR_PATH) : ${ET_DIR_PATH}") + message("ExecuTorch build dir(ET_BUILD_DIR_PATH) : ${ET_BUILD_DIR_PATH}") + message("ExecuTorch headers (ET_INCUDE_PATH) : ${ET_INCLUDE_PATH}") + message("ExecuTorch pte file (ET_PTE_FILE_PATH) : ${ET_PTE_FILE_PATH}") ++message("ExecuTorch pte delegate file (ET_PTE_DELEGATE_FILE_PATH) : ${ET_PTE_DELEGATE_FILE_PATH}") + message("**********************") + + set(LIB_ET_RUNTIME "${ET_BUILD_DIR_PATH}/libexecutorch.a") ++set(LIB_ET_ETHOS "${ET_BUILD_DIR_PATH}/backends/arm/libexecutorch_delegate_ethos_u.a") + set(LIB_ET_OP_REGISTRATION "${ET_BUILD_DIR_PATH}/kernels/portable/libportable_ops_lib.a") + set(LIB_ET_OP_KERNELS "${ET_BUILD_DIR_PATH}/kernels/portable/libportable_kernels.a") + +@@ -54,8 +58,11 @@ add_custom_command( + OUTPUT + ${CMAKE_CURRENT_BINARY_DIR}/fake_dep + ${CMAKE_CURRENT_BINARY_DIR}/model_pte.h ++ ${CMAKE_CURRENT_BINARY_DIR}/model_delegate_pte.h + COMMAND ${PYTHON_EXECUTABLE} ./pte_to_header.py --pte ${ET_PTE_FILE_PATH} +- --out ${CMAKE_CURRENT_BINARY_DIR} ++ --outdir ${CMAKE_CURRENT_BINARY_DIR} ++ COMMAND ${PYTHON_EXECUTABLE} ./pte_to_header.py --pte ${ET_PTE_DELEGATE_FILE_PATH} ++ --outdir ${CMAKE_CURRENT_BINARY_DIR} --outfile model_delegate_pte.h + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) + +@@ -67,10 +74,24 @@ ethosu_add_executable_test(executor_runner PRIVATE + ${LIB_ET_OP_REGISTRATION} + ${LIB_ET_OP_KERNELS}) + +-add_dependencies(executor_runner gen_model_header) +- + target_include_directories(executor_runner PRIVATE + ${ET_INCLUDE_PATH} + ${CMAKE_CURRENT_BINARY_DIR}) + ++ethosu_add_executable_test(executor_runner_delegate PRIVATE ++ WHOLE_ARCHIVE TRUE ++ SOURCES runner_delegate.cpp ++ LIBRARIES ++ ${LIB_ET_RUNTIME} ++ ${LIB_ET_ETHOS} ++ ) ++ ++target_include_directories(executor_runner_delegate PRIVATE ++${ET_INCLUDE_PATH} ++${CMAKE_CURRENT_BINARY_DIR}) ++ ++add_dependencies(executor_runner gen_model_header) ++ ++ ++ + # TODO Memory setup +diff --git a/applications/executorch_tests/pte_to_header.py b/applications/executorch_tests/pte_to_header.py +index 37d88aa..be3282d 100644 +--- a/applications/executorch_tests/pte_to_header.py ++++ b/applications/executorch_tests/pte_to_header.py +@@ -30,11 +30,18 @@ parser.add_argument( + ) + parser.add_argument( + "--outdir", +- help="Output dir for model_pte.h", ++ help="Output dir for model header", + type=str, + required=False, + default=".", + ) ++parser.add_argument( ++ "--outfile", ++ help="Output filename for model header", ++ type=str, ++ required=False, ++ default="model_pte.h", ++) + parser.add_argument( + "--section", + help="Section attribute for the data array", +@@ -43,7 +50,7 @@ parser.add_argument( + default=".sram.data", + ) + args = parser.parse_args() +-outfile = os.path.join(args.outdir, "model_pte.h") ++outfile = os.path.join(args.outdir, args.outfile) + attr = 
f'__attribute__((section("{args.section}"), aligned(16))) char '
+ 
+ with open(args.pte, "rb") as fr, open(
+diff --git a/applications/executorch_tests/runner_delegate.cpp b/applications/executorch_tests/runner_delegate.cpp
+new file mode 100644
+index 0000000..ff40084
+--- /dev/null
++++ b/applications/executorch_tests/runner_delegate.cpp
+@@ -0,0 +1,160 @@
++/*
++ * SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates
++ *
++ * SPDX-License-Identifier: Apache-2.0
++ *
++ * Licensed under the Apache License, Version 2.0 (the License); you may
++ * not use this file except in compliance with the License.
++ * You may obtain a copy of the License at
++ *
++ * www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
++ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */
++
++/****************************************************************************
++ * Includes
++ ****************************************************************************/
++
++#include <memory>
++#include <vector>
++#include <cstdio>
++
++using namespace std;
++
++#include <executorch/runtime/executor/program.h>
++#include <executorch/extension/data_loader/buffer_data_loader.h>
++#include <executorch/runtime/platform/log.h>
++#include <executorch/runtime/platform/runtime.h>
++#include <executorch/util/util.h>
++
++/****************************************************************************
++ * Data
++ ****************************************************************************/
++
++// Our .pte file generated from the AoT flow
++#include "model_delegate_pte.h" // contains model_pte
++
++// Storage for intermediate data in SRAM
++__attribute__((section(".sram.data"), aligned(16))) uint8_t method_allocator_pool[4 * 1024U];
++
++void et_pal_init(void) {}
++
++__ET_NORETURN void et_pal_abort(void) {
++    __builtin_trap();
++}
++
++et_timestamp_t et_pal_current_ticks(void) {
++    // libc.a - warning: _gettimeofday is not implemented and will always fail
++    return 11223344;
++}
++
++/**
++ * Emit a log message via platform output (serial port, console, etc).
++ */
++void et_pal_emit_log_message(
++    __ET_UNUSED et_timestamp_t timestamp,
++    et_pal_log_level_t level,
++    const char* filename,
++    __ET_UNUSED const char* function,
++    size_t line,
++    const char* message,
++    __ET_UNUSED size_t length) {
++    fprintf(
++        stderr,
++        "%c executorch:%s:%zu] %s\n",
++        level,
++        filename,
++        line,
++        message);
++}
++
++int main()
++{
++    ET_LOG(Info, "Initialising runtime");
++    torch::executor::runtime_init();
++
++    using torch::executor::Result;
++    using torch::executor::Error;
++
++    // Load pte from the global model_pte .pte file loaded into SRAM.
++    auto loader = torch::executor::util::BufferDataLoader(model_pte, sizeof(model_pte));
++    Result<torch::executor::Program> program = torch::executor::Program::load(&loader);
++    if(!program.ok()) {
++        ET_LOG(Info, "Program loading failed @ 0x%p: 0x%x", model_pte, (int)program.error());
++    }
++    ET_LOG(Info, "Model buffer loaded, has %u methods", program->num_methods());
++
++    // Find our entrypoint in the .pte program
++    const char* method_name = nullptr;
++    const auto method_name_result = program->get_method_name(0);
++    ET_CHECK_MSG(method_name_result.ok(), "Program has no methods");
++    method_name = *method_name_result;
++    ET_LOG(Info, "Found (and will run) method '%s'", method_name);
++
++    // Allocate necessary memories for this method
++    Result<torch::executor::MethodMeta> method_meta = program->method_meta(method_name);
++    if (!method_meta.ok()) {
++        ET_LOG(Info, "Failed to get method_meta for %s: 0x%x",
++            method_name, (unsigned int)method_meta.error());
++    }
++
++    torch::executor::MemoryAllocator method_allocator{
++        torch::executor::MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)};
++
++    std::vector<std::unique_ptr<uint8_t[]>> planned_buffers; // Owns the memory
++    std::vector<torch::executor::Span<uint8_t>> planned_spans; // Passed to the allocator
++    size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers();
++
++    for (size_t id = 0; id < num_memory_planned_buffers; ++id) {
++        size_t buffer_size = static_cast<size_t>(method_meta->memory_planned_buffer_size(id).get());
++        ET_LOG(Info, "Setting up planned buffer %zu, size %zu.", id, buffer_size);
++
++        planned_buffers.push_back(std::make_unique<uint8_t[]>(buffer_size));
++        planned_spans.push_back({planned_buffers.back().get(), buffer_size});
++    }
++
++    torch::executor::HierarchicalAllocator planned_memory(
++        {planned_spans.data(), planned_spans.size()});
++
++    torch::executor::MemoryManager memory_manager(&method_allocator, &planned_memory);
++
++    Result<torch::executor::Method> method = program->load_method(method_name, &memory_manager);
++
++    if(!method.ok()) {
++        ET_LOG(Info, "Loading of method %s failed with status 0x%x", method_name, (int)method.error());
++    }
++    ET_LOG(Info, "Loading of method '%s' successful", method_name);
++
++    auto inputs = torch::executor::util::PrepareInputTensors(*method);
++
++    ET_LOG(Info, "Starting the model execution...");
++    Error status = method->execute();
++    if(status != Error::Ok){
++        ET_LOG(Info, "Execution of method %s failed with status 0x%x", method_name, (int)status);
++    } else {
++        ET_LOG(Info, "Model executed successfully.");
++    }
++
++    // Print the outputs.
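++    // (These example outputs are int32 tensors, matching the BI_INT test
++    // models, hence the const_data_ptr<int> reads below.)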
++    std::vector<torch::executor::EValue> outputs(method->outputs_size());
++    ET_LOG(Info, "%zu outputs - ", outputs.size());
++    status = method->get_outputs(outputs.data(), outputs.size());
++    ET_CHECK(status == Error::Ok);
++    for (size_t i = 0; i < outputs.size(); ++i)
++    {
++        ET_LOG(Info, "Output %zu numel %d", i, (int)outputs[i].toTensor().numel());
++        for (size_t j = 0; j < outputs[i].toTensor().numel(); ++j)
++        {
++            ET_LOG(Info, "  Output[%zu]: %d", j, outputs[i].toTensor().const_data_ptr<int>()[j]);
++        }
++    }
++
++    return 0;
++}
++
++
+diff --git a/cmake/toolchain/arm-none-eabi-gcc.cmake b/cmake/toolchain/arm-none-eabi-gcc.cmake
+index 0e6a2ed..fdb0d7c 100644
+--- a/cmake/toolchain/arm-none-eabi-gcc.cmake
++++ b/cmake/toolchain/arm-none-eabi-gcc.cmake
+@@ -98,8 +98,6 @@ add_compile_options(
+   # -Wswitch
+   # -Wswitch-default
+   # -Wunused
+-
+-  # -Wno-redundant-decls
+-
+-  # -Wno-psabi
++  -Wno-redundant-decls
++  -Wno-psabi
+ )
+-- 
+2.41.0
+
diff --git a/examples/arm/ethos-u-setup/ethos-u-vela/patches/0001-Improve-rescale-codegen-for-TOSA.patch b/examples/arm/ethos-u-setup/ethos-u-vela/patches/0001-Improve-rescale-codegen-for-TOSA.patch
new file mode 100644
index 00000000000..e131ca76ee8
--- /dev/null
+++ b/examples/arm/ethos-u-setup/ethos-u-vela/patches/0001-Improve-rescale-codegen-for-TOSA.patch
@@ -0,0 +1,129 @@
+From ef07230fbb15edbf27ecaf48994fb157430a5e7c Mon Sep 17 00:00:00 2001
+From: Rob Elliott
+Date: Thu, 5 Oct 2023 16:45:42 +0000
+Subject: [PATCH] Improve rescale codegen for TOSA
+
+Signed-off-by: Rob Elliott
+---
+ ethosu/vela/tosa_graph_optimiser.py | 56 +++++++++++------------------
+ ethosu/vela/tosa_mapping.py         |  2 +-
+ 2 files changed, 22 insertions(+), 36 deletions(-)
+
+diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
+index df6b575..b2e3697 100644
+--- a/ethosu/vela/tosa_graph_optimiser.py
++++ b/ethosu/vela/tosa_graph_optimiser.py
+@@ -337,7 +337,8 @@ def rewrite_concat(op):
+ 
+ def remove_memory_ops(op, arch):
+     if op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
+-        bypass_memory_only_ops(op)
++        # TODO: is this ok - function doesn't use arch or nng
++        bypass_memory_only_ops(op, arch, None)
+ 
+ 
+ def rewrite_activation(op, arch, nng):
+@@ -357,7 +358,6 @@ def rewrite_activation(op, arch, nng):
+ 
+     return op
+ 
+-
+ def rewrite_rescale(op, arch, nng):
+     if op.type == Op.Rescale:
+         ifm = op.ifm
+@@ -368,7 +368,7 @@ def rewrite_rescale(op, arch, nng):
+         prev_op = ifm.ops[0]
+ 
+         # TODO currently not supported
+-        assert len(ifm.consumer_list) == 1
++        #assert len(ifm.consumer_list) == 1
+ 
+         input_zp = op.attrs["input_zp"]
+         output_zp = op.attrs["output_zp"]
+@@ -390,6 +390,9 @@ def rewrite_rescale(op, arch, nng):
+             assert False
+         ifm.quantization.zero_point = input_zp
+         ofm.quantization.zero_point = output_zp
++
++        assert False == per_channel, "Don't like per_channel!"
++ + for s, m in zip(shift, multiplier): + # TODO these are the TOSA limitations + assert m >= 0 +@@ -403,45 +406,28 @@ def rewrite_rescale(op, arch, nng): + else: + rounding_mode = RoundingMode.HalfUp + +- if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: ++ fuse = len(ifm.ops) == 1 and prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() ++ if fuse: ++ # TODO: ERROR: bias.values didn't exist for an op like Add - presumably not a capability of that op + assert len(multiplier) == len(shift) == len(prev_op.bias.values) +- +- if ifm.dtype == DataType.int32 and per_channel: +- prev_op.explicit_scaling = explicit_scaling +- prev_op.rounding_mode = rounding_mode +- +- # Bypass op +- prev_op.set_output_tensor(ofm) +- DebugDatabase.add_optimised(op, prev_op) +- return op +- else: +- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) +- assert False +- # TODO which are the cases we need to and can do standalone Rescale? +- # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops? +- # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE? +- # limited to these at the moment: +- elif ( +- (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) +- or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8) +- or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8) +- ): +- # Create NOP performing the RESCALE ++ # TODO: generate replacement fusion code from below ++ assert False, "Fusion possible but i've not implemented it" ++ else: ++ # Generate Rescale behaviour attached to a compatible NOP ++ # TODO: I assume this attaches a new operator into the graph?? + avgpool_op = replace_rescale_with_avg_pool(op) + avgpool_op.rounding_mode = rounding_mode +- ++ + if per_channel: +- # TODO +- avgpool_op.explicit_scaling = explicit_scaling +- print("Warning, unsupported TOSA Rescale") +- assert False ++ assert False, "Assert above removed but still not implemented... 
:/" + else: + avgpool_op.explicit_scaling = explicit_scaling +- else: +- print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) +- assert False +- return op + ++ #print( len(multiplier), len(shift), len(prev_op.get_bias_tensors()) ) ++ #print( ifm.dtype, "PC:", per_channel, op.type ) ++ #print( ifm.dtype, ofm.dtype ) ++ ++ return op + + def convert_pad_in_width(op): + """ +diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py +index 2dafd81..ed5aa2e 100644 +--- a/ethosu/vela/tosa_mapping.py ++++ b/ethosu/vela/tosa_mapping.py +@@ -148,7 +148,7 @@ transpose_conv_attrs = AttrSerializer( + ) + transpose_attrs = AttrSerializer("TransposeAttribute", (("perms", is_vec),)) + axis_attrs = AttrSerializer("AxisAttribute", ("axis",)) +-reshape_attrs = AttrSerializer("ReshapeAttribute", (("shape", is_vec),)) ++reshape_attrs = AttrSerializer("ReshapeAttribute", (("newShape", is_vec),)) + slice_attrs = AttrSerializer("SliceAttribute", (("start", is_vec), ("size", is_vec))) + tile_attrs = AttrSerializer("TileAttribute", (("multiplies", is_vec),)) + resize_attrs = AttrSerializer( +-- +2.41.0 + diff --git a/examples/arm/run.sh b/examples/arm/run.sh index 828ac16bdc6..3f9bd37d90c 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -7,7 +7,7 @@ set -eu -if [[ "${1}" == "-h" ]]; then +if [[ "${1:-"."}" == "-h" ]]; then echo "Usage: $(basename $0) [path-to-a-scratch-dir] [buck2 binary]" exit 0 fi @@ -18,7 +18,8 @@ fi script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) # Ethos-u -root_dir=${1:-"$(realpath ${script_dir}/ethos-u-scratch)"} +root_dir=${1:-"${script_dir}/ethos-u-scratch"} +root_dir=$(realpath ${root_dir}) buck2=${2:-"/tmp/buck2"} ethos_u_root_dir="$(cd ${root_dir}/ethos-u && pwd)" ethos_u_build_dir=${ethos_u_root_dir}/core_platform/build @@ -43,6 +44,16 @@ function generate_pte_file() { echo "${pte_file}" } +# Generate the ethos delegate PTE file +function generate_ethos_pte_file() { + cd $et_root_dir + python3 examples/arm/arm_ethosu_minimal.py &> /dev/null + cd ./ethosout/simple_add/torch/ + local pte_file=$(realpath ./delegated.pte) + [[ -f ${pte_file} ]] || { echo "Failed to generate a pte file - ${pte_file}"; exit 1; } + echo "${pte_file}" +} + # build ExecuTorch Libraries function build_executorch() { [[ -d "${et_build_dir}" ]] \ @@ -56,6 +67,7 @@ function build_executorch() { -DEXECUTORCH_BUILD_GFLAGS=OFF \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ -DEXECUTORCH_BUILD_HOST_TARGETS=OFF \ + -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \ -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_SELECT_OPS_LIST="aten::_softmax.out" \ @@ -73,8 +85,9 @@ function build_executorch() { # build Arm Baremetal executor_runner function build_executorch_runner() { - [[ $# -ne 1 ]] && { echo "[${FUNCNAME[0]}]" "Expecting pte file as an argument got, $*"; exit 1; } + [[ $# -ne 2 ]] && { echo "[${FUNCNAME[0]}]" "Expecting 2 pte files as arguments got, $*"; exit 1; } local pte=${1} + local pte_delegate=${2} cd "${ethos_u_root_dir}"/core_platform cmake \ -DCMAKE_TOOLCHAIN_FILE=${toolchain_cmake} \ @@ -82,18 +95,21 @@ function build_executorch_runner() { -DET_DIR_PATH:PATH=${et_root_dir} \ -DET_BUILD_DIR_PATH:PATH=${et_build_dir} \ -DET_PTE_FILE_PATH:PATH="${pte}" \ + -DET_PTE_DELEGATE_FILE_PATH:PATH="${pte_delegate}" \ -DPYTHON_EXECUTABLE=$(which python3) echo "[${FUNCNAME[0]}] Configured CMAKE" n=$(nproc) - cmake --build build -- -j"$((n - 5))" executor_runner VERBOSE=1 + cmake --build build -- -j"$((n - 5))" 
executor_runner executor_runner_delegate VERBOSE=1
     echo "[${FUNCNAME[0]}] Generated baremetal elf file:"
     find . -name "executor_runner.elf"
 }
 
 # Execute the executor_runner on FVP Simulator
 function run_fvp() {
-    elf=$(find ${ethos_u_build_dir} -name "executor_runner.elf")
+    [[ $# -ne 1 ]] && { echo "[${FUNCNAME[0]}]" "Expected elf binary name, got $*"; exit 1; }
+    local elf_name=${1}
+    elf=$(find ${ethos_u_build_dir} -name "${elf_name}")
     [[ ! -f $elf ]] && { echo "[${FUNCNAME[0]}]: Unable to find executor_runner elf: ${elf}"; exit 1; }
     FVP_Corstone_SSE-300_Ethos-U55 \
         -C ethosu.num_macs=128 \
@@ -101,7 +117,7 @@ function run_fvp() {
         -C mps3_board.telnetterminal0.start_telnet=0 \
         -C mps3_board.uart0.out_file='-' \
         -a "${elf}" \
-        --timelimit 10 # seconds
+        --timelimit 5 || true # seconds
     echo "[${FUNCNAME[0]} Simulation complete, $?"
 }
 
@@ -132,14 +148,18 @@ type ${buck2} 2>&1 > /dev/null \
 
 # get the pte
 pte=$(generate_pte_file)
+pte_delegate=$(generate_ethos_pte_file)
 
 # build et
 build_executorch
 
 # build the et baremetal app
-build_executorch_runner "${pte}"
+build_executorch_runner "${pte}" "${pte_delegate}"
 
 # run the app
-run_fvp
+run_fvp executor_runner.elf
+
+# run the delegate app
+run_fvp executor_runner_delegate.elf
 
 exit 0
diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh
index d6f6880e173..34b20498cd7 100755
--- a/examples/arm/setup.sh
+++ b/examples/arm/setup.sh
@@ -7,37 +7,13 @@
 
 set -eu
 
-if [[ "${1}" == "-h" ]]; then
+if [[ "${1:-'.'}" == "-h" ]]; then
     echo "Usage: $(basename $0) [path-to-a-scratch-dir]"
     exit 0
 fi
 
 ########
-### Hardcoded constants
-########
-script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
-
-# FVP
-fvp_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64.tgz?rev=018659bd574f4e7b95fa647e7836ccf4&hash=22A79103C6FA5FFA7AFF3BE0447F3FF9"
-fvp_model_dir="Linux64_GCC-9.3"
-fvp_md5_checksum="98e93b949d0fbac977292d8668d34523"
-
-# toochain
-toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi.tar.xz"
-toolchain_dir="arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi"
-toolchain_md5_checksum="00ebb1b70b1f88906c61206457eacb61"
-
-# ethos-u
-ethos_u_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u"
-ethos_u_base_rev="0995223100e3da8011700f58e491f1bf59511e3c"
-
-########
-### Optional user args
-########
-root_dir=${1:-"$(realpath ${script_dir}/ethos-u-scratch)"}
-
-########
-### Functions
+### Helper functions
 ########
 function get_os_name() {
     # Returns the name of the system i.e. Linux or Darwin
@@ -62,6 +38,49 @@ function verify_md5() {
     fi
 }
 
+########
+### Hardcoded constants
+########
+script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+
+if [[ $(get_cpu_arch) == "x86_64" ]]; then
+    # FVP
+    fvp_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64.tgz?rev=018659bd574f4e7b95fa647e7836ccf4&hash=22A79103C6FA5FFA7AFF3BE0447F3FF9"
+    fvp_model_dir="Linux64_GCC-9.3"
+    fvp_md5_checksum="98e93b949d0fbac977292d8668d34523"
+
+    # toolchain
+    toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi.tar.xz"
+    toolchain_dir="arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi"
+    toolchain_md5_checksum="00ebb1b70b1f88906c61206457eacb61"
+elif [[ $(get_cpu_arch) == "aarch64" ]]; then
+    # FVP
+    fvp_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64_armv8l.tgz?rev=9cc6e9a32bb947ca9b21fa162144cb01&hash=7657A4CF27D42E892E3F08D452AAB073"
+    fvp_model_dir="Linux64_armv8l_GCC-9.3"
+    fvp_md5_checksum="cbbabbe39b07939cff7a3738e1492ef1"
+
+    # toolchain
+    toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-aarch64-arm-none-eabi.tar.xz"
+    toolchain_dir="arm-gnu-toolchain-12.3.rel1-aarch64-arm-none-eabi"
+    toolchain_md5_checksum="02c9b0d3bb1110575877d8eee1f223f2"
+else
+    echo "[main] Error: only x86-64 & aarch64 architectures are supported for now!"; exit 1;
+fi
+
+# ethos-u
+ethos_u_repo_url="https://review.mlplatform.org/ml/ethos-u/ethos-u"
+ethos_u_base_rev="0995223100e3da8011700f58e491f1bf59511e3c"
+
+########
+### Optional user args
+########
+root_dir=${1:-"${script_dir}/ethos-u-scratch"}
+root_dir=$(realpath ${root_dir})
+
+########
+### Functions
+########
+
 function setup_fvp() {
     # Download and install the Corstone 300 FVP simulator platform
     cd "${root_dir}"
@@ -132,14 +151,48 @@ function patch_repo() {
     echo -e "[${FUNCNAME[0]}] Patched ${name} @ $(git describe --all --long 2> /dev/null) in ${repo_dir} dir.\n"
 }
 
+function setup_tosa_reference_model() {
+    # The debug flow on the host includes running on a reference implementation of TOSA
+    # This is useful primarily for debugging quantization accuracy, and also for
+    # catching internal errors in the early codebase
+    cd "${root_dir}"
+    if [[ ! -e reference_model ]]; then
+        git clone https://git.mlplatform.org/tosa/reference_model.git -b v0.80.0
+        cd reference_model
+        git submodule update --init --recursive
+        cd ..
+    fi
+    cd reference_model
+    mkdir -p build
+    cd build
+    cmake ..
+    make
+    cd reference_model
+    tosa_bin_path=`pwd`
+    echo "export PATH=\${PATH}:${tosa_bin_path}" >> "${setup_path_script}"
+}
+
+function setup_vela() {
+    #
+    # Prepare the Vela compiler for AoT to Ethos-U compilation
+    #
+    cd "${root_dir}/ethos-u/"
+    if [[ ! -e ethos-u-vela ]]; then
+        git clone https://git.mlplatform.org/ml/ethos-u/ethos-u-vela.git
+        name="ethos-u-vela"
+        base_rev=00a15db3e1a188b25065d095152d701f4394cdc5
+        patch_repo
+    fi
+    pip install .
+}
+
 ########
 ### main
 ########
 
 # do basic checks
 # Make sure we are on a supported platform
-# Linux ARM64 is a supported platform - adding it here is a WIP
-[[ "$(get_cpu_arch)" != "x86_64" ]] \
-    && { echo "[main] Error: only x86-64 architecture is supported for now!"; exit 1; }
+[[ $(get_cpu_arch) != "x86_64" ]] && [[ $(get_cpu_arch) != "aarch64" ]] \
+    && { echo "[main] Error: only x86-64 & aarch64 architectures are supported for now!"; exit 1; }
 
 # No OSx support for FVP
 [[ "$(get_os_name)" != "Linux" ]] \
@@ -169,6 +222,13 @@ name="core_platform"
 base_rev=204210b1074071532627da9dc69950d058a809f4
 patch_repo
 
+# Setup the tosa_reference_model
+setup_tosa_reference_model
+
+# Setup vela and patch in codegen fixes
+setup_vela
+
 echo "[main] update path by doing 'source ${setup_path_script}'"
-echo "[main] sucecss!"
+
+echo "[main] success!"
 exit 0
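As a companion to the diff above, here is a minimal Python sketch of the vela_bin_stream container format that couples the AoT emitter (vela_compile in backends/arm/arm_backend.py) to the runtime reader (ArmBackend::vela_read in ArmBackendEthosU.cpp): a 16 byte "vela_bin_stream" header, then 16-byte-aligned blocks of [16 byte NUL-padded name | int32 length plus 12 zero bytes | payload padded to 16 bytes], terminated by a "vela_end_stream" block. The function names and the little-endian length encoding are illustrative assumptions, not part of the patch.

import struct

def _align16(n):
    # Same rounding as next_mul_16() in ArmBackendEthosU.cpp
    return ((n - 1) | 15) + 1

def read_vela_bin_stream(buf: bytes) -> dict:
    """Parse a vela_bin_stream into {block_name: payload} (a sketch)."""
    assert buf[:15] == b"vela_bin_stream", "not a vela_bin_stream"
    blocks = {}
    pos = 16  # skip the 16 byte stream header
    while pos < len(buf):
        # 16 byte NUL-padded block name
        name = buf[pos : pos + 16].split(b"\x00", 1)[0].decode("utf8")
        if name == "vela_end_stream":
            return blocks
        # int32 payload length, followed by 12 bytes of zeros
        (length,) = struct.unpack("<i", buf[pos + 16 : pos + 20])
        blocks[name] = buf[pos + 32 : pos + 32 + length]
        pos += 32 + _align16(length)  # payload padded to 16 byte alignment
    raise ValueError("missing vela_end_stream footer")

At runtime, ArmBackend::vela_read additionally requires the cmd_data payload to start with the "COP1" magic before treating it as an Ethos-U command stream, which is why changing the emitter in arm_backend.py always implies a matching change in the runtime reader.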