Update XLA #1614


Merged: 17 commits, Jun 17, 2025

13 changes: 8 additions & 5 deletions .github/workflows/ci.yml
@@ -12,10 +12,12 @@ jobs:
fail-fast: false
matrix:
working_directory: ["nx", "exla", "torchx"]
elixir: ["1.15.4", "1.16.2"]
otp: ["25.3"]
elixir: ["1.15.8", "1.18.4"]
include:
- elixir: "1.16.2"
- elixir: "1.15.8"
otp: "25.3"
- elixir: "1.18.4"
otp: "27.3"
lint: true
defaults:
run:
@@ -57,8 +59,9 @@ jobs:
fail-fast: false
matrix:
working_directory: ["nx", "torchx"]
elixir: ["1.16.2"]
otp: ["25.2"]
include:
- elixir: "1.18.4"
otp: "27.3"
defaults:
run:
working-directory: ${{ matrix.working_directory }}
10 changes: 5 additions & 5 deletions exla/Makefile
@@ -25,15 +25,15 @@ EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
# Note that XLA requires c++17, Fine as well
CFLAGS += -fPIC -I$(ERTS_INCLUDE_DIR) -I$(FINE_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
-std=c++17 -w -DLLVM_VERSION_STRING=
-std=c++17 -w

ifdef DEBUG
CFLAGS += -g
else
CFLAGS += -O3
endif

NVCC := $(CXX)
NVCC = $(CXX)
NVCCFLAGS = $(CFLAGS)
LDFLAGS += -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden

@@ -48,8 +48,8 @@ $(info EXLA_CPU_ONLY is not set, checking for nvcc availability)

ifeq ($(NVCC_TEST),nvcc)
$(info CUDA is available.)
NVCC := nvcc
NVCCFLAGS += -DCUDA_ENABLED
NVCC = nvcc
NVCCFLAGS = -Xcompiler "$(CFLAGS)" -DCUDA_ENABLED
else
$(info CUDA is not available.)
endif
@@ -82,7 +82,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
fi

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/ipc.cc
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o
23 changes: 0 additions & 23 deletions exla/c_src/exla/custom_calls.cc

This file was deleted.

59 changes: 22 additions & 37 deletions exla/c_src/exla/custom_calls/eigh.h
@@ -1,12 +1,16 @@
#pragma once

#include "Eigen/Eigenvalues"

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

#include "Eigen/Eigenvalues"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

template <typename DataType>
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
DataType *eigenvectors_out,
@@ -55,51 +59,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
m * n * sizeof(DataType));
}

template <typename DataType>
void eigh_cpu_custom_call(void *out[], const void *in[]) {
DataType *operand = (DataType *)in[0];

uint64_t *dim_sizes = (uint64_t *)in[1];
uint64_t num_operand_dims = dim_sizes[0];
uint64_t num_eigenvalues_dims = dim_sizes[1];
uint64_t num_eigenvectors_dims = dim_sizes[2];

uint64_t *operand_dims_ptr = (uint64_t *)in[2];
std::vector<uint64_t> operand_dims(operand_dims_ptr,
operand_dims_ptr + num_operand_dims);

uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
std::vector<uint64_t> eigenvalues_dims(
eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);

uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
std::vector<uint64_t> eigenvectors_dims(
eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
template <typename DataType, typename BufferType>
ffi::Error eigh_cpu_custom_call_impl(BufferType operand,
ffi::Result<BufferType> eigenvalues,
ffi::Result<BufferType> eigenvectors) {
auto operand_dims = operand.dimensions();
auto eigenvalues_dims = eigenvalues->dimensions();
auto eigenvectors_dims = eigenvectors->dimensions();

uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

auto leading_dimensions =
std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);

uint64_t batch_items = 1;
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
batch_items *= leading_dimensions[i];
for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
batch_items *= *it;
}

DataType *eigenvalues = (DataType *)out[0];
DataType *eigenvectors = (DataType *)out[1];

uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
uint64_t eigenvectors_stride =
eigenvectors_dims[eigenvectors_dims.size() - 1] *
eigenvectors_dims[eigenvectors_dims.size() - 2];
uint64_t eigenvectors_stride = m * n;
uint64_t inner_stride = m * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_eigh_cpu_custom_call<DataType>(
eigenvalues + i * eigenvalues_stride,
eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
n);
eigenvalues->typed_data() + i * eigenvalues_stride,
eigenvectors->typed_data() + i * eigenvectors_stride,
operand.typed_data() + i * inner_stride, m, n);
}
}

return ffi::Error::Success();
}
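
The per-matrix kernel (single_matrix_eigh_cpu_custom_call) is untouched by this hunk; only the batched entry point moves to the typed FFI buffers. For context, a minimal standalone sketch of the kind of self-adjoint eigendecomposition that kernel presumably wraps, using only Eigen (illustrative values, not part of the diff):

// Standalone sketch: Eigen self-adjoint eigendecomposition, the kind of
// per-matrix work the eigh custom call performs for each batch item.
#include <iostream>
#include "Eigen/Dense"
#include "Eigen/Eigenvalues"

int main() {
  using Matrix =
      Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

  // Small symmetric (self-adjoint) input.
  Matrix a(3, 3);
  a << 4, 1, 2,
       1, 3, 0,
       2, 0, 5;

  Eigen::SelfAdjointEigenSolver<Matrix> solver(a);

  // eigenvalues() is a length-n vector and eigenvectors() an n x n matrix --
  // the two output buffers the FFI handler fills per batch item.
  std::cout << "eigenvalues:\n" << solver.eigenvalues() << "\n";
  std::cout << "eigenvectors:\n" << solver.eigenvectors() << "\n";

  // Sanity check: A * V == V * diag(lambda) up to floating-point error.
  std::cout << "residual: "
            << (a * solver.eigenvectors() -
                solver.eigenvectors() * solver.eigenvalues().asDiagonal())
                   .norm()
            << "\n";
  return 0;
}
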
18 changes: 16 additions & 2 deletions exla/c_src/exla/custom_calls/eigh_f32.cc
@@ -1,5 +1,19 @@
#include "eigh.h"

void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
eigh_cpu_custom_call<float>(out, in);
ffi::Error
eigh_cpu_custom_call_f32_impl(ffi::Buffer<ffi::F32> operand,
ffi::ResultBuffer<ffi::F32> eigenvalues,
ffi::ResultBuffer<ffi::F32> eigenvectors) {
return eigh_cpu_custom_call_impl<float, ffi::Buffer<ffi::F32>>(
operand, eigenvalues, eigenvectors);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(eigh_cpu_custom_call_f32,
eigh_cpu_custom_call_f32_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F32>>()
.Ret<ffi::Buffer<ffi::F32>>()
.Ret<ffi::Buffer<ffi::F32>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eigh_cpu_custom_call_f32",
"Host", eigh_cpu_custom_call_f32);
18 changes: 16 additions & 2 deletions exla/c_src/exla/custom_calls/eigh_f64.cc
@@ -1,5 +1,19 @@
#include "eigh.h"

void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
eigh_cpu_custom_call<double>(out, in);
ffi::Error
eigh_cpu_custom_call_f64_impl(ffi::Buffer<ffi::F64> operand,
ffi::ResultBuffer<ffi::F64> eigenvalues,
ffi::ResultBuffer<ffi::F64> eigenvectors) {
return eigh_cpu_custom_call_impl<double, ffi::Buffer<ffi::F64>>(
operand, eigenvalues, eigenvectors);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(eigh_cpu_custom_call_f64,
eigh_cpu_custom_call_f64_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F64>>()
.Ret<ffi::Buffer<ffi::F64>>()
.Ret<ffi::Buffer<ffi::F64>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eigh_cpu_custom_call_f64",
"Host", eigh_cpu_custom_call_f64);
74 changes: 33 additions & 41 deletions exla/c_src/exla/custom_calls/lu.h
@@ -1,16 +1,30 @@
#pragma once

#include "Eigen/LU";
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

#include "Eigen/LU"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

template <typename DataType>
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out,
DataType *u_out, DataType *in,
uint64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
RowMajorMatrix;

Eigen::Map<RowMajorMatrix> input(in, n, n);
Eigen::PartialPivLU<RowMajorMatrix> lu = input.partialPivLu();

// Get the permutation matrix P and convert to indices
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P = lu.permutationP();
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P =
lu.permutationP();
for (uint64_t i = 0; i < n; i++) {
for (uint64_t j = 0; j < n; j++) {
p_out[i * n + j] = static_cast<uint8_t>(P.indices()[i] == j ? 1 : 0);
@@ -24,7 +38,6 @@ void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType
// Copy L matrix
for (uint64_t i = 0; i < n; i++) {
for (uint64_t j = 0; j < n; j++) {

if (j < i) {
l_out[i * n + j] = static_cast<DataType>(L(i, j));
} else if (j == i) {
@@ -47,49 +60,28 @@ void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType
}
}

template <typename DataType>
void lu_cpu_custom_call(void *out[], const void *in[]) {
DataType *operand = (DataType *)in[0];

uint64_t *dim_sizes = (uint64_t *)in[1];
uint64_t num_operand_dims = dim_sizes[0];
uint64_t num_p_dims = dim_sizes[1];
uint64_t num_l_dims = dim_sizes[2];
uint64_t num_u_dims = dim_sizes[3];

uint64_t *operand_dims_ptr = (uint64_t *)in[2];
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);

uint64_t *p_dims_ptr = (uint64_t *)in[3];
std::vector<uint64_t> p_dims(p_dims_ptr, p_dims_ptr + num_p_dims);

uint64_t *l_dims_ptr = (uint64_t *)in[4];
std::vector<uint64_t> l_dims(l_dims_ptr, l_dims_ptr + num_l_dims);

uint64_t *u_dims_ptr = (uint64_t *)in[5];
std::vector<uint64_t> u_dims(u_dims_ptr, u_dims_ptr + num_u_dims);

template <typename DataType, typename BufferType>
ffi::Error
lu_cpu_custom_call_impl(BufferType operand, ffi::Result<ffi::Buffer<ffi::U8>> p,
ffi::Result<BufferType> l, ffi::Result<BufferType> u) {
auto operand_dims = operand.dimensions();
auto l_dims = l->dimensions();
uint64_t n = l_dims[l_dims.size() - 1];

auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);

uint64_t batch_items = 1;
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
batch_items *= leading_dimensions[i];
for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
batch_items *= *it;
}

uint8_t *p = (uint8_t *)out[0];
DataType *l = (DataType *)out[1];
DataType *u = (DataType *)out[2];

uint64_t stride = n * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_lu_cpu_custom_call<DataType>(
p + i * stride,
l + i * stride,
u + i * stride,
operand + i * stride,
n);
p->typed_data() + i * stride,
(DataType *)l->untyped_data() + i * stride,
(DataType *)u->untyped_data() + i * stride,
(DataType *)operand.untyped_data() + i * stride, n);
}
}

return ffi::Error::Success();
}
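
For context on the Eigen calls above, here is a minimal standalone sketch (assuming only Eigen on the include path; values illustrative, not part of the diff) of the PartialPivLU extraction that single_matrix_lu_cpu_custom_call performs, checking Eigen's convention P * A = L * U:

// Standalone sketch mirroring the P/L/U extraction in lu.h.
#include <iostream>
#include "Eigen/Dense"
#include "Eigen/LU"

int main() {
  using Matrix =
      Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

  Matrix a(3, 3);
  a << 2, 1, 1,
       4, 3, 3,
       8, 7, 9;

  Eigen::PartialPivLU<Matrix> lu = a.partialPivLu();

  // P as a dense 0/1 matrix -- the representation the custom call writes
  // into its uint8_t output buffer, one row per permutation index.
  Matrix p = lu.permutationP().toDenseMatrix().cast<double>();

  // L (unit lower triangular) and U (upper triangular) are packed together
  // in matrixLU(); split them apart as the custom call's copy loops do.
  Matrix l = Matrix::Identity(3, 3);
  l.triangularView<Eigen::StrictlyLower>() = lu.matrixLU();
  Matrix u = lu.matrixLU().triangularView<Eigen::Upper>();

  // Eigen decomposes A = P^-1 * L * U, i.e. P * A = L * U.
  std::cout << "reconstruction error: " << (p * a - l * u).norm() << "\n";
  return 0;
}
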
21 changes: 18 additions & 3 deletions exla/c_src/exla/custom_calls/lu_bf16.cc
@@ -1,6 +1,21 @@
#include "lu.h"
#include "../exla_types.h"
#include "lu.h"

void lu_cpu_custom_call_bf16(void *out[], const void *in[]) {
lu_cpu_custom_call<exla::bfloat16>(out, in);
ffi::Error lu_cpu_custom_call_bf16_impl(ffi::Buffer<ffi::BF16> operand,
ffi::ResultBuffer<ffi::U8> p,
ffi::ResultBuffer<ffi::BF16> l,
ffi::ResultBuffer<ffi::BF16> u) {
return lu_cpu_custom_call_impl<exla::bfloat16, ffi::Buffer<ffi::BF16>>(
operand, p, l, u);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_bf16,
lu_cpu_custom_call_bf16_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::BF16>>()
.Ret<ffi::Buffer<ffi::U8>>()
.Ret<ffi::Buffer<ffi::BF16>>()
.Ret<ffi::Buffer<ffi::BF16>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_bf16", "Host",
lu_cpu_custom_call_bf16);
21 changes: 18 additions & 3 deletions exla/c_src/exla/custom_calls/lu_f16.cc
@@ -1,6 +1,21 @@
#include "lu.h"
#include "../exla_types.h"
#include "lu.h"

void lu_cpu_custom_call_f16(void *out[], const void *in[]) {
lu_cpu_custom_call<exla::float16>(out, in);
ffi::Error lu_cpu_custom_call_f16_impl(ffi::Buffer<ffi::F16> operand,
ffi::ResultBuffer<ffi::U8> p,
ffi::ResultBuffer<ffi::F16> l,
ffi::ResultBuffer<ffi::F16> u) {
return lu_cpu_custom_call_impl<exla::float16, ffi::Buffer<ffi::F16>>(operand,
p, l, u);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_f16,
lu_cpu_custom_call_f16_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F16>>()
.Ret<ffi::Buffer<ffi::U8>>()
.Ret<ffi::Buffer<ffi::F16>>()
.Ret<ffi::Buffer<ffi::F16>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_f16", "Host",
lu_cpu_custom_call_f16);