diff --git a/nio_llm/client.py b/nio_llm/client.py index 7ab4ce3..888b430 100644 --- a/nio_llm/client.py +++ b/nio_llm/client.py @@ -23,7 +23,16 @@ class LLMClient(AsyncClient): ggml_path: Path, room: str, ): - """Create a new LLMClient instance.""" + """Create a new LLMClient instance. + + Args: + username (`str`): The username to log in as. + homeserver (`str`): The homeserver to connect to. + device_id (`str`): The device ID to use. + preprompt (`str`): The preprompt to use. + ggml_path (`Path`): The path to the GGML model. + room (`str`): The room to join. + """ self.uid = f"@{username}:{homeserver.removeprefix('https://')}" self.spawn_time = time.time() * 1000 self.username = username @@ -50,8 +59,13 @@ class LLMClient(AsyncClient): # add callbacks self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore - async def message_callback(self, room: MatrixRoom, event: RoomMessageText): - """Process new messages as they come in.""" + async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None: + """Process new messages as they come in. + + Args: + room (`MatrixRoom`): The room the message was sent in. + event (`RoomMessageText`): The message event. + """ logger.debug(f"New RoomMessageText: {event.source}") # ignore messages pre-dating our spawn time @@ -148,3 +162,19 @@ class LLMClient(AsyncClient): "body": output, }, ) + + async def start(self, password, sync_timeout=30000): + """Start the client. + + Args: + password (`str`): The password to log in with. + sync_timeout (`int`, default `30000`): The sync timeout in milliseconds. + """ + # Login to the homeserver + logger.debug(await self.login(password)) + + # Join the room, if not already joined + logger.debug(await self.join(self.room)) + + # Sync with the server forever + await self.sync_forever(timeout=sync_timeout) diff --git a/nio_llm/main.py b/nio_llm/main.py index 819b9fa..629253a 100644 --- a/nio_llm/main.py +++ b/nio_llm/main.py @@ -66,6 +66,13 @@ logger = logging.getLogger("nio-llm.main") default="stable-vicuna-13B.ggmlv3.q5_1.bin", show_default=True, ) +@click.option( + "--sync-timeout", + "-s", + help="The timeout to use when syncing with the homeserver.", + default=30000, + show_default=True, +) def main( *, room: str, @@ -76,6 +83,7 @@ def main( homeserver: str, ggml_repoid: str, ggml_filename: str, + sync_timeout: int, ) -> None: """Run the main program. @@ -89,34 +97,6 @@ def main( ), ) - # start the async loop - asyncio.get_event_loop().run_until_complete( - _main( - room=room, - password=password, - username=username, - device_id=device_id, - ggml_path=ggml_path, - preprompt=preprompt, - homeserver=homeserver, - ), - ) - - -async def _main( - *, - room: str, - password: str, - username: str, - device_id: str, - preprompt: str, - ggml_path: Path, - homeserver: str, -) -> None: - """Run the async main program. - - Create the client, login, join the room, and sync forever. - """ # create the client client = LLMClient( room=room, @@ -127,14 +107,13 @@ async def _main( homeserver=homeserver, ) - # Login to the homeserver - logger.debug(await client.login(password)) - - # Join the room, if not already joined - logger.debug(await client.join(room)) - - # Sync with the server forever - await client.sync_forever(timeout=30000) + # start the client + asyncio.get_event_loop().run_until_complete( + client.start( + password=password, + sync_timeout=sync_timeout, + ), + ) if __name__ == "__main__":