From dbc850825a3b56f9508d209a22fbeedc67d23165 Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Mon, 7 Oct 2024 11:26:43 -0700 Subject: [PATCH] Client implemented tools (#9) --- .github/workflows/test.yaml | 2 + .gitignore | 3 + CHANGELOG.md | 6 +- example/lib/main.dart | 18 +++++ example/pubspec.lock | 2 +- lib/src/session.dart | 99 +++++++++++++++++++++++++-- pubspec.yaml | 4 +- test/fake_lk.dart | 74 +++++++++++++++++++++ test/ultravox_client_test.dart | 118 ++++++++++++++++++++++++++++++++- 9 files changed, 317 insertions(+), 9 deletions(-) create mode 100644 test/fake_lk.dart diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 02ebbbb..23da136 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,6 +15,8 @@ jobs: channel: 'stable' - name: install deps run: flutter pub get + - name: generate mocks + run: dart run build_runner build - name: format run: dart format lib/ test/ --set-exit-if-changed - name: check diff --git a/.gitignore b/.gitignore index 9a6b1d9..1087197 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,6 @@ migrate_working_dir/ *.iws .idea/ .vscode/ + +# Generated mocks +test/**/*.mocks.dart diff --git a/CHANGELOG.md b/CHANGELOG.md index d6e4c35..bf14d3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,4 +15,8 @@ ## 0.0.4 * Changed implementation of mute/unmute. It's now `micMuted` and `speakerMuted` -* Added functions for toggling mute of mic (`toggleMicMuted()`) and speaker (`toggleSpeakerMuted()`) \ No newline at end of file +* Added functions for toggling mute of mic (`toggleMicMuted()`) and speaker (`toggleSpeakerMuted()`) + +## 0.0.5 + +* Add client-implemented tools diff --git a/example/lib/main.dart b/example/lib/main.dart index 4843069..c46b2bc 100644 --- a/example/lib/main.dart +++ b/example/lib/main.dart @@ -1,4 +1,5 @@ import 'dart:async'; +import 'dart:convert'; import 'package:flutter/material.dart'; import 'package:ultravox_client/ultravox_client.dart'; @@ -70,9 +71,26 @@ class _MyHomePageState extends State { UltravoxSession.create(experimentalMessages: _debug ? {"debug"} : {}); }); _session!.statusNotifier.addListener(_onStatusChange); + _session!.registerToolImplementation("getSecretMenu", _getSecretMenu); await _session!.joinCall(joinUrl); } + ClientToolResult _getSecretMenu(Object params) { + return ClientToolResult(json.encode({ + "date": DateTime.now().toIso8601String(), + "specialItems": [ + { + "name": "Banana smoothie", + "price": 3.99, + }, + { + "name": "Butter pecan ice cream (one scoop)", + "price": 1.99, + } + ], + })); + } + Future _endCall() async { if (_session == null) { return; diff --git a/example/pubspec.lock b/example/pubspec.lock index 686b46c..316f65d 100644 --- a/example/pubspec.lock +++ b/example/pubspec.lock @@ -475,7 +475,7 @@ packages: path: ".." relative: true source: path - version: "0.0.3" + version: "0.0.4" uuid: dependency: transitive description: diff --git a/lib/src/session.dart b/lib/src/session.dart index 84d3ae5..8da154e 100644 --- a/lib/src/session.dart +++ b/lib/src/session.dart @@ -1,3 +1,5 @@ +import 'dart:async'; + import 'package:flutter/material.dart'; import 'package:livekit_client/livekit_client.dart' as lk; import 'package:web_socket_channel/web_socket_channel.dart'; @@ -91,6 +93,34 @@ class Transcripts extends ChangeNotifier { } } +/// The result type returned by a ClientToolImplementation. +class ClientToolResult { + /// The result of the client tool. + /// + /// This is exactly the string that will be seen by the model. Often JSON. + final String result; + + /// The type of response the tool is providing. + /// + /// Most tools simply provide information back to the model, in which case + /// responseType need not be set. For other tools that are instead interpreted + /// by the server to affect the call, responseType may be set to indicate how + /// the call should be altered. In this case, [result] should be JSON with + /// instructions for the server. The schema depends on the response type. + /// See https://docs.ultravox.ai/tools for more information. + final String? responseType; + + ClientToolResult(this.result, {this.responseType}); +} + +/// A function that fulfills a client-implemented tool. +/// +/// The function should take an object containing the tool's parameters (parsed +/// from JSON) and return a [ClientToolResult] object. It may or may not be +/// asynchronous. +typedef ClientToolImplementation = FutureOr Function( + Object data); + /// Manages a single session with Ultravox. /// /// In addition to providing methods to manage a call, [UltravoxSession] exposes @@ -180,6 +210,7 @@ class UltravoxSession { final lk.Room _room; final lk.EventsListener _listener; late WebSocketChannel _wsChannel; + final _registeredTools = {}; UltravoxSession(this._room, this._experimentalMessages) : _listener = _room.createListener(); @@ -187,6 +218,21 @@ class UltravoxSession { UltravoxSession.create({Set? experimentalMessages}) : this(lk.Room(), experimentalMessages ?? {}); + /// Registers a client tool implementation using the given name. + /// + /// If the call is started with a client-implemented tool, this implementation + /// will be invoked when the model calls the tool. + /// See https://docs.ultravox.ai/tools for more information. + void registerToolImplementation(String name, ClientToolImplementation impl) { + _registeredTools[name] = impl; + } + + /// Convenience batch wrapper for [registerToolImplementation]. + void registerToolImplementations( + Map implementations) { + implementations.forEach(registerToolImplementation); + } + /// Connects to a call using the given [joinUrl]. Future joinCall(String joinUrl) async { if (status != UltravoxSessionStatus.disconnected) { @@ -204,7 +250,7 @@ class UltravoxSession { _wsChannel = WebSocketChannel.connect(url); await _wsChannel.ready; _wsChannel.stream.listen((event) async { - await _handleSocketMessage(event); + await handleSocketMessage(event); }); } @@ -219,8 +265,7 @@ class UltravoxSession { throw Exception( 'Cannot send text while not connected. Current status: $status'); } - final message = jsonEncode({'type': 'input_text_message', 'text': text}); - _room.localParticipant?.publishData(utf8.encode(message), reliable: true); + await _sendData({'type': 'input_text_message', 'text': text}); } Future _disconnect() async { @@ -235,7 +280,8 @@ class UltravoxSession { statusNotifier.value = UltravoxSessionStatus.disconnected; } - Future _handleSocketMessage(dynamic event) async { + @visibleForTesting + Future handleSocketMessage(dynamic event) async { if (event is! String) { throw Exception('Received unexpected message from socket'); } @@ -259,7 +305,7 @@ class UltravoxSession { await _room.startAudio(); } - void _handleDataMessage(lk.DataReceivedEvent event) { + Future _handleDataMessage(lk.DataReceivedEvent event) async { final data = jsonDecode(utf8.decode(event.data)); switch (data['type']) { case 'state': @@ -315,10 +361,53 @@ class UltravoxSession { } } break; + case 'client_tool_invocation': + await _invokeClientTool(data['toolName'] as String, + data['invocationId'] as String, data['parameters'] as Object); default: if (_experimentalMessages.isNotEmpty) { experimentalMessageNotifier.value = data as Map; } } } + + Future _invokeClientTool( + String toolName, String invocationId, Object parameters) async { + final tool = _registeredTools[toolName]; + if (tool == null) { + await _sendData({ + 'type': 'client_tool_result', + 'invocationId': invocationId, + 'errorType': 'undefined', + 'errorMessage': + 'Client tool $toolName is not registered (Flutter client)', + }); + return; + } + try { + final result = await tool(parameters); + final data = { + 'type': 'client_tool_result', + 'invocationId': invocationId, + 'result': result.result, + }; + if (result.responseType != null) { + data['responseType'] = result.responseType!; + } + await _sendData(data); + } catch (e) { + await _sendData({ + 'type': 'client_tool_result', + 'invocationId': invocationId, + 'errorType': 'implementation-error', + 'errorMessage': e.toString(), + }); + } + } + + Future _sendData(Object data) async { + final message = jsonEncode(data); + await _room.localParticipant + ?.publishData(utf8.encode(message), reliable: true); + } } diff --git a/pubspec.yaml b/pubspec.yaml index fad4200..e269925 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ultravox_client description: "Flutter client SDK for Ultravox." -version: 0.0.4 +version: 0.0.5 homepage: https://ultravox.ai repository: https://github.com/fixie-ai/ultravox-client-sdk-flutter topics: @@ -30,4 +30,6 @@ dev_dependencies: flutter_test: sdk: flutter flutter_lints: ^4.0.0 + mockito: ^5.4.4 + build_runner: ^2.4.13 diff --git a/test/fake_lk.dart b/test/fake_lk.dart new file mode 100644 index 0000000..25cd2e0 --- /dev/null +++ b/test/fake_lk.dart @@ -0,0 +1,74 @@ +import 'dart:async'; +import 'dart:collection'; + +import 'package:flutter_test/flutter_test.dart'; +import 'package:mockito/annotations.dart'; +import 'package:livekit_client/livekit_client.dart' as lk; + +@GenerateNiceMocks([ + MockSpec(), + MockSpec(), +]) +import 'fake_lk.mocks.dart'; + +class FakeRoomEvents extends Fake implements lk.EventsListener { + final _listeners = Function(lk.RoomEvent)>[]; + + @override + lk.CancelListenFunc listen(FutureOr Function(lk.RoomEvent) onEvent) { + _listeners.add(onEvent); + return () {}; + } + + @override // Copied from real implementation. + lk.CancelListenFunc on( + FutureOr Function(E) then, { + bool Function(E)? filter, + }) { + return listen((event) async { + // event must be E + if (event is! E) return; + // filter must be true (if filter is used) + if (filter != null && !filter(event as E)) return; + // cast to E + await then(event as E); + }); + } + + void emit(lk.RoomEvent event) { + for (final listener in _listeners) { + listener(event); + } + } +} + +class FakeRoom extends Fake implements lk.Room { + final _events = FakeRoomEvents(); + + @override + lk.EventsListener createListener({bool synchronized = false}) { + return _events; + } + + @override + Future connect( + String url, + String token, { + lk.ConnectOptions? connectOptions, + @Deprecated('deprecated, please use roomOptions in Room constructor') + lk.RoomOptions? roomOptions, + lk.FastConnectOptions? fastConnectOptions, + }) async {} + + @override + UnmodifiableMapView get remoteParticipants => + UnmodifiableMapView({"remote": remoteParticipant}); + final remoteParticipant = MockRemoteParticipant(); + + @override + final MockLocalParticipant localParticipant = MockLocalParticipant(); + + void emit(lk.RoomEvent event) { + _events.emit(event); + } +} diff --git a/test/ultravox_client_test.dart b/test/ultravox_client_test.dart index 3d4c49c..36c9981 100644 --- a/test/ultravox_client_test.dart +++ b/test/ultravox_client_test.dart @@ -1,12 +1,21 @@ +import 'dart:async'; + import 'package:flutter_test/flutter_test.dart'; import 'package:ultravox_client/ultravox_client.dart'; +import 'package:livekit_client/livekit_client.dart' as lk; +import 'package:mockito/mockito.dart'; +import 'dart:convert'; + +import 'fake_lk.dart'; void main() { group('UltravoxSession mute tests', () { + late FakeRoom room; late UltravoxSession session; setUp(() { - session = UltravoxSession.create(); + room = FakeRoom(); + session = UltravoxSession(room, {}); }); test('micMuted getter and setter', () { @@ -48,5 +57,112 @@ void main() { expect(session.speakerMuted, false); expect(speakerMuteCounter, 2); }); + + group('client tool implementations', () { + invokeTool( + FutureOr Function(Object params) impl) async { + session.registerToolImplementation("test-tool", impl); + await session.handleSocketMessage(json.encode({ + "type": "room_info", + "roomUrl": "wss://test-room", + "token": "test-token" + })); + final data = { + "type": "client_tool_invocation", + "toolName": "test-tool", + "invocationId": "call_1", + "parameters": {"foo": "bar"} + }; + room.emit(lk.DataReceivedEvent( + participant: room.remoteParticipant, + data: utf8.encode(json.encode(data)), + topic: null)); + await Future.delayed(const Duration(milliseconds: 1)); + } + + test('basic', () async { + ClientToolResult impl(Object params) { + expect(params, {"foo": "bar"}); + return ClientToolResult("baz"); + } + + await invokeTool(impl); + + final sentData = verify( + room.localParticipant.publishData(captureAny, reliable: true)) + .captured + .single; + final sentJson = json.decode(utf8.decode(sentData as List)); + expect(sentJson, { + "type": "client_tool_result", + "invocationId": "call_1", + "result": "baz" + }); + }); + + test('async tool', () async { + Future impl(Object params) async { + expect(params, {"foo": "bar"}); + await Future.delayed(Duration.zero); + return ClientToolResult("baz"); + } + + await invokeTool(impl); + + final sentData = verify( + room.localParticipant.publishData(captureAny, reliable: true)) + .captured + .single; + final sentJson = json.decode(utf8.decode(sentData as List)); + expect(sentJson, { + "type": "client_tool_result", + "invocationId": "call_1", + "result": "baz" + }); + }); + + test('setting response type', () async { + ClientToolResult impl(Object params) { + expect(params, {"foo": "bar"}); + return ClientToolResult('{"strict": true}', responseType: "hang-up"); + } + + await invokeTool(impl); + + final sentData = verify( + room.localParticipant.publishData(captureAny, reliable: true)) + .captured + .single; + final sentJson = json.decode(utf8.decode(sentData as List)); + expect(sentJson, { + "type": "client_tool_result", + "invocationId": "call_1", + "result": '{"strict": true}', + "responseType": "hang-up" + }); + }); + + test('error', () async { + final testError = Exception("test error"); + ClientToolResult impl(Object params) { + expect(params, {"foo": "bar"}); + throw testError; + } + + await invokeTool(impl); + + final sentData = verify( + room.localParticipant.publishData(captureAny, reliable: true)) + .captured + .single; + final sentJson = json.decode(utf8.decode(sentData as List)); + expect(sentJson, { + "type": "client_tool_result", + "invocationId": "call_1", + "errorType": "implementation-error", + "errorMessage": testError.toString(), + }); + }); + }); }); }