# Gemini_Voice_Assistant.py
# (Extracted from a GitHub file view; the page navigation text and the 1-99
# line-number gutter that preceded the code have been removed.)
import asyncio
import base64
import contextlib
import json
import os
import threading

import pyaudio
from websockets.asyncio.client import connect
class GeminiVoiceAssistant:
def __init__(self):
self._audio_queue = asyncio.Queue()
self._api_key = os.environ.get("GEMINI_API_KEY")
self._model = "gemini-2.0-flash-exp"
self._uri = f"wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self._api_key}"
# Audio settings
self._FORMAT = pyaudio.paInt16
self._CHANNELS = 1
self._CHUNK = 512
self._RATE = 16000
async def _connect_to_gemini(self):
return await connect(
self._uri, additional_headers={"Content-Type": "application/json"}
)
async def _start_audio_streaming(self):
async with asyncio.TaskGroup() as tg:
tg.create_task(self._capture_audio())
tg.create_task(self._stream_audio())
tg.create_task(self._play_response())
async def _capture_audio(self):
audio = pyaudio.PyAudio()
stream = audio.open(
format=self._FORMAT,
channels=self._CHANNELS,
rate=self._RATE,
input=True,
frames_per_buffer=self._CHUNK,
)
while True:
data = await asyncio.to_thread(stream.read, self._CHUNK)
await self._ws.send(
json.dumps(
{
"realtime_input": {
"media_chunks": [
{
"data": base64.b64encode(data).decode(),
"mime_type": "audio/pcm",
}
]
}
}
)
)
async def _stream_audio(self):
async for msg in self._ws:
response = json.loads(msg)
try:
audio_data = response["serverContent"]["modelTurn"]["parts"][0][
"inlineData"
]["data"]
self._audio_queue.put_nowait(base64.b64decode(audio_data))
except KeyError:
pass
try:
turn_complete = response["serverContent"]["turnComplete"]
except KeyError:
pass
else:
if turn_complete:
# If you interrupt the model, it sends an end_of_turn. For interruptions to work, we need to empty out the audio queue
print("\nEnd of turn")
while not self._audio_queue.empty():
self._audio_queue.get_nowait()
async def _play_response(self):
audio = pyaudio.PyAudio()
stream = audio.open(
format=self._FORMAT, channels=self._CHANNELS, rate=24000, output=True
)
while True:
data = await self._audio_queue.get()
await asyncio.to_thread(stream.write, data)
async def start(self):
self._ws = await self._connect_to_gemini()
await self._ws.send(json.dumps({"setup": {"model": f"models/{self._model}"}}))
await self._ws.recv(decode=False)
print("Connected to Gemini, You can start talking now")
await self._start_audio_streaming()
# Script entry point: run the assistant until the user stops it.
if __name__ == "__main__":
    client = GeminiVoiceAssistant()
    try:
        asyncio.run(client.start())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to end a session; exit without a traceback.
        print("\nStopped.")