Curso-lenguaje-python/catch-all/06_bots_telegram/09_ollama_bot/bot/run.py
2024-08-18 20:16:34 +02:00

427 lines
10 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from aiogram import Bot, Dispatcher, types
from aiogram.enums import ParseMode
from aiogram.filters.command import Command, CommandStart
from aiogram.types import Message
from aiogram.utils.keyboard import InlineKeyboardBuilder
from func.interactions import *
import asyncio
import traceback
import io
import base64
bot = Bot(token=token)
dp = Dispatcher()
start_kb = InlineKeyboardBuilder()
settings_kb = InlineKeyboardBuilder()
start_kb.row(
types.InlineKeyboardButton(text=" About", callback_data="about"),
types.InlineKeyboardButton(text="⚙️ Settings", callback_data="settings"),
)
settings_kb.row(
types.InlineKeyboardButton(text="🔄 Switch LLM", callback_data="switchllm"),
types.InlineKeyboardButton(
text="✏️ Edit system prompt", callback_data="editsystemprompt"
),
)
commands = [
types.BotCommand(command="start", description="Start"),
types.BotCommand(command="reset", description="Reset Chat"),
types.BotCommand(command="history", description="Look through messages"),
]
ACTIVE_CHATS = {}
ACTIVE_CHATS_LOCK = contextLock()
modelname = os.getenv("INITMODEL")
mention = None
CHAT_TYPE_GROUP = "group"
CHAT_TYPE_SUPERGROUP = "supergroup"
async def get_bot_info():
global mention
if mention is None:
get = await bot.get_me()
mention = f"@{get.username}"
return mention
@dp.message(CommandStart())
async def command_start_handler(message: Message) -> None:
start_message = f"Welcome, <b>{message.from_user.full_name}</b>!"
await message.answer(
start_message,
parse_mode=ParseMode.HTML,
reply_markup=start_kb.as_markup(),
disable_web_page_preview=True,
)
@dp.message(Command("reset"))
async def command_reset_handler(message: Message) -> None:
if message.from_user.id in allowed_ids:
if message.from_user.id in ACTIVE_CHATS:
async with ACTIVE_CHATS_LOCK:
ACTIVE_CHATS.pop(message.from_user.id)
logging.info(
f"Chat has been reset for {message.from_user.first_name}"
)
await bot.send_message(
chat_id=message.chat.id,
text="Chat has been reset",
)
@dp.message(Command("history"))
async def command_get_context_handler(message: Message) -> None:
if message.from_user.id in allowed_ids:
if message.from_user.id in ACTIVE_CHATS:
messages = ACTIVE_CHATS.get(message.chat.id)["messages"]
context = ""
for msg in messages:
context += f"*{msg['role'].capitalize()}*: {msg['content']}\n"
await bot.send_message(
chat_id=message.chat.id,
text=context,
parse_mode=ParseMode.MARKDOWN,
)
else:
await bot.send_message(
chat_id=message.chat.id,
text="No chat history available for this user",
)
@dp.callback_query(lambda query: query.data == "settings")
async def settings_callback_handler(query: types.CallbackQuery):
await bot.send_message(
chat_id=query.message.chat.id,
text=f"Choose the right option.",
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
reply_markup=settings_kb.as_markup()
)
@dp.callback_query(lambda query: query.data == "switchllm")
async def switchllm_callback_handler(query: types.CallbackQuery):
models = await model_list()
switchllm_builder = InlineKeyboardBuilder()
for model in models:
modelname = model["name"]
modelfamilies = ""
if model["details"]["families"]:
modelicon = {"llama": "🦙", "clip": "📷"}
try:
modelfamilies = "".join(
[modelicon[family]
for family in model["details"]["families"]]
)
except KeyError as e:
modelfamilies = f""
switchllm_builder.row(
types.InlineKeyboardButton(
text=f"{modelname} {modelfamilies}",
callback_data=f"model_{modelname}"
)
)
await query.message.edit_text(
f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal", reply_markup=switchllm_builder.as_markup(),
)
@dp.callback_query(lambda query: query.data.startswith("model_"))
async def model_callback_handler(query: types.CallbackQuery):
global modelname
global modelfamily
modelname = query.data.split("model_")[1]
await query.answer(f"Chosen model: {modelname}")
@dp.callback_query(lambda query: query.data == "about")
@perms_admins
async def about_callback_handler(query: types.CallbackQuery):
dotenv_model = os.getenv("INITMODEL")
global modelname
await bot.send_message(
chat_id=query.message.chat.id,
text=f"""<b>Your LLMs</b>
Currently using: <code>{modelname}</code>
Default in .env: <code>{dotenv_model}</code>
This project is under <a href='https://github.com/ruecat/ollama-telegram/blob/main/LICENSE'>MIT License.</a>
<a href='https://github.com/ruecat/ollama-telegram'>Source Code</a>
""",
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)
@dp.message()
@perms_allowed
async def handle_message(message: types.Message):
await get_bot_info()
if message.chat.type == "private":
await ollama_request(message)
return
if await is_mentioned_in_group_or_supergroup(message):
thread = await collect_message_thread(message)
prompt = format_thread_for_prompt(thread)
await ollama_request(message, prompt)
async def is_mentioned_in_group_or_supergroup(message: types.Message):
if message.chat.type not in ["group", "supergroup"]:
return False
is_mentioned = (
(message.text and message.text.startswith(mention)) or
(message.caption and message.caption.startswith(mention))
)
is_reply_to_bot = (
message.reply_to_message and
message.reply_to_message.from_user.id == bot.id
)
return is_mentioned or is_reply_to_bot
async def collect_message_thread(message: types.Message, thread=None):
if thread is None:
thread = []
thread.insert(0, message)
if message.reply_to_message:
await collect_message_thread(message.reply_to_message, thread)
return thread
def format_thread_for_prompt(thread):
prompt = "Conversation thread:\n\n"
for msg in thread:
sender = "User" if msg.from_user.id != bot.id else "Bot"
content = msg.text or msg.caption or "[No text content]"
prompt += f"{sender}: {content}\n\n"
prompt += "History:"
return prompt
async def process_image(message):
image_base64 = ""
if message.content_type == "photo":
image_buffer = io.BytesIO()
await bot.download(message.photo[-1], destination=image_buffer)
image_base64 = base64.b64encode(
image_buffer.getvalue()
).decode("utf-8")
return image_base64
async def add_prompt_to_active_chats(message, prompt, image_base64, modelname):
async with ACTIVE_CHATS_LOCK:
if ACTIVE_CHATS.get(message.from_user.id) is None:
ACTIVE_CHATS[message.from_user.id] = {
"model": modelname,
"messages": [
{
"role": "user",
"content": prompt,
"images": ([image_base64] if image_base64 else []),
}
],
"stream": True,
}
else:
ACTIVE_CHATS[message.from_user.id]["messages"].append(
{
"role": "user",
"content": prompt,
"images": ([image_base64] if image_base64 else []),
}
)
async def handle_response(message, response_data, full_response):
full_response_stripped = full_response.strip()
if full_response_stripped == "":
return
if response_data.get("done"):
text = f"{full_response_stripped}\n\n⚙️ {modelname}\nGenerated in {
response_data.get('total_duration') / 1e9:.2f}s."
await send_response(message, text)
async with ACTIVE_CHATS_LOCK:
if ACTIVE_CHATS.get(message.from_user.id) is not None:
ACTIVE_CHATS[message.from_user.id]["messages"].append(
{"role": "assistant", "content": full_response_stripped}
)
logging.info(
f"[Response]: '{full_response_stripped}' for {
message.from_user.first_name} {message.from_user.last_name}"
)
return True
return False
async def send_response(message, text):
# A negative message.chat.id is a group message
if message.chat.id < 0 or message.chat.id == message.from_user.id:
await bot.send_message(chat_id=message.chat.id, text=text)
else:
await bot.edit_message_text(
chat_id=message.chat.id,
message_id=message.message_id,
text=text
)
async def ollama_request(message: types.Message, prompt: str = None):
try:
full_response = ""
await bot.send_chat_action(message.chat.id, "typing")
image_base64 = await process_image(message)
if prompt is None:
prompt = message.text or message.caption
await add_prompt_to_active_chats(message, prompt, image_base64, modelname)
logging.info(
f"[OllamaAPI]: Processing '{prompt}' for {
message.from_user.first_name} {message.from_user.last_name}"
)
payload = ACTIVE_CHATS.get(message.from_user.id)
async for response_data in generate(payload, modelname, prompt):
msg = response_data.get("message")
if msg is None:
continue
chunk = msg.get("content", "")
full_response += chunk
if any([c in chunk for c in ".\n!?"]) or response_data.get("done"):
if await handle_response(message, response_data, full_response):
break
except Exception as e:
print(f"""-----
[OllamaAPI-ERR] CAUGHT FAULT!
{traceback.format_exc()}
-----""")
await bot.send_message(
chat_id=message.chat.id,
text=f"Something went wrong.",
parse_mode=ParseMode.HTML,
)
async def main():
await bot.set_my_commands(commands)
await dp.start_polling(bot, skip_update=True)
if __name__ == "__main__":
asyncio.run(main())