download ggml weights from huggingface_hub

Laureηt 2023-05-24 20:48:00 +02:00
parent 8598244fe8
commit 65bcb80f85
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI


@@ -6,6 +6,7 @@ import time
 from textwrap import dedent
 import click
+from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 from nio import AsyncClient, MatrixRoom, RoomMessageText
@@ -22,6 +23,7 @@ class LLMClient(AsyncClient):
         device_id: str,
         preprompt: str,
         room: str,
+        ggml_path: str,
     ):
         """Create a new LLMClient instance."""
         super().__init__(
@@ -37,7 +39,7 @@ class LLMClient(AsyncClient):
         # create the Llama instance
         self.llm = Llama(
-            model_path="../../../llama.cpp/models/sv13B/stable-vicuna-13B.ggml.q5_1.bin",
+            model_path=ggml_path,
             n_threads=12,
         )
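Passing ggml_path instead of the hard-coded relative path removes the implicit dependency on a sibling llama.cpp checkout: Llama loads the GGML weights from whatever path it is handed, so the caller can point it at the huggingface_hub cache.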
@@ -88,7 +90,7 @@ class LLMClient(AsyncClient):
         output = self.llm(
             prompt,
             max_tokens=100,
-            stop=["<{event.sender}>"],
+            stop=[f"<{event.sender}>"],
             echo=True,
         )
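The added f prefix is a genuine bug fix: without it, the stop list held the literal text "<{event.sender}>", which never appears in the output, so generation only ended at max_tokens. With the f-string, completion is cut off as soon as the model starts writing a turn for that sender.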
@@ -123,17 +125,28 @@ def main(
     username: str,
     password: str,
     room: str,
-    preprompt,
+    preprompt: str,
 ) -> None:
+    """Run the main program.
+
+    Download the model from HuggingFace Hub and start the async loop.
+    """
+    # download the model
+    ggml_path = hf_hub_download(
+        repo_id="TheBloke/stable-vicuna-13B-GGML",
+        filename="stable-vicuna-13B.ggmlv3.q5_1.bin",
+    )
+
     asyncio.get_event_loop().run_until_complete(
         _main(
+            ggml_path=ggml_path,
             homeserver=homeserver,
             device_id=device_id,
             username=username,
             password=password,
             preprompt=preprompt,
             room=room,
-        )
+        ),
    )
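For context, hf_hub_download fetches a single file from a Hub repository and returns its local path, caching it (by default under ~/.cache/huggingface/hub) so restarts do not re-download the multi-gigabyte weights. A minimal standalone sketch, using the repo and filename from the hunk above:

    from huggingface_hub import hf_hub_download

    # first call downloads; later calls resolve from the local cache
    ggml_path = hf_hub_download(
        repo_id="TheBloke/stable-vicuna-13B-GGML",
        filename="stable-vicuna-13B.ggmlv3.q5_1.bin",
    )
    print(ggml_path)  # absolute path inside the cache, ready for Llama(model_path=...)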
@@ -143,9 +156,13 @@ async def _main(
     username: str,
     password: str,
     room: str,
-    preprompt,
+    preprompt: str,
+    ggml_path: str,
 ) -> None:
-    """Run the main program."""
+    """Run the async main program.
+
+    Create the client, login, join the room, and sync forever.
+    """
     # create the client
     client = LLMClient(
         homeserver=homeserver,
@@ -153,6 +170,7 @@ async def _main(
         username=username,
         room=room,
         preprompt=preprompt,
+        ggml_path=ggml_path,
     )

     # Login to the homeserver
@@ -166,6 +184,8 @@ async def _main(
 if __name__ == "__main__":
+    # set up logging
     logging.basicConfig(level=logging.DEBUG)
-    main(auto_envvar_prefix="NIOLLM")
+    # run the main program (with environment variables)
+    main(auto_envvar_prefix="NIOLLM")
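As a usage note, click's auto_envvar_prefix makes every option that is not passed on the command line fall back to an environment variable named NIOLLM_<OPTION>, e.g. NIOLLM_HOMESERVER or NIOLLM_PASSWORD. A minimal sketch of the pattern; the @click.option decorators are not shown in this diff, so the wiring below is an assumption:

    import click

    @click.command()
    @click.option("--homeserver")  # hypothetical wiring; read from NIOLLM_HOMESERVER when unset
    def main(homeserver: str) -> None:
        """Echo where the bot would connect."""
        click.echo(f"connecting to {homeserver}")

    if __name__ == "__main__":
        main(auto_envvar_prefix="NIOLLM")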