Skip to content

Commit

Permalink
sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Feb 6, 2025
1 parent 7d55758 commit 7bfc022
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 2 deletions.
3 changes: 3 additions & 0 deletions lib/src/llama_cpp_native.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class LlamaCppNative {
ContextParams contextParams,
SamplingParams samplingParams
) {
lib.ggml_backend_load_all();
lib.llama_backend_init();
log("backend loaded");

Expand All @@ -30,6 +31,8 @@ class LlamaCppNative {
nativeModelParams
);
assert(model.address != 0, 'Failed to load model');

malloc.free(nativeModelPath);
log("Model loaded");

final nativeContextParams = contextParams.toNative();
Expand Down
86 changes: 85 additions & 1 deletion src/api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include "api.h"

static llama_model * model;
static llama_context * ctx;
static llama_sampler * smpl;

struct api_params api_default_params() {
auto default_model_params = llama_model_default_params();
auto default_context_params = llama_context_default_params();
Expand Down Expand Up @@ -38,7 +42,7 @@ struct api_params api_default_params() {

/*.greedy =*/ false,
/*.infill =*/ false,
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.seed =*/ NULL,
/*.top_k =*/ NULL,
/*.top_p =*/ NULL,
/*.min_p =*/ NULL,
Expand All @@ -55,4 +59,84 @@ struct api_params api_default_params() {
};

return result;
}

int api_init(struct api_params params) {
ggml_backend_load_all();

auto model_params = llama_model_default_params();
model_params.vocab_only = params.vocab_only;
model_params.use_mmap = params.use_mmap;
model_params.use_mlock = params.use_mlock;
model_params.check_tensors = params.check_tensors;

model = llama_model_load_from_file(params.model_path, model_params);

auto context_params = llama_context_default_params();
context_params.n_ctx = params.n_ctx;
context_params.n_batch = params.n_batch;
context_params.n_ubatch = params.n_ubatch;
context_params.n_seq_max = params.n_seq_max;
context_params.n_threads = params.n_threads;
context_params.n_threads_batch = params.n_threads_batch;
context_params.rope_scaling_type = params.rope_scaling_type;
context_params.pooling_type = params.pooling_type;
context_params.attention_type = params.attention_type;
context_params.rope_freq_base = params.rope_freq_base;
context_params.rope_freq_scale = params.rope_freq_scale;
context_params.yarn_ext_factor = params.yarn_ext_factor;
context_params.yarn_attn_factor = params.yarn_attn_factor;
context_params.yarn_beta_fast = params.yarn_beta_fast;
context_params.yarn_beta_slow = params.yarn_beta_slow;
context_params.yarn_orig_ctx = params.yarn_orig_ctx;
context_params.defrag_thold = params.defrag_thold;
context_params.type_k = params.type_k;
context_params.type_v = params.type_v;
context_params.logits_all = params.logits_all;
context_params.embeddings = params.embeddings;
context_params.offload_kqv = params.offload_kqv;
context_params.flash_attn = params.flash_attn;
context_params.no_perf = params.no_perf;

ctx = llama_init_from_model(model, context_params);

auto vocab = llama_model_get_vocab(model);
smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

if (params.greedy) {
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
}

if (params.infill) {
llama_sampler_chain_add(smpl, llama_sampler_init_infill(vocab));
}

if (params.seed != NULL) {
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.seed));
}

if (params.top_k != NULL) {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k));
}

if (params.top_p.p != NULL && params.top_p.min_keep != NULL) {
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.top_p.p, params.top_p.min_keep));
}

if (params.min_p.p != NULL && params.min_p.min_keep != NULL) {
llama_sampler_chain_add(smpl, llama_sampler_init_min_p(params.min_p.p, params.min_p.min_keep));
}

if (params.typical_p.p != NULL && params.typical_p.min_keep != NULL) {
llama_sampler_chain_add(smpl, llama_sampler_init_typical(params.typical_p.p, params.typical_p.min_keep));
}

if (params.temperature.temperature != NULL) {
if (params.temperature.delta != NULL && params.temperature.exponent != NULL) {
llama_sampler_chain_add(smpl, llama_sampler_init_temp_ext(params.temperature.temperature, params.temperature.delta, params.temperature.exponent));
}
else {
llama_sampler_chain_add(smpl, llama_sampler_init_temp(params.temperature.temperature));
}
}
}
2 changes: 1 addition & 1 deletion src/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct p_sampling_params {
};

// Parameters for the temperature sampling stage.
// `delta` and `exponent` are only used by the extended (dynamic)
// temperature sampler; leave them 0 to use plain temperature scaling.
struct temperature_sampling_params {
    float temperature;  // sampling temperature (0 = disabled)
    float delta;        // dynamic temperature range
    float exponent;     // dynamic temperature exponent
};
Expand Down

0 comments on commit 7bfc022

Please sign in to comment.