free memory
danemadsen committed Feb 10, 2025
1 parent 34f408c commit 751d8d0
Showing 6 changed files with 134 additions and 102 deletions.
1 change: 0 additions & 1 deletion lib/llama.dart
@@ -15,7 +15,6 @@ part 'src/llama_isolated.dart';
 part 'src/llama_exception.dart';
 part 'src/llama.dart';
 part 'src/llama_native.dart';
-part 'src/llama_isolate_entry.dart';
 part 'src/model_params.dart';
 part 'src/chat_message.dart';
 part 'src/context_params.dart';
11 changes: 11 additions & 0 deletions lib/src/chat_message.dart
@@ -76,4 +76,15 @@ extension ChatMessages on List<ChatMessage> {
 
     return messages;
   }
 }
+
+extension LlamaChatMessagePtrExtension on ffi.Pointer<llama_chat_message> {
+  void free(int length) {
+    for (var i = 0; i < length; i++) {
+      calloc.free(this[i].role);
+      calloc.free(this[i].content);
+    }
+
+    calloc.free(this);
+  }
+}
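
Note: the new extension mirrors the allocations made by toNative(): each llama_chat_message carries separately allocated role and content strings, so those must be released before the array itself. A minimal usage sketch (assuming toNative() allocates the array and both strings with calloc, as the loop above implies; the message list is hypothetical):

final messages = [ChatMessage(role: 'user', content: 'Hello!')];
final ptr = messages.toNative();

try {
  // ... pass ptr and messages.length to llama_chat_apply_template ...
} finally {
  // Free each message's role/content strings first, then the array itself.
  ptr.free(messages.length);
}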
2 changes: 2 additions & 0 deletions lib/src/llama.dart
@@ -39,4 +39,6 @@ abstract interface class Llama {
   Stream<String> prompt(List<ChatMessage> messages);
 
   void stop();
+
+  void free();
 }
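
Note: free() makes teardown an explicit part of the Llama contract. A sketch of the expected lifecycle against any implementation (stdout is from dart:io; construction is elided):

Future<void> runOnce(Llama llama, List<ChatMessage> messages) async {
  await for (final piece in llama.prompt(messages)) {
    stdout.write(piece);
  }

  llama.free(); // release native resources; the instance presumably should not be reused
}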
78 changes: 0 additions & 78 deletions lib/src/llama_isolate_entry.dart

This file was deleted.

99 changes: 89 additions & 10 deletions lib/src/llama_isolated.dart
@@ -1,10 +1,81 @@
 part of '../llama.dart';
 
-typedef StringResponse = (
-  bool error,
-  String message
+typedef _IsolateArguments = ({
+  ModelParams modelParams,
+  ContextParams contextParams,
+  SamplingParams samplingParams,
+  SendPort sendPort
 });
 
+extension _IsolateArgumentsExtension on _IsolateArguments {
+  _SerializableIsolateArguments get toSerializable => (
+    modelParams.toJson(),
+    contextParams.toJson(),
+    samplingParams.toJson(),
+    sendPort
+  );
+}
+
+typedef _SerializableIsolateArguments = (
+  String,
+  String,
+  String,
+  SendPort
+);
+
+extension _SerializableIsolateArgumentsExtension on _SerializableIsolateArguments {
+  ModelParams get modelParams => ModelParams.fromJson(this.$1);
+
+  ContextParams get contextParams => ContextParams.fromJson(this.$2);
+
+  SamplingParams get samplingParams => SamplingParams.fromJson(this.$3);
+
+  SendPort get sendPort => this.$4;
+}
+
+void _isolateEntry(_SerializableIsolateArguments args) async {
+  final SendPort sendPort = args.sendPort;
+  final LlamaNative llamaCppNative;
+
+  try {
+    final receivePort = ReceivePort();
+    sendPort.send(receivePort.sendPort);
+
+    llamaCppNative = LlamaNative(
+      modelParams: args.modelParams,
+      contextParams: args.contextParams,
+      samplingParams: args.samplingParams
+    );
+
+    await for (final data in receivePort) {
+      if (data is List<ChatMessageRecord>) {
+        final messages = ChatMessages.fromRecords(data);
+        final stream = llamaCppNative.prompt(messages);
+
+        await for (final response in stream) {
+          sendPort.send(response);
+        }
+
+        sendPort.send(false);
+      }
+      else if (data is bool) {
+        if (data) {
+          llamaCppNative.free();
+          return;
+        }
+        else {
+          llamaCppNative.stop();
+        }
+
+        sendPort.send(data);
+      }
+    }
+  }
+  catch (e) {
+    log('LlamaIsolateEntry: $e');
+  }
+}
+
 class LlamaIsolated implements Llama {
   final Completer _initialized = Completer();
   StreamController<String> _responseController = StreamController<String>()..close();
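
Note: only sendable values cross the isolate boundary here — the params objects are flattened to three JSON strings plus a SendPort, and the entry's first message back is a SendPort for commands. After the handshake, the isolate streams String tokens and sends false when a response finishes or a stop is acknowledged; on true it frees the native state and exits without replying. The handshake, distilled to a generic sketch (dart:isolate; not the library's exact API):

Future<SendPort> spawnWorker(void Function(SendPort) entry) async {
  final fromWorker = ReceivePort();
  await Isolate.spawn(entry, fromWorker.sendPort);

  // The worker's first message back is the port for sending it commands.
  return await fromWorker.first as SendPort;
}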
@@ -28,19 +99,20 @@ class LlamaIsolated implements Llama {
       sendPort: receivePort.sendPort
     );
 
-    await Isolate.spawn(LlamaIsolateEntry.entry, isolateParams.toSerializable);
+    await Isolate.spawn(_isolateEntry, isolateParams.toSerializable);
 
-    await for (var data in receivePort) {
+    await for (final data in receivePort) {
       if (data is String) {
         _responseController.add(data);
       }
       else if (data is SendPort) {
         _sendPort = data;
         _initialized.complete();
       }
-      else if (data == null) {
+      else if (data is bool) {
         _responseController.close();
-        break;
+
+        if (data) return;
       }
     }
   }
@@ -58,8 +130,6 @@ class LlamaIsolated implements Llama {
     await for (final response in _responseController.stream) {
       yield response;
     }
-
-    _responseController.close();
   }
 
   @override
@@ -68,6 +138,15 @@
       await _initialized.future;
     }
 
-    _sendPort!.send(null);
+    _sendPort!.send(false);
   }
+
+  @override
+  void free() async {
+    if (!_initialized.isCompleted) {
+      await _initialized.future;
+    }
+
+    _sendPort!.send(true);
+  }
 }
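
Note: a hypothetical end-to-end use of LlamaIsolated (the constructor is assumed to take the same named parameters as the _IsolateArguments record above; stdout from dart:io):

final llama = LlamaIsolated(
  modelParams: modelParams,
  contextParams: contextParams,
  samplingParams: samplingParams
);

await for (final piece in llama.prompt([ChatMessage(role: 'user', content: 'Hi')])) {
  stdout.write(piece);
}

llama.free(); // sends true; the isolate frees its LlamaNative and exits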
45 changes: 32 additions & 13 deletions lib/src/llama_native.dart
@@ -63,6 +63,8 @@ class LlamaNative implements Llama {
   }
 
   void _initContext() {
+    assert(_model != ffi.nullptr, 'Model is not loaded');
+
     final nativeContextParams = _contextParams.toNative();
 
     if (_context != ffi.nullptr) {
@@ -74,6 +76,8 @@ class LlamaNative implements Llama {
   }
 
   void _initSampler() {
+    assert(_model != ffi.nullptr, 'Model is not loaded');
+
     if (_sampler != ffi.nullptr) {
       Llama.lib.llama_sampler_free(_sampler);
     }
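
Note: the next two hunks rework prompt() so that every native allocation is paired with a free. One recurring idiom below is measure-then-fill tokenization: called with a null output buffer, llama_tokenize reports the required token count as a negative value, and the commit now allocates the UTF-8 prompt copy once (promptPtr) instead of leaking two separate toNativeUtf8() copies. A condensed sketch of that idiom, as it appears in the hunk below:

final promptPtr = prompt.toNativeUtf8().cast<ffi.Char>();

// First pass: null buffer, so the negated return value is the token count.
final count = -Llama.lib.llama_tokenize(vocab, promptPtr, prompt.length, ffi.nullptr, 0, isFirst, true);
final tokens = calloc<llama_token>(count);

// Second pass fills the buffer for real.
if (Llama.lib.llama_tokenize(vocab, promptPtr, prompt.length, tokens, count, isFirst, true) < 0) {
  throw Exception('Failed to tokenize');
}

calloc.free(promptPtr); // the UTF-8 copy is no longer needed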
@@ -99,45 +103,55 @@

     final template = Llama.lib.llama_model_chat_template(_model, ffi.nullptr);
 
+    ffi.Pointer<llama_chat_message> messagesPtr = messagesCopy.toNative();
+
     int newContextLength = Llama.lib.llama_chat_apply_template(
       template,
-      messagesCopy.toNative(),
+      messagesPtr,
       messagesCopy.length,
       true,
       formatted,
       nCtx
     );
 
     if (newContextLength > nCtx) {
+      calloc.free(formatted);
       formatted = calloc<ffi.Char>(newContextLength);
       newContextLength = Llama.lib.llama_chat_apply_template(
         template,
-        messagesCopy.toNative(),
+        messagesPtr,
         messagesCopy.length,
         true,
         formatted,
         newContextLength
       );
     }
 
+    messagesPtr.free(messagesCopy.length);
+
     if (newContextLength < 0) {
       throw Exception('Failed to apply template');
     }
 
     final prompt = formatted.cast<Utf8>().toDartString().substring(_contextLength);
+    calloc.free(formatted);
 
     final vocab = Llama.lib.llama_model_get_vocab(_model);
     final isFirst = Llama.lib.llama_get_kv_cache_used_cells(_context) == 0;
 
-    final nPromptTokens = -Llama.lib.llama_tokenize(vocab, prompt.toNativeUtf8().cast<ffi.Char>(), prompt.length, ffi.nullptr, 0, isFirst, true);
+    final promptPtr = prompt.toNativeUtf8().cast<ffi.Char>();
+
+    final nPromptTokens = -Llama.lib.llama_tokenize(vocab, promptPtr, prompt.length, ffi.nullptr, 0, isFirst, true);
     ffi.Pointer<llama_token> promptTokens = calloc<llama_token>(nPromptTokens);
 
-    if (Llama.lib.llama_tokenize(vocab, prompt.toNativeUtf8().cast<ffi.Char>(), prompt.length, promptTokens, nPromptTokens, isFirst, true) < 0) {
+    if (Llama.lib.llama_tokenize(vocab, promptPtr, prompt.length, promptTokens, nPromptTokens, isFirst, true) < 0) {
       throw Exception('Failed to tokenize');
     }
 
+    calloc.free(promptPtr);
+
     llama_batch batch = Llama.lib.llama_batch_get_one(promptTokens, nPromptTokens);
-    int newTokenId;
+    ffi.Pointer<llama_token> newTokenId = calloc<llama_token>(1);
 
     String response = '';
 
@@ -153,49 +167,54 @@
         throw Exception('Failed to decode');
       }
 
-      newTokenId = Llama.lib.llama_sampler_sample(_sampler, _context, -1);
+      newTokenId.value = Llama.lib.llama_sampler_sample(_sampler, _context, -1);
 
       // is it an end of generation?
-      if (Llama.lib.llama_vocab_is_eog(vocab, newTokenId)) {
+      if (Llama.lib.llama_vocab_is_eog(vocab, newTokenId.value)) {
        break;
       }
 
       final buffer = calloc<ffi.Char>(256);
-      final n = Llama.lib.llama_token_to_piece(vocab, newTokenId, buffer, 256, 0, true);
+      final n = Llama.lib.llama_token_to_piece(vocab, newTokenId.value, buffer, 256, 0, true);
       if (n < 0) {
         throw Exception('Failed to convert token to piece');
       }
 
       final piece = buffer.cast<Utf8>().toDartString();
+      calloc.free(buffer);
       response += piece;
       yield piece;
 
-      final newTokenPointer = calloc<llama_token>(1);
-      newTokenPointer.value = newTokenId;
-
-      batch = Llama.lib.llama_batch_get_one(newTokenPointer, 1);
+      batch = Llama.lib.llama_batch_get_one(newTokenId, 1);
     }
 
     messagesCopy.add(ChatMessage(
       role: 'assistant',
       content: response
     ));
 
+    messagesPtr = messagesCopy.toNative();
+
     _contextLength = Llama.lib.llama_chat_apply_template(
       template,
-      messagesCopy.toNative(),
+      messagesPtr,
       messagesCopy.length,
       false,
       ffi.nullptr,
       0
     );
 
+    messagesPtr.free(messagesCopy.length);
+    calloc.free(promptTokens);
+    Llama.lib.llama_batch_free(batch);
   }
 
   @override
   void stop() {
     _completer?.complete();
   }
+
+  @override
   void free() {
     Llama.lib.llama_free_model(_model);
     Llama.lib.llama_sampler_free(_sampler);
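
Note: inside the sampling loop, the old code allocated a fresh calloc<llama_token>(1) on every iteration (newTokenPointer) and never released it; the new code allocates one slot (newTokenId) before the loop and reuses it for every llama_batch_get_one call. A condensed sketch of the reuse pattern (the final calloc.free of the slot is assumed to live in code outside the hunk shown above):

final tokenSlot = calloc<llama_token>(1); // one reusable native token slot

while (true) {
  tokenSlot.value = Llama.lib.llama_sampler_sample(_sampler, _context, -1);

  if (Llama.lib.llama_vocab_is_eog(vocab, tokenSlot.value)) break;

  // ... detokenize the token, yield the piece ...

  batch = Llama.lib.llama_batch_get_one(tokenSlot, 1);
}

calloc.free(tokenSlot); // assumed cleanup once generation ends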
