Skip to content

Commit

Permalink
chat message
Browse files Browse the repository at this point in the history
  • Loading branch information
danemadsen committed Feb 10, 2025
1 parent 94085e0 commit 63403b5
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 30 deletions.
4 changes: 2 additions & 2 deletions example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ class _LlamaAppState extends State<LlamaApp> {
}

setState(() {
_messages.add(ChatMessage(role: 'user', content: value));
_messages.add(UserChatMessage(value));
_controller.clear();
});

final stream = _model!.prompt(_messages.copy());

_messages.add(ChatMessage(role: 'assistant', content: ''));
_messages.add(AssistantChatMessage(''));

await for (var response in stream) {
setState(() {
Expand Down
79 changes: 58 additions & 21 deletions lib/src/chat_message.dart
Original file line number Diff line number Diff line change
@@ -1,26 +1,42 @@
part of '../lcpp.dart';

typedef ChatMessageRecord = (
typedef _ChatMessageRecord = (
String role,
String content
);

class ChatMessage {
final String role;
abstract class ChatMessage {
String get role;

String content;

ChatMessage({
required this.role,
required this.content,
});
ChatMessage(this.content);

factory ChatMessage.withRole({
required String role,
required String content,
}) {
switch (role) {
case 'user':
return UserChatMessage(content);
case 'assistant':
return AssistantChatMessage(content);
case 'system':
return SystemChatMessage(content);
default:
throw ArgumentError('Invalid role: $role');
}
}

ChatMessage.fromRecord(ChatMessageRecord record)
: role = record.$1,
content = record.$2;
factory ChatMessage._fromRecord(_ChatMessageRecord record) => ChatMessage.withRole(
role: record.$1,
content: record.$2
);

ChatMessage.fromNative(llama_chat_message message)
: role = message.role.cast<Utf8>().toDartString(),
content = message.content.cast<Utf8>().toDartString();
factory ChatMessage.fromNative(llama_chat_message message) => ChatMessage.withRole(
role: message.role.cast<Utf8>().toDartString(),
content: message.content.cast<Utf8>().toDartString()
);

llama_chat_message toNative() {
final message = calloc<llama_chat_message>();
Expand All @@ -30,18 +46,39 @@ class ChatMessage {
return message.ref;
}

ChatMessageRecord toRecord() => (
_ChatMessageRecord _toRecord() => (
role,
content
);
}

class UserChatMessage extends ChatMessage {
@override
String get role => 'user';

UserChatMessage(super.content);
}

class AssistantChatMessage extends ChatMessage {
@override
String get role => 'assistant';

AssistantChatMessage(super.content);
}

class SystemChatMessage extends ChatMessage {
@override
String get role => 'system';

SystemChatMessage(super.content);
}

extension ChatMessages on List<ChatMessage> {
static List<ChatMessage> fromRecords(List<ChatMessageRecord> records) {
static List<ChatMessage> _fromRecords(List<_ChatMessageRecord> records) {
final List<ChatMessage> messages = [];

for (var record in records) {
messages.add(ChatMessage.fromRecord(record));
messages.add(ChatMessage._fromRecord(record));
}

return messages;
Expand All @@ -57,11 +94,11 @@ extension ChatMessages on List<ChatMessage> {
return messages;
}

List<ChatMessageRecord> toRecords() {
final List<ChatMessageRecord> records = [];
List<_ChatMessageRecord> _toRecords() {
final List<_ChatMessageRecord> records = [];

for (var i = 0; i < length; i++) {
records.add(this[i].toRecord());
records.add(this[i]._toRecord());
}

return records;
Expand All @@ -71,14 +108,14 @@ extension ChatMessages on List<ChatMessage> {
final List<ChatMessage> messages = [];

for (var message in this) {
messages.add(ChatMessage(role: message.role, content: message.content));
messages.add(ChatMessage.withRole(role: message.role, content: message.content));
}

return messages;
}
}

extension LlamaChatMessagePtrExtension on ffi.Pointer<llama_chat_message> {
extension _LlamaChatMessagePtrExtension on ffi.Pointer<llama_chat_message> {
void free(int length) {
for (var i = 0; i < length; i++) {
calloc.free(this[i].role);
Expand Down
6 changes: 3 additions & 3 deletions lib/src/llama_isolated.dart
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ void _isolateEntry(_SerializableIsolateArguments args) async {
);

await for (final data in receivePort) {
if (data is List<ChatMessageRecord>) {
final messages = ChatMessages.fromRecords(data);
if (data is List<_ChatMessageRecord>) {
final messages = ChatMessages._fromRecords(data);
final stream = llamaCppNative.prompt(messages);

await for (final response in stream) {
Expand Down Expand Up @@ -125,7 +125,7 @@ class LlamaIsolated implements Llama {

_responseController = StreamController<String>();

_sendPort!.send(messages.toRecords());
_sendPort!.send(messages._toRecords());

await for (final response in _responseController.stream) {
yield response;
Expand Down
5 changes: 1 addition & 4 deletions lib/src/llama_native.dart
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,7 @@ class LlamaNative implements Llama {
batch = Llama.lib.llama_batch_get_one(newTokenId, 1);
}

messagesCopy.add(ChatMessage(
role: 'assistant',
content: response
));
messagesCopy.add(AssistantChatMessage(response));

messagesPtr = messagesCopy.toNative();

Expand Down

0 comments on commit 63403b5

Please sign in to comment.