
metal : fix memory leak #2762

Merged · 7 commits · Aug 28, 2023

1 change: 1 addition & 0 deletions ggml-metal.h
@@ -24,6 +24,7 @@

// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16
#define GGML_METAL_MAX_COMMAND_BUFFERS 32

struct ggml_tensor;
struct ggml_cgraph;

100 changes: 81 additions & 19 deletions ggml-metal.m
@@ -33,12 +33,15 @@
struct ggml_metal_context {
int n_cb;

float * logits;

id<MTLDevice> device;
id<MTLCommandQueue> queue;
id<MTLLibrary> library;

id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];

dispatch_queue_t d_queue;

int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

@@ -114,12 +117,13 @@ @implementation GGMLMetalClass

struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));

ctx->n_cb = n_cb;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
ctx->concur_list_len = 0;

ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

#if 0
// compile from source string and show compile log
@@ -239,9 +243,67 @@ @implementation GGMLMetalClass

void ggml_metal_free(struct ggml_metal_context * ctx) {
fprintf(stderr, "%s: deallocating\n", __func__);
#define GGML_METAL_DEL_KERNEL(name) \
[ctx->function_##name release]; \
[ctx->pipeline_##name release];

GGML_METAL_DEL_KERNEL(add);
GGML_METAL_DEL_KERNEL(add_row);
GGML_METAL_DEL_KERNEL(mul);
GGML_METAL_DEL_KERNEL(mul_row);
GGML_METAL_DEL_KERNEL(scale);
GGML_METAL_DEL_KERNEL(silu);
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DEL_KERNEL(rope);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);

#undef GGML_METAL_DEL_KERNEL

for (int i = 0; i < ctx->n_buffers; ++i) {
[ctx->buffers[i].metal release];
}

[ctx->library release];
[ctx->queue release];
[ctx->device release];

dispatch_release(ctx->d_queue);

free(ctx);
}

@@ -261,7 +323,7 @@ void ggml_metal_host_free(void * data) {
}

void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}

int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -507,6 +569,8 @@ void ggml_metal_graph_compute(
struct ggml_cgraph * gf) {
metal_printf("%s: evaluating graph\n", __func__);

@autoreleasepool {

// if there is ctx->concur_list, dispatch concurrently
// else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
@@ -521,29 +585,25 @@ void ggml_metal_graph_compute(

const int n_cb = ctx->n_cb;

NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];

for (int i = 0; i < n_cb; ++i) {
command_buffers[i] = [ctx->queue commandBuffer];
ctx->command_buffers[i] = [ctx->queue commandBuffer];

// enqueue the command buffers in order to specify their execution order
[command_buffers[i] enqueue];
}
[ctx->command_buffers[i] enqueue];

// TODO: is this the best way to start threads?
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
}

for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

dispatch_async(queue, ^{
dispatch_async(ctx->d_queue, ^{
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_dst = 0;

id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];

id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];

const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
@@ -1117,17 +1177,19 @@ void ggml_metal_graph_compute(
}

// wait for all threads to finish
dispatch_barrier_sync(queue, ^{});

[command_buffers[n_cb - 1] waitUntilCompleted];
dispatch_barrier_sync(ctx->d_queue, ^{});

// check status of command buffers
// needed to detect if the device ran out-of-memory for example (#1881)
for (int i = 0; i < n_cb; i++) {
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
[ctx->command_buffers[i] waitUntilCompleted];

MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) {
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
GGML_ASSERT(false);
}
}

}
}
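
Taken together, the ggml-metal.m changes move the per-call Metal objects (command buffers, compute encoders, dispatch queue) into the context so they live in fixed-size arrays and are torn down explicitly in ggml_metal_free, and they wrap graph evaluation in an @autoreleasepool so the autoreleased objects returned by the Metal API are drained on every call instead of accumulating. The following is a minimal sketch of that ownership pattern, not the PR's code: the names example_ctx, example_compute, example_free and MAX_COMMAND_BUFFERS are hypothetical stand-ins, kernel encoding and error handling are omitted, and it assumes manual reference counting (as the explicit release calls in the diff suggest).

    #import <Foundation/Foundation.h>
    #import <Metal/Metal.h>

    #define MAX_COMMAND_BUFFERS 32   // stand-in for GGML_METAL_MAX_COMMAND_BUFFERS

    struct example_ctx {
        id<MTLDevice>                device;
        id<MTLCommandQueue>          queue;
        dispatch_queue_t             d_queue; // created once per context, released in example_free
        id<MTLCommandBuffer>         command_buffers [MAX_COMMAND_BUFFERS];
        id<MTLComputeCommandEncoder> command_encoders[MAX_COMMAND_BUFFERS];
        int n_cb;
    };

    static void example_compute(struct example_ctx * ctx) {
        @autoreleasepool {
            // the command buffers and encoders are autoreleased; the pool drains
            // them at the end of this call instead of letting them pile up
            for (int i = 0; i < ctx->n_cb; ++i) {
                ctx->command_buffers [i] = [ctx->queue commandBuffer];
                [ctx->command_buffers[i] enqueue]; // fix execution order up front
                ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoder];
            }

            for (int i = 0; i < ctx->n_cb; ++i) {
                dispatch_async(ctx->d_queue, ^{
                    // ... encode this chunk of the graph here ...
                    [ctx->command_encoders[i] endEncoding];
                    [ctx->command_buffers[i] commit];
                });
            }

            dispatch_barrier_sync(ctx->d_queue, ^{}); // wait for the encoder threads
            for (int i = 0; i < ctx->n_cb; ++i) {
                [ctx->command_buffers[i] waitUntilCompleted];
            }
        }
    }

    static void example_free(struct example_ctx * ctx) {
        [ctx->queue  release];
        [ctx->device release];
        dispatch_release(ctx->d_queue);
    }

Keeping the dispatch queue in the context (rather than creating one on every ggml_metal_graph_compute call) and pairing each retained object with a release in the free path is the pattern the diff applies.
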
11 changes: 6 additions & 5 deletions ggml.c
@@ -2436,7 +2436,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const int nb = n / qk;

assert(n % qk == 0);
assert(nb % 2 == 0);

const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
@@ -2445,6 +2444,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
@@ -2623,6 +2623,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
}

// Main loop
GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
for (int i = 2; i < nb; i+=2) {
_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
@@ -2706,7 +2707,6 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const int nb = n / qk;

assert(n % qk == 0);
assert(nb % 2 == 0);

const block_q4_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
@@ -2718,6 +2718,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *

float summs = 0;

GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict x1 = &x[i + 1];
@@ -2832,7 +2833,6 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const int nb = n / qk;

assert(n % qk == 0);
assert(nb % 2 == 0);
assert(qk == QK5_0);

const block_q5_0 * restrict x = vx;
@@ -2848,6 +2848,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
uint64_t tmp0[4];
uint64_t tmp1[4];

GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q5_0 * restrict x0 = &x[i];
const block_q5_0 * restrict x1 = &x[i + 1];
@@ -3072,7 +3073,6 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const int nb = n / qk;

assert(n % qk == 0);
assert(nb % 2 == 0);
assert(qk == QK5_1);

const block_q5_1 * restrict x = vx;
@@ -3091,6 +3091,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
uint64_t tmp0[4];
uint64_t tmp1[4];

GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q5_1 * restrict x0 = &x[i];
const block_q5_1 * restrict x1 = &x[i + 1];
@@ -3328,7 +3329,6 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
const int nb = n / qk;

assert(n % qk == 0);
assert(nb % 2 == 0);

const block_q8_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
@@ -3337,6 +3337,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q8_0 * restrict x0 = &x[i + 0];
const block_q8_0 * restrict x1 = &x[i + 1];
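
The ggml.c changes relax the even-block requirement: the function-wide assert(nb % 2 == 0) is dropped, and a GGML_ASSERT is placed only in front of the SIMD main loops that consume two quantized blocks per iteration, while the scalar fallback at the end of each function handles any nb. Below is a minimal, hypothetical C sketch of that shape; vec_dot_sketch and EXAMPLE_SIMD are made-up names, and the real functions operate on quantized blocks rather than plain floats.

    #include <assert.h>

    static void vec_dot_sketch(int nb, float * s, const float * x, const float * y) {
        float sum = 0.0f;
    #if defined(EXAMPLE_SIMD)
        assert(nb % 2 == 0);              // TODO: handle odd nb; only the paired loop needs this
        for (int i = 0; i < nb; i += 2) { // two blocks per iteration
            sum += x[i + 0]*y[i + 0];
            sum += x[i + 1]*y[i + 1];
        }
    #else
        for (int i = 0; i < nb; ++i) {    // scalar path: any nb is fine
            sum += x[i]*y[i];
        }
    #endif
        *s = sum;
    }

Moving the check next to the loop that actually requires it documents the pairing assumption where it matters and leaves the scalar path usable with an odd number of blocks.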