diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c60d01f0d3..f2e8b5d1f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,10 +12,12 @@ jobs: fail-fast: false matrix: working_directory: ["nx", "exla", "torchx"] - elixir: ["1.15.4", "1.16.2"] - otp: ["25.3"] + elixir: ["1.15.8", "1.18.4"] include: - - elixir: "1.16.2" + - elixir: "1.15.8" + otp: "25.3" + - elixir: "1.18.4" + otp: "27.3" lint: true defaults: run: @@ -57,8 +59,9 @@ jobs: fail-fast: false matrix: working_directory: ["nx", "torchx"] - elixir: ["1.16.2"] - otp: ["25.2"] + include: + - elixir: "1.18.4" + otp: "27.3" defaults: run: working-directory: ${{ matrix.working_directory }} diff --git a/exla/Makefile b/exla/Makefile index c8f5be211c..43e79e3907 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -25,7 +25,7 @@ EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO) # Note that XLA requires c++17, Fine as well CFLAGS += -fPIC -I$(ERTS_INCLUDE_DIR) -I$(FINE_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \ -Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \ - -std=c++17 -w -DLLVM_VERSION_STRING= + -std=c++17 -w ifdef DEBUG CFLAGS += -g @@ -33,7 +33,7 @@ else CFLAGS += -O3 endif -NVCC := $(CXX) +NVCC = $(CXX) NVCCFLAGS = $(CFLAGS) LDFLAGS += -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden @@ -48,8 +48,8 @@ $(info EXLA_CPU_ONLY is not set, checking for nvcc availability) ifeq ($(NVCC_TEST),nvcc) $(info CUDA is available.) - NVCC := nvcc - NVCCFLAGS += -DCUDA_ENABLED + NVCC = nvcc + NVCCFLAGS = -Xcompiler "$(CFLAGS)" -DCUDA_ENABLED else $(info CUDA is not available.) endif @@ -82,7 +82,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO) ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \ fi -SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/ipc.cc +SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc) HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o diff --git a/exla/c_src/exla/custom_calls.cc b/exla/c_src/exla/custom_calls.cc deleted file mode 100644 index 8acd67cab6..0000000000 --- a/exla/c_src/exla/custom_calls.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "xla/service/custom_call_target_registry.h" - -void qr_cpu_custom_call_f32(void *out[], const void *in[]); -void qr_cpu_custom_call_f64(void *out[], const void *in[]); -void qr_cpu_custom_call_f16(void *out[], const void *in[]); -void qr_cpu_custom_call_bf16(void *out[], const void *in[]); -void lu_cpu_custom_call_f32(void *out[], const void *in[]); -void lu_cpu_custom_call_f64(void *out[], const void *in[]); -void lu_cpu_custom_call_f16(void *out[], const void *in[]); -void lu_cpu_custom_call_bf16(void *out[], const void *in[]); -void eigh_cpu_custom_call_f32(void *out[], const void *in[]); -void eigh_cpu_custom_call_f64(void *out[], const void *in[]); - -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f64", lu_cpu_custom_call_f64); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f32", lu_cpu_custom_call_f32); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f16", lu_cpu_custom_call_f16); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_bf16", lu_cpu_custom_call_bf16); \ No newline at end of file diff --git a/exla/c_src/exla/custom_calls/eigh.h b/exla/c_src/exla/custom_calls/eigh.h index 55cb5adfc1..9ab743336f 100644 --- a/exla/c_src/exla/custom_calls/eigh.h +++ b/exla/c_src/exla/custom_calls/eigh.h @@ -1,12 +1,16 @@ #pragma once -#include "Eigen/Eigenvalues" - #include #include #include #include +#include "Eigen/Eigenvalues" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; + template void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, @@ -55,51 +59,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, m * n * sizeof(DataType)); } -template -void eigh_cpu_custom_call(void *out[], const void *in[]) { - DataType *operand = (DataType *)in[0]; - - uint64_t *dim_sizes = (uint64_t *)in[1]; - uint64_t num_operand_dims = dim_sizes[0]; - uint64_t num_eigenvalues_dims = dim_sizes[1]; - uint64_t num_eigenvectors_dims = dim_sizes[2]; - - uint64_t *operand_dims_ptr = (uint64_t *)in[2]; - std::vector operand_dims(operand_dims_ptr, - operand_dims_ptr + num_operand_dims); - - uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3]; - std::vector eigenvalues_dims( - eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims); - - uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4]; - std::vector eigenvectors_dims( - eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims); +template +ffi::Error eigh_cpu_custom_call_impl(BufferType operand, + ffi::Result eigenvalues, + ffi::Result eigenvectors) { + auto operand_dims = operand.dimensions(); + auto eigenvalues_dims = eigenvalues->dimensions(); + auto eigenvectors_dims = eigenvectors->dimensions(); uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2]; uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1]; - auto leading_dimensions = - std::vector(operand_dims.begin(), operand_dims.end() - 2); - uint64_t batch_items = 1; - for (uint64_t i = 0; i < leading_dimensions.size(); i++) { - batch_items *= leading_dimensions[i]; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= *it; } - DataType *eigenvalues = (DataType *)out[0]; - DataType *eigenvectors = (DataType *)out[1]; - uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1]; - uint64_t eigenvectors_stride = - eigenvectors_dims[eigenvectors_dims.size() - 1] * - eigenvectors_dims[eigenvectors_dims.size() - 2]; + uint64_t eigenvectors_stride = m * n; uint64_t inner_stride = m * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_eigh_cpu_custom_call( - eigenvalues + i * eigenvalues_stride, - eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m, - n); + eigenvalues->typed_data() + i * eigenvalues_stride, + eigenvectors->typed_data() + i * eigenvectors_stride, + operand.typed_data() + i * inner_stride, m, n); } -} \ No newline at end of file + + return ffi::Error::Success(); +} diff --git a/exla/c_src/exla/custom_calls/eigh_f32.cc b/exla/c_src/exla/custom_calls/eigh_f32.cc index 62395984ad..5e9e15096c 100644 --- a/exla/c_src/exla/custom_calls/eigh_f32.cc +++ b/exla/c_src/exla/custom_calls/eigh_f32.cc @@ -1,5 +1,19 @@ #include "eigh.h" -void eigh_cpu_custom_call_f32(void *out[], const void *in[]) { - eigh_cpu_custom_call(out, in); +ffi::Error +eigh_cpu_custom_call_f32_impl(ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + return eigh_cpu_custom_call_impl>( + operand, eigenvalues, eigenvectors); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(eigh_cpu_custom_call_f32, + eigh_cpu_custom_call_f32_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eigh_cpu_custom_call_f32", + "Host", eigh_cpu_custom_call_f32); diff --git a/exla/c_src/exla/custom_calls/eigh_f64.cc b/exla/c_src/exla/custom_calls/eigh_f64.cc index 5e7ffef084..b047c234e7 100644 --- a/exla/c_src/exla/custom_calls/eigh_f64.cc +++ b/exla/c_src/exla/custom_calls/eigh_f64.cc @@ -1,5 +1,19 @@ #include "eigh.h" -void eigh_cpu_custom_call_f64(void *out[], const void *in[]) { - eigh_cpu_custom_call(out, in); +ffi::Error +eigh_cpu_custom_call_f64_impl(ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + return eigh_cpu_custom_call_impl>( + operand, eigenvalues, eigenvectors); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(eigh_cpu_custom_call_f64, + eigh_cpu_custom_call_f64_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eigh_cpu_custom_call_f64", + "Host", eigh_cpu_custom_call_f64); diff --git a/exla/c_src/exla/custom_calls/lu.h b/exla/c_src/exla/custom_calls/lu.h index 1c72565d4b..51d100191e 100644 --- a/exla/c_src/exla/custom_calls/lu.h +++ b/exla/c_src/exla/custom_calls/lu.h @@ -1,16 +1,30 @@ #pragma once -#include "Eigen/LU"; +#include +#include +#include +#include + +#include "Eigen/LU" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; template -void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) { - typedef Eigen::Matrix RowMajorMatrix; +void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, + DataType *u_out, DataType *in, + uint64_t n) { + typedef Eigen::Matrix + RowMajorMatrix; Eigen::Map input(in, n, n); Eigen::PartialPivLU lu = input.partialPivLu(); // Get the permutation matrix P and convert to indices - Eigen::PermutationMatrix P = lu.permutationP(); + Eigen::PermutationMatrix P = + lu.permutationP(); for (uint64_t i = 0; i < n; i++) { for (uint64_t j = 0; j < n; j++) { p_out[i * n + j] = static_cast(P.indices()[i] == j ? 1 : 0); @@ -24,7 +38,6 @@ void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType // Copy L matrix for (uint64_t i = 0; i < n; i++) { for (uint64_t j = 0; j < n; j++) { - if (j < i) { l_out[i * n + j] = static_cast(L(i, j)); } else if (j == i) { @@ -47,49 +60,28 @@ void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType } } -template -void lu_cpu_custom_call(void *out[], const void *in[]) { - DataType *operand = (DataType *)in[0]; - - uint64_t *dim_sizes = (uint64_t *)in[1]; - uint64_t num_operand_dims = dim_sizes[0]; - uint64_t num_p_dims = dim_sizes[1]; - uint64_t num_l_dims = dim_sizes[2]; - uint64_t num_u_dims = dim_sizes[3]; - - uint64_t *operand_dims_ptr = (uint64_t *)in[2]; - std::vector operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims); - - uint64_t *p_dims_ptr = (uint64_t *)in[3]; - std::vector p_dims(p_dims_ptr, p_dims_ptr + num_p_dims); - - uint64_t *l_dims_ptr = (uint64_t *)in[4]; - std::vector l_dims(l_dims_ptr, l_dims_ptr + num_l_dims); - - uint64_t *u_dims_ptr = (uint64_t *)in[5]; - std::vector u_dims(u_dims_ptr, u_dims_ptr + num_u_dims); - +template +ffi::Error +lu_cpu_custom_call_impl(BufferType operand, ffi::Result> p, + ffi::Result l, ffi::Result u) { + auto operand_dims = operand.dimensions(); + auto l_dims = l->dimensions(); uint64_t n = l_dims[l_dims.size() - 1]; - auto leading_dimensions = std::vector(operand_dims.begin(), operand_dims.end() - 2); - uint64_t batch_items = 1; - for (uint64_t i = 0; i < leading_dimensions.size(); i++) { - batch_items *= leading_dimensions[i]; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= *it; } - uint8_t *p = (uint8_t *)out[0]; - DataType *l = (DataType *)out[1]; - DataType *u = (DataType *)out[2]; - uint64_t stride = n * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_lu_cpu_custom_call( - p + i * stride, - l + i * stride, - u + i * stride, - operand + i * stride, - n); + p->typed_data() + i * stride, + (DataType *)l->untyped_data() + i * stride, + (DataType *)u->untyped_data() + i * stride, + (DataType *)operand.untyped_data() + i * stride, n); } -} \ No newline at end of file + + return ffi::Error::Success(); +} diff --git a/exla/c_src/exla/custom_calls/lu_bf16.cc b/exla/c_src/exla/custom_calls/lu_bf16.cc index 806f886b4c..3d040d1121 100644 --- a/exla/c_src/exla/custom_calls/lu_bf16.cc +++ b/exla/c_src/exla/custom_calls/lu_bf16.cc @@ -1,6 +1,21 @@ -#include "lu.h" #include "../exla_types.h" +#include "lu.h" -void lu_cpu_custom_call_bf16(void *out[], const void *in[]) { - lu_cpu_custom_call(out, in); +ffi::Error lu_cpu_custom_call_bf16_impl(ffi::Buffer operand, + ffi::ResultBuffer p, + ffi::ResultBuffer l, + ffi::ResultBuffer u) { + return lu_cpu_custom_call_impl>( + operand, p, l, u); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_bf16, + lu_cpu_custom_call_bf16_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_bf16", "Host", + lu_cpu_custom_call_bf16); diff --git a/exla/c_src/exla/custom_calls/lu_f16.cc b/exla/c_src/exla/custom_calls/lu_f16.cc index 81f6724e6e..022248bdef 100644 --- a/exla/c_src/exla/custom_calls/lu_f16.cc +++ b/exla/c_src/exla/custom_calls/lu_f16.cc @@ -1,6 +1,21 @@ -#include "lu.h" #include "../exla_types.h" +#include "lu.h" -void lu_cpu_custom_call_f16(void *out[], const void *in[]) { - lu_cpu_custom_call(out, in); +ffi::Error lu_cpu_custom_call_f16_impl(ffi::Buffer operand, + ffi::ResultBuffer p, + ffi::ResultBuffer l, + ffi::ResultBuffer u) { + return lu_cpu_custom_call_impl>(operand, + p, l, u); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_f16, + lu_cpu_custom_call_f16_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_f16", "Host", + lu_cpu_custom_call_f16); diff --git a/exla/c_src/exla/custom_calls/lu_f32.cc b/exla/c_src/exla/custom_calls/lu_f32.cc index c506caab72..0332190158 100644 --- a/exla/c_src/exla/custom_calls/lu_f32.cc +++ b/exla/c_src/exla/custom_calls/lu_f32.cc @@ -1,5 +1,20 @@ #include "lu.h" -void lu_cpu_custom_call_f32(void *out[], const void *in[]) { - lu_cpu_custom_call(out, in); +ffi::Error lu_cpu_custom_call_f32_impl(ffi::Buffer operand, + ffi::ResultBuffer p, + ffi::ResultBuffer l, + ffi::ResultBuffer u) { + return lu_cpu_custom_call_impl>(operand, p, l, + u); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_f32, + lu_cpu_custom_call_f32_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_f32", "Host", + lu_cpu_custom_call_f32); diff --git a/exla/c_src/exla/custom_calls/lu_f64.cc b/exla/c_src/exla/custom_calls/lu_f64.cc index aed6ed2dab..e129166611 100644 --- a/exla/c_src/exla/custom_calls/lu_f64.cc +++ b/exla/c_src/exla/custom_calls/lu_f64.cc @@ -1,5 +1,20 @@ #include "lu.h" -void lu_cpu_custom_call_f64(void *out[], const void *in[]) { - lu_cpu_custom_call(out, in); +ffi::Error lu_cpu_custom_call_f64_impl(ffi::Buffer operand, + ffi::ResultBuffer p, + ffi::ResultBuffer l, + ffi::ResultBuffer u) { + return lu_cpu_custom_call_impl>(operand, p, l, + u); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_f64, + lu_cpu_custom_call_f64_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_f64", "Host", + lu_cpu_custom_call_f64); diff --git a/exla/c_src/exla/custom_calls/qr.h b/exla/c_src/exla/custom_calls/qr.h index 85e881447c..ff489a6908 100644 --- a/exla/c_src/exla/custom_calls/qr.h +++ b/exla/c_src/exla/custom_calls/qr.h @@ -1,17 +1,29 @@ - #pragma once +#include +#include +#include +#include + #include "Eigen/QR" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; template -void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, uint64_t m, uint64_t k, uint64_t n, bool complete) { - typedef Eigen::Matrix RowMajorMatrix; +void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, + DataType *in, uint64_t m, uint64_t k, + uint64_t n, bool complete) { + typedef Eigen::Matrix + RowMajorMatrix; Eigen::Map input(in, m, n); Eigen::HouseholderQR qr = input.householderQr(); RowMajorMatrix Q, R; - size_t num_bytes_q, num_bytes_r; + size_t num_bytes_q; if (complete) { Q = qr.householderQ() * RowMajorMatrix::Identity(m, m); @@ -40,48 +52,37 @@ void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType memcpy(q_out, Q.data(), num_bytes_q); } -template -void qr_cpu_custom_call(void *out[], const void *in[]) { - DataType *operand = (DataType *)in[0]; - - uint64_t *dim_sizes = (uint64_t *)in[1]; - uint64_t num_operand_dims = dim_sizes[0]; - uint64_t num_q_dims = dim_sizes[1]; - uint64_t num_r_dims = dim_sizes[2]; - - uint64_t *operand_dims_ptr = (uint64_t *)in[2]; - std::vector operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims); - - uint64_t *q_dims_ptr = (uint64_t *)in[3]; - std::vector q_dims(q_dims_ptr, q_dims_ptr + num_q_dims); - - uint64_t *r_dims_ptr = (uint64_t *)in[4]; - std::vector r_dims(r_dims_ptr, r_dims_ptr + num_r_dims); +template +ffi::Error qr_cpu_custom_call_impl(BufferType operand, + ffi::Result q, + ffi::Result r) { + auto operand_dims = operand.dimensions(); + auto q_dims = q->dimensions(); + auto r_dims = r->dimensions(); uint64_t m = q_dims[q_dims.size() - 2]; uint64_t k = q_dims[q_dims.size() - 1]; uint64_t n = r_dims[r_dims.size() - 1]; - bool complete = r_dims[r_dims.size() - 2] == m; + uint64_t l = r_dims[r_dims.size() - 2]; - auto leading_dimensions = std::vector(operand_dims.begin(), operand_dims.end() - 2); + bool complete = l == m; uint64_t batch_items = 1; - for (uint64_t i = 0; i < leading_dimensions.size(); i++) { - batch_items *= leading_dimensions[i]; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= *it; } - DataType *q = (DataType *)out[0]; - DataType *r = (DataType *)out[1]; - - uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2]; - uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2]; + uint64_t q_stride = m * k; + uint64_t r_stride = n * l; uint64_t inner_stride = m * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_qr_cpu_custom_call( - (DataType *)out[0] + i * q_stride, - (DataType *)out[1] + i * r_stride, - operand + i * inner_stride, - m, k, n, complete); + (DataType *)q->untyped_data() + i * q_stride, + (DataType *)r->untyped_data() + i * r_stride, + (DataType *)operand.untyped_data() + i * inner_stride, m, k, n, + complete); } -} \ No newline at end of file + + return ffi::Error::Success(); +} diff --git a/exla/c_src/exla/custom_calls/qr_bf16.cc b/exla/c_src/exla/custom_calls/qr_bf16.cc index 32b6e616a9..01cc6ca3cc 100644 --- a/exla/c_src/exla/custom_calls/qr_bf16.cc +++ b/exla/c_src/exla/custom_calls/qr_bf16.cc @@ -1,6 +1,19 @@ -#include "qr.h" #include "../exla_types.h" +#include "qr.h" -void qr_cpu_custom_call_bf16(void *out[], const void *in[]) { - qr_cpu_custom_call(out, in); +ffi::Error qr_cpu_custom_call_bf16_impl(ffi::Buffer operand, + ffi::ResultBuffer q, + ffi::ResultBuffer r) { + return qr_cpu_custom_call_impl>( + operand, q, r); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(qr_cpu_custom_call_bf16, + qr_cpu_custom_call_bf16_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_bf16", "Host", + qr_cpu_custom_call_bf16); diff --git a/exla/c_src/exla/custom_calls/qr_f16.cc b/exla/c_src/exla/custom_calls/qr_f16.cc index ed689ac432..49e5f7b96b 100644 --- a/exla/c_src/exla/custom_calls/qr_f16.cc +++ b/exla/c_src/exla/custom_calls/qr_f16.cc @@ -1,6 +1,19 @@ -#include "qr.h" #include "../exla_types.h" +#include "qr.h" -void qr_cpu_custom_call_f16(void *out[], const void *in[]) { - qr_cpu_custom_call(out, in); +ffi::Error qr_cpu_custom_call_f16_impl(ffi::Buffer operand, + ffi::ResultBuffer q, + ffi::ResultBuffer r) { + return qr_cpu_custom_call_impl>(operand, + q, r); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(qr_cpu_custom_call_f16, + qr_cpu_custom_call_f16_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f16", "Host", + qr_cpu_custom_call_f16); diff --git a/exla/c_src/exla/custom_calls/qr_f32.cc b/exla/c_src/exla/custom_calls/qr_f32.cc index ffa5f87465..544a192141 100644 --- a/exla/c_src/exla/custom_calls/qr_f32.cc +++ b/exla/c_src/exla/custom_calls/qr_f32.cc @@ -1,5 +1,17 @@ #include "qr.h" -void qr_cpu_custom_call_f32(void *out[], const void *in[]) { - qr_cpu_custom_call(out, in); +ffi::Error qr_cpu_custom_call_f32_impl(ffi::Buffer operand, + ffi::ResultBuffer q, + ffi::ResultBuffer r) { + return qr_cpu_custom_call_impl>(operand, q, r); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(qr_cpu_custom_call_f32, + qr_cpu_custom_call_f32_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32", "Host", + qr_cpu_custom_call_f32); diff --git a/exla/c_src/exla/custom_calls/qr_f64.cc b/exla/c_src/exla/custom_calls/qr_f64.cc index 0f930352f8..64ec1e8aac 100644 --- a/exla/c_src/exla/custom_calls/qr_f64.cc +++ b/exla/c_src/exla/custom_calls/qr_f64.cc @@ -1,5 +1,17 @@ #include "qr.h" -void qr_cpu_custom_call_f64(void *out[], const void *in[]) { - qr_cpu_custom_call(out, in); +ffi::Error qr_cpu_custom_call_f64_impl(ffi::Buffer operand, + ffi::ResultBuffer q, + ffi::ResultBuffer r) { + return qr_cpu_custom_call_impl>(operand, q, r); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL(qr_cpu_custom_call_f64, + qr_cpu_custom_call_f64_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f64", "Host", + qr_cpu_custom_call_f64); diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index ed3ce31a03..7a32f5e8a0 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -14,7 +14,7 @@ #include "stablehlo/dialect/StablehloOps.h" #include "xla/pjrt/pjrt_api.h" #include "xla/service/platform_util.h" -#include "xla/statusor.h" +#include "xla/tsl/platform/statusor.h" #include "llvm/Support/ThreadPool.h" namespace exla { @@ -178,7 +178,7 @@ std::string mlir_module_to_string(ErlNifEnv *env, FINE_NIF(mlir_module_to_string, 0); -template T unwrap(xla::StatusOr status_or) { +template T unwrap(tsl::StatusOr status_or) { if (!status_or.ok()) { throw std::runtime_error(status_or.status().message().data()); } @@ -186,7 +186,7 @@ template T unwrap(xla::StatusOr status_or) { return std::move(status_or.value()); } -void unwrap(xla::Status status) { +void unwrap(tsl::Status status) { if (!status.ok()) { throw std::runtime_error(status.message().data()); } @@ -302,8 +302,9 @@ fine::ResourcePtr create_buffer_from_device_pointer( auto device = unwrap( client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); + auto memory_space = unwrap(device->default_memory_space()); auto buffer = unwrap(client->client()->CreateViewOfDeviceBuffer( - ptr, shape, device, on_delete_callback)); + ptr, shape, memory_space, on_delete_callback)); return fine::make_resource(std::move(buffer)); } @@ -329,7 +330,7 @@ std::variant, fine::Error> deallocate_device_mem(ErlNifEnv *env, fine::Term buffer_term) { auto buffer = decode_exla_buffer(env, buffer_term); - xla::Status dealloc_status = buffer->Deallocate(); + tsl::Status dealloc_status = buffer->Deallocate(); if (!dealloc_status.ok()) { return fine::Error(atoms::already_deallocated); diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index a13f87082c..9d58239401 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -4,6 +4,7 @@ #include #include "exla_nif_util.h" #include "xla/layout_util.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/gpu/gpu_helpers.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" @@ -24,7 +25,7 @@ void CopyLiteralToBinary(xla::Literal* literal, ErlNifBinary* binary, exla::int6 std::memcpy(binary->data, literal->untyped_data(), size); } -xla::StatusOr ExlaBuffer::ToBinary(ErlNifEnv* env, exla::int64 size) { +tsl::StatusOr ExlaBuffer::ToBinary(ErlNifEnv* env, exla::int64 size) { EXLA_ASSIGN_OR_RETURN(std::shared_ptr literal, buffer_->ToLiteralSync()); exla::int64 actual_size = literal->size_bytes(); @@ -37,7 +38,7 @@ xla::StatusOr ExlaBuffer::ToBinary(ErlNifEnv* env, exla::int64 siz return binary_term; } -xla::Status ExlaBuffer::Deallocate() { +tsl::Status ExlaBuffer::Deallocate() { if (buffer_->IsDeleted()) { return xla::FailedPrecondition("Attempt to deallocate already deallocated buffer."); } else { @@ -46,9 +47,11 @@ xla::Status ExlaBuffer::Deallocate() { } } -xla::StatusOr> ExlaBuffer::CopyToDevice(xla::PjRtDevice* dst_device) { +tsl::StatusOr> ExlaBuffer::CopyToDevice(xla::PjRtDevice* dst_device) { + EXLA_ASSIGN_OR_RETURN(auto memory_space, + dst_device->default_memory_space()); EXLA_ASSIGN_OR_RETURN(std::unique_ptr buf, - buffer_->CopyToDevice(dst_device)); + buffer_->CopyToMemorySpace(memory_space)); return fine::make_resource(std::move(buf)); } @@ -58,7 +61,7 @@ ExlaExecutable::ExlaExecutable(std::unique_ptr execut fingerprint_(std::move(fingerprint)), client_(client) {} -xla::StatusOr> PjRtBufferFromBinary(xla::PjRtClient* client, +tsl::StatusOr> PjRtBufferFromBinary(xla::PjRtClient* client, ERL_NIF_TERM source_term, const xla::Shape& shape, int device_id) { @@ -75,16 +78,17 @@ xla::StatusOr> PjRtBufferFromBinary(xla::PjRtCl std::function on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); }; EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); + EXLA_ASSIGN_OR_RETURN(auto memory_space, device->default_memory_space()); // Passing std::nullopt should work, but it fails for subbyte types, // so we build the default strides. See https://github.com/openxla/xla/issues/16795 auto byte_strides = xla::ShapeUtil::ByteStrides(shape); EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer( - binary.data, shape.element_type(), shape.dimensions(), byte_strides, semantics, on_done_with_host_buffer, device)); + binary.data, shape.element_type(), shape.dimensions(), byte_strides, semantics, on_done_with_host_buffer, memory_space, /*device_layout=*/nullptr)); return std::move(buffer); } -xla::StatusOr>> +tsl::StatusOr>> UnpackRunArguments(ErlNifEnv* env, ExlaExecutable::RunArguments arguments, std::vector> &transient_buffers, @@ -150,7 +154,7 @@ ExlaExecutable::RunResult UnpackResult(ErlNifEnv* env, int64_t device = device_id >= 0 ? device_id : device_assignment(i, 0); for (auto& pjrt_buf : result.at(i)) { - pjrt_buf->BlockHostUntilReady(); + pjrt_buf->GetReadyFuture().Await(); auto result = fine::make_resource(std::move(pjrt_buf)); replica_results.push_back(result); } @@ -161,7 +165,7 @@ ExlaExecutable::RunResult UnpackResult(ErlNifEnv* env, return per_replica_results; } -xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, +tsl::StatusOr ExlaExecutable::Run(ErlNifEnv* env, ExlaExecutable::RunArguments arguments, int device_id) { xla::ExecuteOptions options; @@ -268,14 +272,14 @@ xla::StatusOr ExlaExecutable::Run(ErlNifEnv* env, ExlaClient::ExlaClient(std::shared_ptr client) : client_(std::move(client)) {} -xla::StatusOr> ExlaClient::BufferFromBinary(ERL_NIF_TERM source_term, +tsl::StatusOr> ExlaClient::BufferFromBinary(ERL_NIF_TERM source_term, xla::Shape& shape, int device_id) { EXLA_ASSIGN_OR_RETURN(auto buffer, PjRtBufferFromBinary(client(), source_term, shape, device_id)); return fine::make_resource(std::move(buffer)); } -xla::StatusOr> ExecutableFingerprint(std::unique_ptr& executable) { +tsl::StatusOr> ExecutableFingerprint(std::unique_ptr& executable) { auto fingerprint = executable->FingerprintExecutable(); if (fingerprint.ok()) { @@ -288,9 +292,9 @@ xla::StatusOr> ExecutableFingerprint(std::unique_ptr< } } -xla::StatusOr> ExlaClient::DeserializeExecutable(std::string deserialized_executable) { +tsl::StatusOr> ExlaClient::DeserializeExecutable(std::string deserialized_executable) { EXLA_ASSIGN_OR_RETURN(std::unique_ptr executable, - client_->DeserializeExecutable(deserialized_executable, std::nullopt)); + client_->LoadSerializedExecutable(deserialized_executable, std::nullopt, xla::LoadOptions())); EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); @@ -298,7 +302,7 @@ xla::StatusOr> ExlaClient::DeserializeExecutab return fine::make_resource(std::move(executable), std::move(fingerprint), this); } -xla::StatusOr> ExlaClient::Compile(mlir::ModuleOp module, +tsl::StatusOr> ExlaClient::Compile(mlir::ModuleOp module, std::vector argument_layouts, xla::ExecutableBuildOptions& options, bool compile_portable_executable) { @@ -317,14 +321,14 @@ xla::StatusOr> ExlaClient::Compile(mlir::Modul compile_opts.compile_portable_executable = compile_portable_executable; EXLA_ASSIGN_OR_RETURN(std::unique_ptr executable, - client_->Compile(module, std::move(compile_opts))); + client_->CompileAndLoad(module, std::move(compile_opts))); EXLA_ASSIGN_OR_RETURN(absl::optional fingerprint, ExecutableFingerprint(executable)); return fine::make_resource(std::move(executable), std::move(fingerprint), this); } -xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env, +tsl::Status ExlaClient::TransferToInfeed(ErlNifEnv* env, std::vector buffer_bins, std::vector shapes, int device_id) { @@ -361,12 +365,12 @@ xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env, EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); - xla::Status status = device->TransferToInfeed(literal); + tsl::Status status = device->TransferToInfeed(literal); return status; } -xla::StatusOr ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape) { +tsl::StatusOr ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape) { EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); auto literal = std::make_shared(shape); @@ -386,14 +390,16 @@ xla::StatusOr ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int return binary_term; } -xla::StatusOr> GetHostClient() { +tsl::StatusOr> GetHostClient() { + xla::CpuClientOptions options; + options.asynchronous = false; EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, - xla::GetTfrtCpuClient(false)); + xla::GetXlaPjrtCpuClient(options)); return fine::make_resource(std::move(client)); } -xla::StatusOr> GetGpuClient(double memory_fraction, +tsl::StatusOr> GetGpuClient(double memory_fraction, bool preallocate, xla::GpuAllocatorConfig::Kind kind) { xla::GpuAllocatorConfig allocator_config = { @@ -410,13 +416,13 @@ xla::StatusOr> GetGpuClient(double memory_fraction return fine::make_resource(std::move(client)); } -xla::StatusOr> GetTpuClient() { +tsl::StatusOr> GetTpuClient() { auto statusor = pjrt::LoadPjrtPlugin("tpu", "libtpu.so"); if (!statusor.ok()) { return statusor.status(); } - xla::Status status = pjrt::InitializePjrtPlugin("tpu"); + tsl::Status status = pjrt::InitializePjrtPlugin("tpu"); if (!status.ok()) { return status; @@ -428,7 +434,7 @@ xla::StatusOr> GetTpuClient() { return fine::make_resource(std::move(client)); } -xla::StatusOr> GetCApiClient(std::string device_type) { +tsl::StatusOr> GetCApiClient(std::string device_type) { EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetCApiClient(device_type)); diff --git a/exla/c_src/exla/exla_client.h b/exla/c_src/exla/exla_client.h index 0dcc0842cb..fbab1d51e1 100644 --- a/exla/c_src/exla/exla_client.h +++ b/exla/c_src/exla/exla_client.h @@ -28,15 +28,15 @@ class ExlaBuffer { int device_id() { return buffer_->device()->id(); } xla::PjRtBuffer* buffer() { return buffer_.get(); } - xla::StatusOr> CopyToDevice(xla::PjRtDevice* dst_device); - xla::StatusOr ToBinary(ErlNifEnv* env, exla::int64 size); - xla::Status Deallocate(); + tsl::StatusOr> CopyToDevice(xla::PjRtDevice* dst_device); + tsl::StatusOr ToBinary(ErlNifEnv* env, exla::int64 size); + tsl::Status Deallocate(); - xla::StatusOr GetDevicePointer(xla::PjRtClient* client) { + tsl::StatusOr GetDevicePointer(xla::PjRtClient* client) { return client->UnsafeBufferPointer(buffer_.get()); } - xla::StatusOr GetOnDeviceSizeInBytes() { + tsl::StatusOr GetOnDeviceSizeInBytes() { return buffer_.get()->GetOnDeviceSizeInBytes(); } @@ -63,9 +63,9 @@ class ExlaExecutable { xla::PjRtLoadedExecutable* executable() { return executable_.get(); } - xla::StatusOr Run(ErlNifEnv* env, RunArguments arguments, int device_id); + tsl::StatusOr Run(ErlNifEnv* env, RunArguments arguments, int device_id); - xla::StatusOr SerializeExecutable() { return executable_->SerializeExecutable(); } + tsl::StatusOr SerializeExecutable() { return executable_->SerializeExecutable(); } private: std::unique_ptr executable_; @@ -83,38 +83,38 @@ class ExlaClient { // Compiles the given computation with the given compile options - xla::StatusOr> Compile(mlir::ModuleOp computation, + tsl::StatusOr> Compile(mlir::ModuleOp computation, std::vector argument_layouts, xla::ExecutableBuildOptions& options, bool compile_portable_executable); - xla::StatusOr> BufferFromBinary(ERL_NIF_TERM binary_term, + tsl::StatusOr> BufferFromBinary(ERL_NIF_TERM binary_term, xla::Shape& shape, int device_id); - xla::StatusOr> DeserializeExecutable(std::string serialized_executable); + tsl::StatusOr> DeserializeExecutable(std::string serialized_executable); // TODO(seanmor5): This is device logic and should be refactored - xla::Status TransferToInfeed(ErlNifEnv* env, + tsl::Status TransferToInfeed(ErlNifEnv* env, std::vector buffer_bins, std::vector shapes, int device_id); - xla::StatusOr TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape); + tsl::StatusOr TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape); private: std::shared_ptr client_; }; -xla::StatusOr> GetHostClient(); +tsl::StatusOr> GetHostClient(); -xla::StatusOr> GetGpuClient(double memory_fraction, +tsl::StatusOr> GetGpuClient(double memory_fraction, bool preallocate, xla::GpuAllocatorConfig::Kind kind); -xla::StatusOr> GetTpuClient(); +tsl::StatusOr> GetTpuClient(); -xla::StatusOr> GetCApiClient(std::string device_type); +tsl::StatusOr> GetCApiClient(std::string device_type); } // namespace exla #endif diff --git a/exla/c_src/exla/exla_cuda.cc b/exla/c_src/exla/exla_cuda.cc index 6f5bbe7b97..f3851de5d5 100644 --- a/exla/c_src/exla/exla_cuda.cc +++ b/exla/c_src/exla/exla_cuda.cc @@ -27,7 +27,7 @@ std::optional get_cuda_ipc_handle(std::uintptr_t ptr) { std::optional get_pointer_for_ipc_handle(uint8_t* handle_bin, size_t handle_size, int device_id) { if (handle_size != sizeof(cudaIpcMemHandle_t)) { - return std::make_tuple(nullptr, 1); // Return with error status + return std::nullopt; } unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)]; diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 55b13ee9b0..626f2b4b15 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -404,8 +404,8 @@ defmodule EXLA.Defn do data: %Expr{ args: [ %{data: %{op: :eigh, args: [tensor, _opts]}}, - {%{type: {evec_type_kind, _}} = eigenvecs_expr, - %{type: {eval_type_kind, _}} = eigenvals_expr}, + {%{type: {evec_type_kind, _}} = eigenvals_expr, + %{type: {eval_type_kind, _}} = eigenvecs_expr}, _callback ] } @@ -429,14 +429,14 @@ defmodule EXLA.Defn do tensor end - {eigenvecs, eigenvals} = + {eigenvals, eigenvecs} = Value.eigh( tensor, - expr_to_typespec(%{eigenvecs_expr | type: out_type}), - expr_to_typespec(%{eigenvals_expr | type: out_type}) + expr_to_typespec(%{eigenvals_expr | type: out_type}), + expr_to_typespec(%{eigenvecs_expr | type: out_type}) ) - {[to_type(eigenvecs, eigenvecs_expr.type), to_type(eigenvals, eigenvals_expr.type)], cache} + {[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache} end defp cached_recur_operator( diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 46baa95c8c..f955e67200 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -719,30 +719,11 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.return", values, []) end - def eigh(%Value{function: func} = value, eigenvecs_typespec, eigenvals_typespec) do - %{type: op_type, shape: op_shape} = get_typespec(value) - %{type: eigenvecs_type, shape: eigenvecs_shape} = eigenvecs_typespec - %{type: eigenvals_type, shape: eigenvals_shape} = eigenvals_typespec + def eigh(%Value{function: func} = value, eigenvals_typespec, eigenvecs_typespec) do + %{type: op_type} = get_typespec(value) - dim_sizes = [tuple_size(op_shape), tuple_size(eigenvecs_shape), tuple_size(eigenvals_shape)] - operand_dims = Tuple.to_list(op_shape) - eigenvecs_dims = Tuple.to_list(eigenvecs_shape) - eigenvals_dims = Tuple.to_list(eigenvals_shape) - - dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)})) - operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)})) - - eigenvecs_dims = - constant(func, eigenvecs_dims, Typespec.tensor({:u, 64}, {length(eigenvecs_dims)})) - - eigenvals_dims = - constant(func, eigenvals_dims, Typespec.tensor({:u, 64}, {length(eigenvals_dims)})) - - operands = [value, dim_sizes, operand_dims, eigenvecs_dims, eigenvals_dims] - - eigenvecs_result_type = type_tensor(eigenvecs_type, eigenvecs_shape) - eigenvals_result_type = type_tensor(eigenvals_type, eigenvals_shape) - result_types = [type_tuple([eigenvecs_result_type, eigenvals_result_type])] + operands = [value] + result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) call_target_name = case op_type do @@ -759,37 +740,20 @@ defmodule EXLA.MLIR.Value do attributes = [ call_target_name: attr_string(call_target_name), - backend_config: attr_string("Host") + api_version: attr_i32(4) ] - result = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() - - eigenvecs = get_tuple_element(result, 0, eigenvecs_typespec) - eigenvals = get_tuple_element(result, 1, eigenvals_typespec) + [eigenvals, eigenvecs] = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - {eigenvecs, eigenvals} + {eigenvals, eigenvecs} end def qr(%Value{function: func} = value, q_typespec, r_typespec) do - %{type: op_type, shape: op_shape} = get_typespec(value) - %{type: q_type, shape: q_shape} = q_typespec - %{type: r_type, shape: r_shape} = r_typespec - - dim_sizes = [tuple_size(op_shape), tuple_size(q_shape), tuple_size(r_shape)] - operand_dims = Tuple.to_list(op_shape) - q_dims = Tuple.to_list(q_shape) - r_dims = Tuple.to_list(r_shape) + %{type: op_type} = get_typespec(value) - dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)})) - operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)})) - q_dims = constant(func, q_dims, Typespec.tensor({:u, 64}, {length(q_dims)})) - r_dims = constant(func, r_dims, Typespec.tensor({:u, 64}, {length(r_dims)})) - operands = [value, dim_sizes, operand_dims, q_dims, r_dims] - - q_result_type = type_tensor(q_type, q_shape) - r_result_type = type_tensor(r_type, r_shape) - result_types = [type_tuple([q_result_type, r_result_type])] + operands = [value] + result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) call_target_name = case op_type do @@ -812,48 +776,23 @@ defmodule EXLA.MLIR.Value do attributes = [ call_target_name: attr_string(call_target_name), - backend_config: attr_string("Host") + api_version: attr_i32(4) ] - result = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() - - q = get_tuple_element(result, 0, q_typespec) - r = get_tuple_element(result, 1, r_typespec) + [q, r] = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) {q, r} end def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do - %{type: op_type, shape: op_shape} = get_typespec(value) - %{type: _p_type, shape: p_shape} = p_typespec - %{type: l_type, shape: l_shape} = l_typespec - %{type: u_type, shape: u_shape} = u_typespec - - dim_sizes = [ - tuple_size(op_shape), - tuple_size(p_shape), - tuple_size(l_shape), - tuple_size(u_shape) - ] - - operand_dims = Tuple.to_list(op_shape) - p_dims = Tuple.to_list(p_shape) - l_dims = Tuple.to_list(l_shape) - u_dims = Tuple.to_list(u_shape) + %{type: op_type} = get_typespec(value) - dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)})) - operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)})) - p_dims = constant(func, p_dims, Typespec.tensor({:u, 64}, {length(p_dims)})) - l_dims = constant(func, l_dims, Typespec.tensor({:u, 64}, {length(l_dims)})) - u_dims = constant(func, u_dims, Typespec.tensor({:u, 64}, {length(u_dims)})) - operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims] + operands = [value] - # Force P to always b u8 to avoid requiring too many template instances during custom_call registration - p_result_type = type_tensor({:u, 8}, p_shape) - l_result_type = type_tensor(l_type, l_shape) - u_result_type = type_tensor(u_type, u_shape) - result_types = [type_tuple([p_result_type, l_result_type, u_result_type])] + # Force P to always be u8 to avoid requiring too many template instances during custom_call registration + u8_typespec = Typespec.to_type(p_typespec, {:u, 8}) + result_types = typespecs_to_mlir_types([u8_typespec, l_typespec, u_typespec]) call_target_name = case op_type do @@ -876,16 +815,13 @@ defmodule EXLA.MLIR.Value do attributes = [ call_target_name: attr_string(call_target_name), - backend_config: attr_string("Host") + api_version: attr_i32(4) ] - result = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() - - # This is not the best approach, but the alternative would require many more template instances - u8_typespec = Typespec.to_type(p_typespec, {:u, 8}) - p = get_tuple_element(result, 0, u8_typespec) + [p, l, u] = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + # Convert p to the requested type if necessary p = if u8_typespec != p_typespec do convert(p, p_typespec) @@ -893,9 +829,6 @@ defmodule EXLA.MLIR.Value do p end - l = get_tuple_element(result, 1, l_typespec) - u = get_tuple_element(result, 2, u_typespec) - {p, l, u} end @@ -965,10 +898,6 @@ defmodule EXLA.MLIR.Value do defp type_token(), do: "!stablehlo.token" - defp type_tuple(children) do - "tuple<#{Enum.join(children, ", ")}>" - end - defp number_literal(value, type) do cond do Nx.Type.complex?(type) -> diff --git a/exla/mix.exs b/exla/mix.exs index 2d1980bfd5..544f46d343 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -68,7 +68,7 @@ defmodule EXLA.MixProject do # {:nx, "~> 0.9.0"}, {:nx, path: "../nx"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, - {:xla, "~> 0.8.0", runtime: false}, + {:xla, "~> 0.9.0", runtime: false}, {:fine, "~> 0.1.0", runtime: false}, {:elixir_make, "~> 0.6", runtime: false}, {:benchee, "~> 1.0", only: :dev}, diff --git a/exla/mix.lock b/exla/mix.lock index de6ba787a3..768c68ded2 100644 --- a/exla/mix.lock +++ b/exla/mix.lock @@ -3,7 +3,7 @@ "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, - "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, + "elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"}, "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, "fine": {:hex, :fine, "0.1.0", "9bb99a5ff9b968f12c3b458fa1277c39e9a620b23a9439103703a25917293871", [:mix], [], "hexpm", "1d6485bf811b95dc6ae3d197c0e6f994880b86167a827983bb29cbfc03a02684"}, "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, @@ -14,5 +14,5 @@ "nx": {:hex, :nx, "0.9.0", "03a622a27d93eaaa2d24ff9b812d9f675cc04eb0340ca3dd065674f3642867d3", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3810a5a90db0654b6e538430c0fb473a22bfc11b3d02ea7834db493cf3f56153"}, "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, - "xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"}, + "xla": {:hex, :xla, "0.9.0", "18a97b47746c371c6b5ac0ccf77155eaebff0c25c588cbd08f54be74eb95f862", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "a7f929c425ab624ff1c5b4aa049f7261c5c0a879d509ab1ff159dedcd62af944"}, } diff --git a/exla/test/exla/device_memory_sharing_test.exs b/exla/test/exla/device_memory_sharing_test.exs index 7ef3b165ef..25348ec83a 100644 --- a/exla/test/exla/device_memory_sharing_test.exs +++ b/exla/test/exla/device_memory_sharing_test.exs @@ -28,12 +28,13 @@ defmodule EXLA.DeviceMemorySharingTest do @tag :cuda_required test "invalid ipc handles don't crash the runtime" do - assert {:error, ~c"Unable to get pointer for IPC handle."} == - Nx.from_pointer( - {EXLA.Backend, client: :cuda}, - %Nx.Pointer{handle: "#{System.unique_integer()}", kind: :ipc, data_size: 4}, - {:f, 32}, - {1} - ) + assert_raise RuntimeError, "unable to get pointer for IPC handle", fn -> + Nx.from_pointer( + {EXLA.Backend, client: :cuda}, + %Nx.Pointer{handle: "#{System.unique_integer()}", kind: :ipc, data_size: 4}, + {:f, 32}, + {1} + ) + end end end diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index cae7c94997..974d558b0d 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -522,11 +522,13 @@ defmodule Nx.BinaryBackend do right_batch_item_bits = right_batch_item_length * right_size <<_::bitstring-size(^left_offset_bits), - left_batch_item_binary::bitstring-size(^left_batch_item_bits), _::bitstring>> = + left_batch_item_binary::bitstring-size(^left_batch_item_bits), + _::bitstring>> = left_binary <<_::bitstring-size(^right_offset_bits), - right_batch_item_binary::bitstring-size(^right_batch_item_bits), _::bitstring>> = + right_batch_item_binary::bitstring-size(^right_batch_item_bits), + _::bitstring>> = right_binary bin_dot( @@ -1756,7 +1758,8 @@ defmodule Nx.BinaryBackend do before_slice_size = current - previous <> = + current_bitstring::bitstring-size(^target_chunk), + to_traverse::bitstring>> = to_traverse updated_elements =