
Feature: Add llava support #577

Closed · 2 commits
22 changes: 22 additions & 0 deletions LLama/Common/LLamaSamplingContext.cs
@@ -0,0 +1,22 @@
using LLama.Grammars;
using LLama.Native;
using System;

namespace LLama.Common
{
public class LLamaSamplingContext
{
public LLamaSamplingParams parameters;

// mirostat sampler state
public float mirostat_mu;

public IntPtr grammar;
// internal
public IntPtr parsed_grammar;

// TODO: replace with ring-buffer
public LLamaToken[] prev;
public LLamaTokenData[] cur;
}
}
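Note: a minimal initialization sketch for review purposes, mirroring llama.cpp's llama_sampling_init; the grammar wiring is omitted and the array size follows the n_prev default declared below:

var sampling = new LLamaSamplingContext
{
    parameters = new LLamaSamplingParams(),
    mirostat_mu = 0.0f,
    grammar = IntPtr.Zero,
    parsed_grammar = IntPtr.Zero,
    prev = new LLamaToken[64],          // history buffer, sized to n_prev
    cur = Array.Empty<LLamaTokenData>() // repopulated from the logits each step
};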
53 changes: 53 additions & 0 deletions LLama/Common/LLamaSamplingParams.cs
@@ -0,0 +1,53 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.Common
{
[StructLayout(LayoutKind.Sequential)]
public class LLamaSamplingParams
{
public int n_prev = 64; // number of previous tokens to remember
public int n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
public int top_k = 40; // <= 0 to use vocab size
public float top_p = 0.95f; // 1.0 = disabled
public float min_p = 0.05f; // 0.0 = disabled
public float tfs_z = 1.00f; // 1.0 = disabled
public float typical_p = 1.00f; // 1.0 = disabled
public float temp = 0.70f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
public float dynatemp_range = 0.0f; // 0.0 = disabled
public float dynatemp_exponent = 1.0f; // controls how entropy maps to temperature in dynamic temperature sampler
public int penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
public float penalty_repeat = 1.10f; // 1.0 = disabled
public float penalty_freq = 0.00f; // 0.0 = disabled
public float penalty_present = 0.00f; // 0.0 = disabled
public int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
public float mirostat_tau = 5.00f; // target entropy
public float mirostat_eta = 0.10f; // learning rate
public bool penalize_nl = true; // consider newlines as a repeatable token

public string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp

public string grammar = string.Empty; // optional BNF-like grammar to constrain sampling

// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806

public string cfg_negative_prompt = string.Empty; // string to help guidance
public float cfg_scale = 1.0f; // how strong is guidance

//std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
public IntPtr logit_bias;
public LLamaToken[] penalty_prompt_tokens;
public bool use_penalty_prompt_tokens = false;
}
public struct logit_bias_struct
{
public LLamaToken token;
public float bias;
}


}
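Note: a small sketch of how the samplers_sequence characters decode, matching the comment above; this helper is illustrative only and not part of the PR:

// Hypothetical helper mapping samplers_sequence characters to sampler names.
static string SamplerName(char c) => c switch
{
    'k' => "top_k",
    'f' => "tail_free",
    'y' => "typical_p",
    'p' => "top_p",
    'm' => "min_p",
    't' => "temp",
    _   => "unknown"
};
// "kfypmt" => top_k, tail_free, typical_p, top_p, min_p, temp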
9 changes: 8 additions & 1 deletion LLama/Common/ModelParams.cs
@@ -3,6 +3,7 @@
using System.Text.Json.Serialization;
using LLama.Native;
using System.Collections.Generic;
using System;

namespace LLama.Common
{
@@ -43,7 +44,7 @@ public record ModelParams
public string LoraBase { get; set; } = string.Empty;

/// <inheritdoc />
public uint? Threads { get; set; }
public uint? Threads { get; set; } = GetNumPhysicalCores();

/// <inheritdoc />
public uint? BatchThreads { get; set; }
@@ -126,5 +127,11 @@ private ModelParams()
// This constructor (default parameterless constructor) is used by Newtonsoft to deserialize!
ModelPath = "";
}

private static uint GetNumPhysicalCores()
{
int n_cores = Environment.ProcessorCount;
return (uint)(n_cores > 0 ? (n_cores <= 4 ? n_cores : n_cores / 2) : 4);
}
}
}
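Note: with this heuristic, a machine reporting 16 logical processors defaults to 8 threads, a 4-core machine keeps all 4, and a non-positive ProcessorCount falls back to 4.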
15 changes: 15 additions & 0 deletions LLama/LLava/ImageEmbed.cs
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.LLava
{
[StructLayout(LayoutKind.Sequential)]
public unsafe class LLavaImageEmbed
{
public float* embed;
public int n_image_pos;
}
}
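Note: a minimal sketch of copying the native embedding into managed memory, assuming (per llava.h) that embed points at n_image_pos * n_embd floats:

// Illustrative only: n_embd would come from the model's hyperparameters.
public static unsafe float[] CopyEmbed(LLavaImageEmbed e, int n_embd)
{
    var managed = new float[e.n_image_pos * n_embd];
    new Span<float>(e.embed, managed.Length).CopyTo(managed);
    return managed;
}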
14 changes: 14 additions & 0 deletions LLama/LLava/LLavaContext.cs
@@ -0,0 +1,14 @@
using LLama.Native;
using System;
using System.Runtime.InteropServices;

namespace LLama.LLava
{
[StructLayout(LayoutKind.Sequential)]
public class LLavaContext
{
public IntPtr ClipContext;
public SafeLLamaContextHandle LLamaContext;
public SafeLlamaModelHandle model;
}
}
11 changes: 11 additions & 0 deletions LLama/LLavaContext.cs
@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama
{
internal class LLavaContext
Collaborator comment: It seems it's duplicated with LLava/LLavaContext.

{

}
}
47 changes: 47 additions & 0 deletions LLama/LLavaInteractExecutor.cs
@@ -0,0 +1,47 @@
using LLama.Native;
using Microsoft.Extensions.Logging;
using System.Threading.Tasks;

namespace LLama
{
public class LLavaInteractExecutor /*: InteractiveExecutor*/
{
/// <summary>
/// Weights of the LLava model.
/// </summary>
protected SafeLlavaModelHandle handle;

/// <summary>
///
/// </summary>
/// <param name="context"></param>
/// <param name="logger"></param>
//public LLavaInteractExecutor(SafeLlavaModelHandle handle, ILogger? logger = null)
//{
// this.handle = handle;
// this.logger = logger;
//}

//protected override Task PreprocessInputs(string text, byte[] imageByte, InferStateArgs args)
//{
// if (_is_prompt_run)
// {
// // When running the first input (prompt) in interactive mode, we should process it specially.
// _embed_inps = Context.Tokenize(text, true).ToList();
// }
// else
// {
// if (!text.EndsWith("\n"))
// {
// text += "\n";
// }
// var line_inp = Context.Tokenize(text, false);
// _embed_inps.AddRange(line_inp);
// args.RemainedTokens -= line_inp.Length;
// }

// return Task.CompletedTask;

//}
}
}
9 changes: 6 additions & 3 deletions LLama/Native/LLamaNativeBatch.cs
@@ -45,10 +45,13 @@ public unsafe struct LLamaNativeBatch
/// </summary>
public byte* logits;

/// <summary>
/// An error occurred (_all_pos_1 got a wrong value) when using llama_batch_get_one from llama.dll, so these fields were changed from private to public.
/// </summary>
// Note from llama.cpp:
// > helpers for smooth API transition - can be deprecated in the future
// > for future-proof code, use the above fields instead and ignore everything below
private LLamaPos _all_pos_0;
private LLamaPos _all_pos_1;
private LLamaSeqId _all_seq_id;
public LLamaPos _all_pos_0;
public LLamaPos _all_pos_1;
public LLamaSeqId _all_seq_id;
}
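Note: a short sketch of the call site this change unblocks, using the llama_batch_get_one binding added in NativeApi.cs below (GetPromptTokens is a hypothetical helper):

unsafe
{
    LLamaToken[] tokens = GetPromptTokens(); // hypothetical
    int pos = 0;
    fixed (LLamaToken* p = tokens)
    {
        var batch = NativeApi.llama_batch_get_one(p, tokens.Length, ref pos, 0);
        // The transition fields can now be inspected when debugging:
        LLamaPos pos0 = batch._all_pos_0;
        LLamaPos pos1 = batch._all_pos_1;
    }
}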
12 changes: 12 additions & 0 deletions LLama/Native/LLamaVocabType.cs
@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Native
{
public enum LLamaVocabType
{
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
};
}
1 change: 1 addition & 0 deletions LLama/Native/NativeApi.Load.cs
Expand Up @@ -338,6 +338,7 @@ string TryFindPath(string filename)
}

internal const string libraryName = "llama";
internal const string llavaLibName = "llava_shared";
private const string cudaVersionFile = "version.json";
private const string loggingPrefix = "[LLamaSharp Native]";
private static bool enableLogging = false;
16 changes: 14 additions & 2 deletions LLama/Native/NativeApi.Sampling.cs
@@ -99,8 +99,20 @@ public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<
/// <param name="min_keep"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_min_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);

/// <summary>
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="min_temp"></param>
/// <param name="max_temp"></param>
/// <param name="exponent_val"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_entropy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float min_temp, float max_temp, float exponent_val);


/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
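Note: a usage sketch for the new entropy sampler, assuming an existing context handle ctx and a candidates array prepared from the current logits; the min/max derivation follows llama.cpp's dynatemp handling:

float temp = 0.70f, dynatempRange = 0.50f, dynatempExponent = 1.0f;
float minTemp = Math.Max(0.0f, temp - dynatempRange);
float maxTemp = temp + dynatempRange;
NativeApi.llama_sample_entropy(ctx, ref candidates, minTemp, maxTemp, dynatempExponent);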
71 changes: 69 additions & 2 deletions LLama/Native/NativeApi.cs
@@ -1,4 +1,6 @@
using LLama.LLava;
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

#pragma warning disable IDE1006 // Naming Styles
@@ -187,6 +189,7 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);


/// <summary>
/// Get the embeddings for the input
/// </summary>
@@ -330,7 +333,7 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll
/// </summary>
/// <param name="logCallback"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_log_set(LLamaLogCallback logCallback);

/// <summary>
/// Clear the KV cache
@@ -438,5 +441,69 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);

/// <summary>
/// Get vocab type from model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model);

/// <summary>
/// Get a single-sequence batch for the given tokens, starting at pos_0 (a llama.cpp helper for API transition).
/// </summary>
/// <param name="tokens"></param>
/// <param name="n_tokens"></param>
/// <param name="pos_0"></param>
/// <param name="seq_id"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe LLamaNativeBatch llama_batch_get_one(LLamaToken* tokens, int n_tokens, ref int pos_0, int seq_id);

/// <summary>
/// Initialize ggml's timing subsystem.
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void ggml_time_init();

/// <summary>
/// Load clip model
/// </summary>
/// <param name="clip_model_path">Clip model path</param>
/// <param name="verbosity"></param>
/// <returns></returns>
[DllImport(llavaLibName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr clip_model_load(string clip_model_path, int verbosity = 1);

/// <summary>
/// Free the clip context
/// </summary>
/// <param name="ctx">Clip context</param>
[DllImport(llavaLibName, CallingConvention = CallingConvention.Cdecl)]
public static extern void clip_free(IntPtr ctx);

/// <summary>
/// Read an image from bytes
/// </summary>
/// <param name="ctx_clip"></param>
/// <param name="n_threads"></param>
/// <param name="bytes"></param>
/// <param name="image_bytes_length"></param>
/// <returns></returns>
[DllImport(llavaLibName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLavaImageEmbed llava_image_embed_make_with_bytes(IntPtr ctx_clip, int n_threads, byte[] bytes, int image_bytes_length);

/// <summary>
/// Embed an image and get the token length
/// </summary>
/// <param name="ctx_llama"></param>
/// <param name="image_embed"></param>
/// <param name="n_batch"></param>
/// <param name="n_past"></param>
/// <returns></returns>
[DllImport(llavaLibName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)] // the native return type is a 1-byte C++ bool
public static extern unsafe bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, LLavaImageEmbed image_embed, int n_batch, ref int n_past);
Collaborator comment: Please use a separate file such as NativeApi.LLava.cs to add code related to llava only.


}
}
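Note: a minimal end-to-end sketch of how the llava bindings above compose (file paths and parameter values are illustrative only; llamaCtx is an existing SafeLLamaContextHandle):

var clipCtx = NativeApi.clip_model_load("mmproj-model-f16.gguf", 1);
var imageBytes = System.IO.File.ReadAllBytes("image.jpg");
var embed = NativeApi.llava_image_embed_make_with_bytes(clipCtx, 4, imageBytes, imageBytes.Length);

int n_past = 0;
bool ok = NativeApi.llava_eval_image_embed(llamaCtx, embed, 512, ref n_past);

// ...evaluate the text prompt and sample as usual, continuing from n_past...
NativeApi.clip_free(clipCtx);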