download ggml weights from huggingface_hub

This commit is contained in:
Laureηt 2023-05-24 20:48:00 +02:00
parent 8598244fe8
commit 65bcb80f85
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI

View file

@ -6,6 +6,7 @@ import time
from textwrap import dedent
import click
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from nio import AsyncClient, MatrixRoom, RoomMessageText
@ -22,6 +23,7 @@ class LLMClient(AsyncClient):
device_id: str,
preprompt: str,
room: str,
ggml_path: str,
):
"""Create a new LLMClient instance."""
super().__init__(
@ -37,7 +39,7 @@ class LLMClient(AsyncClient):
# create the Llama instance
self.llm = Llama(
model_path="../../../llama.cpp/models/sv13B/stable-vicuna-13B.ggml.q5_1.bin",
model_path=ggml_path,
n_threads=12,
)
@ -88,7 +90,7 @@ class LLMClient(AsyncClient):
output = self.llm(
prompt,
max_tokens=100,
stop=["<{event.sender}>"],
stop=[f"<{event.sender}>"],
echo=True,
)
@ -123,17 +125,28 @@ def main(
username: str,
password: str,
room: str,
preprompt,
preprompt: str,
) -> None:
"""Run the main program.
Download the model from HuggingFace Hub and start the async loop.
"""
# download the model
ggml_path = hf_hub_download(
repo_id="TheBloke/stable-vicuna-13B-GGML",
filename="stable-vicuna-13B.ggmlv3.q5_1.bin",
)
asyncio.get_event_loop().run_until_complete(
_main(
ggml_path=ggml_path,
homeserver=homeserver,
device_id=device_id,
username=username,
password=password,
preprompt=preprompt,
room=room,
)
),
)
@ -143,9 +156,13 @@ async def _main(
username: str,
password: str,
room: str,
preprompt,
preprompt: str,
ggml_path: str,
) -> None:
"""Run the main program."""
"""Run the async main program.
Create the client, login, join the room, and sync forever.
"""
# create the client
client = LLMClient(
homeserver=homeserver,
@ -153,6 +170,7 @@ async def _main(
username=username,
room=room,
preprompt=preprompt,
ggml_path=ggml_path,
)
# Login to the homeserver
@ -166,6 +184,8 @@ async def _main(
if __name__ == "__main__":
# set up logging
logging.basicConfig(level=logging.DEBUG)
main(auto_envvar_prefix="NIOLLM")
# run the main program (with environment variables)
main(auto_envvar_prefix="NIOLLM")