native
danemadsen committed Feb 8, 2025
1 parent 5a5a29b commit 59593eb
Showing 1 changed file with 54 additions and 25 deletions.
79 changes: 54 additions & 25 deletions lib/src/llama_native.dart
@@ -2,11 +2,13 @@ part of '../llama.dart';

class LlamaNative {
ffi.Pointer<llama_model> _model = ffi.nullptr;
+ ffi.Pointer<llama_context> _context = ffi.nullptr;
+ ffi.Pointer<llama_sampler> _sampler = ffi.nullptr;

String _modelPath;
ModelParams _modelParams;
- ContextParams contextParams;
- SamplingParams samplingParams;
+ ContextParams _contextParams;
+ SamplingParams _samplingParams;

Completer? _completer;

@@ -22,16 +24,33 @@ class LlamaNative {
_initModel();
}

+ set contextParams(ContextParams contextParams) {
+ _contextParams = contextParams;
+
+ _initContext();
+ }
+
+ set samplingParams(SamplingParams samplingParams) {
+ _samplingParams = samplingParams;
+
+ _initSampler();
+ }

LlamaNative({
required String modelPath,
ModelParams modelParams = const ModelParams(),
- this.contextParams = const ContextParams(),
- this.samplingParams = const SamplingParams()
- }) : _modelPath = modelPath, _modelParams = modelParams {
+ ContextParams contextParams = const ContextParams(),
+ SamplingParams samplingParams = const SamplingParams()
+ }) : _modelPath = modelPath,
+ _modelParams = modelParams,
+ _contextParams = contextParams,
+ _samplingParams = samplingParams {
lib.ggml_backend_load_all();
lib.llama_backend_init();

_initModel();
+ _initContext();
+ _initSampler();
}
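
With this change the constructor brings up all three native handles eagerly (previously the context and sampler were created per call inside `_generate`). A minimal construction sketch, using only the defaults visible in this diff:

```dart
// Model, context, and sampler are all live once the constructor returns.
final llama = LlamaNative(
  modelPath: '/path/to/model.gguf',      // required; path is illustrative
  modelParams: const ModelParams(),      // optional, defaults as in the signature above
  contextParams: const ContextParams(),
  samplingParams: const SamplingParams(),
);
```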

void _initModel() {
@@ -46,14 +65,36 @@ class LlamaNative {
nativeModelPath,
nativeModelParams
);
+ assert(_model != ffi.nullptr, 'Failed to load model');
}

+ void _initContext() {
+ final nativeContextParams = _contextParams.toNative();
+
+ if (_context != ffi.nullptr) {
+ lib.llama_free(_context);
+ }
+
+ _context = lib.llama_init_from_model(_model, nativeContextParams);
+ assert(_context != ffi.nullptr, 'Failed to initialize context');
+ }
+
+ void _initSampler() {
+ if (_sampler != ffi.nullptr) {
+ lib.llama_sampler_free(_sampler);
+ }
+
+ final vocab = lib.llama_model_get_vocab(_model);
+ _sampler = _samplingParams.toNative(vocab);
+ assert(_sampler != ffi.nullptr, 'Failed to initialize sampler');
+ }
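
Each `_init*` method frees any existing native handle before creating its replacement, so the new `contextParams`/`samplingParams` setters can be reassigned repeatedly without leaking. A sketch of mid-session reconfiguration (the `nCtx` named parameter is an assumption inferred from the removed `contextParams.nCtx` read below, not something this diff shows):

```dart
// Rebuilds only the context; the model and sampler stay untouched.
llama.contextParams = ContextParams(nCtx: 4096); // nCtx is assumed

// Rebuilds only the sampler against the model's vocab.
llama.samplingParams = const SamplingParams();
```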

- Stream<String> prompt(List<ChatMessage> messages) {
+ Stream<String> prompt(List<ChatMessage> messages) async* {
assert(_model != ffi.nullptr, 'Model is not loaded');

_completer = Completer();

- final nCtx = contextParams.nCtx ?? 2048;
+ final nCtx = lib.llama_n_ctx(_context);

ffi.Pointer<ffi.Char> formatted = calloc<ffi.Char>(nCtx);

@@ -86,20 +127,9 @@

final prompt = formatted.cast<Utf8>().toDartString();

- return _generate(prompt);
- }
-
- Stream<String> _generate(String prompt) async* {
final vocab = lib.llama_model_get_vocab(_model);
- final sampler = samplingParams.toNative(vocab);
- assert(sampler != ffi.nullptr, 'Failed to initialize sampler');
-
- final nativeContextParams = contextParams.toNative();
-
- final context = lib.llama_init_from_model(_model, nativeContextParams);
- assert(context != ffi.nullptr, 'Failed to initialize context');
-
- final isFirst = lib.llama_get_kv_cache_used_cells(context) == 0;
+ final isFirst = lib.llama_get_kv_cache_used_cells(_context) == 0;

final nPromptTokens = -lib.llama_tokenize(vocab, prompt.toNativeUtf8().cast<ffi.Char>(), prompt.length, ffi.nullptr, 0, isFirst, true);
ffi.Pointer<llama_token> promptTokens = calloc<llama_token>(nPromptTokens);
@@ -112,18 +142,17 @@ class LlamaNative {
int newTokenId;

while (!_completer!.isCompleted) {
- final nCtx = lib.llama_n_ctx(context);
- final nCtxUsed = lib.llama_get_kv_cache_used_cells(context);
+ final nCtxUsed = lib.llama_get_kv_cache_used_cells(_context);

if (nCtxUsed + batch.n_tokens > nCtx) {
throw LlamaException('Context size exceeded');
}

- if (lib.llama_decode(context, batch) != 0) {
+ if (lib.llama_decode(_context, batch) != 0) {
throw LlamaException('Failed to decode');
}

- newTokenId = lib.llama_sampler_sample(sampler, context, -1);
+ newTokenId = lib.llama_sampler_sample(_sampler, _context, -1);

// is it an end of generation?
if (lib.llama_vocab_is_eog(vocab, newTokenId)) {
@@ -153,8 +182,6 @@
}

lib.llama_batch_free(batch);
- lib.llama_sampler_free(sampler);
- lib.llama_free(context);
calloc.free(promptTokens);
}

@@ -164,5 +191,7 @@

void free() {
lib.llama_free_model(_model);
+ lib.llama_sampler_free(_sampler);
+ lib.llama_free(_context);
}
}
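
For reference, a hedged end-to-end sketch of driving the refactored class; `ChatMessage`'s constructor is not shown in this diff, so its named arguments here are assumptions:

```dart
import 'dart:io';

Future<void> main() async {
  final llama = LlamaNative(modelPath: '/path/to/model.gguf');

  // prompt() is now an async* generator, so tokens stream as they decode.
  await for (final token in llama.prompt([
    ChatMessage(role: 'user', content: 'Hello!'), // hypothetical constructor
  ])) {
    stdout.write(token);
  }

  llama.free(); // now also releases the sampler and context
}
```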
