Add Ollama bot
@@ -0,0 +1,199 @@
# >> interactions
import logging
import os
import aiohttp
import json

from aiogram import types
from aiohttp import ClientTimeout
from asyncio import Lock
from functools import wraps
from dotenv import load_dotenv


load_dotenv('.env')

token = os.getenv("TOKEN")

allowed_ids = list(map(int, os.getenv("USER_IDS", "").split(",")))
admin_ids = list(map(int, os.getenv("ADMIN_IDS", "").split(",")))

ollama_base_url = os.getenv("OLLAMA_BASE_URL")
ollama_port = os.getenv("OLLAMA_PORT", "11434")

log_level_str = os.getenv("LOG_LEVEL", "INFO")

allow_all_users_in_groups = bool(int(os.getenv("ALLOW_ALL_USERS_IN_GROUPS", "0")))

log_levels = list(logging._levelToName.values())

timeout = os.getenv("TIMEOUT", "3000")

if log_level_str not in log_levels:
    log_level = logging.DEBUG
else:
    log_level = logging.getLevelName(log_level_str)

logging.basicConfig(level=log_level)
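
# Illustrative only: the .env keys read above (plus INITMODEL, read later in
# bot/run.py). The values below are placeholders, not part of this commit.
#
#   TOKEN=123456:your-telegram-bot-token
#   USER_IDS=111111111,222222222
#   ADMIN_IDS=111111111
#   OLLAMA_BASE_URL=localhost
#   OLLAMA_PORT=11434
#   INITMODEL=llama3
#   LOG_LEVEL=INFO
#   ALLOW_ALL_USERS_IN_GROUPS=0
#   TIMEOUT=3000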


async def model_list():
    async with aiohttp.ClientSession() as session:
        url = f"http://{ollama_base_url}:{ollama_port}/api/tags"
        async with session.get(url) as response:
            if response.status == 200:
                data = await response.json()
                return data["models"]
            else:
                return []
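
# Sketch of the /api/tags response shape this code relies on, trimmed to the
# fields actually consumed downstream (model name and detail families);
# values are illustrative, not from the commit:
#
#   {"models": [{"name": "llama3:latest", "details": {"families": ["llama"]}}]}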


async def generate(payload: dict, modelname: str, prompt: str):
    client_timeout = ClientTimeout(total=int(timeout))
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        url = f"http://{ollama_base_url}:{ollama_port}/api/chat"
        try:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=response.status,
                        message=response.reason,
                    )
                buffer = b""
                async for chunk in response.content.iter_any():
                    buffer += chunk
                    while b"\n" in buffer:
                        line, buffer = buffer.split(b"\n", 1)
                        line = line.strip()
                        if line:
                            yield json.loads(line)
        except aiohttp.ClientError as e:
            print(f"Error during request: {e}")
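
# Each yielded item is one NDJSON line from Ollama's streaming /api/chat
# response. Roughly (illustrative, trimmed to the fields bot/run.py consumes):
#
#   {"message": {"role": "assistant", "content": "Hel"}, "done": false}
#   {"message": {"role": "assistant", "content": "lo."}, "done": true, "total_duration": 1234567890}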


def perms_allowed(func):
    @wraps(func)
    async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
        user_id = message.from_user.id if message else query.from_user.id
        if user_id in admin_ids or user_id in allowed_ids:
            if message:
                return await func(message)
            elif query:
                return await func(query=query)
        else:
            if message:
                if message and message.chat.type in ["supergroup", "group"]:
                    if allow_all_users_in_groups:
                        return await func(message)
                    return
                await message.answer("Access Denied")
            elif query:
                if message and message.chat.type in ["supergroup", "group"]:
                    return
                await query.answer("Access Denied")
    return wrapper


def perms_admins(func):
    @wraps(func)
    async def wrapper(message: types.Message = None, query: types.CallbackQuery = None):
        user_id = message.from_user.id if message else query.from_user.id
        if user_id in admin_ids:
            if message:
                return await func(message)
            elif query:
                return await func(query=query)
        else:
            if message:
                if message and message.chat.type in ["supergroup", "group"]:
                    return
                await message.answer("Access Denied")
                logging.info(
                    f"[MSG] {message.from_user.first_name} {message.from_user.last_name}"
                    f"({message.from_user.id}) is not allowed to use this bot."
                )
            elif query:
                if message and message.chat.type in ["supergroup", "group"]:
                    return
                await query.answer("Access Denied")
                logging.info(
                    f"[QUERY] {query.from_user.first_name} {query.from_user.last_name}"
                    f"({query.from_user.id}) is not allowed to use this bot."
                )
    return wrapper


class contextLock:
    # Single class-level lock shared by every instance; used as an async
    # context manager to serialize access to the shared chat state.
    lock = Lock()

    async def __aenter__(self):
        await self.lock.acquire()

    async def __aexit__(self, exc_type, exc_value, exc_traceback):
        self.lock.release()
426
catch-all/06_bots_telegram/09_ollama_bot/bot/run.py
Normal file
@@ -0,0 +1,426 @@
from aiogram import Bot, Dispatcher, types
from aiogram.enums import ParseMode
from aiogram.filters.command import Command, CommandStart
from aiogram.types import Message
from aiogram.utils.keyboard import InlineKeyboardBuilder
from func.interactions import *

import asyncio
import traceback
import io
import base64

bot = Bot(token=token)
dp = Dispatcher()

start_kb = InlineKeyboardBuilder()
settings_kb = InlineKeyboardBuilder()

start_kb.row(
    types.InlineKeyboardButton(text="ℹ️ About", callback_data="about"),
    types.InlineKeyboardButton(text="⚙️ Settings", callback_data="settings"),
)
settings_kb.row(
    types.InlineKeyboardButton(text="🔄 Switch LLM", callback_data="switchllm"),
    types.InlineKeyboardButton(
        text="✏️ Edit system prompt", callback_data="editsystemprompt"
    ),
)

commands = [
    types.BotCommand(command="start", description="Start"),
    types.BotCommand(command="reset", description="Reset Chat"),
    types.BotCommand(command="history", description="Look through messages"),
]

ACTIVE_CHATS = {}
ACTIVE_CHATS_LOCK = contextLock()

modelname = os.getenv("INITMODEL")
mention = None

CHAT_TYPE_GROUP = "group"
CHAT_TYPE_SUPERGROUP = "supergroup"


async def get_bot_info():
    global mention
    if mention is None:
        get = await bot.get_me()
        mention = f"@{get.username}"
    return mention


@dp.message(CommandStart())
async def command_start_handler(message: Message) -> None:
    start_message = f"Welcome, <b>{message.from_user.full_name}</b>!"
    await message.answer(
        start_message,
        parse_mode=ParseMode.HTML,
        reply_markup=start_kb.as_markup(),
        disable_web_page_preview=True,
    )


@dp.message(Command("reset"))
async def command_reset_handler(message: Message) -> None:
    if message.from_user.id in allowed_ids:
        if message.from_user.id in ACTIVE_CHATS:
            async with ACTIVE_CHATS_LOCK:
                ACTIVE_CHATS.pop(message.from_user.id)
            logging.info(
                f"Chat has been reset for {message.from_user.first_name}"
            )
            await bot.send_message(
                chat_id=message.chat.id,
                text="Chat has been reset",
            )


@dp.message(Command("history"))
async def command_get_context_handler(message: Message) -> None:
    if message.from_user.id in allowed_ids:
        if message.from_user.id in ACTIVE_CHATS:
            messages = ACTIVE_CHATS.get(message.from_user.id)["messages"]
            context = ""
            for msg in messages:
                context += f"*{msg['role'].capitalize()}*: {msg['content']}\n"
            await bot.send_message(
                chat_id=message.chat.id,
                text=context,
                parse_mode=ParseMode.MARKDOWN,
            )
        else:
            await bot.send_message(
                chat_id=message.chat.id,
                text="No chat history available for this user",
            )


@dp.callback_query(lambda query: query.data == "settings")
async def settings_callback_handler(query: types.CallbackQuery):
    await bot.send_message(
        chat_id=query.message.chat.id,
        text="Choose the right option.",
        parse_mode=ParseMode.HTML,
        disable_web_page_preview=True,
        reply_markup=settings_kb.as_markup(),
    )


@dp.callback_query(lambda query: query.data == "switchllm")
async def switchllm_callback_handler(query: types.CallbackQuery):
    models = await model_list()
    switchllm_builder = InlineKeyboardBuilder()
    for model in models:
        modelname = model["name"]
        modelfamilies = ""
        if model["details"]["families"]:
            modelicon = {"llama": "🦙", "clip": "📷"}
            try:
                modelfamilies = "".join(
                    [modelicon[family] for family in model["details"]["families"]]
                )
            except KeyError:
                modelfamilies = "✨"
        switchllm_builder.row(
            types.InlineKeyboardButton(
                text=f"{modelname} {modelfamilies}",
                callback_data=f"model_{modelname}",
            )
        )
    await query.message.edit_text(
        f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal",
        reply_markup=switchllm_builder.as_markup(),
    )


@dp.callback_query(lambda query: query.data.startswith("model_"))
async def model_callback_handler(query: types.CallbackQuery):
    global modelname
    global modelfamily

    modelname = query.data.split("model_")[1]
    await query.answer(f"Chosen model: {modelname}")


@dp.callback_query(lambda query: query.data == "about")
@perms_admins
async def about_callback_handler(query: types.CallbackQuery):
    dotenv_model = os.getenv("INITMODEL")
    global modelname
    await bot.send_message(
        chat_id=query.message.chat.id,
        text=f"""<b>Your LLMs</b>
Currently using: <code>{modelname}</code>
Default in .env: <code>{dotenv_model}</code>
This project is under <a href='https://github.com/ruecat/ollama-telegram/blob/main/LICENSE'>MIT License.</a>
<a href='https://github.com/ruecat/ollama-telegram'>Source Code</a>
""",
        parse_mode=ParseMode.HTML,
        disable_web_page_preview=True,
    )


@dp.message()
@perms_allowed
async def handle_message(message: types.Message):
    await get_bot_info()

    if message.chat.type == "private":
        await ollama_request(message)
        return

    if await is_mentioned_in_group_or_supergroup(message):
        thread = await collect_message_thread(message)
        prompt = format_thread_for_prompt(thread)

        await ollama_request(message, prompt)


async def is_mentioned_in_group_or_supergroup(message: types.Message):
    if message.chat.type not in ["group", "supergroup"]:
        return False

    is_mentioned = (
        (message.text and message.text.startswith(mention)) or
        (message.caption and message.caption.startswith(mention))
    )

    is_reply_to_bot = (
        message.reply_to_message and
        message.reply_to_message.from_user.id == bot.id
    )

    return is_mentioned or is_reply_to_bot


async def collect_message_thread(message: types.Message, thread=None):
    if thread is None:
        thread = []

    thread.insert(0, message)

    if message.reply_to_message:
        await collect_message_thread(message.reply_to_message, thread)

    return thread


def format_thread_for_prompt(thread):
    prompt = "Conversation thread:\n\n"
    for msg in thread:
        sender = "User" if msg.from_user.id != bot.id else "Bot"
        content = msg.text or msg.caption or "[No text content]"
        prompt += f"{sender}: {content}\n\n"

    prompt += "History:"
    return prompt


async def process_image(message):
    image_base64 = ""
    if message.content_type == "photo":
        image_buffer = io.BytesIO()
        await bot.download(message.photo[-1], destination=image_buffer)
        image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")
    return image_base64


async def add_prompt_to_active_chats(message, prompt, image_base64, modelname):
    async with ACTIVE_CHATS_LOCK:
        if ACTIVE_CHATS.get(message.from_user.id) is None:
            ACTIVE_CHATS[message.from_user.id] = {
                "model": modelname,
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                        "images": ([image_base64] if image_base64 else []),
                    }
                ],
                "stream": True,
            }
        else:
            ACTIVE_CHATS[message.from_user.id]["messages"].append(
                {
                    "role": "user",
                    "content": prompt,
                    "images": ([image_base64] if image_base64 else []),
                }
            )
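
# The per-user entry built above doubles as the request body that generate()
# posts to /api/chat, e.g. (illustrative values):
#
#   {"model": "llama3", "stream": True,
#    "messages": [{"role": "user", "content": "Hi", "images": []}]}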


async def handle_response(message, response_data, full_response):
    full_response_stripped = full_response.strip()
    if full_response_stripped == "":
        return

    if response_data.get("done"):
        text = (
            f"{full_response_stripped}\n\n⚙️ {modelname}\n"
            f"Generated in {response_data.get('total_duration') / 1e9:.2f}s."
        )
        await send_response(message, text)

        async with ACTIVE_CHATS_LOCK:
            if ACTIVE_CHATS.get(message.from_user.id) is not None:
                ACTIVE_CHATS[message.from_user.id]["messages"].append(
                    {"role": "assistant", "content": full_response_stripped}
                )
        logging.info(
            f"[Response]: '{full_response_stripped}' for "
            f"{message.from_user.first_name} {message.from_user.last_name}"
        )
        return True
    return False


async def send_response(message, text):
    # A negative message.chat.id is a group message
    if message.chat.id < 0 or message.chat.id == message.from_user.id:
        await bot.send_message(chat_id=message.chat.id, text=text)
    else:
        await bot.edit_message_text(
            chat_id=message.chat.id,
            message_id=message.message_id,
            text=text,
        )


async def ollama_request(message: types.Message, prompt: str = None):
    try:
        full_response = ""
        await bot.send_chat_action(message.chat.id, "typing")
        image_base64 = await process_image(message)

        if prompt is None:
            prompt = message.text or message.caption

        await add_prompt_to_active_chats(message, prompt, image_base64, modelname)
        logging.info(
            f"[OllamaAPI]: Processing '{prompt}' for "
            f"{message.from_user.first_name} {message.from_user.last_name}"
        )
        payload = ACTIVE_CHATS.get(message.from_user.id)
        async for response_data in generate(payload, modelname, prompt):
            msg = response_data.get("message")
            if msg is None:
                continue
            chunk = msg.get("content", "")
            full_response += chunk

            # Flush the buffered text once a chunk contains sentence-ending
            # punctuation or the stream reports completion.
            if any([c in chunk for c in ".\n!?"]) or response_data.get("done"):
                if await handle_response(message, response_data, full_response):
                    break

    except Exception:
        print(f"""-----
[OllamaAPI-ERR] CAUGHT FAULT!
{traceback.format_exc()}
-----""")
        await bot.send_message(
            chat_id=message.chat.id,
            text="Something went wrong.",
            parse_mode=ParseMode.HTML,
        )


async def main():
    await bot.set_my_commands(commands)
    await dp.start_polling(bot, skip_update=True)


if __name__ == "__main__":
    asyncio.run(main())