update to llama.cpp b5688 #115


Open · wants to merge 4 commits into master

CMakeLists.txt (8 changes: 5 additions & 3 deletions)
@@ -7,6 +7,8 @@ include(FetchContent)
set(BUILD_SHARED_LIBS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(BUILD_SHARED_LIBS OFF)
set(LLAMA_BUILD_TOOLS ON)
set(LLAMA_CURL OFF)

option(LLAMA_VERBOSE "llama: verbose output" OFF)

@@ -15,7 +17,7 @@ option(LLAMA_VERBOSE "llama: verbose output" OFF)
FetchContent_Declare(
json
GIT_REPOSITORY https://github.com/nlohmann/json
GIT_TAG v3.11.3
GIT_TAG v3.12.0
)
FetchContent_MakeAvailable(json)

@@ -25,7 +27,7 @@ set(LLAMA_BUILD_COMMON ON)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG b4916
GIT_TAG b5688
)
FetchContent_MakeAvailable(llama.cpp)

@@ -96,7 +98,7 @@ add_library(jllama SHARED src/main/cpp/jllama.cpp src/main/cpp/server.hpp src/ma

set_target_properties(jllama PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(jllama PRIVATE src/main/cpp ${JNI_INCLUDE_DIRS})
target_link_libraries(jllama PRIVATE common llama nlohmann_json)
target_link_libraries(jllama PRIVATE common mtmd llama nlohmann_json)
target_compile_features(jllama PRIVATE cxx_std_11)

target_compile_definitions(jllama PRIVATE
src/main/cpp/jllama.cpp (50 changes: 24 additions & 26 deletions)
@@ -377,7 +377,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
return;
}

SRV_INF("loading model '%s'\n", params.model.c_str());
SRV_INF("loading model '%s'\n", params.model.path.c_str());

common_init();

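Reviewer note: between b4916 and b5688, llama.cpp turned `common_params::model` from a plain `std::string` into a struct, which is why every `params.model.c_str()` in this file becomes `params.model.path.c_str()`. Roughly the shape the code now assumes (a simplified sketch of llama.cpp's `common.h`, not a verbatim copy):

    // Model selection now travels as one struct instead of four loose fields.
    struct common_params_model {
        std::string path;    // local GGUF path
        std::string url;     // direct download URL
        std::string hf_repo; // Hugging Face repo id
        std::string hf_file; // file within that repo
    };
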
@@ -413,15 +413,12 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo

const auto model_meta = ctx_server->model_meta();

if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
if (!params.speculative.model.path.empty() || !params.speculative.model.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params.speculative.model.path.c_str());
auto params_dft = params;

params_dft.devices = params.speculative.devices;
params_dft.hf_file = params.speculative.hf_file;
params_dft.hf_repo = params.speculative.hf_repo;
params_dft.model = params.speculative.model;
params_dft.model_url = params.speculative.model_url;
params_dft.n_ctx = params.speculative.n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative.n_ctx;
params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
@@ -431,12 +428,12 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
llama_model *model_dft = llama_init_dft.model.get();

if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.path.c_str());
}

if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
params.speculative.model.c_str(), params.model.c_str());
params.speculative.model.path.c_str(), params.model.path.c_str());
}

const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
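Reviewer note: the struct change above is also what lets the draft-model setup drop its per-field copies; `hf_file`, `hf_repo`, and `model_url` no longer exist as separate members of `common_params`. One assignment now carries everything:

    // b4916 needed four hand-copied fields; b5688 needs one struct copy.
    params_dft.model = params.speculative.model; // path, url, hf_repo, hf_file together
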
@@ -511,7 +508,7 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
task.id = ctx_server->queue_tasks.get_new_id();
task.index = i;

task.prompt_tokens = std::move(tokenized_prompts[i]);
task.prompt_tokens = server_tokens(tokenized_prompts[i], false);
task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, data);
task.id_selected_slot = json_value(data, "id_slot", -1);
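Reviewer note: b5688's server code stores prompts as `server_tokens`, a wrapper that can hold plain text tokens or multimodal (mtmd) chunks; that dependency is also why the CMake change links `mtmd`. A minimal sketch of the text-only construction used here (`task` and `ctx_server` as in the hunk above; the constructor signature is taken from llama.cpp's server utils and hedged as an assumption):

    // Wrap plain tokens; `false` means "no multimodal chunks attached".
    llama_tokens tokens = common_tokenize(ctx_server->vocab, "Hello", /*add_special=*/true, /*parse_special=*/true);
    task.prompt_tokens = server_tokens(tokens, /*has_mtmd=*/false);
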

Expand All @@ -520,7 +517,7 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
task.params.oaicompat_cmpl_id = completion_id;
// oaicompat_model is already populated by params_from_json_cmpl

tasks.push_back(task);
tasks.push_back(std::move(task));
}
} catch (const std::exception &e) {
const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -529,10 +526,10 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
}

ctx_server->queue_results.add_waiting_tasks(tasks);
ctx_server->queue_tasks.post(tasks);

const auto task_ids = server_task::get_list_id(tasks);

ctx_server->queue_tasks.post(std::move(tasks));

if (task_ids.size() != 1) {
env->ThrowNew(c_llama_error, "multitasking currently not supported");
return 0;
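Reviewer note: the reordering around `post` is not cosmetic. `queue_tasks.post(std::move(tasks))` leaves `tasks` in a moved-from state, so the task ids must be collected before posting. The invariant, condensed:

    ctx_server->queue_results.add_waiting_tasks(tasks);     // needs `tasks` intact
    const auto task_ids = server_task::get_list_id(tasks);  // collect ids before the move
    ctx_server->queue_tasks.post(std::move(tasks));         // `tasks` is unusable after this
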
@@ -600,24 +597,24 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,

SRV_INF("Calling embedding '%s'\n", prompt.c_str());

const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true);
auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true);
std::vector<server_task> tasks;

server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);

task.id = ctx_server->queue_tasks.get_new_id();
task.index = 0;
task.prompt_tokens = std::move(tokens);
task.prompt_tokens = server_tokens(tokens, false);

// OAI-compat
task.params.oaicompat = OAICOMPAT_TYPE_NONE;

tasks.push_back(task);
tasks.push_back(std::move(task));

ctx_server->queue_results.add_waiting_tasks(tasks);
ctx_server->queue_tasks.post(tasks);

std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

ctx_server->queue_tasks.post(std::move(tasks));
const auto id_task = *task_ids.begin();
json responses = json::array();

@@ -677,7 +674,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
if (!ctx_server->params_base.embedding || ctx_server->params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
env->ThrowNew(c_llama_error,
"This server does not support reranking. Start it with `--reranking` and without `--embedding`");
return nullptr;
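Reviewer note: upstream folded reranking into the embedding pipeline, so the capability check flips from "reranking on, embedding off" to "embedding on with rank pooling". The unchanged error text still says to start without `--embedding`, which no longer matches the new condition and may deserve a follow-up. The new gate, spelled out (a sketch mirroring the hunk, not new behavior):

    // Rerank requests are served through the embedding path with RANK pooling.
    const bool can_rerank = ctx_server->params_base.embedding
                         && ctx_server->params_base.pooling_type == LLAMA_POOLING_TYPE_RANK;
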
@@ -702,14 +699,15 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
auto task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server->queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
tasks.push_back(task);
auto tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
task.prompt_tokens = server_tokens(tokens, false);
tasks.push_back(std::move(task));
}
ctx_server->queue_results.add_waiting_tasks(tasks);
ctx_server->queue_tasks.post(tasks);
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

ctx_server->queue_tasks.post(std::move(tasks));
// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
std::vector<server_task_result_ptr> results(task_ids.size());

// Create a new HashMap instance
@@ -754,14 +752,14 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
const auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

std::string c_params = parse_jstring(env, jparams);
json data = json::parse(c_params);

json templateData =
oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
std::vector<raw_buffer> files;
json templateData = oaicompat_chat_params_parse(data, ctx_server->oai_parser_opt, files);

std::string tok_str = templateData.at("prompt");
jstring jtok_str = env->NewStringUTF(tok_str.c_str());

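Reviewer note: `oaicompat_completion_params_parse` is gone upstream; its replacement `oaicompat_chat_params_parse` takes the parser options pre-assembled on the server context and reports multimodal attachments through an out-parameter. The new call shape (names from the hunk above; `files` stays empty in this text-only binding):

    std::vector<raw_buffer> files; // receives decoded image/audio attachments, if any
    json parsed = oaicompat_chat_params_parse(data, ctx_server->oai_parser_opt, files);
    std::string prompt = parsed.at("prompt"); // applyTemplate only needs the rendered prompt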