mirror of
https://github.com/Laurent2916/nio-llm.git
synced 2024-11-23 22:58:48 +00:00
✨ use chat history in prompt
This commit is contained in:
parent
35b987dd14
commit
8abae999e0
|
@ -2,8 +2,8 @@
|
|||
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from llama_cpp import Llama
|
||||
from nio import AsyncClient, MatrixRoom, RoomMessageText
|
||||
|
@ -41,8 +41,12 @@ class LLMClient(AsyncClient):
|
|||
self.llm = Llama(
|
||||
model_path=str(ggml_path),
|
||||
n_threads=12,
|
||||
n_ctx=512 + 128,
|
||||
)
|
||||
|
||||
# create message history queue
|
||||
self.history: deque[RoomMessageText] = deque(maxlen=10)
|
||||
|
||||
# add callbacks
|
||||
self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore
|
||||
|
||||
|
@ -50,11 +54,6 @@ class LLMClient(AsyncClient):
|
|||
"""Process new messages as they come in."""
|
||||
logger.debug(f"New RoomMessageText: {event.source}")
|
||||
|
||||
# ignore our own messages
|
||||
if event.sender == self.user:
|
||||
logger.debug("Ignoring our own message.")
|
||||
return
|
||||
|
||||
# ignore messages pre-dating our spawn time
|
||||
if event.server_timestamp < self.spawn_time:
|
||||
logger.debug("Ignoring message pre-spawn.")
|
||||
|
@ -70,6 +69,14 @@ class LLMClient(AsyncClient):
|
|||
logger.debug("Ignoring edited message.")
|
||||
return
|
||||
|
||||
# update history
|
||||
self.history.append(event)
|
||||
|
||||
# ignore our own messages
|
||||
if event.sender == self.user:
|
||||
logger.debug("Ignoring our own message.")
|
||||
return
|
||||
|
||||
# ignore messages not mentioning us
|
||||
if not (
|
||||
"format" in event.source["content"]
|
||||
|
@ -81,15 +88,24 @@ class LLMClient(AsyncClient):
|
|||
logger.debug("Ignoring message not directed at us.")
|
||||
return
|
||||
|
||||
# generate prompt from message
|
||||
prompt = dedent(
|
||||
f"""
|
||||
{self.preprompt}
|
||||
<{event.sender}>: {event.body}
|
||||
<{self.username}>:
|
||||
""",
|
||||
).strip()
|
||||
logger.debug(f"Prompt: {prompt}")
|
||||
# generate prompt from message and history
|
||||
history = "\n".join(f"<{message.sender}>: {message.body}" for message in self.history)
|
||||
prompt = "\n".join([self.preprompt, history, f"<{self.uid}>:"])
|
||||
tokens = self.llm.tokenize(str.encode(prompt))
|
||||
logger.debug(f"Prompt:\n{prompt}")
|
||||
logger.debug(f"Tokens: {len(tokens)}")
|
||||
|
||||
if len(tokens) > 512:
|
||||
logger.debug("Prompt too long, skipping.")
|
||||
await self.room_send(
|
||||
room_id=self.room,
|
||||
message_type="m.room.message",
|
||||
content={
|
||||
"msgtype": "m.emote",
|
||||
"body": "reached prompt token limit",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
# enable typing indicator
|
||||
await self.room_typing(
|
||||
|
@ -99,10 +115,11 @@ class LLMClient(AsyncClient):
|
|||
)
|
||||
|
||||
# generate response using llama.cpp
|
||||
senders = [f"<{message.sender}>" for message in self.history]
|
||||
output = self.llm(
|
||||
prompt,
|
||||
max_tokens=100,
|
||||
stop=[f"<{event.sender}>"],
|
||||
max_tokens=128,
|
||||
stop=[f"<{self.uid}>", "### Human", "### Assistant", *senders],
|
||||
echo=True,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in a new issue