use chat history in prompt

Laureηt 2023-05-27 14:58:05 +02:00
parent 35b987dd14
commit 8abae999e0
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI

@@ -2,8 +2,8 @@
 import logging
 import time
+from collections import deque
 from pathlib import Path
-from textwrap import dedent
 
 from llama_cpp import Llama
 from nio import AsyncClient, MatrixRoom, RoomMessageText
@@ -41,8 +41,12 @@ class LLMClient(AsyncClient):
         self.llm = Llama(
             model_path=str(ggml_path),
             n_threads=12,
+            n_ctx=512 + 128,
         )
 
+        # create message history queue
+        self.history: deque[RoomMessageText] = deque(maxlen=10)
+
         # add callbacks
         self.add_event_callback(self.message_callback, RoomMessageText)  # type: ignore
@@ -50,11 +54,6 @@ class LLMClient(AsyncClient):
         """Process new messages as they come in."""
         logger.debug(f"New RoomMessageText: {event.source}")
 
-        # ignore our own messages
-        if event.sender == self.user:
-            logger.debug("Ignoring our own message.")
-            return
-
         # ignore messages pre-dating our spawn time
         if event.server_timestamp < self.spawn_time:
             logger.debug("Ignoring message pre-spawn.")
@@ -70,6 +69,14 @@ class LLMClient(AsyncClient):
             logger.debug("Ignoring edited message.")
             return
 
+        # update history
+        self.history.append(event)
+
+        # ignore our own messages
+        if event.sender == self.user:
+            logger.debug("Ignoring our own message.")
+            return
+
         # ignore messages not mentioning us
         if not (
             "format" in event.source["content"]
@@ -81,15 +88,24 @@ class LLMClient(AsyncClient):
-        # generate prompt from message
-        prompt = dedent(
-            f"""
-            {self.preprompt}
-            <{event.sender}>: {event.body}
-            <{self.username}>:
-            """,
-        ).strip()
-        logger.debug(f"Prompt: {prompt}")
+        # generate prompt from message and history
+        history = "\n".join(f"<{message.sender}>: {message.body}" for message in self.history)
+        prompt = "\n".join([self.preprompt, history, f"<{self.uid}>:"])
+        tokens = self.llm.tokenize(str.encode(prompt))
+        logger.debug(f"Prompt:\n{prompt}")
+        logger.debug(f"Tokens: {len(tokens)}")
+
+        if len(tokens) > 512:
+            logger.debug("Prompt too long, skipping.")
+            await self.room_send(
+                room_id=self.room,
+                message_type="m.room.message",
+                content={
+                    "msgtype": "m.emote",
+                    "body": "reached prompt token limit",
+                },
+            )
+            return
 
         # enable typing indicator
         await self.room_typing(
@@ -99,10 +115,11 @@ class LLMClient(AsyncClient):
         )
 
         # generate response using llama.cpp
+        senders = [f"<{message.sender}>" for message in self.history]
        output = self.llm(
             prompt,
-            max_tokens=100,
-            stop=[f"<{event.sender}>"],
+            max_tokens=128,
+            stop=[f"<{self.uid}>", "### Human", "### Assistant", *senders],
             echo=True,
         )
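
For reference, here is a minimal, self-contained sketch (not the repository's code) of the flow this commit introduces, assuming llama-cpp-python's Llama completion API; Message, PREPROMPT, UID, the model path and build_prompt are illustrative stand-ins for the bot's actual types and configuration:

# Sketch of the history-to-prompt flow above: a bounded deque keeps the last
# 10 messages, the prompt is the preprompt plus one "<sender>: body" line per
# message, and generation is skipped when the tokenized prompt would exceed
# the 512-token budget used in the diff.
from collections import deque
from dataclasses import dataclass

from llama_cpp import Llama


@dataclass
class Message:
    sender: str  # e.g. a Matrix user id such as "@alice:example.org"
    body: str    # plain-text message body


PREPROMPT = "You are a helpful assistant in a Matrix room."  # hypothetical preprompt
UID = "@bot:example.org"  # hypothetical bot user id

llm = Llama(model_path="ggml-model.bin", n_threads=12, n_ctx=512 + 128)
history: deque[Message] = deque(maxlen=10)  # old messages fall off automatically


def build_prompt(history: deque[Message]) -> str:
    """Join the preprompt, the chat history and the bot's own tag."""
    lines = "\n".join(f"<{m.sender}>: {m.body}" for m in history)
    return "\n".join([PREPROMPT, lines, f"<{UID}>:"])


def respond(history: deque[Message]) -> str | None:
    """Generate a reply, or return None when the prompt is over budget."""
    prompt = build_prompt(history)
    tokens = llm.tokenize(prompt.encode())
    if len(tokens) > 512:
        return None  # mirrors the "reached prompt token limit" emote in the diff
    # stop on any participant tag so the model does not speak for other users
    senders = [f"<{m.sender}>" for m in history]
    output = llm(prompt, max_tokens=128, stop=[f"<{UID}>", *senders], echo=True)
    # echo=True returns prompt + completion, so strip the prompt prefix
    return output["choices"][0]["text"][len(prompt):].strip()

Keeping the history in a deque(maxlen=10) bounds memory and prompt growth without explicit pruning, and using each participant's "<sender>" tag as a stop string keeps the model from continuing the dialogue on behalf of other users.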