diff --git a/lib/src/llama_native.dart b/lib/src/llama_native.dart index f1f1859..b5db29d 100644 --- a/lib/src/llama_native.dart +++ b/lib/src/llama_native.dart @@ -2,11 +2,13 @@ part of '../llama.dart'; class LlamaNative { ffi.Pointer _model = ffi.nullptr; + ffi.Pointer _context = ffi.nullptr; + ffi.Pointer _sampler = ffi.nullptr; String _modelPath; ModelParams _modelParams; - ContextParams contextParams; - SamplingParams samplingParams; + ContextParams _contextParams; + SamplingParams _samplingParams; Completer? _completer; @@ -22,16 +24,33 @@ class LlamaNative { _initModel(); } + set contextParams(ContextParams contextParams) { + _contextParams = contextParams; + + _initContext(); + } + + set samplingParams(SamplingParams samplingParams) { + _samplingParams = samplingParams; + + _initSampler(); + } + LlamaNative({ required String modelPath, ModelParams modelParams = const ModelParams(), - this.contextParams = const ContextParams(), - this.samplingParams = const SamplingParams() - }) : _modelPath = modelPath, _modelParams = modelParams { + ContextParams contextParams = const ContextParams(), + SamplingParams samplingParams = const SamplingParams() + }) : _modelPath = modelPath, + _modelParams = modelParams, + _contextParams = contextParams, + _samplingParams = samplingParams { lib.ggml_backend_load_all(); lib.llama_backend_init(); _initModel(); + _initContext(); + _initSampler(); } void _initModel() { @@ -46,14 +65,36 @@ class LlamaNative { nativeModelPath, nativeModelParams ); + assert(_model != ffi.nullptr, 'Failed to load model'); + } + + void _initContext() { + final nativeContextParams = _contextParams.toNative(); + + if (_context != ffi.nullptr) { + lib.llama_free(_context); + } + + _context = lib.llama_init_from_model(_model, nativeContextParams); + assert(_context != ffi.nullptr, 'Failed to initialize context'); + } + + void _initSampler() { + if (_sampler != ffi.nullptr) { + lib.llama_sampler_free(_sampler); + } + + final vocab = lib.llama_model_get_vocab(_model); + _sampler = _samplingParams.toNative(vocab); + assert(_sampler != ffi.nullptr, 'Failed to initialize sampler'); } - Stream prompt(List messages) { + Stream prompt(List messages) async* { assert(_model != ffi.nullptr, 'Model is not loaded'); _completer = Completer(); - final nCtx = contextParams.nCtx ?? 2048; + final nCtx = lib.llama_n_ctx(_context); ffi.Pointer formatted = calloc(nCtx); @@ -86,20 +127,9 @@ class LlamaNative { final prompt = formatted.cast().toDartString(); - return _generate(prompt); - } - - Stream _generate(String prompt) async* { final vocab = lib.llama_model_get_vocab(_model); - final sampler = samplingParams.toNative(vocab); - assert(sampler != ffi.nullptr, 'Failed to initialize sampler'); - - final nativeContextParams = contextParams.toNative(); - - final context = lib.llama_init_from_model(_model, nativeContextParams); - assert(context != ffi.nullptr, 'Failed to initialize context'); - final isFirst = lib.llama_get_kv_cache_used_cells(context) == 0; + final isFirst = lib.llama_get_kv_cache_used_cells(_context) == 0; final nPromptTokens = -lib.llama_tokenize(vocab, prompt.toNativeUtf8().cast(), prompt.length, ffi.nullptr, 0, isFirst, true); ffi.Pointer promptTokens = calloc(nPromptTokens); @@ -112,18 +142,17 @@ class LlamaNative { int newTokenId; while (!_completer!.isCompleted) { - final nCtx = lib.llama_n_ctx(context); - final nCtxUsed = lib.llama_get_kv_cache_used_cells(context); + final nCtxUsed = lib.llama_get_kv_cache_used_cells(_context); if (nCtxUsed + batch.n_tokens > nCtx) { throw LlamaException('Context size exceeded'); } - if (lib.llama_decode(context, batch) != 0) { + if (lib.llama_decode(_context, batch) != 0) { throw LlamaException('Failed to decode'); } - newTokenId = lib.llama_sampler_sample(sampler, context, -1); + newTokenId = lib.llama_sampler_sample(_sampler, _context, -1); // is it an end of generation? if (lib.llama_vocab_is_eog(vocab, newTokenId)) { @@ -153,8 +182,6 @@ class LlamaNative { } lib.llama_batch_free(batch); - lib.llama_sampler_free(sampler); - lib.llama_free(context); calloc.free(promptTokens); } @@ -164,5 +191,7 @@ class LlamaNative { void free() { lib.llama_free_model(_model); + lib.llama_sampler_free(_sampler); + lib.llama_free(_context); } } \ No newline at end of file