diff --git a/nio-llm/test.py b/nio-llm/test.py deleted file mode 100644 index c96de97..0000000 --- a/nio-llm/test.py +++ /dev/null @@ -1,28 +0,0 @@ -from textwrap import dedent - -from llama_cpp import Llama - -llm = Llama(model_path="../../../llama.cpp/models/sv13B/stable-vicuna-13B.ggml.q5_1.bin", n_threads=12) - -msg = dedent( - """ - You are pipobot, an arrogant assistant. Answer as concisely as possible. - <@fainsil:inpt.fr>: Qu'est ce qu'une intégrale de Lebesgue ? - <@pipobot:inpt.fr>: - """, -).strip() - -print(msg) -print(repr(msg)) - -output = llm( - msg, - max_tokens=100, - stop=["<@fainsil:inpt.fr>:", "\n"], - echo=True, -) - -print(output) -res = output["choices"][0]["text"] -print(res) -print(res.removeprefix(msg).strip()) diff --git a/nio_llm/__init__.py b/nio_llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nio-llm/client.py b/nio_llm/client.py similarity index 58% rename from nio-llm/client.py rename to nio_llm/client.py index be1b447..708bf46 100644 --- a/nio-llm/client.py +++ b/nio_llm/client.py @@ -1,17 +1,14 @@ """A Matrix client that uses Llama to respond to messages.""" -import asyncio import logging import time +from pathlib import Path from textwrap import dedent -import click -from huggingface_hub import hf_hub_download from llama_cpp import Llama from nio import AsyncClient, MatrixRoom, RoomMessageText -from rich.logging import RichHandler -logger = logging.getLogger("nio-llm") +logger = logging.getLogger("nio-llm.client") class LLMClient(AsyncClient): @@ -23,7 +20,7 @@ class LLMClient(AsyncClient): homeserver: str, device_id: str, preprompt: str, - ggml_path: str, + ggml_path: Path, room: str, ): """Create a new LLMClient instance.""" @@ -42,7 +39,7 @@ class LLMClient(AsyncClient): # create the Llama instance self.llm = Llama( - model_path=ggml_path, + model_path=str(ggml_path), n_threads=12, ) @@ -125,87 +122,3 @@ class LLMClient(AsyncClient): "body": output, }, ) - - -@click.command() -@click.option("--homeserver", "-h", help="The homeserver to connect to.", required=True) -@click.option("--device-id", "-d", help="The device ID to use.", required=True) -@click.option("--username", "-u", help="The username to log in as.", required=True) -@click.option("--password", "-p", help="The password to log in with.", required=True) -@click.option("--room", "-r", help="The room to join.", required=True) -@click.option("--preprompt", "-t", help="The preprompt to use.", required=True) -def main( - homeserver: str, - device_id: str, - username: str, - password: str, - room: str, - preprompt: str, -) -> None: - """Run the main program. - - Download the model from HuggingFace Hub and start the async loop. - """ - # download the model - ggml_path = hf_hub_download( - repo_id="TheBloke/stable-vicuna-13B-GGML", - filename="stable-vicuna-13B.ggmlv3.q5_1.bin", - ) - - asyncio.get_event_loop().run_until_complete( - _main( - ggml_path=ggml_path, - homeserver=homeserver, - device_id=device_id, - username=username, - password=password, - preprompt=preprompt, - room=room, - ), - ) - - -async def _main( - homeserver: str, - device_id: str, - username: str, - password: str, - room: str, - preprompt: str, - ggml_path: str, -) -> None: - """Run the async main program. - - Create the client, login, join the room, and sync forever. - """ - # create the client - client = LLMClient( - homeserver=homeserver, - device_id=device_id, - username=username, - room=room, - preprompt=preprompt, - ggml_path=ggml_path, - ) - - # Login to the homeserver - logger.debug(await client.login(password)) - - # Join the room, if not already joined - logger.debug(await client.join(room)) - - # Sync with the server forever - await client.sync_forever(timeout=30000) - - -if __name__ == "__main__": - # set up logging - logging.captureWarnings(True) - logging.basicConfig( - level="DEBUG", - format="%(name)s: %(message)s", - handlers=[RichHandler(markup=True)], - ) - - # run the main program (with environment variables) - main(auto_envvar_prefix="NIOLLM") diff --git a/nio_llm/main.py b/nio_llm/main.py new file mode 100644 index 0000000..348d6ae --- /dev/null +++ b/nio_llm/main.py @@ -0,0 +1,100 @@ +"""The main program for nio-llm.""" + +import asyncio +import logging +from pathlib import Path + +import click +from huggingface_hub import hf_hub_download +from rich.logging import RichHandler + +from nio_llm.client import LLMClient + +logger = logging.getLogger("nio-llm.main") + + +@click.command() +@click.option("--homeserver", "-h", help="The homeserver to connect to.", required=True) +@click.option("--device-id", "-d", help="The device ID to use.", required=True) +@click.option("--username", "-u", help="The username to log in as.", required=True) +@click.option("--password", "-p", help="The password to log in with.", required=True) +@click.option("--room", "-r", help="The room to join.", required=True) +@click.option("--preprompt", "-t", help="The preprompt to use.", required=True) +def main( + homeserver: str, + device_id: str, + username: str, + password: str, + room: str, + preprompt: str, +) -> None: + """Run the main program. + + Download the model from HuggingFace Hub and start the async loop. + """ + # download the model + ggml_path = Path( + hf_hub_download( + repo_id="TheBloke/stable-vicuna-13B-GGML", + filename="stable-vicuna-13B.ggmlv3.q5_1.bin", + ), + ) + + # start the async loop + asyncio.get_event_loop().run_until_complete( + _main( + ggml_path=ggml_path, + homeserver=homeserver, + device_id=device_id, + username=username, + password=password, + preprompt=preprompt, + room=room, + ), + ) + + +async def _main( + homeserver: str, + device_id: str, + username: str, + password: str, + room: str, + preprompt: str, + ggml_path: Path, +) -> None: + """Run the async main program. + + Create the client, login, join the room, and sync forever. + """ + # create the client + client = LLMClient( + homeserver=homeserver, + device_id=device_id, + username=username, + room=room, + preprompt=preprompt, + ggml_path=ggml_path, + ) + + # Login to the homeserver + logger.debug(await client.login(password)) + + # Join the room, if not already joined + logger.debug(await client.join(room)) + + # Sync with the server forever + await client.sync_forever(timeout=30000) + + +if __name__ == "__main__": + # set up logging + logging.captureWarnings(True) + logging.basicConfig( + level="DEBUG", + format="%(name)s: %(message)s", + handlers=[RichHandler(markup=True)], + ) + + # run the main program (with environment variables) + main(auto_envvar_prefix="NIOLLM")