️ use async openai methods + create async task for the typing loop

This commit is contained in:
Laureηt 2023-10-19 16:10:47 +00:00
parent 8eda4825d9
commit 0f312a0a70
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI

View file

@ -1,3 +1,4 @@
import asyncio
import logging
import time
from collections import deque
@ -45,7 +46,7 @@ class LLMClient(AsyncClient):
The OpenAI temperature to use.
openai_max_tokens (`int`):
The OpenAI max tokens to use.
history_size (`int`, default `3`):
history_size (`int`):
The number of messages to keep in history.
"""
self.uid = f"@{username}:{homeserver.removeprefix('https://')}"
@ -73,7 +74,23 @@ class LLMClient(AsyncClient):
# add callbacks
self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore
async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None:
async def typing_loop(self) -> None:
"""Send typing indicators every 5 seconds."""
logging.debug("Started typing indicator.")
try:
while True:
logging.debug("Sending typing indicator.")
await self.room_typing(self.room, True)
await asyncio.sleep(5)
except asyncio.CancelledError:
await self.room_typing(self.room, False)
logging.debug("Stopped typing indicator.")
async def message_callback(
self,
room: MatrixRoom,
event: RoomMessageText,
) -> None:
"""Process new messages as they come in.
Args:
@ -128,16 +145,11 @@ class LLMClient(AsyncClient):
logger.debug("Ignoring message not mentioning us.")
return
# enable typing indicator
await self.room_typing(
self.room,
typing_state=True,
timeout=30000,
)
logger.debug("Enabled typing indicator.")
# start typing indicator loop
typing_task = asyncio.create_task(self.typing_loop())
# generate response using llama.cpp
response = openai.ChatCompletion.create(
response = await openai.ChatCompletion.acreate(
model="local-model",
messages=[
{
@ -162,10 +174,6 @@ class LLMClient(AsyncClient):
output = response["choices"][0]["message"]["content"] # type: ignore
output = output.strip().removeprefix(f"{self.uid}:").strip()
# disable typing indicator
await self.room_typing(self.room, typing_state=False)
logger.debug("Disabled typing indicator.")
# send the response
await self.room_send(
room_id=self.room,
@ -177,12 +185,21 @@ class LLMClient(AsyncClient):
)
logger.debug(f"Sent response: {output}")
async def start(self, password, sync_timeout=30000) -> None:
# stop typing indicator loop
typing_task.cancel()
async def start(
self,
password: str,
sync_timeout: int = 30000,
) -> None:
"""Start the client.
Args:
password (`str`): The password to log in with.
sync_timeout (`int`, default `30000`): The sync timeout in milliseconds.
password (`str`):
The password to log in with.
sync_timeout (`int`, default `30000`):
The sync timeout in milliseconds.
"""
# Login to the homeserver
logger.debug(await self.login(password))