-
Notifications
You must be signed in to change notification settings - Fork 13
/
main.py
623 lines (562 loc) · 26.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
import asyncio
import base64
import copy
import datetime
import functools
import hashlib
import logging
import os
import shelve
import signal
import time
import traceback
from collections import defaultdict

import openai
from telethon import TelegramClient, events, errors, functions, types

from richtext import RichText
def debug_signal_handler(signum, frame):
    """SIGUSR1 handler: drop into the debugger for live inspection.

    Parameters renamed from (signal, frame): the original first parameter
    shadowed the `signal` module inside the handler.
    """
    breakpoint()
# Send SIGUSR1 (kill -USR1 <pid>) to attach a debugger to the running bot.
signal.signal(signal.SIGUSR1, debug_signal_handler)
# Telegram user id allowed to run admin commands (/add_whitelist etc.).
ADMIN_ID = 71863318
# System prompt templates mirroring ChatGPT's official prompts per model
# generation; {current_date} is substituted by get_prompt() at request time.
GPT_35_PROMPT = 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture.\nKnowledge cutoff: 2021-09\nCurrent date: {current_date}'
GPT_4_PROMPT = 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\nKnowledge cutoff: 2021-09\nCurrent date: {current_date}'
GPT_4_PROMPT_2 = 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\nKnowledge cutoff: 2023-04\nCurrent date: {current_date}'
GPT_4_TURBO_PROMPT = 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\nKnowledge cutoff: 2023-12\nCurrent date: {current_date}\n\nImage input capabilities: Enabled\nPersonality: v2'
GPT_4O_PROMPT = 'You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\nKnowledge cutoff: 2023-10\nCurrent date: {current_date}\n\nImage input capabilities: Enabled\nPersonality: v2'
# Trigger table: a message beginning with '<prefix>' selects that model.
# '$' alone selects the default; short aliases ('4o$', '3$', ...) and full
# model-name prefixes are both accepted (matched in reply_handler).
MODELS = [
    {'prefix': '$', 'model': 'gpt-4o-2024-08-06', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': '4o$', 'model': 'gpt-4o-2024-08-06', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': '4om$', 'model': 'gpt-4o-mini-2024-07-18', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': '4$', 'model': 'gpt-4-turbo-2024-04-09', 'prompt_template': GPT_4_TURBO_PROMPT},
    {'prefix': '3$', 'model': 'gpt-3.5-turbo-0125', 'prompt_template': GPT_35_PROMPT},
    {'prefix': 'gpt-4o-mini-2024-07-18$', 'model': 'gpt-4o-mini-2024-07-18', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': 'gpt-4o-mini$', 'model': 'gpt-4o-mini', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': 'gpt-4o-2024-05-13$', 'model': 'gpt-4o-2024-05-13', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': 'gpt-4o-2024-08-06$', 'model': 'gpt-4o-2024-08-06', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': 'gpt-4o$', 'model': 'gpt-4o', 'prompt_template': GPT_4O_PROMPT},
    {'prefix': 'gpt-4-turbo-2024-04-09$', 'model': 'gpt-4-turbo-2024-04-09', 'prompt_template': GPT_4_TURBO_PROMPT},
    {'prefix': 'gpt-4-0125-preview$', 'model': 'gpt-4-0125-preview', 'prompt_template': GPT_4_TURBO_PROMPT},
    {'prefix': 'gpt-4-1106-preview$', 'model': 'gpt-4-1106-preview', 'prompt_template': GPT_4_PROMPT_2},
    {'prefix': 'gpt-4-1106-vision-preview$', 'model': 'gpt-4-1106-vision-preview', 'prompt_template': GPT_4_PROMPT_2},
    {'prefix': 'gpt-4-0613$', 'model': 'gpt-4-0613', 'prompt_template': GPT_4_PROMPT},
    {'prefix': 'gpt-4-32k-0613$', 'model': 'gpt-4-32k-0613', 'prompt_template': GPT_4_PROMPT},
    {'prefix': 'gpt-3.5-turbo-0125$', 'model': 'gpt-3.5-turbo-0125', 'prompt_template': GPT_35_PROMPT},
    {'prefix': 'gpt-3.5-turbo-1106$', 'model': 'gpt-3.5-turbo-1106', 'prompt_template': GPT_35_PROMPT},
    {'prefix': 'gpt-3.5-turbo-0613$', 'model': 'gpt-3.5-turbo-0613', 'prompt_template': GPT_35_PROMPT},
    {'prefix': 'gpt-3.5-turbo-16k-0613$', 'model': 'gpt-3.5-turbo-16k-0613', 'prompt_template': GPT_35_PROMPT},
    {'prefix': 'gpt-3.5-turbo-0301$', 'model': 'gpt-3.5-turbo-0301', 'prompt_template': GPT_35_PROMPT},
]
DEFAULT_MODEL = 'gpt-4-0613' # For compatibility with the old database format
def get_prompt(model):
    """Return *model*'s system prompt with {current_date} substituted.

    The date is rendered in UTC+8 as YYYY-MM-DD.  Raises ValueError (naming
    the offending model, so log lines are actionable) when *model* does not
    appear in MODELS.
    """
    # Compute the date once instead of per-candidate inside the loop.
    current_date = (datetime.datetime.now(datetime.UTC) + datetime.timedelta(hours=8)).strftime('%Y-%m-%d')
    for m in MODELS:
        if m['model'] == model:
            return m['prompt_template'].replace('{current_date}', current_date)
    raise ValueError(f'Model not found: {model}')
# Async OpenAI client.  max_retries=0 because retrying is handled manually in
# process_request(); a 15 s timeout keeps a stalled stream from hanging a task.
aclient = openai.AsyncOpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    max_retries=0,
    timeout=15,
)
TELEGRAM_BOT_TOKEN = os.getenv("TELEGRAM_BOT_TOKEN")
TELEGRAM_API_ID = int(os.getenv("TELEGRAM_API_ID"))
TELEGRAM_API_HASH = os.getenv("TELEGRAM_API_HASH")
TELEGRAM_LENGTH_LIMIT = 4096  # Telegram's hard per-message length limit
TELEGRAM_MIN_INTERVAL = 3  # minimum seconds between bot messages per chat
OPENAI_MAX_RETRY = 3  # attempts per OpenAI request in process_request()
OPENAI_RETRY_INTERVAL = 10  # seconds between OpenAI retries
FIRST_BATCH_DELAY = 1  # seconds of streaming before the first Telegram edit
TEXT_FILE_SIZE_LIMIT = 1_000_000  # max attached text-file size in bytes
TRIGGERS_LIMIT = 20  # max comma-separated model triggers per message
# Per-chat rate-limit state shared by within_interval()/ensure_interval():
# last send timestamp (None = never sent) and a lock serializing sends.
telegram_last_timestamp = defaultdict(lambda: None)
telegram_rate_limit_lock = defaultdict(asyncio.Lock)
class PendingReplyManager:
    """Tracks bot reply messages that are still being generated.

    Maps reply_id -> asyncio.Event.  Callers that want to reply to a bot
    message can wait_for() it; the generator remove()s the id when the
    reply is complete, waking all waiters.
    """

    def __init__(self):
        self.messages = {}

    def add(self, reply_id):
        """Register *reply_id* as in-flight; it must not already be pending."""
        assert reply_id not in self.messages
        self.messages[reply_id] = asyncio.Event()

    def remove(self, reply_id):
        """Mark *reply_id* finished and wake its waiters; no-op if unknown."""
        event = self.messages.pop(reply_id, None)
        if event is not None:
            event.set()

    async def wait_for(self, reply_id):
        """Block until *reply_id* finishes; return at once if it is not pending."""
        event = self.messages.get(reply_id)
        if event is None:
            return
        logging.info('PendingReplyManager waiting for %r', reply_id)
        await event.wait()
        logging.info('PendingReplyManager waiting for %r finished', reply_id)
def within_interval(chat_id):
    """Return True if sending to *chat_id* right now would be rate-limited.

    True when either a send is already in progress (lock held) or the last
    send was under TELEGRAM_MIN_INTERVAL seconds ago.
    """
    if telegram_rate_limit_lock[chat_id].locked():
        return True
    last_sent = telegram_last_timestamp[chat_id]
    if last_sent is None:
        return False
    remaining_time = last_sent + TELEGRAM_MIN_INTERVAL - time.time()
    return remaining_time > 0
def ensure_interval(interval=TELEGRAM_MIN_INTERVAL):
    """Decorator factory: serialize calls per chat and keep >= *interval* s apart.

    The wrapped coroutine's first positional argument must be the chat_id.
    Adds functools.wraps so the wrapped coroutine keeps its name/docstring
    (the original decorator lost them).  The redundant `global` statement was
    dropped: only dict items are mutated, the name is never rebound.
    """
    def decorator(func):
        @functools.wraps(func)
        async def new_func(*args, **kwargs):
            chat_id = args[0]
            async with telegram_rate_limit_lock[chat_id]:
                if telegram_last_timestamp[chat_id] is not None:
                    remaining_time = telegram_last_timestamp[chat_id] + interval - time.time()
                    if remaining_time > 0:
                        await asyncio.sleep(remaining_time)
                result = await func(*args, **kwargs)
                # Stamp after the call so the interval measures send completion.
                telegram_last_timestamp[chat_id] = time.time()
                return result
        return new_func
    return decorator
def retry(max_retry=30, interval=10):
    """Decorator factory: retry the coroutine on Telegram FloodWaitError.

    Makes up to *max_retry* attempts total, sleeping *interval* seconds
    between them; the final attempt propagates any exception.  Adds
    functools.wraps so the wrapped coroutine keeps its identity (missing in
    the original).
    """
    def decorator(func):
        @functools.wraps(func)
        async def new_func(*args, **kwargs):
            for _ in range(max_retry - 1):
                try:
                    return await func(*args, **kwargs)
                except errors.FloodWaitError as e:
                    logging.exception(e)
                    await asyncio.sleep(interval)
            # Last attempt: let any exception escape to the caller.
            return await func(*args, **kwargs)
        return new_func
    return decorator
def is_whitelist(chat_id):
    """Return True if *chat_id* is on the persisted whitelist."""
    return chat_id in db['whitelist']
def add_whitelist(chat_id):
    """Add *chat_id* to the whitelist and persist it."""
    updated = db['whitelist']
    updated.add(chat_id)
    db['whitelist'] = updated  # reassign so shelve writes the mutation back
def del_whitelist(chat_id):
    """Remove *chat_id* from the whitelist (no-op if absent) and persist."""
    updated = db['whitelist']
    updated.discard(chat_id)
    db['whitelist'] = updated  # reassign so shelve writes the mutation back
def get_whitelist():
    """Return the set of whitelisted chat ids."""
    return db['whitelist']
def only_admin(func):
    """Decorator: run the handler only when the sender is ADMIN_ID."""
    async def new_func(message):
        if message.sender_id == ADMIN_ID:
            await func(message)
            return
        await send_message(message.chat_id, 'Only admin can use this command', message.id)
    return new_func
def only_private(func):
    """Decorator: run the handler only in private chats (chat_id == sender_id)."""
    async def new_func(message):
        if message.chat_id == message.sender_id:
            await func(message)
            return
        await send_message(message.chat_id, 'This command only works in private chat', message.id)
    return new_func
def only_whitelist(func):
    """Decorator: run the handler only for whitelisted chats.

    A non-whitelisted private chat gets an explanatory reply; non-whitelisted
    groups are ignored silently.
    """
    async def new_func(message):
        if is_whitelist(message.chat_id):
            await func(message)
            return
        if message.chat_id == message.sender_id:
            await send_message(message.chat_id, 'This chat is not in whitelist', message.id)
    return new_func
def save_photo(photo_blob): # TODO: change to async
    """Store *photo_blob* content-addressed under photos/ and return its sha256 hex digest.

    Layout: photos/<h[:2]>/<h[2:4]>/<h>.  Writing is skipped when the file
    already exists — identical digest implies identical bytes.  Renamed the
    local `dir` (which shadowed the builtin) to `photo_dir`.
    """
    h = hashlib.sha256(photo_blob).hexdigest()
    photo_dir = f'photos/{h[:2]}/{h[2:4]}'
    path = f'{photo_dir}/{h}'
    if not os.path.isfile(path):
        os.makedirs(photo_dir, exist_ok=True)
        with open(path, 'wb') as f:
            f.write(photo_blob)
    return h
def load_photo(h):
    """Read back the photo bytes stored under sha256 hex digest *h*."""
    path = f'photos/{h[:2]}/{h[2:4]}/{h}'
    with open(path, 'rb') as f:
        return f.read()
async def completion(chat_history, model, chat_id, msg_id, task_id): # chat_history = [user, ai, user, ai, ..., user]
    """Stream a chat completion for *chat_history* on *model*, yielding text deltas.

    chat_history must alternate user/assistant, starting and ending with a
    user turn (odd length).  Non-'stop' finish reasons are surfaced to the
    user by yielding an error string at the end of the stream.
    """
    assert len(chat_history) % 2 == 1
    system_prompt = get_prompt(model)
    messages=[{"role": "system", "content": system_prompt}] if system_prompt else []
    roles = ["user", "assistant"]
    role_id = 0
    for msg in chat_history:
        messages.append({"role": roles[role_id], "content": msg})
        role_id = 1 - role_id
    def remove_image(messages):
        # Deep-copied view with base64 image URLs truncated to 50 chars, so
        # request logging stays readable instead of dumping megabytes of base64.
        new_messages = copy.deepcopy(messages)
        for message in new_messages:
            if 'content' in message:
                if isinstance(message['content'], list):
                    for obj in message['content']:
                        if obj['type'] == 'image_url':
                            obj['image_url']['url'] = obj['image_url']['url'][:50] + '...'
        return new_messages
    logging.info('Request (chat_id=%r, msg_id=%r, task_id=%r): %s', chat_id, msg_id, task_id, remove_image(messages))
    stream = await aclient.chat.completions.create(model=model, messages=messages, stream=True)
    finished = False
    async for response in stream:
        logging.info('Response (chat_id=%r, msg_id=%r, task_id=%r): %s', chat_id, msg_id, task_id, response)
        assert not finished  # no chunks are expected after a finish marker
        obj = response.choices[0]
        if obj.delta.role is not None:
            if obj.delta.role != 'assistant':
                raise ValueError("Role error")
        if obj.delta.content is not None:
            yield obj.delta.content
        # Some API variants report 'finish_details' instead of 'finish_reason'.
        if obj.finish_reason is not None or ('finish_details' in obj.model_extra and obj.finish_details is not None):
            # A finish chunk is expected to carry no payload at all.
            assert all(item is None for item in [
                obj.delta.content,
                obj.delta.function_call,
                obj.delta.role,
                obj.delta.tool_calls,
            ])
            finish_reason = obj.finish_reason
            if 'finish_details' in obj.model_extra and obj.finish_details is not None:
                assert finish_reason is None
                finish_reason = obj.finish_details['type']
            if finish_reason == 'length':
                yield '\n\n[!] Error: Output truncated due to limit'
            elif finish_reason == 'stop':
                pass
            elif finish_reason is not None:
                if obj.finish_reason is not None:
                    yield f'\n\n[!] Error: finish_reason="{finish_reason}"'
                else:
                    yield f'\n\n[!] Error: finish_details="{obj.finish_details}"'
            finished = True
def construct_chat_history(chat_id, msg_id):
    """Walk the reply chain backwards from (chat_id, msg_id) and build the history.

    Returns (messages, model): messages in chronological order, alternating
    user/assistant and starting with a user turn; or (None, None) when the
    chain is broken (missing record or role mismatch).  Stored image entries
    are inlined as base64 data URLs for the OpenAI API.
    """
    messages = []
    should_be_bot = False  # the newest message (the one we start at) is the user's
    model = DEFAULT_MODEL
    has_image = False  # NOTE(review): set but never read — candidate for removal
    while True:
        key = repr((chat_id, msg_id))
        if key not in db:
            logging.error('History message not found (chat_id=%r, msg_id=%r)', chat_id, msg_id)
            return None, None
        is_bot, message, reply_id, *params = db[key]
        # New-format records carry the model as a 4th element; the first
        # non-None one encountered (i.e. the most recent turn) wins.
        if params:
            if params[0] is not None:
                model = params[0]
        if is_bot != should_be_bot:
            logging.error('Role does not match (chat_id=%r, msg_id=%r)', chat_id, msg_id)
            return None, None
        if isinstance(message, list):
            # Multi-part message: text plus content-addressed photo hashes.
            new_message = []
            for obj in message:
                if obj['type'] == 'text':
                    new_message.append(obj)
                elif obj['type'] == 'image':
                    blob = load_photo(obj['hash'])
                    blob_base64 = base64.b64encode(blob).decode()
                    image_url = 'data:image/jpeg;base64,' + blob_base64
                    new_message.append({'type': 'image_url', 'image_url': {'url': image_url, 'detail': 'high'}})
                    has_image = True
                else:
                    raise ValueError('Unknown message type in chat history')
            message = new_message
        messages.append(message)
        should_be_bot = not should_be_bot
        if reply_id is None:  # reached the start of the conversation
            break
        msg_id = reply_id
    if len(messages) % 2 != 1:
        logging.error('First message not from user (chat_id=%r, msg_id=%r)', chat_id, msg_id)
        return None, None
    return messages[::-1], model  # reverse into chronological order
@only_admin
async def add_whitelist_handler(message):
    """/add_whitelist handler (admin only): whitelist the current chat."""
    chat_id = message.chat_id
    if is_whitelist(chat_id):
        await send_message(chat_id, 'Already in whitelist', message.id)
        return
    add_whitelist(chat_id)
    await send_message(chat_id, 'Whitelist added', message.id)
@only_admin
async def del_whitelist_handler(message):
    """/del_whitelist handler (admin only): un-whitelist the current chat."""
    chat_id = message.chat_id
    if not is_whitelist(chat_id):
        await send_message(chat_id, 'Not in whitelist', message.id)
        return
    del_whitelist(chat_id)
    await send_message(chat_id, 'Whitelist deleted', message.id)
@only_admin
@only_private
async def get_whitelist_handler(message):
    """/get_whitelist handler (admin only, private chat only): dump the whitelist."""
    whitelist_text = str(get_whitelist())
    await send_message(message.chat_id, whitelist_text, message.id)
@only_whitelist
async def list_models_handler(message):
    """/list_models handler: reply with every trigger prefix and its model."""
    text = ''.join(f'Prefix: "{m["prefix"]}", model: {m["model"]}\n' for m in MODELS)
    await send_message(message.chat_id, text, message.id)
@retry()
@ensure_interval()
async def send_message(chat_id, text, reply_to_message_id):
    """Send *text* as a reply in *chat_id* and return the new message id.

    Markdown-ish formatting is converted to explicit Telegram entities via
    RichText; the call is rate-limited per chat and retried on flood waits.
    """
    logging.info('Sending message: chat_id=%r, reply_to_message_id=%r, text=%r', chat_id, reply_to_message_id, text)
    body, entities = RichText(text).to_telegram()
    sent = await bot.send_message(
        chat_id,
        body,
        reply_to=reply_to_message_id,
        link_preview=False,
        formatting_entities=entities,
    )
    logging.info('Message sent: chat_id=%r, reply_to_message_id=%r, message_id=%r', chat_id, reply_to_message_id, sent.id)
    return sent.id
@retry()
@ensure_interval()
async def edit_message(chat_id, text, message_id):
    """Replace the content of an existing bot message with *text*.

    A MessageNotModifiedError (same content) is logged and ignored; the call
    is rate-limited per chat and retried on flood waits.
    """
    logging.info('Editing message: chat_id=%r, message_id=%r, text=%r', chat_id, message_id, text)
    body, entities = RichText(text).to_telegram()
    try:
        await bot.edit_message(
            chat_id,
            message_id,
            body,
            link_preview=False,
            formatting_entities=entities,
        )
    except errors.MessageNotModifiedError:
        logging.info('Message not modified: chat_id=%r, message_id=%r', chat_id, message_id)
        return
    logging.info('Message edited: chat_id=%r, message_id=%r', chat_id, message_id)
@retry()
@ensure_interval()
async def delete_message(chat_id, message_id):
    """Delete one bot message; rate-limited per chat and retried on flood waits."""
    logging.info('Deleting message: chat_id=%r, message_id=%r', chat_id, message_id)
    await bot.delete_messages(chat_id, message_id)
    logging.info('Message deleted: chat_id=%r, message_id=%r', chat_id, message_id)
class BotReplyMessages:
    """Async context manager owning the chain of Telegram messages for one reply.

    A reply longer than one Telegram message is split into fixed-size chunks;
    each chunk is its own message replying to the previous one.  update() is
    throttled by the per-chat rate limit; finalize() (also run on exit)
    forces the final state out unconditionally.
    """
    def __init__(self, chat_id, orig_msg_id, prefix):
        # prefix (e.g. '[model] ') is prepended to every chunk, shrinking the
        # usable capacity per message accordingly.
        self.prefix = prefix
        self.msg_len = TELEGRAM_LENGTH_LIMIT - len(prefix)
        assert self.msg_len > 0
        self.chat_id = chat_id
        self.orig_msg_id = orig_msg_id
        self.replied_msgs = []  # [(msg_id, chunk_text_without_prefix), ...] already sent
        self.text = ''  # latest full reply text
    async def __aenter__(self):
        return self
    async def __aexit__(self, type, value, tb):
        # Flush the final text, then release anyone waiting on our messages.
        await self.finalize()
        for msg_id, _ in self.replied_msgs:
            pending_reply_manager.remove((self.chat_id, msg_id))
    async def _force_update(self, text):
        # Reconcile the sent messages with *text*: edit changed chunks,
        # append new ones, delete surplus ones.
        slices = []
        while len(text) > self.msg_len:
            slices.append(text[:self.msg_len])
            text = text[self.msg_len:]
        if text:
            slices.append(text)
        if not slices:
            slices = [''] # deal with empty message
        # Edit existing chunks whose content changed.
        for i in range(min(len(slices), len(self.replied_msgs))):
            msg_id, msg_text = self.replied_msgs[i]
            if slices[i] != msg_text:
                await edit_message(self.chat_id, self.prefix + slices[i], msg_id)
                self.replied_msgs[i] = (msg_id, slices[i])
        # Send new chunks, each replying to the previous one in the chain.
        if len(slices) > len(self.replied_msgs):
            for i in range(len(self.replied_msgs), len(slices)):
                if i == 0:
                    reply_to = self.orig_msg_id
                else:
                    reply_to, _ = self.replied_msgs[i - 1]
                msg_id = await send_message(self.chat_id, self.prefix + slices[i], reply_to)
                self.replied_msgs.append((msg_id, slices[i]))
                pending_reply_manager.add((self.chat_id, msg_id))
        # Delete surplus chunks when the text shrank.
        if len(self.replied_msgs) > len(slices):
            for i in range(len(slices), len(self.replied_msgs)):
                msg_id, _ = self.replied_msgs[i]
                await delete_message(self.chat_id, msg_id)
                pending_reply_manager.remove((self.chat_id, msg_id))
            self.replied_msgs = self.replied_msgs[:len(slices)]
    async def update(self, text):
        # Remember *text*; push it out now unless the chat is rate-limited
        # (a later update or finalize will deliver the latest state).
        self.text = text
        if not within_interval(self.chat_id):
            await self._force_update(self.text)
    async def finalize(self):
        # Unconditionally push the latest text (used on context exit).
        await self._force_update(self.text)
@only_whitelist
async def reply_handler(message):
    """Handle an incoming chat message: build history and fan out model requests.

    Two entry modes:
      * new conversation — the first line must contain a model trigger ending
        in '$' (possibly several, comma-separated, fanned out concurrently);
      * reply to one of the bot's own messages — continues that conversation
        with the model recorded in the database.
    Photos and UTF-8 text documents (attached or replied-to) are folded into
    the user message.  Malformed input is answered with a hint in private
    chats and silently ignored in groups.

    Fix: the document decode used a bare `except:` around an `assert` — the
    assert disappears under `python -O` and the bare except also swallowed
    asyncio.CancelledError.  Validation is now explicit and the except clause
    catches only decode/validation failures.
    """
    chat_id = message.chat_id
    sender_id = message.sender_id
    msg_id = message.id
    text = message.message
    logging.info('New message: chat_id=%r, sender_id=%r, msg_id=%r, text=%r, photo=%s, document=%s', chat_id, sender_id, msg_id, text, message.photo, message.document)
    reply_to_id = None
    models = None  # None -> continue the conversation with its stored model
    extra_photo_message = None
    extra_document_message = None
    if not text and message.photo is None and message.document is None: # unknown media types
        return
    if message.is_reply:
        if message.reply_to.quote_text is not None:  # quoted replies unsupported
            return
        reply_to_message = await message.get_reply_message()
        if reply_to_message.sender_id == bot_id: # user reply to bot message
            reply_to_id = message.reply_to.reply_to_msg_id
            # Wait until the replied-to bot message is fully generated and stored.
            await pending_reply_manager.wait_for((chat_id, reply_to_id))
        elif reply_to_message.photo is not None: # user reply to a photo
            extra_photo_message = reply_to_message
        elif reply_to_message.document is not None: # user reply to a document
            extra_document_message = reply_to_message
        else:
            return
    if not message.is_reply or extra_photo_message is not None or extra_document_message is not None: # new message
        if '$' not in text:
            if chat_id == sender_id: # if in private chat, send hint
                await send_message(chat_id, '[!] Error: Please start a new conversation with $ or reply to a bot message', msg_id)
            return
        prefix, text = text.split('$', 1)
        if '\n' in prefix:  # '$' occurred, but not in the first line -> not a trigger
            if chat_id == sender_id:
                await send_message(chat_id, '[!] Error: Please start a new conversation with $ or reply to a bot message', msg_id)
            return
        triggers = prefix.split(',')
        models = []
        for t in triggers:
            for m in MODELS:
                if m['prefix'] == t.strip() + '$':
                    models.append(m['model'])
                    break
        if models and len(triggers) > TRIGGERS_LIMIT:
            await send_message(chat_id, f'[!] Error: Too many triggers (limit: {TRIGGERS_LIMIT})', msg_id)
            return
        if chat_id == sender_id and len(models) != len(triggers):
            await send_message(chat_id, '[!] Error: Unknown trigger in prefix', msg_id)
            return
        if not models:  # in groups, silently ignore non-trigger '$' messages
            return
    photo_message = message if message.photo is not None else extra_photo_message
    photo_hash = None
    if photo_message is not None:
        if photo_message.grouped_id is not None:
            await send_message(chat_id, '[!] Error: Grouped photos are not yet supported, but will be supported soon', msg_id)
            return
        photo_blob = await photo_message.download_media(bytes)
        photo_hash = save_photo(photo_blob)
    document_message = message if message.document is not None else extra_document_message
    document_text = None
    if document_message is not None:
        if document_message.grouped_id is not None:
            await send_message(chat_id, '[!] Error: Grouped files are not yet supported, but will be supported soon', msg_id)
            return
        if document_message.document.size > TEXT_FILE_SIZE_LIMIT:
            await send_message(chat_id, '[!] Error: File too large', msg_id)
            return
        document_blob = await document_message.download_media(bytes)
        try:
            document_text = document_blob.decode()
            if '\x00' in document_text:
                raise ValueError('document contains NUL bytes')
        except (UnicodeDecodeError, ValueError):
            await send_message(chat_id, '[!] Error: File is not text file or not valid UTF-8', msg_id)
            return
    if photo_hash:
        new_message = [{'type': 'text', 'text': text}, {'type': 'image', 'hash': photo_hash}]
    elif document_text:
        if text:
            new_message = document_text + '\n\n' + text
        else:
            new_message = document_text
    else:
        new_message = text
    # Persist the user turn so later replies can reconstruct the thread.
    db[repr((chat_id, msg_id))] = (False, new_message, reply_to_id, None)
    chat_history, model = construct_chat_history(chat_id, msg_id)
    if chat_history is None:
        await send_message(chat_id, "[!] Error: Unable to proceed with this conversation. Potential causes: the message replied to may be incomplete, contain an error, be a system message, or not exist in the database.", msg_id)
        return
    models = models if models is not None else [model]
    # One concurrent task per requested model.
    async with asyncio.TaskGroup() as tg:
        for task_id, m in enumerate(models):
            tg.create_task(process_request(chat_id, msg_id, chat_history, m, task_id))
async def process_request(chat_id, msg_id, chat_history, model, task_id):
    """Run one model request and stream the answer into Telegram messages.

    Retries on any error except openai.BadRequestError (permanent, e.g.
    context too long), up to OPENAI_MAX_RETRY times; partial output is kept
    and shown above the error text.  On success each sent Telegram message is
    recorded in db as a bot turn so the conversation can be continued.
    """
    error_cnt = 0
    while True:
        reply = ''
        async with BotReplyMessages(chat_id, msg_id, f'[{model}] ') as replymsgs:
            try:
                stream = completion(chat_history, model, chat_id, msg_id, task_id)
                first_update_timestamp = None
                async for delta in stream:
                    reply += delta
                    if first_update_timestamp is None:
                        first_update_timestamp = time.time()
                    # Start pushing partial output only after the first-batch delay.
                    if time.time() >= first_update_timestamp + FIRST_BATCH_DELAY:
                        await replymsgs.update(RichText.from_markdown(reply) + ' [!Generating...]')
                await replymsgs.update(RichText.from_markdown(reply))
                await replymsgs.finalize()
                for message_id, _ in replymsgs.replied_msgs:
                    db[repr((chat_id, message_id))] = (True, reply, msg_id, model)
                return
            except Exception as e:
                error_cnt += 1
                logging.exception('Error (chat_id=%r, msg_id=%r, model=%r, task_id=%r, cnt=%r): %s', chat_id, msg_id, model, task_id, error_cnt, e)
                will_retry = not isinstance (e, openai.BadRequestError) and error_cnt <= OPENAI_MAX_RETRY
                error_msg = f'[!] Error: {traceback.format_exception_only(e)[-1].strip()}'
                if will_retry:
                    error_msg += f'\nRetrying ({error_cnt}/{OPENAI_MAX_RETRY})...'
                if reply:
                    error_msg = reply + '\n\n' + error_msg
                await replymsgs.update(error_msg)
                if will_retry:
                    await asyncio.sleep(OPENAI_RETRY_INTERVAL)
        # Only reachable on the exception path (success returns above).
        if not will_retry:
            break
async def ping(message):
    """/ping handler: report chat id, sender id, and whitelist status."""
    status = f'chat_id={message.chat_id} user_id={message.sender_id} is_whitelisted={is_whitelist(message.chat_id)}'
    await send_message(message.chat_id, status, message.id)
async def main():
    """Entry point: configure logging, open the shelve db, and run the bot."""
    global bot_id, pending_reply_manager, db, bot
    logFormatter = logging.Formatter("%(asctime)s %(process)d %(levelname)s %(message)s")
    rootLogger = logging.getLogger()
    rootLogger.setLevel(logging.INFO)
    fileHandler = logging.FileHandler(__file__ + ".log")
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)
    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)
    with shelve.open('db') as db:
        # db[(chat_id, msg_id)] = (is_bot, text, reply_id, model)
        # compatible old db format: db[(chat_id, msg_id)] = (is_bot, text, reply_id)
        # db['whitelist'] = set(whitelist_chat_ids)
        if 'whitelist' not in db:
            db['whitelist'] = {ADMIN_ID}
        # Bot tokens have the form '<bot_id>:<secret>'.
        bot_id = int(TELEGRAM_BOT_TOKEN.split(':')[0])
        pending_reply_manager = PendingReplyManager()
        async with await TelegramClient('bot', TELEGRAM_API_ID, TELEGRAM_API_HASH).start(bot_token=TELEGRAM_BOT_TOKEN) as bot:
            bot.parse_mode = None  # raw text; formatting is sent as explicit entities
            me = await bot.get_me()
            @bot.on(events.NewMessage)
            async def process(event):
                # Dispatch each incoming message to a command handler or the
                # default reply handler; skip our own messages and odd events.
                if event.message.chat_id is None:
                    return
                if event.message.sender_id is None or event.message.sender_id == bot_id:
                    return
                if event.message.message is None:
                    return
                text = event.message.message
                if text == '/ping' or text == f'/ping@{me.username}':
                    await ping(event.message)
                elif text == '/list_models' or text == f'/list_models@{me.username}':
                    await list_models_handler(event.message)
                elif text == '/add_whitelist' or text == f'/add_whitelist@{me.username}':
                    await add_whitelist_handler(event.message)
                elif text == '/del_whitelist' or text == f'/del_whitelist@{me.username}':
                    await del_whitelist_handler(event.message)
                elif text == '/get_whitelist' or text == f'/get_whitelist@{me.username}':
                    await get_whitelist_handler(event.message)
                else:
                    await reply_handler(event.message)
            # Advertise the command list in the Telegram client UI.
            assert await bot(functions.bots.SetBotCommandsRequest(
                scope=types.BotCommandScopeDefault(),
                lang_code='en',
                commands=[types.BotCommand(command, description) for command, description in [
                    ('ping', 'Test bot connectivity'),
                    ('list_models', 'List supported models'),
                    ('add_whitelist', 'Add this group to whitelist (only admin)'),
                    ('del_whitelist', 'Delete this group from whitelist (only admin)'),
                    ('get_whitelist', 'List groups in whitelist (only admin)'),
                ]]
            ))
            await bot.run_until_disconnected()
asyncio.run(main())