From 637be12f16b99b8e059437ec7aaf8c6b4d08adf4 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Mon, 8 May 2023 22:21:03 +0200 Subject: [PATCH 1/9] CUDA kernel for q4_0 dequant. + mat. vec. mult. --- examples/common.cpp | 8 +++ examples/common.h | 1 + ggml-cuda.cu | 158 +++++++++++++++++++++++++++++++++++++------- ggml-cuda.h | 2 + ggml.c | 1 + ggml.h | 8 ++- llama.cpp | 22 +++++- llama.h | 1 + 8 files changed, 175 insertions(+), 26 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 80e35d2e9cec8..86cea79736ab2 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_color = true; } else if (arg == "--mlock") { params.use_mlock = true; + } else if (arg == "--gpu_layers") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.gpu_layers = std::stoi(argv[i]); } else if (arg == "--no-mmap") { params.use_mmap = false; } else if (arg == "--mtest") { @@ -421,6 +427,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { if (llama_mmap_supported()) { fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } + fprintf(stderr, " --gpu_layers number of layers to store in VRAM\n"); fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); @@ -469,6 +476,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; lparams.use_mlock = params.use_mlock; + lparams.gpu_layers = params.gpu_layers; lparams.logits_all = params.perplexity; lparams.embedding = params.embedding; diff --git a/examples/common.h b/examples/common.h index 499671b2e8d6d..636dc359412c0 100644 --- a/examples/common.h +++ b/examples/common.h @@ -69,6 +69,7 @@ struct gpt_params { bool perplexity = false; // compute perplexity over the prompt bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory + int gpu_layers = 0; // number of layers to store in VRAM bool mem_test = false; // compute maximum memory usage bool verbose_prompt = false; // print prompt tokens before generation }; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8a3beb0e54b88..99c3ea80890d9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -173,6 +173,52 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) { } } +template static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) { + const block_q4_0 * x = (const block_q4_0 *) vx; + const int qk = QK4_0; + + const int row = blockIdx.x; + const int tid = threadIdx.x; + + __shared__ float tmp[block_size]; // separate sum for each thread + tmp[tid] = 0; + + for (int i = 0; i < ncols/block_size; i += 2) { + const int col = i*block_size + 2*tid; + const int ib = (row*ncols + col)/qk; // block index + const int iqs = (col%qk)/2; // quant index + const int iybs = col - col%qk; // y block start index + + // dequantize + const float d = x[ib].d; + + const uint8_t * pp = x[ib].qs; + + const uint8_t vui = pp[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + // matrix multiplication + tmp[tid] += v0 * y[iybs + iqs + 0]; + tmp[tid] += v1 * y[iybs + iqs + qk/2]; + } + + // sum up partial sums and write back result + for (int s=block_size/2; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + dst[row] = tmp[0]; + } +} + static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; dequantize_block_q4_0<<>>(vx, y); @@ -198,6 +244,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre dequantize_block_q8_0<<>>(vx, y); } +static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + // static int block_size = -1; + // if (block_size == -1) { + // int min_grid_size, max_block_size = 1; + // CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0)); + // max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE); + // block_size = 1; + // while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) { + // block_size *= 2; + // } + // } + // dequantize_mul_mat_q4_0<<>>(vx, y, dst, ncols); + const int block_size = 32; + GGML_ASSERT(ncols % block_size == 0); + dequantize_mul_mat_q4_0<<>>(vx, y, dst, ncols); +} + // TODO: optimize static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { const half * x = (const half *) vx; @@ -231,7 +294,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } // buffer pool for cuda -#define MAX_CUDA_BUFFERS 16 +#define MAX_CUDA_BUFFERS 256 struct scoped_spin_lock { std::atomic_flag& lock; @@ -538,7 +601,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type); size_t x_size, y_size, d_size, q_size; - float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_X; + if (ne11 > 1) { + d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + } float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size); @@ -553,31 +619,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS]; cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS]; - float * c_X = d_X + i * x_ne; float * c_Y = d_Y + i * y_ne; float * c_D = d_D + i * d_ne; char * c_Q = d_Q + i * q_sz; - // copy src0 and convert to fp32 on device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); - to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + // copy src0 to device if necessary + if (src0->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); + } else if (src0->backend == GGML_BACKEND_CUDA) { + c_Q = ((char *) src0->data) + i * q_sz; + } else { + GGML_ASSERT(false); + } + if (ne11 == 1) { + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); - // copy src1 to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + // copy src1 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); - // wait for conversion - CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + // wait for data + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); - // compute - CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); - CUBLAS_CHECK( - cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha, c_X, ne00, - c_Y, ne10, - &beta, c_D, ne01)); + // compute + dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream); + CUDA_CHECK(cudaGetLastError()); + + } else { + float * c_X = d_X + i * x_ne; + + // convert src0 to fp32 on device + to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // copy src1 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // wait for conversion + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); + } // copy dst to host float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); @@ -586,7 +675,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor } CUDA_CHECK(cudaDeviceSynchronize()); - ggml_cuda_pool_free(d_X, x_size); + if (ne11 > 1) { + ggml_cuda_pool_free(d_X, x_size); + } ggml_cuda_pool_free(d_Y, y_size); ggml_cuda_pool_free(d_D, d_size); ggml_cuda_pool_free(d_Q, q_size); @@ -602,8 +693,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - + ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) { return true; } @@ -655,3 +745,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct return 0; } } + +void ggml_cuda_transform_tensor(ggml_tensor * tensor) { + const int64_t ne0 = tensor->ne[0]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne2 = tensor->ne[2]; + const int64_t ne3 = tensor->ne[3]; + + const ggml_type type = tensor->type; + const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); + + size_t q_size; + char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); + + cudaStream_t cudaStream2 = g_cudaStreams2[0]; + + // copy tensor to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); + CUDA_CHECK(cudaDeviceSynchronize()); + + tensor->data = d_Q; + tensor->backend = GGML_BACKEND_CUDA; +} diff --git a/ggml-cuda.h b/ggml-cuda.h index f7d6a8bc1842a..4e2c24283ccf4 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); +void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); + #ifdef __cplusplus } #endif diff --git a/ggml.c b/ggml.c index 096ccacfb7e08..344ed8c912076 100644 --- a/ggml.c +++ b/ggml.c @@ -3702,6 +3702,7 @@ struct ggml_tensor * ggml_new_tensor_impl( *result = (struct ggml_tensor) { /*.type =*/ type, + /*.backend =*/ GGML_BACKEND_CPU, /*.n_dims =*/ n_dims, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, diff --git a/ggml.h b/ggml.h index bb9a025e257d5..d16f416b129c2 100644 --- a/ggml.h +++ b/ggml.h @@ -243,6 +243,11 @@ extern "C" { GGML_TYPE_COUNT, }; + enum ggml_backend { + GGML_BACKEND_CPU = 0, + GGML_BACKEND_CUDA = 1, + }; + // model file types enum ggml_ftype { GGML_FTYPE_UNKNOWN = -1, @@ -322,6 +327,7 @@ extern "C" { // n-dimensional tensor struct ggml_tensor { enum ggml_type type; + enum ggml_backend backend; int n_dims; int64_t ne[GGML_MAX_DIMS]; // number of elements @@ -352,7 +358,7 @@ extern "C" { char name[32]; - char padding[8]; // TODO: remove and add padding to name? + char padding[9]; // TODO: remove and add padding to name? }; // computation graph diff --git a/llama.cpp b/llama.cpp index 0a47faa9d738d..fd890ed3196ca 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9,6 +9,9 @@ #include "llama.h" #include "ggml.h" +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif #include #include @@ -816,6 +819,7 @@ struct llama_context_params llama_context_default_params() { /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, + /*.gpu_layers =*/ 0, /*.embedding =*/ false, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, @@ -879,6 +883,7 @@ static void llama_model_load_internal( ggml_type memory_type, bool use_mmap, bool use_mlock, + int gpu_layers, bool vocab_only, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -1021,6 +1026,18 @@ static void llama_model_load_internal( ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL); model.mapping = std::move(ml->mapping); +#ifdef GGML_USE_CUBLAS + for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) { + auto & layer = model.layers[i]; + ggml_cuda_transform_tensor(layer.wq); + ggml_cuda_transform_tensor(layer.wk); + ggml_cuda_transform_tensor(layer.wv); + ggml_cuda_transform_tensor(layer.wo); + ggml_cuda_transform_tensor(layer.w1); + ggml_cuda_transform_tensor(layer.w2); + ggml_cuda_transform_tensor(layer.w3); + } +#endif // loading time will be recalculate after the first eval, so // we take page faults deferred by mmap() into consideration @@ -1034,11 +1051,12 @@ static bool llama_model_load( ggml_type memory_type, bool use_mmap, bool use_mlock, + int gpu_layers, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, + llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::string & err) { @@ -2097,7 +2115,7 @@ struct llama_context * llama_init_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type, - params.use_mmap, params.use_mlock, params.vocab_only, + params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { fprintf(stderr, "%s: failed to load model\n", __func__); llama_free(ctx); diff --git a/llama.h b/llama.h index 1a65cd5892389..07e686208bc01 100644 --- a/llama.h +++ b/llama.h @@ -63,6 +63,7 @@ extern "C" { bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM + int gpu_layers; // number of layers to store in VRAM bool embedding; // embedding mode only // called with a progress value between 0 and 1, pass NULL to disable From 12fc292ee64bc90ddf3c6c421ec576e87a736733 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 12 May 2023 12:42:09 +0200 Subject: [PATCH 2/9] Added q4_1 via template --- ggml-cuda.cu | 88 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 99c3ea80890d9..0674bd3c9147e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -32,7 +32,9 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } \ } while (0) +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); +typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); #define QK4_0 32 typedef struct { @@ -73,6 +75,37 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +#define CUDA_DMMV_BLOCK_SIZE 32 + +static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const float d = x[ib].d; + + const uint8_t vui = x[ib].qs[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + v0 = (vi0 - 8)*d; + v1 = (vi1 - 8)*d; +} + +static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const float d = x[ib].d; + const float m = x[ib].m; + + const uint8_t vui = x[ib].qs[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + v0 = vi0*d + m; + v1 = vi1*d + m; +} + static __global__ void dequantize_block_q4_0(const void * vx, float * y) { static const int qk = QK4_0; @@ -173,10 +206,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) { } } -template static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) { - const block_q4_0 * x = (const block_q4_0 *) vx; - const int qk = QK4_0; - +template static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { const int row = blockIdx.x; const int tid = threadIdx.x; @@ -190,17 +220,8 @@ template static __global__ void dequantize_mul_mat_q4_0(const v const int iybs = col - col%qk; // y block start index // dequantize - const float d = x[ib].d; - - const uint8_t * pp = x[ib].qs; - - const uint8_t vui = pp[iqs]; - - const int8_t vi0 = vui & 0xF; - const int8_t vi1 = vui >> 4; - - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + float v0, v1; + dequantize_kernel(vx, ib, iqs, v0, v1); // matrix multiplication tmp[tid] += v0 * y[iybs + iqs + 0]; @@ -244,21 +265,14 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre dequantize_block_q8_0<<>>(vx, y); } -static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - // static int block_size = -1; - // if (block_size == -1) { - // int min_grid_size, max_block_size = 1; - // CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0)); - // max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE); - // block_size = 1; - // while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) { - // block_size *= 2; - // } - // } - // dequantize_mul_mat_q4_0<<>>(vx, y, dst, ncols); - const int block_size = 32; - GGML_ASSERT(ncols % block_size == 0); - dequantize_mul_mat_q4_0<<>>(vx, y, dst, ncols); +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); } // TODO: optimize @@ -293,6 +307,17 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } +static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_mul_mat_vec_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_mul_mat_vec_q4_1_cuda; + default: + return nullptr; + } +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -610,6 +635,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type); + dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type); GGML_ASSERT(to_fp32_cuda != nullptr); for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -641,7 +667,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); // compute - dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream); + dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); CUDA_CHECK(cudaGetLastError()); } else { From 7dc2f57e5e145b8136300eedb5836da6821b76e0 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 12 May 2023 21:37:34 +0200 Subject: [PATCH 3/9] Added missing __syncthreads(); --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0674bd3c9147e..4c37f42789902 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -229,6 +229,7 @@ template static } // sum up partial sums and write back result + __syncthreads(); for (int s=block_size/2; s>0; s>>=1) { if (tid < s) { tmp[tid] += tmp[tid + s]; From f0af475739524cc8a372ac9ec58458527cfdaf84 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 12 May 2023 21:43:47 +0200 Subject: [PATCH 4/9] --gpu_layers -> --gpu-layers --- examples/common.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 86cea79736ab2..43a105cdde840 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -277,7 +277,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_color = true; } else if (arg == "--mlock") { params.use_mlock = true; - } else if (arg == "--gpu_layers") { + } else if (arg == "--gpu-layers") { if (++i >= argc) { invalid_param = true; break; @@ -427,7 +427,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { if (llama_mmap_supported()) { fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } - fprintf(stderr, " --gpu_layers number of layers to store in VRAM\n"); + fprintf(stderr, " --gpu-layers number of layers to store in VRAM\n"); fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); From 0986c2f44e5fb950cfa0331f6846bbc5518266da Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 12 May 2023 23:30:17 +0200 Subject: [PATCH 5/9] Shorter dequantize_mul_mat_vec line --- ggml-cuda.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4c37f42789902..7873f4f6b1fc9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -206,7 +206,8 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) { } } -template static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { +template +static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { const int row = blockIdx.x; const int tid = threadIdx.x; From 9da44fdcb38ff3e6627f4318f9bb47bd3709ea22 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 12 May 2023 23:57:10 +0200 Subject: [PATCH 6/9] q5_0 dequantize_mul_mat kernel --- ggml-cuda.cu | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7873f4f6b1fc9..66a2b0f931104 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -106,6 +106,24 @@ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int v1 = vi1*d + m; } +static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const float d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16; + const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16; + + v0 = x0*d; + v1 = x1*d; +} + static __global__ void dequantize_block_q4_0(const void * vx, float * y) { static const int qk = QK4_0; @@ -277,6 +295,11 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); } +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); +} + // TODO: optimize static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { const half * x = (const half *) vx; @@ -315,6 +338,8 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t return dequantize_mul_mat_vec_q4_0_cuda; case GGML_TYPE_Q4_1: return dequantize_mul_mat_vec_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_mul_mat_vec_q5_0_cuda; default: return nullptr; } From 5a0ecf768d1d69e4f40c3e922f72e88112580d2f Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 13 May 2023 07:14:27 +0200 Subject: [PATCH 7/9] More readable dequantize_mul_mat_vec logic --- ggml-cuda.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 66a2b0f931104..161b5468228f9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; const ggml_type type = src0->type; + const bool mul_mat_vec = ne11 == 1; const float alpha = 1.0f; const float beta = 0.0f; @@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor size_t x_size, y_size, d_size, q_size; float * d_X; - if (ne11 > 1) { + if (!mul_mat_vec) { d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); } float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); @@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor } else { GGML_ASSERT(false); } - if (ne11 == 1) { + if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); // copy src1 to device @@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); CUDA_CHECK(cudaGetLastError()); - } else { + } else { // general dequantization kernel + cuBLAS matrix matrix multiplication float * c_X = d_X + i * x_ne; // convert src0 to fp32 on device @@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor } CUDA_CHECK(cudaDeviceSynchronize()); - if (ne11 > 1) { + if (!mul_mat_vec) { ggml_cuda_pool_free(d_X, x_size); } ggml_cuda_pool_free(d_Y, y_size); From bb0993ed4830112052b318bb5100706f996f923a Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 13 May 2023 08:10:38 +0200 Subject: [PATCH 8/9] dequantize_mul_mat_vec kernels for q5_1, q8_0, f16 --- ggml-cuda.cu | 87 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 81 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 161b5468228f9..812e0d4027520 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -36,7 +36,11 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); +// QK = number of values after dequantization +// QR = QK / number of values before dequantization + #define QK4_0 32 +#define QR4_0 2 typedef struct { float d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants @@ -44,6 +48,7 @@ typedef struct { static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 +#define QR4_1 2 typedef struct { float d; // delta float m; // min @@ -52,6 +57,7 @@ typedef struct { static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 +#define QR5_0 2 typedef struct { half d; // delta uint8_t qh[4]; // 5-th bit of quants @@ -60,6 +66,7 @@ typedef struct { static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); #define QK5_1 32 +#define QR5_1 2 typedef struct { half d; // delta half m; // min @@ -69,6 +76,7 @@ typedef struct { static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 +#define QR8_0 1 typedef struct { float d; // delta int8_t qs[QK8_0]; // quants @@ -124,6 +132,44 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int v1 = x1*d; } +static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const float d = x[ib].d; + const float m = x[ib].m; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); + const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1); + + v0 = x0*d + m; + v1 = x1*d + m; +} + +static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const float d = x[ib].d; + + const int8_t vi0 = x[ib].qs[iqs + 0]; + const int8_t vi1 = x[ib].qs[iqs + 1]; + + v0 = vi0*d; + v1 = vi1*d; +} + +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const half * x = (const half *) vx; + + v0 = __half2float(x[ib + 0]); + v1 = __half2float(x[ib + 1]); +} + static __global__ void dequantize_block_q4_0(const void * vx, float * y) { static const int qk = QK4_0; @@ -224,18 +270,20 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) { } } -template +template static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { const int row = blockIdx.x; const int tid = threadIdx.x; + const int y_offset = qr == 1 ? 1 : qk/2; + __shared__ float tmp[block_size]; // separate sum for each thread tmp[tid] = 0; for (int i = 0; i < ncols/block_size; i += 2) { const int col = i*block_size + 2*tid; const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/2; // quant index + const int iqs = (col%qk)/qr; // quant index const int iybs = col - col%qk; // y block start index // dequantize @@ -244,7 +292,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, // matrix multiplication tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + qk/2]; + tmp[tid] += v1 * y[iybs + iqs + y_offset]; } // sum up partial sums and write back result @@ -287,17 +335,32 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<<>>(vx, y, dst, ncols); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); } // TODO: optimize @@ -313,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre convert_fp16_to_fp32<<>>(x, y); } +static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -340,6 +409,12 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t return dequantize_mul_mat_vec_q4_1_cuda; case GGML_TYPE_Q5_0: return dequantize_mul_mat_vec_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_mul_mat_vec_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_mul_mat_vec_q8_0_cuda; + case GGML_TYPE_F16: + return dequantize_mul_mat_vec_q8_0_cuda; default: return nullptr; } From ad8a9e69711ac18092bf4d13f29bcf26c209248f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 13 May 2023 16:35:21 +0300 Subject: [PATCH 9/9] llama : offload "output" tensor to GPU too + coding style fixes --- examples/common.cpp | 25 +++++++++++++------------ examples/common.h | 12 ++++++------ ggml-cuda.cu | 2 +- llama.cpp | 45 ++++++++++++++++++++++++++++++--------------- llama.h | 8 ++++---- 5 files changed, 54 insertions(+), 38 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 43a105cdde840..86c1eef41b475 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -277,12 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_color = true; } else if (arg == "--mlock") { params.use_mlock = true; - } else if (arg == "--gpu-layers") { + } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { if (++i >= argc) { invalid_param = true; break; } - params.gpu_layers = std::stoi(argv[i]); + params.n_gpu_layers = std::stoi(argv[i]); } else if (arg == "--no-mmap") { params.use_mmap = false; } else if (arg == "--mtest") { @@ -427,7 +427,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { if (llama_mmap_supported()) { fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } - fprintf(stderr, " --gpu-layers number of layers to store in VRAM\n"); + fprintf(stderr, " -ngl N, --n-gpu-layers N\n"); + fprintf(stderr, " number of layers to store in VRAM\n"); fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); @@ -470,15 +471,15 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.gpu_layers = params.gpu_layers; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; + lparams.n_ctx = params.n_ctx; + lparams.n_parts = params.n_parts; + lparams.n_gpu_layers = params.n_gpu_layers; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams); diff --git a/examples/common.h b/examples/common.h index 636dc359412c0..717838f06e064 100644 --- a/examples/common.h +++ b/examples/common.h @@ -21,13 +21,14 @@ int32_t get_num_physical_cores(); struct gpt_params { - int32_t seed = -1; // RNG seed + int32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); int32_t n_predict = -1; // new tokens to predict - int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_gpu_layers = 0; // number of layers to store in VRAM // sampling parameters std::unordered_map logit_bias; // logit bias for specific tokens @@ -69,7 +70,6 @@ struct gpt_params { bool perplexity = false; // compute perplexity over the prompt bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory - int gpu_layers = 0; // number of layers to store in VRAM bool mem_test = false; // compute maximum memory usage bool verbose_prompt = false; // print prompt tokens before generation }; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 812e0d4027520..b6a7754d534e6 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -729,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type); size_t x_size, y_size, d_size, q_size; - float * d_X; + float * d_X = nullptr; if (!mul_mat_vec) { d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); } diff --git a/llama.cpp b/llama.cpp index fd890ed3196ca..436d0c67807ff 100644 --- a/llama.cpp +++ b/llama.cpp @@ -813,13 +813,13 @@ struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.n_ctx =*/ 512, /*.n_parts =*/ -1, + /*.gpu_layers =*/ 0, /*.seed =*/ -1, /*.f16_kv =*/ false, /*.logits_all =*/ false, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, - /*.gpu_layers =*/ 0, /*.embedding =*/ false, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, @@ -880,10 +880,10 @@ static void llama_model_load_internal( const std::string & fname, llama_context & lctx, int n_ctx, + int n_gpu_layers, ggml_type memory_type, bool use_mmap, bool use_mlock, - int gpu_layers, bool vocab_only, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -1027,15 +1027,30 @@ static void llama_model_load_internal( model.mapping = std::move(ml->mapping); #ifdef GGML_USE_CUBLAS - for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) { - auto & layer = model.layers[i]; - ggml_cuda_transform_tensor(layer.wq); - ggml_cuda_transform_tensor(layer.wk); - ggml_cuda_transform_tensor(layer.wv); - ggml_cuda_transform_tensor(layer.wo); - ggml_cuda_transform_tensor(layer.w1); - ggml_cuda_transform_tensor(layer.w2); - ggml_cuda_transform_tensor(layer.w3); + { + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu); + + size_t vram_total = 0; + + for (int i = 0; i < n_gpu; ++i) { + const auto & layer = model.layers[i]; + + ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq); + ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk); + ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv); + ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo); + ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1); + ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2); + ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3); + } + if (n_gpu_layers > (int) hparams.n_layer) { + fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__); + ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output); + } + + fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); } #endif @@ -1048,15 +1063,15 @@ static bool llama_model_load( const std::string & fname, llama_context & lctx, int n_ctx, + int n_gpu_layers, ggml_type memory_type, bool use_mmap, bool use_mlock, - int gpu_layers, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers, + llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::string & err) { @@ -2114,8 +2129,8 @@ struct llama_context * llama_init_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type, - params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only, + if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type, + params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { fprintf(stderr, "%s: failed to load model\n", __func__); llama_free(ctx); diff --git a/llama.h b/llama.h index 07e686208bc01..78017fc86e1af 100644 --- a/llama.h +++ b/llama.h @@ -54,16 +54,16 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { - int n_ctx; // text context - int n_parts; // -1 for default - int seed; // RNG seed, -1 for random + int n_ctx; // text context + int n_parts; // -1 for default + int n_gpu_layers; // number of layers to store in VRAM + int seed; // RNG seed, -1 for random bool f16_kv; // use fp16 for KV cache bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM - int gpu_layers; // number of layers to store in VRAM bool embedding; // embedding mode only // called with a progress value between 0 and 1, pass NULL to disable