Skip to content

A way to use abort_callback with the cpu backend #725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,11 @@ static bool ggml_graph_compute_helper(
struct ggml_cgraph * graph,
std::vector<uint8_t> & buf,
int n_threads,
whisper_abort_callback abort_callback,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

plan.abort_callback = abort_callback;
plan.abort_callback = abort_callback;
plan.abort_callback_data = abort_callback_data;

if (plan.work_size > 0) {
Expand Down Expand Up @@ -2130,7 +2130,7 @@ static bool whisper_encode_internal(
whisper_state & wstate,
const int mel_offset,
const int n_threads,
whisper_abort_callback abort_callback,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();

Expand Down Expand Up @@ -2561,7 +2561,7 @@ static bool whisper_decode_internal(
whisper_state & wstate,
const whisper_batch & batch,
const int n_threads,
whisper_abort_callback abort_callback,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();

Expand Down
7 changes: 1 addition & 6 deletions examples/whisper/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,6 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);

// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
typedef bool (*whisper_abort_callback)(void * user_data);

// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
Expand Down Expand Up @@ -512,7 +507,7 @@ extern "C" {
void * encoder_begin_callback_user_data;

// called each time before ggml computation starts
whisper_abort_callback abort_callback;
ggml_abort_callback abort_callback;
void * abort_callback_user_data;

// called by each decoder to filter obtained logits
Expand Down
5 changes: 3 additions & 2 deletions include/ggml/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ extern "C" {

GGML_API ggml_backend_t ggml_backend_cpu_init(void);

GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);

// Create a backend buffer from an existing pointer
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
Expand Down
9 changes: 7 additions & 2 deletions include/ggml/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,11 @@ extern "C" {

static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
// NOTE(review): `data` is presumably the abort_callback_data pointer stored
// next to the callback (see struct ggml_cplan) - confirm at the call site
typedef bool (*ggml_abort_callback)(void * data);

// the compute plan that needs to be prepared for ggml_graph_compute()
// since https://github.com/ggerganov/ggml/issues/287
struct ggml_cplan {
Expand All @@ -576,8 +581,8 @@ extern "C" {
int n_threads;

// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
ggml_abort_callback abort_callback;
void * abort_callback_data;
};

enum ggml_cgraph_eval_order {
Expand Down
26 changes: 22 additions & 4 deletions src/ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ struct ggml_backend_cpu_context {
int n_threads;
void * work_data;
size_t work_size;

ggml_abort_callback abort_callback;
void * abort_callback_data;
};

GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
Expand Down Expand Up @@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
}

cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;

return cpu_plan;
}

Expand Down Expand Up @@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
cpu_ctx->work_size = cplan.work_size;
}

cplan.work_data = cpu_ctx->work_data;

cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data;

ggml_graph_compute(cgraph, &cplan);
return true;
}
Expand Down Expand Up @@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = {
ggml_backend_t ggml_backend_cpu_init(void) {
struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));

ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL;
ctx->work_size = 0;
ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL;
ctx->work_size = 0;
ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL;

ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));

Expand All @@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
ctx->n_threads = n_threads;
}

// Record an abort callback and its user-data pointer in the context of a CPU backend.
// Asserts that `backend_cpu` is a CPU backend. The stored pair is what the CPU
// graph plan / graph compute paths copy into their ggml_cplan.
void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;

    // the two stores are independent of each other
    cpu_ctx->abort_callback_data = abort_callback_data;
    cpu_ctx->abort_callback      = abort_callback;
}

// Create a backend buffer from an existing pointer: forwards `ptr` and `size`
// to ggml_backend_buffer_init together with the CPU buffer type and the
// from-ptr buffer interface.
GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
    return ggml_backend_buffer_init(
        ggml_backend_cpu_buffer_type(),
        cpu_backend_buffer_i_from_ptr,
        ptr,
        size);
}
Expand Down
2 changes: 1 addition & 1 deletion src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -16560,7 +16560,7 @@ struct ggml_compute_state_shared {
atomic_int node_n; // active graph node
atomic_int node_task; // active graph node task phase

bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
void * abort_callback_data;
};

Expand Down