You've already forked Curso-lenguaje-python
Add chatgpt bot
This commit is contained in:
875
catch-all/06_bots_telegram/08_chatgpt_bot/bot/bot.py
Normal file
875
catch-all/06_bots_telegram/08_chatgpt_bot/bot/bot.py
Normal file
@@ -0,0 +1,875 @@
|
||||
import io
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
import html
|
||||
import json
|
||||
from datetime import datetime
|
||||
import openai
|
||||
|
||||
import telegram
|
||||
from telegram import (
|
||||
Update,
|
||||
User,
|
||||
InlineKeyboardButton,
|
||||
InlineKeyboardMarkup,
|
||||
BotCommand
|
||||
)
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
ApplicationBuilder,
|
||||
CallbackContext,
|
||||
CommandHandler,
|
||||
MessageHandler,
|
||||
CallbackQueryHandler,
|
||||
AIORateLimiter,
|
||||
filters
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatAction
|
||||
|
||||
import config
|
||||
import database
|
||||
import openai_utils
|
||||
|
||||
import base64
|
||||
|
||||
# setup
db = database.Database()  # persistent storage: users, dialogs, usage counters
logger = logging.getLogger(__name__)

# per-user concurrency guard: at most one in-flight completion per user
user_semaphores = {}
# currently running asyncio.Task per user id (consumed by /cancel)
user_tasks = {}

# shown by /help and /start; HTML parse mode is used when sending it
HELP_MESSAGE = """Commands:
⚪ /retry – Regenerate last bot answer
⚪ /new – Start new dialog
⚪ /mode – Select chat mode
⚪ /settings – Show settings
⚪ /balance – Show balance
⚪ /help – Show help

🎨 Generate images from text prompts in <b>👩🎨 Artist</b> /mode
👥 Add bot to <b>group chat</b>: /help_group_chat
🎤 You can send <b>Voice Messages</b> instead of text
"""

# shown by /help_group_chat; {bot_username} is filled in at send time
HELP_GROUP_CHAT_MESSAGE = """You can add bot to any <b>group chat</b> to help and entertain its participants!

Instructions (see <b>video</b> below):
1. Add the bot to the group chat
2. Make it an <b>admin</b>, so that it can see messages (all other rights can be restricted)
3. You're awesome!

To get a reply from the bot in the chat – @ <b>tag</b> it or <b>reply</b> to its message.
For example: "{bot_username} write a poem about Telegram"
"""
|
||||
|
||||
|
||||
def split_text_into_chunks(text, chunk_size):
    """Yield successive pieces of *text*, each at most *chunk_size* characters long."""
    start = 0
    while start < len(text):
        yield text[start:start + chunk_size]
        start += chunk_size
|
||||
|
||||
|
||||
async def register_user_if_not_exists(update: Update, context: CallbackContext, user: User):
    """Ensure *user* exists in the DB with every attribute this bot version expects.

    Creates the user row and a first dialog on first contact, then backfills
    attributes introduced by later versions (current model, usage counters) and
    lazily creates the per-user semaphore that serializes completions.
    """
    if not db.check_if_user_exists(user.id):
        db.add_new_user(
            user.id,
            update.message.chat_id,
            username=user.username,
            first_name=user.first_name,
            last_name=user.last_name
        )
        db.start_new_dialog(user.id)

    # user may exist but have no open dialog (e.g. interrupted registration)
    if db.get_user_attribute(user.id, "current_dialog_id") is None:
        db.start_new_dialog(user.id)

    # one semaphore per user — allows a single in-flight request at a time
    if user.id not in user_semaphores:
        user_semaphores[user.id] = asyncio.Semaphore(1)

    # default model: first entry of the configured text-model list
    if db.get_user_attribute(user.id, "current_model") is None:
        db.set_user_attribute(user.id, "current_model", config.models["available_text_models"][0])

    # back compatibility for n_used_tokens field
    n_used_tokens = db.get_user_attribute(user.id, "n_used_tokens")
    if isinstance(n_used_tokens, int) or isinstance(n_used_tokens, float):  # old format: a single number
        # migrate to the per-model {n_input_tokens, n_output_tokens} mapping
        new_n_used_tokens = {
            "gpt-3.5-turbo": {
                "n_input_tokens": 0,
                "n_output_tokens": n_used_tokens
            }
        }
        db.set_user_attribute(user.id, "n_used_tokens", new_n_used_tokens)

    # voice message transcription counter (seconds of audio sent to Whisper)
    if db.get_user_attribute(user.id, "n_transcribed_seconds") is None:
        db.set_user_attribute(user.id, "n_transcribed_seconds", 0.0)

    # image generation counter
    if db.get_user_attribute(user.id, "n_generated_images") is None:
        db.set_user_attribute(user.id, "n_generated_images", 0)
|
||||
|
||||
|
||||
async def is_bot_mentioned(update: Update, context: CallbackContext):
    """Return True if the bot should react to this update.

    True in private chats, when the message text @-mentions the bot, or when
    the message is a reply to one of the bot's own messages.  Any error while
    inspecting the update (e.g. an update without a message) errs on the side
    of answering and returns True.
    """
    try:
        message = update.message

        if message.chat.type == "private":
            return True

        if message.text is not None and ("@" + context.bot.username) in message.text:
            return True

        if message.reply_to_message is not None:
            if message.reply_to_message.from_user.id == context.bot.id:
                return True
    except Exception:
        # was a bare `except:` — narrowed so CancelledError/SystemExit propagate;
        # keep the original best-effort "treat as mentioned" behavior otherwise
        return True

    return False
|
||||
|
||||
|
||||
async def start_handle(update: Update, context: CallbackContext):
    """Handle /start: greet the user, open a fresh dialog and show chat modes."""
    await register_user_if_not_exists(update, context, update.message.from_user)

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())
    db.start_new_dialog(user_id)

    greeting = "Hi! I'm <b>ChatGPT</b> bot implemented with OpenAI API 🤖\n\n" + HELP_MESSAGE
    await update.message.reply_text(greeting, parse_mode=ParseMode.HTML)
    await show_chat_modes_handle(update, context)
|
||||
|
||||
|
||||
async def help_handle(update: Update, context: CallbackContext):
    """Handle /help: send the command overview."""
    await register_user_if_not_exists(update, context, update.message.from_user)
    db.set_user_attribute(update.message.from_user.id, "last_interaction", datetime.now())
    await update.message.reply_text(HELP_MESSAGE, parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def help_group_chat_handle(update: Update, context: CallbackContext):
    """Handle /help_group_chat: explain group setup and send the demo video."""
    await register_user_if_not_exists(update, context, update.message.from_user)
    db.set_user_attribute(update.message.from_user.id, "last_interaction", datetime.now())

    instructions = HELP_GROUP_CHAT_MESSAGE.format(bot_username="@" + context.bot.username)
    await update.message.reply_text(instructions, parse_mode=ParseMode.HTML)
    await update.message.reply_video(config.help_group_chat_video_path)
|
||||
|
||||
|
||||
async def retry_handle(update: Update, context: CallbackContext):
    """Handle /retry: drop the last exchange and replay its user message."""
    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context):
        return

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
    if not dialog_messages:
        await update.message.reply_text("No message to retry 🤷♂️")
        return

    # remove the last exchange from the stored context, then re-ask its user part
    last_dialog_message = dialog_messages.pop()
    db.set_dialog_messages(user_id, dialog_messages, dialog_id=None)

    await message_handle(update, context, message=last_dialog_message["user"], use_new_dialog_timeout=False)
|
||||
|
||||
async def _vision_message_handle_fn(
    update: Update, context: CallbackContext, use_new_dialog_timeout: bool = True
):
    """Answer a message that may carry a photo, using a vision-capable model.

    Mirrors the plain-text flow in message_handle but additionally downloads
    the attached photo into memory and forwards it to the OpenAI vision
    endpoint.  Streams the answer into a placeholder message when streaming
    is enabled in config.
    """
    logger.info('_vision_message_handle_fn')
    user_id = update.message.from_user.id
    current_model = db.get_user_attribute(user_id, "current_model")

    # image input only works with the vision-capable models
    if current_model != "gpt-4-vision-preview" and current_model != "gpt-4o":
        await update.message.reply_text(
            "🥲 Images processing is only available for <b>gpt-4-vision-preview</b> and <b>gpt-4o</b> model. Please change your settings in /settings",
            parse_mode=ParseMode.HTML,
        )
        return

    chat_mode = db.get_user_attribute(user_id, "current_chat_mode")

    # new dialog timeout: open a fresh dialog after a long pause
    if use_new_dialog_timeout:
        if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0:
            db.start_new_dialog(user_id)
            await update.message.reply_text(f"Starting new dialog due to timeout (<b>{config.chat_modes[chat_mode]['name']}</b> mode) ✅", parse_mode=ParseMode.HTML)
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    buf = None
    if update.message.effective_attachment:
        # presumably the last PhotoSize is the largest resolution — TODO confirm
        photo = update.message.effective_attachment[-1]
        photo_file = await context.bot.get_file(photo.file_id)

        # store file in memory, not on disk
        buf = io.BytesIO()
        await photo_file.download_to_memory(buf)
        buf.name = "image.jpg"  # file extension is required
        buf.seek(0)  # move cursor to the beginning of the buffer

    # in case of CancelledError
    n_input_tokens, n_output_tokens = 0, 0

    try:
        # send placeholder message to user
        placeholder_message = await update.message.reply_text("...")
        message = update.message.caption or update.message.text or ''

        # send typing action
        await update.message.chat.send_action(action="typing")

        dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
        parse_mode = {"html": ParseMode.HTML, "markdown": ParseMode.MARKDOWN}[
            config.chat_modes[chat_mode]["parse_mode"]
        ]

        chatgpt_instance = openai_utils.ChatGPT(model=current_model)
        if config.enable_message_streaming:
            gen = chatgpt_instance.send_vision_message_stream(
                message,
                dialog_messages=dialog_messages,
                image_buffer=buf,
                chat_mode=chat_mode,
            )
        else:
            (
                answer,
                (n_input_tokens, n_output_tokens),
                n_first_dialog_messages_removed,
            ) = await chatgpt_instance.send_vision_message(
                message,
                dialog_messages=dialog_messages,
                image_buffer=buf,
                chat_mode=chat_mode,
            )

            # wrap the single response so the streaming loop below serves both paths
            async def fake_gen():
                yield "finished", answer, (
                    n_input_tokens,
                    n_output_tokens,
                ), n_first_dialog_messages_removed

            gen = fake_gen()

        prev_answer = ""
        async for gen_item in gen:
            (
                status,
                answer,
                (n_input_tokens, n_output_tokens),
                n_first_dialog_messages_removed,
            ) = gen_item

            answer = answer[:4096]  # telegram message limit

            # update only when 100 new symbols are ready
            if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
                continue

            try:
                await context.bot.edit_message_text(
                    answer,
                    chat_id=placeholder_message.chat_id,
                    message_id=placeholder_message.message_id,
                    parse_mode=parse_mode,
                )
            except telegram.error.BadRequest as e:
                if str(e).startswith("Message is not modified"):
                    continue
                else:
                    # markup likely failed to parse — retry without parse_mode
                    await context.bot.edit_message_text(
                        answer,
                        chat_id=placeholder_message.chat_id,
                        message_id=placeholder_message.message_id,
                    )

            await asyncio.sleep(0.01)  # wait a bit to avoid flooding

            prev_answer = answer

        # update user data: persist the exchange (image stored as base64 inside the dialog)
        if buf is not None:
            base_image = base64.b64encode(buf.getvalue()).decode("utf-8")
            new_dialog_message = {"user": [
                {
                    "type": "text",
                    "text": message,
                },
                {
                    "type": "image",
                    "image": base_image,
                }
            ]
            , "bot": answer, "date": datetime.now()}
        else:
            new_dialog_message = {"user": [{"type": "text", "text": message}], "bot": answer, "date": datetime.now()}

        db.set_dialog_messages(
            user_id,
            db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
            dialog_id=None
        )

        db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)

    except asyncio.CancelledError:
        # note: intermediate token updates only work when enable_message_streaming=True (config.yml)
        db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)
        raise

    except Exception as e:
        error_text = f"Something went wrong during completion. Reason: {e}"
        logger.error(error_text)
        await update.message.reply_text(error_text)
        return
|
||||
|
||||
async def unsupport_message_handle(update: Update, context: CallbackContext, message=None):
    """Reply that file/video attachments are not supported.

    ``message`` is unused but kept for handler-signature compatibility.
    """
    # plain string — the original used an f-string with no placeholders (F541);
    # trailing bare `return` was also redundant and is dropped
    error_text = "I don't know how to read files or videos. Send the picture in normal mode (Quick Mode)."
    logger.error(error_text)
    await update.message.reply_text(error_text)
|
||||
|
||||
async def message_handle(update: Update, context: CallbackContext, message=None, use_new_dialog_timeout=True):
    """Main text-message entry point.

    Routes the message to image generation (artist mode), the vision flow
    (vision models / photo attachments) or the plain chat-completion flow,
    serialized per user via user_semaphores.  ``message`` overrides the update
    text (used by /retry and voice transcription).
    """
    # check if bot was mentioned (for group chats)
    if not await is_bot_mentioned(update, context):
        return

    # check if message is edited
    if update.edited_message is not None:
        await edited_message_handle(update, context)
        return

    _message = message or update.message.text

    # remove bot mention (in group chats)
    if update.message.chat.type != "private":
        _message = _message.replace("@" + context.bot.username, "").strip()

    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context): return

    user_id = update.message.from_user.id
    chat_mode = db.get_user_attribute(user_id, "current_chat_mode")

    # artist mode never hits the chat completion API — it generates images
    if chat_mode == "artist":
        await generate_image_handle(update, context, message=message)
        return

    current_model = db.get_user_attribute(user_id, "current_model")

    async def message_handle_fn():
        # plain (text-only) completion flow, run as a cancellable task below
        # new dialog timeout
        if use_new_dialog_timeout:
            if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0:
                db.start_new_dialog(user_id)
                await update.message.reply_text(f"Starting new dialog due to timeout (<b>{config.chat_modes[chat_mode]['name']}</b> mode) ✅", parse_mode=ParseMode.HTML)
        db.set_user_attribute(user_id, "last_interaction", datetime.now())

        # in case of CancelledError
        n_input_tokens, n_output_tokens = 0, 0

        try:
            # send placeholder message to user
            placeholder_message = await update.message.reply_text("...")

            # send typing action
            await update.message.chat.send_action(action="typing")

            if _message is None or len(_message) == 0:
                await update.message.reply_text("🥲 You sent <b>empty message</b>. Please, try again!", parse_mode=ParseMode.HTML)
                return

            dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
            parse_mode = {
                "html": ParseMode.HTML,
                "markdown": ParseMode.MARKDOWN
            }[config.chat_modes[chat_mode]["parse_mode"]]

            chatgpt_instance = openai_utils.ChatGPT(model=current_model)
            if config.enable_message_streaming:
                gen = chatgpt_instance.send_message_stream(_message, dialog_messages=dialog_messages, chat_mode=chat_mode)
            else:
                answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed = await chatgpt_instance.send_message(
                    _message,
                    dialog_messages=dialog_messages,
                    chat_mode=chat_mode
                )

                # wrap the single response so the streaming loop below serves both paths
                async def fake_gen():
                    yield "finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed

                gen = fake_gen()

            prev_answer = ""

            async for gen_item in gen:
                status, answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed = gen_item

                answer = answer[:4096]  # telegram message limit

                # update only when 100 new symbols are ready
                if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
                    continue

                try:
                    await context.bot.edit_message_text(answer, chat_id=placeholder_message.chat_id, message_id=placeholder_message.message_id, parse_mode=parse_mode)
                except telegram.error.BadRequest as e:
                    if str(e).startswith("Message is not modified"):
                        continue
                    else:
                        # markup likely failed to parse — retry without parse_mode
                        await context.bot.edit_message_text(answer, chat_id=placeholder_message.chat_id, message_id=placeholder_message.message_id)

                await asyncio.sleep(0.01)  # wait a bit to avoid flooding

                prev_answer = answer

            # update user data
            new_dialog_message = {"user": [{"type": "text", "text": _message}], "bot": answer, "date": datetime.now()}

            db.set_dialog_messages(
                user_id,
                db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
                dialog_id=None
            )

            db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)

        except asyncio.CancelledError:
            # note: intermediate token updates only work when enable_message_streaming=True (config.yml)
            db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)
            raise

        except Exception as e:
            error_text = f"Something went wrong during completion. Reason: {e}"
            logger.error(error_text)
            await update.message.reply_text(error_text)
            return

        # send message if some messages were removed from the context
        if n_first_dialog_messages_removed > 0:
            if n_first_dialog_messages_removed == 1:
                text = "✍️ <i>Note:</i> Your current dialog is too long, so your <b>first message</b> was removed from the context.\n Send /new command to start new dialog"
            else:
                text = f"✍️ <i>Note:</i> Your current dialog is too long, so <b>{n_first_dialog_messages_removed} first messages</b> were removed from the context.\n Send /new command to start new dialog"
            await update.message.reply_text(text, parse_mode=ParseMode.HTML)

    async with user_semaphores[user_id]:
        # route to the vision flow for vision models or photo attachments;
        # note `and` binds tighter than `or` here: any photo message takes this branch
        if current_model == "gpt-4-vision-preview" or current_model == "gpt-4o" or update.message.photo is not None and len(update.message.photo) > 0:

            logger.error(current_model)
            # NOTE(review): logging the model name at ERROR level looks like
            # leftover debugging — confirm and downgrade/remove

            # a photo arrived while a non-vision model was selected: force gpt-4o
            if current_model != "gpt-4o" and current_model != "gpt-4-vision-preview":
                current_model = "gpt-4o"
                db.set_user_attribute(user_id, "current_model", "gpt-4o")
            task = asyncio.create_task(
                _vision_message_handle_fn(update, context, use_new_dialog_timeout=use_new_dialog_timeout)
            )
        else:
            task = asyncio.create_task(
                message_handle_fn()
            )

        # remember the task so /cancel can abort it
        user_tasks[user_id] = task

        try:
            await task
        except asyncio.CancelledError:
            await update.message.reply_text("✅ Canceled", parse_mode=ParseMode.HTML)
        else:
            pass
        finally:
            if user_id in user_tasks:
                del user_tasks[user_id]
|
||||
|
||||
|
||||
async def is_previous_message_not_answered_yet(update: Update, context: CallbackContext):
    """Return True (and warn the user) if a reply for this user is still being generated."""
    await register_user_if_not_exists(update, context, update.message.from_user)

    user_id = update.message.from_user.id
    if not user_semaphores[user_id].locked():
        return False

    warning = "⏳ Please <b>wait</b> for a reply to the previous message\n" + "Or you can /cancel it"
    await update.message.reply_text(warning, reply_to_message_id=update.message.id, parse_mode=ParseMode.HTML)
    return True
|
||||
|
||||
|
||||
async def voice_message_handle(update: Update, context: CallbackContext):
    """Transcribe an incoming voice message and feed the text to message_handle."""
    # check if bot was mentioned (for group chats)
    if not await is_bot_mentioned(update, context):
        return

    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context):
        return

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    voice = update.message.voice
    voice_file = await context.bot.get_file(voice.file_id)

    # download into an in-memory buffer instead of a temp file
    audio_buffer = io.BytesIO()
    await voice_file.download_to_memory(audio_buffer)
    audio_buffer.name = "voice.oga"  # file extension is required
    audio_buffer.seek(0)  # rewind so the whole buffer is read

    transcribed_text = await openai_utils.transcribe_audio(audio_buffer)
    await update.message.reply_text(f"🎤: <i>{transcribed_text}</i>", parse_mode=ParseMode.HTML)

    # account for the transcribed audio duration
    db.set_user_attribute(user_id, "n_transcribed_seconds", voice.duration + db.get_user_attribute(user_id, "n_transcribed_seconds"))

    await message_handle(update, context, message=transcribed_text)
|
||||
|
||||
|
||||
async def generate_image_handle(update: Update, context: CallbackContext, message=None):
    """Generate images for the user's prompt and send them one by one.

    Used by the "artist" chat mode; ``message`` overrides the update text.
    """
    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context):
        return

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    await update.message.chat.send_action(action="upload_photo")

    message = message or update.message.text

    try:
        image_urls = await openai_utils.generate_images(message, n_images=config.return_n_generated_images, size=config.image_size)
    # NOTE(review): openai.error.InvalidRequestError exists only in openai<1.0 —
    # confirm the pinned library version, otherwise this except clause never matches
    except openai.error.InvalidRequestError as e:
        if str(e).startswith("Your request was rejected as a result of our safety system"):
            text = "🥲 Your request <b>doesn't comply</b> with OpenAI's usage policies.\nWhat did you write there, huh?"
            await update.message.reply_text(text, parse_mode=ParseMode.HTML)
            return
        else:
            raise

    # token usage
    db.set_user_attribute(user_id, "n_generated_images", config.return_n_generated_images + db.get_user_attribute(user_id, "n_generated_images"))

    # index from the original enumerate was unused — plain iteration instead
    for image_url in image_urls:
        await update.message.chat.send_action(action="upload_photo")
        await update.message.reply_photo(image_url, parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def new_dialog_handle(update: Update, context: CallbackContext):
    """Handle /new: open a fresh dialog and greet with the current mode's welcome message."""
    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context):
        return

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())
    # NOTE(review): /new silently switches the user back to gpt-3.5-turbo — confirm intended
    db.set_user_attribute(user_id, "current_model", "gpt-3.5-turbo")

    db.start_new_dialog(user_id)
    await update.message.reply_text("Starting new dialog ✅")

    current_mode = db.get_user_attribute(user_id, "current_chat_mode")
    await update.message.reply_text(f"{config.chat_modes[current_mode]['welcome_message']}", parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def cancel_handle(update: Update, context: CallbackContext):
    """Handle /cancel: abort the user's in-flight generation task, if any."""
    await register_user_if_not_exists(update, context, update.message.from_user)

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    if user_id not in user_tasks:
        await update.message.reply_text("<i>Nothing to cancel...</i>", parse_mode=ParseMode.HTML)
    else:
        # cancellation is observed by message_handle, which replies "✅ Canceled"
        user_tasks[user_id].cancel()
|
||||
|
||||
|
||||
def get_chat_mode_menu(page_index: int):
    """Build the text and inline keyboard for one page of the chat-mode picker."""
    per_page = config.n_chat_modes_per_page
    text = f"Select <b>chat mode</b> ({len(config.chat_modes)} modes available):"

    mode_keys = list(config.chat_modes.keys())
    page_keys = mode_keys[page_index * per_page:(page_index + 1) * per_page]

    # one button row per chat mode on this page
    keyboard = [
        [InlineKeyboardButton(config.chat_modes[key]["name"], callback_data=f"set_chat_mode|{key}")]
        for key in page_keys
    ]

    # pagination row, only needed when the modes do not fit on a single page
    if len(mode_keys) > per_page:
        is_first_page = (page_index == 0)
        is_last_page = ((page_index + 1) * per_page >= len(mode_keys))

        prev_button = InlineKeyboardButton("«", callback_data=f"show_chat_modes|{page_index - 1}")
        next_button = InlineKeyboardButton("»", callback_data=f"show_chat_modes|{page_index + 1}")

        if is_first_page:
            keyboard.append([next_button])
        elif is_last_page:
            keyboard.append([prev_button])
        else:
            keyboard.append([prev_button, next_button])

    return text, InlineKeyboardMarkup(keyboard)
|
||||
|
||||
|
||||
async def show_chat_modes_handle(update: Update, context: CallbackContext):
    """Handle /mode: show the first page of the chat-mode picker."""
    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context):
        return

    db.set_user_attribute(update.message.from_user.id, "last_interaction", datetime.now())

    menu_text, menu_markup = get_chat_mode_menu(0)
    await update.message.reply_text(menu_text, reply_markup=menu_markup, parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def show_chat_modes_callback_handle(update: Update, context: CallbackContext):
    """Callback for the «/» pagination buttons: redraw the chat-mode menu in place."""
    await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
    if await is_previous_message_not_answered_yet(update.callback_query, context):
        return

    user_id = update.callback_query.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    query = update.callback_query
    await query.answer()

    page_index = int(query.data.split("|")[1])
    if page_index < 0:
        return

    text, reply_markup = get_chat_mode_menu(page_index)
    try:
        await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)
    except telegram.error.BadRequest as e:
        # "Message is not modified" just means the menu is already up to date;
        # the original swallowed every BadRequest — re-raise unexpected ones
        if not str(e).startswith("Message is not modified"):
            raise
|
||||
|
||||
|
||||
async def set_chat_mode_handle(update: Update, context: CallbackContext):
    """Callback for a chat-mode button: persist the choice and start a new dialog."""
    await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)

    query = update.callback_query
    await query.answer()

    user_id = query.from_user.id
    chat_mode = query.data.split("|")[1]

    db.set_user_attribute(user_id, "current_chat_mode", chat_mode)
    db.start_new_dialog(user_id)

    await context.bot.send_message(
        query.message.chat.id,
        f"{config.chat_modes[chat_mode]['welcome_message']}",
        parse_mode=ParseMode.HTML
    )
|
||||
|
||||
|
||||
def get_settings_menu(user_id: int):
    """Build the model-selection text (description + score bars) and keyboard for a user."""
    current_model = db.get_user_attribute(user_id, "current_model")

    # model description followed by one bar of 🟢/⚪️ per score entry
    text = config.models["info"][current_model]["description"]
    text += "\n\n"
    for score_key, score_value in config.models["info"][current_model]["scores"].items():
        text += "🟢" * score_value + "⚪️" * (5 - score_value) + f" – {score_key}\n\n"
    text += "\nSelect <b>model</b>:"

    # one button per available model; the active one is marked with a check
    buttons = []
    for model_key in config.models["available_text_models"]:
        title = config.models["info"][model_key]["name"]
        if model_key == current_model:
            title = "✅ " + title
        buttons.append(InlineKeyboardButton(title, callback_data=f"set_settings|{model_key}"))

    return text, InlineKeyboardMarkup([buttons])
|
||||
|
||||
|
||||
async def settings_handle(update: Update, context: CallbackContext):
    """Handle /settings: show the model-selection menu."""
    await register_user_if_not_exists(update, context, update.message.from_user)
    if await is_previous_message_not_answered_yet(update, context):
        return

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    menu_text, menu_markup = get_settings_menu(user_id)
    await update.message.reply_text(menu_text, reply_markup=menu_markup, parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def set_settings_handle(update: Update, context: CallbackContext):
    """Callback for a model button: switch model, reset the dialog and redraw the menu."""
    await register_user_if_not_exists(update.callback_query, context, update.callback_query.from_user)
    user_id = update.callback_query.from_user.id

    query = update.callback_query
    await query.answer()

    _, model_key = query.data.split("|")
    db.set_user_attribute(user_id, "current_model", model_key)
    db.start_new_dialog(user_id)

    text, reply_markup = get_settings_menu(user_id)
    try:
        await query.edit_message_text(text, reply_markup=reply_markup, parse_mode=ParseMode.HTML)
    except telegram.error.BadRequest as e:
        # redrawing identical content is harmless ("Message is not modified");
        # the original swallowed every BadRequest — re-raise unexpected ones
        if not str(e).startswith("Message is not modified"):
            raise
|
||||
|
||||
|
||||
async def show_balance_handle(update: Update, context: CallbackContext):
    """Handle /balance: report dollars spent and tokens used, per model and service."""
    await register_user_if_not_exists(update, context, update.message.from_user)

    user_id = update.message.from_user.id
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    # count total usage statistics
    total_n_spent_dollars = 0
    total_n_used_tokens = 0

    n_used_tokens_dict = db.get_user_attribute(user_id, "n_used_tokens")
    n_generated_images = db.get_user_attribute(user_id, "n_generated_images")
    n_transcribed_seconds = db.get_user_attribute(user_id, "n_transcribed_seconds")

    details_text = "🏷️ Details:\n"
    for model_key in sorted(n_used_tokens_dict.keys()):
        n_input_tokens, n_output_tokens = n_used_tokens_dict[model_key]["n_input_tokens"], n_used_tokens_dict[model_key]["n_output_tokens"]
        total_n_used_tokens += n_input_tokens + n_output_tokens

        # prices are configured per 1000 tokens, separately for input and output
        n_input_spent_dollars = config.models["info"][model_key]["price_per_1000_input_tokens"] * (n_input_tokens / 1000)
        n_output_spent_dollars = config.models["info"][model_key]["price_per_1000_output_tokens"] * (n_output_tokens / 1000)
        total_n_spent_dollars += n_input_spent_dollars + n_output_spent_dollars

        details_text += f"- {model_key}: <b>{n_input_spent_dollars + n_output_spent_dollars:.03f}$</b> / <b>{n_input_tokens + n_output_tokens} tokens</b>\n"

    # image generation (flat price per image)
    image_generation_n_spent_dollars = config.models["info"]["dalle-2"]["price_per_1_image"] * n_generated_images
    if n_generated_images != 0:
        details_text += f"- DALL·E 2 (image generation): <b>{image_generation_n_spent_dollars:.03f}$</b> / <b>{n_generated_images} generated images</b>\n"

    total_n_spent_dollars += image_generation_n_spent_dollars

    # voice recognition (priced per minute of audio)
    voice_recognition_n_spent_dollars = config.models["info"]["whisper"]["price_per_1_min"] * (n_transcribed_seconds / 60)
    if n_transcribed_seconds != 0:
        details_text += f"- Whisper (voice recognition): <b>{voice_recognition_n_spent_dollars:.03f}$</b> / <b>{n_transcribed_seconds:.01f} seconds</b>\n"

    total_n_spent_dollars += voice_recognition_n_spent_dollars

    text = f"You spent <b>{total_n_spent_dollars:.03f}$</b>\n"
    text += f"You used <b>{total_n_used_tokens}</b> tokens\n\n"
    text += details_text

    await update.message.reply_text(text, parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def edited_message_handle(update: Update, context: CallbackContext):
    """Tell the user (in private chats only) that edited messages are ignored."""
    if update.edited_message.chat.type != "private":
        return
    text = "🥲 Unfortunately, message <b>editing</b> is not supported"
    await update.edited_message.reply_text(text, parse_mode=ParseMode.HTML)
|
||||
|
||||
|
||||
async def error_handle(update: Update, context: CallbackContext) -> None:
    """Global error handler: log the exception and report it into the chat.

    The offending update and the traceback are rendered as HTML <pre> blocks
    and sent to the chat where the error happened, split into chunks to fit
    Telegram's 4096-character message limit.
    """
    logger.error(msg="Exception while handling an update:", exc_info=context.error)

    try:
        # collect error message
        tb_list = traceback.format_exception(None, context.error, context.error.__traceback__)
        tb_string = "".join(tb_list)
        update_str = update.to_dict() if isinstance(update, Update) else str(update)
        message = (
            f"An exception was raised while handling an update\n"
            f"<pre>update = {html.escape(json.dumps(update_str, indent=2, ensure_ascii=False))}"
            "</pre>\n\n"
            f"<pre>{html.escape(tb_string)}</pre>"
        )

        # split text into multiple messages due to 4096 character limit
        for message_chunk in split_text_into_chunks(message, 4096):
            try:
                await context.bot.send_message(update.effective_chat.id, message_chunk, parse_mode=ParseMode.HTML)
            except telegram.error.BadRequest:
                # answer has invalid characters, so we send it without parse_mode
                await context.bot.send_message(update.effective_chat.id, message_chunk)
    except Exception:
        # Was a bare `except:` which also swallowed CancelledError/SystemExit
        # and left no trace of the secondary failure; narrow it and log.
        logger.exception("Error inside the error handler")
        await context.bot.send_message(update.effective_chat.id, "Some error in error handler")
|
||||
|
||||
async def post_init(application: Application):
    """Register the bot's command menu once the application is initialized."""
    # NOTE(review): Telegram's Bot API documents command names without the
    # leading "/" — confirm these register correctly server-side.
    command_menu = [
        ("/new", "Start new dialog"),
        ("/mode", "Select chat mode"),
        ("/retry", "Re-generate response for previous query"),
        ("/balance", "Show balance"),
        ("/settings", "Show settings"),
        ("/help", "Show help message"),
    ]
    await application.bot.set_my_commands(
        [BotCommand(command, description) for command, description in command_menu]
    )
|
||||
|
||||
def run_bot() -> None:
    """Build the Telegram application, register all handlers and start polling."""
    application = (
        ApplicationBuilder()
        .token(config.telegram_token)
        .concurrent_updates(True)
        .rate_limiter(AIORateLimiter(max_retries=5))
        .http_version("1.1")
        .get_updates_http_version("1.1")
        .post_init(post_init)
        .build()
    )

    # Restrict access when an allow-list is configured; otherwise the bot is public.
    user_filter = filters.ALL
    if len(config.allowed_telegram_usernames) > 0:
        usernames = [x for x in config.allowed_telegram_usernames if isinstance(x, str)]
        any_ids = [x for x in config.allowed_telegram_usernames if isinstance(x, int)]
        user_ids = [x for x in any_ids if x > 0]
        group_ids = [x for x in any_ids if x < 0]  # negative ids denote group chats
        user_filter = filters.User(username=usernames) | filters.User(user_id=user_ids) | filters.Chat(chat_id=group_ids)

    # basic commands
    for command, callback in (
        ("start", start_handle),
        ("help", help_handle),
        ("help_group_chat", help_group_chat_handle),
    ):
        application.add_handler(CommandHandler(command, callback, filters=user_filter))

    # content handlers (non-command messages)
    not_command = ~filters.COMMAND
    application.add_handler(MessageHandler(filters.TEXT & not_command & user_filter, message_handle))
    application.add_handler(MessageHandler(filters.PHOTO & not_command & user_filter, message_handle))
    application.add_handler(MessageHandler(filters.VIDEO & not_command & user_filter, unsupport_message_handle))
    application.add_handler(MessageHandler(filters.Document.ALL & not_command & user_filter, unsupport_message_handle))

    # dialog-control commands
    for command, callback in (
        ("retry", retry_handle),
        ("new", new_dialog_handle),
        ("cancel", cancel_handle),
    ):
        application.add_handler(CommandHandler(command, callback, filters=user_filter))

    application.add_handler(MessageHandler(filters.VOICE & user_filter, voice_message_handle))

    # chat-mode selection (command + inline keyboard callbacks)
    application.add_handler(CommandHandler("mode", show_chat_modes_handle, filters=user_filter))
    application.add_handler(CallbackQueryHandler(show_chat_modes_callback_handle, pattern="^show_chat_modes"))
    application.add_handler(CallbackQueryHandler(set_chat_mode_handle, pattern="^set_chat_mode"))

    # settings (command + inline keyboard callbacks)
    application.add_handler(CommandHandler("settings", settings_handle, filters=user_filter))
    application.add_handler(CallbackQueryHandler(set_settings_handle, pattern="^set_settings"))

    application.add_handler(CommandHandler("balance", show_balance_handle, filters=user_filter))

    application.add_error_handler(error_handle)

    # start the bot
    application.run_polling()
|
||||
|
||||
|
||||
# Script entry point: start long polling when run directly.
if __name__ == "__main__":
    run_bot()
|
||||
35
catch-all/06_bots_telegram/08_chatgpt_bot/bot/config.py
Normal file
35
catch-all/06_bots_telegram/08_chatgpt_bot/bot/config.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Runtime configuration loaded from YAML/.env files under <repo>/config."""
import yaml
import dotenv
from pathlib import Path

# All configuration files live one level above this package, in <repo>/config.
config_dir = Path(__file__).parent.parent.resolve() / "config"

# load yaml config
with open(config_dir / "config.yml", 'r') as f:
    config_yaml = yaml.safe_load(f)

# load .env config
config_env = dotenv.dotenv_values(config_dir / "config.env")

# config parameters
telegram_token = config_yaml["telegram_token"]  # required
openai_api_key = config_yaml["openai_api_key"]  # required
openai_api_base = config_yaml.get("openai_api_base", None)  # optional custom API endpoint
allowed_telegram_usernames = config_yaml["allowed_telegram_usernames"]  # empty list -> bot is public
new_dialog_timeout = config_yaml["new_dialog_timeout"]
enable_message_streaming = config_yaml.get("enable_message_streaming", True)
return_n_generated_images = config_yaml.get("return_n_generated_images", 1)
image_size = config_yaml.get("image_size", "512x512")
n_chat_modes_per_page = config_yaml.get("n_chat_modes_per_page", 5)
# "mongo" is presumably the docker-compose service hostname — TODO confirm
mongodb_uri = f"mongodb://mongo:{config_env['MONGODB_PORT']}"

# chat_modes
with open(config_dir / "chat_modes.yml", 'r') as f:
    chat_modes = yaml.safe_load(f)

# models
with open(config_dir / "models.yml", 'r') as f:
    models = yaml.safe_load(f)

# files
help_group_chat_video_path = Path(__file__).parent.parent.resolve() / "static" / "help_group_chat.mp4"
|
||||
128
catch-all/06_bots_telegram/08_chatgpt_bot/bot/database.py
Normal file
128
catch-all/06_bots_telegram/08_chatgpt_bot/bot/database.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
import pymongo
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import config
|
||||
|
||||
|
||||
class Database:
    """MongoDB persistence layer for the bot: user profiles and dialog history."""

    def __init__(self):
        # One client per instance; collections are created lazily by Mongo.
        self.client = pymongo.MongoClient(config.mongodb_uri)
        self.db = self.client["chatgpt_telegram_bot"]

        self.user_collection = self.db["user"]
        self.dialog_collection = self.db["dialog"]

    def check_if_user_exists(self, user_id: int, raise_exception: bool = False):
        """Return whether the user document exists; optionally raise ValueError when absent."""
        exists = self.user_collection.count_documents({"_id": user_id}) > 0
        if not exists and raise_exception:
            raise ValueError(f"User {user_id} does not exist")
        return exists

    def add_new_user(
        self,
        user_id: int,
        chat_id: int,
        username: str = "",
        first_name: str = "",
        last_name: str = "",
    ):
        """Insert a fresh user document; a no-op if the user already exists."""
        new_user = {
            "_id": user_id,
            "chat_id": chat_id,

            "username": username,
            "first_name": first_name,
            "last_name": last_name,

            "last_interaction": datetime.now(),
            "first_seen": datetime.now(),

            "current_dialog_id": None,
            "current_chat_mode": "assistant",
            "current_model": config.models["available_text_models"][0],

            "n_used_tokens": {},

            "n_generated_images": 0,
            "n_transcribed_seconds": 0.0  # voice message transcription
        }

        if not self.check_if_user_exists(user_id):
            self.user_collection.insert_one(new_user)

    def start_new_dialog(self, user_id: int):
        """Create a new dialog for the user, make it current, and return its id."""
        self.check_if_user_exists(user_id, raise_exception=True)

        new_dialog_id = str(uuid.uuid4())
        # add new dialog
        self.dialog_collection.insert_one({
            "_id": new_dialog_id,
            "user_id": user_id,
            "chat_mode": self.get_user_attribute(user_id, "current_chat_mode"),
            "start_time": datetime.now(),
            "model": self.get_user_attribute(user_id, "current_model"),
            "messages": []
        })

        # update user's current dialog
        self.user_collection.update_one(
            {"_id": user_id},
            {"$set": {"current_dialog_id": new_dialog_id}}
        )

        return new_dialog_id

    def get_user_attribute(self, user_id: int, key: str):
        """Fetch a single field from the user document (None when the field is absent)."""
        self.check_if_user_exists(user_id, raise_exception=True)
        user_dict = self.user_collection.find_one({"_id": user_id})
        return user_dict.get(key)

    def set_user_attribute(self, user_id: int, key: str, value: Any):
        """Overwrite a single field on the user document."""
        self.check_if_user_exists(user_id, raise_exception=True)
        self.user_collection.update_one({"_id": user_id}, {"$set": {key: value}})

    def update_n_used_tokens(self, user_id: int, model: str, n_input_tokens: int, n_output_tokens: int):
        """Accumulate per-model input/output token counters for the user."""
        usage = self.get_user_attribute(user_id, "n_used_tokens")
        per_model = usage.setdefault(model, {"n_input_tokens": 0, "n_output_tokens": 0})
        per_model["n_input_tokens"] += n_input_tokens
        per_model["n_output_tokens"] += n_output_tokens
        self.set_user_attribute(user_id, "n_used_tokens", usage)

    def get_dialog_messages(self, user_id: int, dialog_id: Optional[str] = None):
        """Return the message list of a dialog (defaults to the user's current one)."""
        self.check_if_user_exists(user_id, raise_exception=True)

        if dialog_id is None:
            dialog_id = self.get_user_attribute(user_id, "current_dialog_id")

        dialog_dict = self.dialog_collection.find_one({"_id": dialog_id, "user_id": user_id})
        return dialog_dict["messages"]

    def set_dialog_messages(self, user_id: int, dialog_messages: list, dialog_id: Optional[str] = None):
        """Replace the message list of a dialog (defaults to the user's current one)."""
        self.check_if_user_exists(user_id, raise_exception=True)

        if dialog_id is None:
            dialog_id = self.get_user_attribute(user_id, "current_dialog_id")

        self.dialog_collection.update_one(
            {"_id": dialog_id, "user_id": user_id},
            {"$set": {"messages": dialog_messages}}
        )
|
||||
364
catch-all/06_bots_telegram/08_chatgpt_bot/bot/openai_utils.py
Normal file
364
catch-all/06_bots_telegram/08_chatgpt_bot/bot/openai_utils.py
Normal file
@@ -0,0 +1,364 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import config
|
||||
import logging
|
||||
|
||||
import tiktoken
|
||||
import openai
|
||||
|
||||
|
||||
# setup openai: credentials and optional custom endpoint come from config
openai.api_key = config.openai_api_key
if config.openai_api_base is not None:
    openai.api_base = config.openai_api_base
logger = logging.getLogger(__name__)


# Default sampling/request parameters applied to every completion call.
OPENAI_COMPLETION_OPTIONS = {
    "temperature": 0.7,
    "max_tokens": 1000,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "request_timeout": 60.0,  # seconds
}
|
||||
|
||||
|
||||
class ChatGPT:
    """Async wrapper around the OpenAI completion APIs.

    All ``send_*`` methods share the same retry contract: when the request
    fails with ``InvalidRequestError`` (assumed to mean the context is too
    long), the oldest dialog message is dropped and the request is retried
    until it fits or the history is exhausted.
    """

    def __init__(self, model="gpt-3.5-turbo"):
        # model: one of the chat/completion model names supported below
        assert model in {"text-davinci-003", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4-1106-preview", "gpt-4-vision-preview"}, f"Unknown model: {model}"
        self.model = model

    async def send_message(self, message, dialog_messages=[], chat_mode="assistant"):
        """Return ``(answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed)``.

        NOTE(review): the mutable default ``dialog_messages=[]`` is safe only
        because the method rebinds it via slicing instead of mutating it.
        """
        if chat_mode not in config.chat_modes.keys():
            raise ValueError(f"Chat mode {chat_mode} is not supported")

        n_dialog_messages_before = len(dialog_messages)
        answer = None
        # retry until a completion succeeds, shrinking the history on overflow
        while answer is None:
            try:
                if self.model in {"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4-1106-preview", "gpt-4-vision-preview"}:
                    messages = self._generate_prompt_messages(message, dialog_messages, chat_mode)

                    r = await openai.ChatCompletion.acreate(
                        model=self.model,
                        messages=messages,
                        **OPENAI_COMPLETION_OPTIONS
                    )
                    answer = r.choices[0].message["content"]
                elif self.model == "text-davinci-003":
                    # legacy completion endpoint takes a single text prompt
                    prompt = self._generate_prompt(message, dialog_messages, chat_mode)
                    r = await openai.Completion.acreate(
                        engine=self.model,
                        prompt=prompt,
                        **OPENAI_COMPLETION_OPTIONS
                    )
                    answer = r.choices[0].text
                else:
                    raise ValueError(f"Unknown model: {self.model}")

                answer = self._postprocess_answer(answer)
                n_input_tokens, n_output_tokens = r.usage.prompt_tokens, r.usage.completion_tokens
            except openai.error.InvalidRequestError as e:  # too many tokens
                if len(dialog_messages) == 0:
                    raise ValueError("Dialog messages is reduced to zero, but still has too many tokens to make completion") from e

                # forget first message in dialog_messages
                dialog_messages = dialog_messages[1:]

        n_first_dialog_messages_removed = n_dialog_messages_before - len(dialog_messages)

        return answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed

    async def send_message_stream(self, message, dialog_messages=[], chat_mode="assistant"):
        """Async generator yielding ``(status, answer_so_far, (n_in, n_out), n_removed)``.

        ``status`` is "not_finished" while streaming and "finished" for the
        final item. Token counts are estimated locally via tiktoken.
        """
        if chat_mode not in config.chat_modes.keys():
            raise ValueError(f"Chat mode {chat_mode} is not supported")

        n_dialog_messages_before = len(dialog_messages)
        answer = None
        while answer is None:
            try:
                # NOTE: vision model is intentionally absent from this set
                if self.model in {"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4","gpt-4o", "gpt-4-1106-preview"}:
                    messages = self._generate_prompt_messages(message, dialog_messages, chat_mode)

                    r_gen = await openai.ChatCompletion.acreate(
                        model=self.model,
                        messages=messages,
                        stream=True,
                        **OPENAI_COMPLETION_OPTIONS
                    )

                    answer = ""
                    async for r_item in r_gen:
                        delta = r_item.choices[0].delta

                        if "content" in delta:
                            answer += delta.content
                            n_input_tokens, n_output_tokens = self._count_tokens_from_messages(messages, answer, model=self.model)
                            # NOTE(review): hard-coded 0 even after retries removed
                            # messages — confirm whether that is intended
                            n_first_dialog_messages_removed = 0

                            yield "not_finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed


                elif self.model == "text-davinci-003":
                    prompt = self._generate_prompt(message, dialog_messages, chat_mode)
                    r_gen = await openai.Completion.acreate(
                        engine=self.model,
                        prompt=prompt,
                        stream=True,
                        **OPENAI_COMPLETION_OPTIONS
                    )

                    answer = ""
                    async for r_item in r_gen:
                        answer += r_item.choices[0].text
                        n_input_tokens, n_output_tokens = self._count_tokens_from_prompt(prompt, answer, model=self.model)
                        n_first_dialog_messages_removed = n_dialog_messages_before - len(dialog_messages)
                        yield "not_finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed

                answer = self._postprocess_answer(answer)

            except openai.error.InvalidRequestError as e:  # too many tokens
                if len(dialog_messages) == 0:
                    raise e

                # forget first message in dialog_messages
                dialog_messages = dialog_messages[1:]

        # NOTE(review): if the stream yields no chunks, the token counters below
        # are unbound — presumably the API always sends at least one; verify.
        yield "finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed  # sending final answer

    async def send_vision_message(
        self,
        message,
        dialog_messages=[],
        chat_mode="assistant",
        image_buffer: BytesIO = None,
    ):
        """Non-streaming completion for vision-capable models.

        Same return shape as :meth:`send_message`; ``image_buffer`` (optional)
        is embedded in the prompt as a base64 data URL.
        """
        n_dialog_messages_before = len(dialog_messages)
        answer = None
        while answer is None:
            try:
                if self.model == "gpt-4-vision-preview" or self.model == "gpt-4o":
                    messages = self._generate_prompt_messages(
                        message, dialog_messages, chat_mode, image_buffer
                    )
                    r = await openai.ChatCompletion.acreate(
                        model=self.model,
                        messages=messages,
                        **OPENAI_COMPLETION_OPTIONS
                    )
                    answer = r.choices[0].message.content
                else:
                    raise ValueError(f"Unsupported model: {self.model}")

                answer = self._postprocess_answer(answer)
                n_input_tokens, n_output_tokens = (
                    r.usage.prompt_tokens,
                    r.usage.completion_tokens,
                )
            except openai.error.InvalidRequestError as e:  # too many tokens
                if len(dialog_messages) == 0:
                    raise ValueError(
                        "Dialog messages is reduced to zero, but still has too many tokens to make completion"
                    ) from e

                # forget first message in dialog_messages
                dialog_messages = dialog_messages[1:]

        n_first_dialog_messages_removed = n_dialog_messages_before - len(
            dialog_messages
        )

        return (
            answer,
            (n_input_tokens, n_output_tokens),
            n_first_dialog_messages_removed,
        )

    async def send_vision_message_stream(
        self,
        message,
        dialog_messages=[],
        chat_mode="assistant",
        image_buffer: BytesIO = None,
    ):
        """Streaming variant of :meth:`send_vision_message`.

        Yields the same tuples as :meth:`send_message_stream`.
        """
        n_dialog_messages_before = len(dialog_messages)
        answer = None
        while answer is None:
            try:
                if self.model == "gpt-4-vision-preview" or self.model == "gpt-4o":
                    messages = self._generate_prompt_messages(
                        message, dialog_messages, chat_mode, image_buffer
                    )

                    r_gen = await openai.ChatCompletion.acreate(
                        model=self.model,
                        messages=messages,
                        stream=True,
                        **OPENAI_COMPLETION_OPTIONS,
                    )

                    answer = ""
                    async for r_item in r_gen:
                        delta = r_item.choices[0].delta
                        if "content" in delta:
                            answer += delta.content
                            (
                                n_input_tokens,
                                n_output_tokens,
                            ) = self._count_tokens_from_messages(
                                messages, answer, model=self.model
                            )
                            n_first_dialog_messages_removed = (
                                n_dialog_messages_before - len(dialog_messages)
                            )
                            yield "not_finished", answer, (
                                n_input_tokens,
                                n_output_tokens,
                            ), n_first_dialog_messages_removed

                answer = self._postprocess_answer(answer)

            except openai.error.InvalidRequestError as e:  # too many tokens
                if len(dialog_messages) == 0:
                    raise e
                # forget first message in dialog_messages
                dialog_messages = dialog_messages[1:]

        # NOTE(review): as in send_message_stream, the counters below are unbound
        # if the stream produced no content chunks — verify upstream guarantees.
        yield "finished", answer, (
            n_input_tokens,
            n_output_tokens,
        ), n_first_dialog_messages_removed

    def _generate_prompt(self, message, dialog_messages, chat_mode):
        """Build a single text prompt (legacy completion API) from the dialog."""
        prompt = config.chat_modes[chat_mode]["prompt_start"]
        prompt += "\n\n"

        # add chat context
        if len(dialog_messages) > 0:
            prompt += "Chat:\n"
            for dialog_message in dialog_messages:
                prompt += f"User: {dialog_message['user']}\n"
                prompt += f"Assistant: {dialog_message['bot']}\n"

        # current message
        prompt += f"User: {message}\n"
        prompt += "Assistant: "

        return prompt

    def _encode_image(self, image_buffer: BytesIO) -> str:
        """Base64-encode the buffer's bytes and return them as an ASCII string."""
        return base64.b64encode(image_buffer.read()).decode("utf-8")

    def _generate_prompt_messages(self, message, dialog_messages, chat_mode, image_buffer: BytesIO = None):
        """Build the chat-API message list; attaches the image (if any) to the last user turn."""
        prompt = config.chat_modes[chat_mode]["prompt_start"]

        messages = [{"role": "system", "content": prompt}]

        for dialog_message in dialog_messages:
            messages.append({"role": "user", "content": dialog_message["user"]})
            messages.append({"role": "assistant", "content": dialog_message["bot"]})

        if image_buffer is not None:
            # vision payload: text part + base64-encoded image as a data URL
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": message,
                        },
                        {
                            "type": "image_url",
                            "image_url" : {

                                "url": f"data:image/jpeg;base64,{self._encode_image(image_buffer)}",
                                "detail":"high"
                            }
                        }
                    ]
                }

            )
        else:
            messages.append({"role": "user", "content": message})

        return messages

    def _postprocess_answer(self, answer):
        """Strip surrounding whitespace from the model's answer."""
        answer = answer.strip()
        return answer

    def _count_tokens_from_messages(self, messages, answer, model="gpt-3.5-turbo"):
        """Estimate (n_input_tokens, n_output_tokens) for a chat-API exchange with tiktoken."""
        encoding = tiktoken.encoding_for_model(model)

        # per-model message framing overheads (OpenAI token-counting guide)
        if model == "gpt-3.5-turbo-16k":
            tokens_per_message = 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_name = -1  # if there's a name, the role is omitted
        elif model == "gpt-3.5-turbo":
            tokens_per_message = 4
            tokens_per_name = -1
        elif model == "gpt-4":
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-4-1106-preview":
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-4-vision-preview":
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-4o":
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise ValueError(f"Unknown model: {model}")

        # input
        n_input_tokens = 0
        for message in messages:
            n_input_tokens += tokens_per_message
            if isinstance(message["content"], list):
                # vision-style content: list of typed parts; image parts are not counted
                for sub_message in message["content"]:
                    if "type" in sub_message:
                        if sub_message["type"] == "text":
                            n_input_tokens += len(encoding.encode(sub_message["text"]))
                        elif sub_message["type"] == "image_url":
                            pass
            else:
                if "type" in message:
                    if message["type"] == "text":
                        n_input_tokens += len(encoding.encode(message["text"]))
                    elif message["type"] == "image_url":
                        pass


        n_input_tokens += 2  # reply priming overhead

        # output
        n_output_tokens = 1 + len(encoding.encode(answer))

        return n_input_tokens, n_output_tokens

    def _count_tokens_from_prompt(self, prompt, answer, model="text-davinci-003"):
        """Estimate (n_input_tokens, n_output_tokens) for the legacy completion API."""
        encoding = tiktoken.encoding_for_model(model)

        n_input_tokens = len(encoding.encode(prompt)) + 1
        n_output_tokens = len(encoding.encode(answer))

        return n_input_tokens, n_output_tokens
|
||||
|
||||
|
||||
async def transcribe_audio(audio_file) -> str:
    """Transcribe an audio file with Whisper; returns "" for empty results."""
    response = await openai.Audio.atranscribe("whisper-1", audio_file)
    transcript = response["text"]
    return transcript if transcript else ""
|
||||
|
||||
|
||||
async def generate_images(prompt, n_images=4, size="512x512"):
    """Request DALL·E images for *prompt* and return the list of their URLs."""
    response = await openai.Image.acreate(prompt=prompt, n=n_images, size=size)
    return [image.url for image in response.data]
|
||||
|
||||
|
||||
async def is_content_acceptable(prompt):
    """Return True when *prompt* passes OpenAI moderation (no category flagged).

    Bug fix: the original used ``not all(...)``, which only rejected text
    flagged in *every* moderation category; content is unacceptable as soon
    as *any* category flags it.
    """
    r = await openai.Moderation.acreate(input=prompt)
    return not any(r.results[0].categories.values())
|
||||
Reference in New Issue
Block a user