Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose DRY Sampler [Incomplete] #91

Merged
merged 9 commits into from
Nov 18, 2024
15 changes: 15 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
// double[][] logit_bias,
logit_bias,
// float dry_multiplier,
params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f,
// float dry_base,
params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f,
// int dry_allowed_length,
params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
// int dry_penalty_last_n,
params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
// String[] dry_sequence_breakers, when undef, we use the default definition from common.h
params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
// PartialCompletionCallback partial_completion_callback
new PartialCompletionCallback(
this,
Expand Down Expand Up @@ -445,6 +455,11 @@ protected static native WritableMap doCompletion(
String[] stop,
boolean ignore_eos,
double[][] logit_bias,
float dry_multiplier,
float dry_base,
int dry_allowed_length,
int dry_penalty_last_n,
String[] dry_sequence_breakers,
PartialCompletionCallback partial_completion_callback
);
protected static native void stopCompletion(long contextPtr);
Expand Down
25 changes: 25 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,11 @@ Java_com_rnllama_LlamaContext_doCompletion(
jobjectArray stop,
jboolean ignore_eos,
jobjectArray logit_bias,
jfloat dry_multiplier,
jfloat dry_base,
jint dry_allowed_length,
jint dry_penalty_last_n,
jobjectArray dry_sequence_breakers,
jobject partial_completion_callback
) {
UNUSED(thiz);
Expand Down Expand Up @@ -573,12 +578,32 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
sparams.xtc_threshold = xtc_threshold;
sparams.xtc_probability = xtc_probability;
sparams.dry_multiplier = dry_multiplier;
sparams.dry_base = dry_base;
sparams.dry_allowed_length = dry_allowed_length;
sparams.dry_penalty_last_n = dry_penalty_last_n;

sparams.logit_bias.clear();
if (ignore_eos) {
sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
}

// dry break seq

jint size = env->GetArrayLength(dry_sequence_breakers);
std::vector<std::string> dry_sequence_breakers_vector;

for (jint i = 0; i < size; i++) {
jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
const char *nativeString = env->GetStringUTFChars(javaString, 0);
dry_sequence_breakers_vector.push_back(std::string(nativeString));
env->ReleaseStringUTFChars(javaString, nativeString);
env->DeleteLocalRef(javaString);
}

sparams.dry_sequence_breakers = dry_sequence_breakers_vector;

// logit bias
const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
jsize logit_bias_len = env->GetArrayLength(logit_bias);

Expand Down
37 changes: 23 additions & 14 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -383,22 +383,32 @@
{
messages: msgs,
n_predict: 100,
grammar,
seed: -1,
n_probs: 0,

Check warning on line 389 in example/src/App.tsx

View workflow job for this annotation

GitHub Actions / lint

Unexpected constant condition
// Sampling params
top_k: 40,
top_p: 0.5,
min_p: 0.05,
xtc_probability: 0.5,
xtc_threshold: 0.1,
typical_p: 1.0,
temperature: 0.7,
top_k: 40, // <= 0 to use vocab size
top_p: 0.5, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
penalty_last_n: 256, // 0 = disable penalty, -1 = context size
penalty_repeat: 1.18, // 1.0 = disabled
penalty_freq: 0.0, // 0.0 = disabled
penalty_present: 0.0, // 0.0 = disabled
mirostat: 0, // 0/1/2
mirostat_tau: 5, // target entropy
mirostat_eta: 0.1, // learning rate
penalize_nl: false, // penalize newlines
seed: -1, // random seed
n_probs: 0, // Show probabilities
penalty_last_n: 64,
penalty_repeat: 1.0,
penalty_freq: 0.0,
penalty_present: 0.0,
dry_multiplier: 0,
dry_base: 1.75,
dry_allowed_length: 2,
dry_penalty_last_n: -1,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
mirostat: 0,
mirostat_tau: 5,
mirostat_eta: 0.1,
penalize_nl: false,
ignore_eos: false,
stop: [
'</s>',
'<|end|>',
Expand All @@ -410,7 +420,6 @@
'<|end_of_turn|>',
'<|endoftext|>',
],
grammar,
// n_threads: 4,
// logit_bias: [[15043,1.0]],
},
Expand All @@ -430,7 +439,7 @@
})
}
return [
{

Check warning on line 442 in example/src/App.tsx

View workflow job for this annotation

GitHub Actions / lint

'msgs' is already declared in the upper scope on line 297 column 11
author: system,
createdAt,
id,
Expand Down Expand Up @@ -465,7 +474,7 @@
}
}
return msg
})

Check warning on line 477 in example/src/App.tsx

View workflow job for this annotation

GitHub Actions / lint

'msgs' is already declared in the upper scope on line 297 column 11
}
return msgs
})
Expand Down
13 changes: 13 additions & 0 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,19 @@ - (NSDictionary *)completion:(NSDictionary *)params
if (params[@"xtc_probability"]) sparams.xtc_probability = [params[@"xtc_probability"] doubleValue];
if (params[@"typical_p"]) sparams.typ_p = [params[@"typical_p"] doubleValue];

if (params[@"dry_multiplier"]) sparams.dry_multiplier = [params[@"dry_multiplier"] doubleValue];
if (params[@"dry_base"]) sparams.dry_base = [params[@"dry_base"] doubleValue];
if (params[@"dry_allowed_length"]) sparams.dry_allowed_length = [params[@"dry_allowed_length"] intValue];
if (params[@"dry_penalty_last_n"]) sparams.dry_penalty_last_n = [params[@"dry_penalty_last_n"] intValue];

// dry break seq
if (params[@"dry_sequence_breakers"] && [params[@"dry_sequence_breakers"] isKindOfClass:[NSArray class]]) {
NSArray *dry_sequence_breakers = params[@"dry_sequence_breakers"];
for (NSString *s in dry_sequence_breakers) {
sparams.dry_sequence_breakers.push_back([s UTF8String]);
}
}

if (params[@"grammar"]) {
sparams.grammar = [params[@"grammar"] UTF8String];
}
Expand Down
6 changes: 6 additions & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ export type NativeCompletionParams = {
penalize_nl?: boolean
seed?: number

dry_multiplier?: number
dry_base?: number
dry_allowed_length?: number
dry_penalty_last_n?: number
dry_sequence_breakers?: Array<string>

ignore_eos?: boolean
logit_bias?: Array<Array<number>>

Expand Down
Loading