♻️ move _main function to internal client method

This commit is contained in:
Laureηt 2023-05-29 18:26:19 +02:00
parent 38c0249ea9
commit 1d8a6bc573
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 48 additions and 39 deletions

View file

@ -23,7 +23,16 @@ class LLMClient(AsyncClient):
ggml_path: Path,
room: str,
):
"""Create a new LLMClient instance."""
"""Create a new LLMClient instance.
Args:
username (`str`): The username to log in as.
homeserver (`str`): The homeserver to connect to.
device_id (`str`): The device ID to use.
preprompt (`str`): The preprompt to use.
ggml_path (`Path`): The path to the GGML model.
room (`str`): The room to join.
"""
self.uid = f"@{username}:{homeserver.removeprefix('https://')}"
self.spawn_time = time.time() * 1000
self.username = username
@ -50,8 +59,13 @@ class LLMClient(AsyncClient):
# add callbacks
self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore
async def message_callback(self, room: MatrixRoom, event: RoomMessageText):
"""Process new messages as they come in."""
async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None:
"""Process new messages as they come in.
Args:
room (`MatrixRoom`): The room the message was sent in.
event (`RoomMessageText`): The message event.
"""
logger.debug(f"New RoomMessageText: {event.source}")
# ignore messages pre-dating our spawn time
@ -148,3 +162,19 @@ class LLMClient(AsyncClient):
"body": output,
},
)
async def start(self, password, sync_timeout=30000):
"""Start the client.
Args:
password (`str`): The password to log in with.
sync_timeout (`int`, default `30000`): The sync timeout in milliseconds.
"""
# Login to the homeserver
logger.debug(await self.login(password))
# Join the room, if not already joined
logger.debug(await self.join(self.room))
# Sync with the server forever
await self.sync_forever(timeout=sync_timeout)

View file

@ -66,6 +66,13 @@ logger = logging.getLogger("nio-llm.main")
default="stable-vicuna-13B.ggmlv3.q5_1.bin",
show_default=True,
)
@click.option(
"--sync-timeout",
"-s",
help="The timeout to use when syncing with the homeserver.",
default=30000,
show_default=True,
)
def main(
*,
room: str,
@ -76,6 +83,7 @@ def main(
homeserver: str,
ggml_repoid: str,
ggml_filename: str,
sync_timeout: int,
) -> None:
"""Run the main program.
@ -89,34 +97,6 @@ def main(
),
)
# start the async loop
asyncio.get_event_loop().run_until_complete(
_main(
room=room,
password=password,
username=username,
device_id=device_id,
ggml_path=ggml_path,
preprompt=preprompt,
homeserver=homeserver,
),
)
async def _main(
*,
room: str,
password: str,
username: str,
device_id: str,
preprompt: str,
ggml_path: Path,
homeserver: str,
) -> None:
"""Run the async main program.
Create the client, login, join the room, and sync forever.
"""
# create the client
client = LLMClient(
room=room,
@ -127,14 +107,13 @@ async def _main(
homeserver=homeserver,
)
# Login to the homeserver
logger.debug(await client.login(password))
# Join the room, if not already joined
logger.debug(await client.join(room))
# Sync with the server forever
await client.sync_forever(timeout=30000)
# start the client
asyncio.get_event_loop().run_until_complete(
client.start(
password=password,
sync_timeout=sync_timeout,
),
)
if __name__ == "__main__":