use chat history in prompt

Laureηt 2023-05-27 14:58:05 +02:00
parent 35b987dd14
commit 8abae999e0
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI


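In short, the diff below replaces the single-message prompt with one assembled from a rolling window of the last ten room messages, and skips generation when that prompt exceeds 512 tokens. A minimal standalone sketch of the assembly, for reference only: build_prompt and the example values are hypothetical, only the "<sender>: body" line format mirrors the actual change.

def build_prompt(preprompt: str, history, uid: str) -> str:
    # history: iterable of (sender, body) pairs, oldest first
    lines = [f"<{sender}>: {body}" for sender, body in history]
    # preprompt, then the conversation so far, then the bot's own tag for the model to complete
    return "\n".join([preprompt, *lines, f"<{uid}>:"])

# hypothetical usage
preprompt = "You are a helpful assistant in a Matrix room."
history = [("@alice:example.org", "hello bot"), ("@bot:example.org", "hi!")]
print(build_prompt(preprompt, history, "@bot:example.org"))

The actual change also tokenizes the result with the model's own tokenizer and bails out with an m.emote notice above 512 tokens.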
@@ -2,8 +2,8 @@
import logging
import time
from collections import deque
from pathlib import Path
from textwrap import dedent
from llama_cpp import Llama
from nio import AsyncClient, MatrixRoom, RoomMessageText
@@ -41,8 +41,12 @@ class LLMClient(AsyncClient):
        self.llm = Llama(
            model_path=str(ggml_path),
            n_threads=12,
            n_ctx=512 + 128,
        )
        # create message history queue
        self.history: deque[RoomMessageText] = deque(maxlen=10)
        # add callbacks
        self.add_event_callback(self.message_callback, RoomMessageText)  # type: ignore
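For reference, deque(maxlen=10) drops its oldest entry once full, so the history used for the prompt is naturally capped at ten messages. A tiny illustration with a smaller, made-up window (plain strings stand in for RoomMessageText events):

from collections import deque

window = deque(maxlen=3)
for msg in ["a", "b", "c", "d"]:
    window.append(msg)
print(list(window))  # ['b', 'c', 'd'] -- the oldest entry was silently discarded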
@@ -50,11 +54,6 @@ class LLMClient(AsyncClient):
        """Process new messages as they come in."""
        logger.debug(f"New RoomMessageText: {event.source}")
        # ignore our own messages
        if event.sender == self.user:
            logger.debug("Ignoring our own message.")
            return
        # ignore messages pre-dating our spawn time
        if event.server_timestamp < self.spawn_time:
            logger.debug("Ignoring message pre-spawn.")
@@ -70,6 +69,14 @@ class LLMClient(AsyncClient):
            logger.debug("Ignoring edited message.")
            return
        # update history
        self.history.append(event)
        # ignore our own messages
        if event.sender == self.user:
            logger.debug("Ignoring our own message.")
            return
        # ignore messages not mentioning us
        if not (
            "format" in event.source["content"]
@@ -81,15 +88,24 @@ class LLMClient(AsyncClient):
            logger.debug("Ignoring message not directed at us.")
            return
        # generate prompt from message
        prompt = dedent(
            f"""
            {self.preprompt}
            <{event.sender}>: {event.body}
            <{self.username}>:
            """,
        ).strip()
        logger.debug(f"Prompt: {prompt}")
        # generate prompt from message and history
        history = "\n".join(f"<{message.sender}>: {message.body}" for message in self.history)
        prompt = "\n".join([self.preprompt, history, f"<{self.uid}>:"])
        tokens = self.llm.tokenize(str.encode(prompt))
        logger.debug(f"Prompt:\n{prompt}")
        logger.debug(f"Tokens: {len(tokens)}")
        if len(tokens) > 512:
            logger.debug("Prompt too long, skipping.")
            await self.room_send(
                room_id=self.room,
                message_type="m.room.message",
                content={
                    "msgtype": "m.emote",
                    "body": "reached prompt token limit",
                },
            )
            return
        # enable typing indicator
        await self.room_typing(
@@ -99,10 +115,11 @@ class LLMClient(AsyncClient):
        )
        # generate response using llama.cpp
        senders = [f"<{message.sender}>" for message in self.history]
        output = self.llm(
            prompt,
            max_tokens=100,
            stop=[f"<{event.sender}>"],
            max_tokens=128,
            stop=[f"<{self.uid}>", "### Human", "### Assistant", *senders],
            echo=True,
        )
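With the new stop list, generation halts as soon as the model starts a line for any known sender tag or the Alpaca-style "### Human" / "### Assistant" markers, which keeps the bot from writing other participants' turns. What happens to output afterwards is not shown in this diff; presumably the echoed prompt is stripped off before the reply is posted. A hedged sketch of that post-processing, assuming the standard llama-cpp-python completion dict (the variable names are illustrative):

# output follows llama-cpp-python's completion layout: {"choices": [{"text": ...}], ...}
full_text = output["choices"][0]["text"]
# with echo=True the returned text starts with the prompt, so drop it to keep only the reply
reply = full_text[len(prompt):].strip()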