From 0f312a0a70819073e380ad286792c9ed29fb5004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Thu, 19 Oct 2023 16:10:47 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20use=20async=20openai=20met?= =?UTF-8?q?hods=20+=20create=20async=20task=20for=20the=20typing=20loop?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/nio_llm/client.py | 51 ++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/nio_llm/client.py b/src/nio_llm/client.py index 9cf1844..c2aa204 100644 --- a/src/nio_llm/client.py +++ b/src/nio_llm/client.py @@ -1,3 +1,4 @@ +import asyncio import logging import time from collections import deque @@ -45,7 +46,7 @@ class LLMClient(AsyncClient): The OpenAI temperature to use. openai_max_tokens (`int`): The OpenAI max tokens to use. - history_size (`int`, default `3`): + history_size (`int`): The number of messages to keep in history. """ self.uid = f"@{username}:{homeserver.removeprefix('https://')}" @@ -73,7 +74,23 @@ class LLMClient(AsyncClient): # add callbacks self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore - async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None: + async def typing_loop(self) -> None: + """Send typing indicators every 5 seconds.""" + logging.debug("Started typing indicator.") + try: + while True: + logging.debug("Sending typing indicator.") + await self.room_typing(self.room, True) + await asyncio.sleep(5) + except asyncio.CancelledError: + await self.room_typing(self.room, False) + logging.debug("Stopped typing indicator.") + + async def message_callback( + self, + room: MatrixRoom, + event: RoomMessageText, + ) -> None: """Process new messages as they come in. Args: @@ -128,16 +145,11 @@ class LLMClient(AsyncClient): logger.debug("Ignoring message not mentioning us.") return - # enable typing indicator - await self.room_typing( - self.room, - typing_state=True, - timeout=30000, - ) - logger.debug("Enabled typing indicator.") + # start typing indicator loop + typing_task = asyncio.create_task(self.typing_loop()) # generate response using llama.cpp - response = openai.ChatCompletion.create( + response = await openai.ChatCompletion.acreate( model="local-model", messages=[ { @@ -162,10 +174,6 @@ class LLMClient(AsyncClient): output = response["choices"][0]["message"]["content"] # type: ignore output = output.strip().removeprefix(f"{self.uid}:").strip() - # disable typing indicator - await self.room_typing(self.room, typing_state=False) - logger.debug("Disabled typing indicator.") - # send the response await self.room_send( room_id=self.room, @@ -177,12 +185,21 @@ class LLMClient(AsyncClient): ) logger.debug(f"Sent response: {output}") - async def start(self, password, sync_timeout=30000) -> None: + # stop typing indicator loop + typing_task.cancel() + + async def start( + self, + password: str, + sync_timeout: int = 30000, + ) -> None: """Start the client. Args: - password (`str`): The password to log in with. - sync_timeout (`int`, default `30000`): The sync timeout in milliseconds. + password (`str`): + The password to log in with. + sync_timeout (`int`, default `30000`): + The sync timeout in milliseconds. """ # Login to the homeserver logger.debug(await self.login(password))