Commit

context

danemadsen committed Feb 7, 2025
1 parent 2e723b9 commit c7a4d2b
Showing 5 changed files with 224 additions and 157 deletions.
230 changes: 147 additions & 83 deletions lib/src/context_params.dart
@@ -3,102 +3,98 @@ part of '../llama.dart';
 
 class ContextParams {
   // text context, 0 = from model
-  int nCtx;
+  final int? nCtx;
 
   // logical maximum batch size that can be submitted to llama_decode
-  int nBatch;
+  final int? nBatch;
 
   // physical maximum batch size
-  int nUBatch;
+  final int? nUBatch;
 
   // max number of sequences (i.e. distinct states for recurrent models)
-  int nSeqMax;
+  final int? nSeqMax;
 
   // number of threads to use for generation
-  int nThreads;
+  final int? nThreads;
 
   // number of threads to use for batch processing
-  int nThreadsBatch;
+  final int? nThreadsBatch;
 
   // RoPE scaling type, from `enum llama_rope_scaling_type`
-  RopeScalingType ropeScalingType;
+  final RopeScalingType? ropeScalingType;
 
   // whether to pool (sum) embedding results by sequence id
-  PoolingType poolingType;
+  final PoolingType? poolingType;
 
   // attention type to use for embeddings
-  AttentionType attentionType;
+  final AttentionType? attentionType;
 
   // RoPE base frequency, 0 = from model
-  double ropeFrequencyBase;
+  final double? ropeFrequencyBase;
 
   // RoPE frequency scaling factor, 0 = from model
-  double ropeFrequencyScale;
+  final double? ropeFrequencyScale;
 
   // YaRN extrapolation mix factor, negative = from model
-  double yarnExtrapolationFactor;
+  final double? yarnExtrapolationFactor;
 
   // YaRN magnitude scaling factor
-  double yarnAttenuationFactor;
+  final double? yarnAttenuationFactor;
 
   // YaRN low correction dim
-  double yarnBetaFast;
+  final double? yarnBetaFast;
 
   // YaRN high correction dim
-  double yarnBetaSlow;
+  final double? yarnBetaSlow;
 
   // YaRN original context size
-  int yarnOriginalContext;
+  final int? yarnOriginalContext;
 
   // defragment the KV cache if holes/size > thold, < 0 disabled (default)
-  double defragmentationThreshold;
+  final double? defragmentationThreshold;
 
   // data type for K cache
-  GgmlType typeK;
+  final GgmlType? typeK;
 
   // data type for V cache
-  GgmlType typeV;
-
-  // if true, extract logits for each token
-  bool logitsAll;
+  final GgmlType? typeV;
 
   // if true, extract embeddings (together with logits)
-  bool embeddings;
+  final bool? embeddings;
 
   // whether to offload the KQV ops (including the KV cache) to GPU
-  bool offloadKqv;
+  final bool? offloadKqv;
 
   // whether to use flash attention
-  bool flashAttention;
+  final bool? flashAttention;
 
   // whether to measure performance timings
-  bool noPerformance;
-
-  ContextParams({
-    this.nCtx = 512,
-    this.nBatch = 2048,
-    this.nUBatch = 512,
-    this.nSeqMax = 1,
-    this.nThreads = 4,
-    this.nThreadsBatch = 4,
-    this.ropeScalingType = RopeScalingType.unspecified,
-    this.poolingType = PoolingType.unspecified,
-    this.attentionType = AttentionType.unspecified,
-    this.ropeFrequencyBase = 0.0,
-    this.ropeFrequencyScale = 0.0,
-    this.yarnExtrapolationFactor = -1.0,
-    this.yarnAttenuationFactor = 1.0,
-    this.yarnBetaFast = 32.0,
-    this.yarnBetaSlow = 1.0,
-    this.yarnOriginalContext = 0,
-    this.defragmentationThreshold = -1.0,
-    this.typeK = GgmlType.f16,
-    this.typeV = GgmlType.f16,
-    this.logitsAll = false,
-    this.embeddings = false,
-    this.offloadKqv = true,
-    this.flashAttention = false,
-    this.noPerformance = true,
+  final bool? noPerformance;
+
+  const ContextParams({
+    this.nCtx,
+    this.nBatch,
+    this.nUBatch,
+    this.nSeqMax,
+    this.nThreads,
+    this.nThreadsBatch,
+    this.ropeScalingType,
+    this.poolingType,
+    this.attentionType,
+    this.ropeFrequencyBase,
+    this.ropeFrequencyScale,
+    this.yarnExtrapolationFactor,
+    this.yarnAttenuationFactor,
+    this.yarnBetaFast,
+    this.yarnBetaSlow,
+    this.yarnOriginalContext,
+    this.defragmentationThreshold,
+    this.typeK,
+    this.typeV,
+    this.embeddings,
+    this.offloadKqv,
+    this.flashAttention,
+    this.noPerformance,
   });
 
   factory ContextParams.fromMap(Map<String, dynamic> map) => ContextParams(
@@ -130,32 +126,100 @@ class ContextParams {
 
   factory ContextParams.fromJson(String source) => ContextParams.fromMap(jsonDecode(source));
 
   llama_context_params toNative() {
-    final llama_context_params contextParams = calloc<llama_context_params>().ref;
-
-    contextParams.n_ctx = nCtx;
-    contextParams.n_batch = nBatch;
-    contextParams.n_ubatch = nUBatch;
-    contextParams.n_seq_max = nSeqMax;
-    contextParams.n_threads = nThreads;
-    contextParams.n_threads_batch = nThreadsBatch;
-    contextParams.rope_scaling_type = ropeScalingType.index - 1;
-    contextParams.pooling_type = poolingType.index - 1;
-    contextParams.attention_type = attentionType.index - 1;
-    contextParams.rope_freq_base = ropeFrequencyBase;
-    contextParams.rope_freq_scale = ropeFrequencyScale;
-    contextParams.yarn_ext_factor = yarnExtrapolationFactor;
-    contextParams.yarn_attn_factor = yarnAttenuationFactor;
-    contextParams.yarn_beta_fast = yarnBetaFast;
-    contextParams.yarn_beta_slow = yarnBetaSlow;
-    contextParams.yarn_orig_ctx = yarnOriginalContext;
-    contextParams.defrag_thold = defragmentationThreshold;
-    contextParams.type_k = typeK.index;
-    contextParams.type_v = typeV.index;
-    contextParams.logits_all = logitsAll;
-    contextParams.embeddings = embeddings;
-    contextParams.offload_kqv = offloadKqv;
-    contextParams.flash_attn = flashAttention;
-    contextParams.no_perf = noPerformance;
+    final llama_context_params contextParams = lib.llama_context_default_params();
+    log('Context params initialized');
+
+    if (nCtx != null) {
+      contextParams.n_ctx = nCtx!;
+    }
+
+    if (nBatch != null) {
+      contextParams.n_batch = nBatch!;
+    }
+
+    if (nUBatch != null) {
+      contextParams.n_ubatch = nUBatch!;
+    }
+
+    if (nSeqMax != null) {
+      contextParams.n_seq_max = nSeqMax!;
+    }
+
+    if (nThreads != null) {
+      contextParams.n_threads = nThreads!;
+    }
+
+    if (nThreadsBatch != null) {
+      contextParams.n_threads_batch = nThreadsBatch!;
+    }
+
+    if (ropeScalingType != null) {
+      contextParams.rope_scaling_type = ropeScalingType!.index;
+    }
+
+    if (poolingType != null) {
+      contextParams.pooling_type = poolingType!.index;
+    }
+
+    if (attentionType != null) {
+      contextParams.attention_type = attentionType!.index;
+    }
+
+    if (ropeFrequencyBase != null) {
+      contextParams.rope_freq_base = ropeFrequencyBase!;
+    }
+
+    if (ropeFrequencyScale != null) {
+      contextParams.rope_freq_scale = ropeFrequencyScale!;
+    }
+
+    if (yarnExtrapolationFactor != null) {
+      contextParams.yarn_ext_factor = yarnExtrapolationFactor!;
+    }
+
+    if (yarnAttenuationFactor != null) {
+      contextParams.yarn_attn_factor = yarnAttenuationFactor!;
+    }
+
+    if (yarnBetaFast != null) {
+      contextParams.yarn_beta_fast = yarnBetaFast!;
+    }
+
+    if (yarnBetaSlow != null) {
+      contextParams.yarn_beta_slow = yarnBetaSlow!;
+    }
+
+    if (yarnOriginalContext != null) {
+      contextParams.yarn_orig_ctx = yarnOriginalContext!;
+    }
+
+    if (defragmentationThreshold != null) {
+      contextParams.defrag_thold = defragmentationThreshold!;
+    }
+
+    if (typeK != null) {
+      contextParams.type_k = typeK!.index;
+    }
+
+    if (typeV != null) {
+      contextParams.type_v = typeV!.index;
+    }
+
+    if (embeddings != null) {
+      contextParams.embeddings = embeddings!;
+    }
+
+    if (offloadKqv != null) {
+      contextParams.offload_kqv = offloadKqv!;
+    }
+
+    if (flashAttention != null) {
+      contextParams.flash_attn = flashAttention!;
+    }
+
+    if (noPerformance != null) {
+      contextParams.no_perf = noPerformance!;
+    }
 
     return contextParams;
   }
@@ -167,9 +231,9 @@ class ContextParams {
       'nSeqMax': nSeqMax,
       'nThreads': nThreads,
       'nThreadsBatch': nThreadsBatch,
-      'ropeScalingType': ropeScalingType.name,
-      'poolingType': poolingType.name,
-      'attentionType': attentionType.name,
+      'ropeScalingType': ropeScalingType?.name,
+      'poolingType': poolingType?.name,
+      'attentionType': attentionType?.name,
       'ropeFrequencyBase': ropeFrequencyBase,
       'ropeFrequencyScale': ropeFrequencyScale,
       'yarnExtrapolationFactor': yarnExtrapolationFactor,
Expand All @@ -178,8 +242,8 @@ class ContextParams {
'yarnBetaSlow': yarnBetaSlow,
'yarnOriginalContext': yarnOriginalContext,
'defragmentationThreshold': defragmentationThreshold,
'typeK': typeK.name,
'typeV': typeV.name,
'typeK': typeK?.name,
'typeV': typeV?.name,
'embeddings': embeddings,
'offloadKqv': offloadKqv,
'flashAttention': flashAttention,
Expand Down
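Taken together, the ContextParams changes swap hard-coded Dart-side defaults for nullable final fields: toNative() now starts from lib.llama_context_default_params() and only overwrites the fields a caller actually set, the logitsAll flag is dropped, and the enum mappings lose the old .index - 1 offset (presumably the Dart enums now line up with the native values directly). A minimal call-site sketch of the new pattern, with illustrative values that are not library defaults:

// Only the fields set here are written into the native struct.
const params = ContextParams(
  nCtx: 4096,           // text context size
  nThreads: 8,          // generation threads
  flashAttention: true, // enable flash attention
);

// Every field left null (nBatch, typeK, typeV, ...) keeps whatever
// llama_context_default_params() returns, so the Dart wrapper no longer
// has to mirror llama.cpp's defaults.
final native = params.toNative();

The repeated if-not-null blocks are verbose, but they guarantee that an unset field never clobbers a native default, which is the point of the refactor.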
19 changes: 11 additions & 8 deletions lib/src/llama_cpp_native.dart
@@ -24,10 +24,13 @@ class LlamaCppNative {
 
   LlamaCppNative({
     required String modelPath,
-    required ModelParams modelParams,
-    required this.contextParams,
-    required this.samplingParams
+    ModelParams modelParams = const ModelParams(),
+    this.contextParams = const ContextParams(),
+    this.samplingParams = const SamplingParams()
   }) : _modelPath = modelPath, _modelParams = modelParams {
+    lib.ggml_backend_load_all();
+    lib.llama_backend_init();
+
     _initModel();
   }

@@ -50,7 +53,7 @@
 
     _completer = Completer();
 
-    final nCtx = contextParams.nCtx;
+    final nCtx = contextParams.nCtx ?? 2048;
 
     ffi.Pointer<ffi.Char> formatted = calloc<ffi.Char>(nCtx);

@@ -87,15 +90,15 @@
   }
 
   Stream<String> _generate(String prompt) async* {
-    final vocab = lib.llama_model_get_vocab(_model);
-    final sampler = samplingParams.toNative(vocab);
-    assert(sampler != ffi.nullptr, 'Failed to initialize sampler');
-
     final nativeContextParams = contextParams.toNative();
 
     final context = lib.llama_init_from_model(_model, nativeContextParams);
     assert(context != ffi.nullptr, 'Failed to initialize context');
 
+    final vocab = lib.llama_model_get_vocab(_model);
+    final sampler = samplingParams.toNative(vocab);
+    assert(sampler != ffi.nullptr, 'Failed to initialize sampler');
+
     final isFirst = lib.llama_get_kv_cache_used_cells(context) == 0;
 
     final nPromptTokens = -lib.llama_tokenize(vocab, prompt.toNativeUtf8().cast<ffi.Char>(), prompt.length, ffi.nullptr, 0, isFirst, true);
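Since modelParams, contextParams, and samplingParams now default to const instances, only the model path is required to construct the wrapper. A hypothetical call site, where the path is a placeholder:

final llama = LlamaCppNative(
  modelPath: '/path/to/model.gguf',               // placeholder path
  contextParams: const ContextParams(nCtx: 4096), // optional override
);

Note that the constructor now calls lib.ggml_backend_load_all() and lib.llama_backend_init() itself before loading the model, and that the prompt path guards against an unset context size with contextParams.nCtx ?? 2048.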
10 changes: 5 additions & 5 deletions lib/src/model_params.dart
@@ -1,12 +1,12 @@
 part of '../llama.dart';
 
 class ModelParams {
-  bool? vocabOnly;
-  bool? useMmap;
-  bool? useMlock;
-  bool? checkTensors;
+  final bool? vocabOnly;
+  final bool? useMmap;
+  final bool? useMlock;
+  final bool? checkTensors;
 
-  ModelParams({
+  const ModelParams({
     this.vocabOnly,
     this.useMmap,
     this.useMlock,
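Marking the ModelParams constructor const is what permits ModelParams modelParams = const ModelParams() as a default parameter value in the LlamaCppNative constructor above, since Dart default values must be compile-time constants. A one-line sketch:

// A compile-time constant; fields left null presumably fall back to the
// native llama.cpp defaults, mirroring the ContextParams.toNative pattern.
const params = ModelParams(useMmap: false);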
