make history_size configurable

This commit is contained in:
Laureηt 2023-10-19 14:15:09 +00:00
parent ca22fe640f
commit 12080ad3a5
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 9 additions and 1 deletions

View file

@ -23,6 +23,7 @@ def main(
openai_api_endpoint: str = "http://localhost:8000/v1", openai_api_endpoint: str = "http://localhost:8000/v1",
openai_temperature: float = 0, openai_temperature: float = 0,
openai_max_tokens: int = 256, openai_max_tokens: int = 256,
history_size: int = 3,
) -> None: ) -> None:
"""Instantiate and start the client. """Instantiate and start the client.
@ -57,6 +58,9 @@ def main(
openai_max_tokens (`int`): openai_max_tokens (`int`):
The OpenAI max tokens to use. The OpenAI max tokens to use.
Defaults to `256`. Defaults to `256`.
history_size (`int`):
The number of messages to keep in history.
Defaults to `3`.
""" """
# create the client # create the client
client = LLMClient( client = LLMClient(
@ -69,6 +73,7 @@ def main(
openai_api_endpoint=openai_api_endpoint, openai_api_endpoint=openai_api_endpoint,
openai_temperature=openai_temperature, openai_temperature=openai_temperature,
openai_max_tokens=openai_max_tokens, openai_max_tokens=openai_max_tokens,
history_size=history_size,
) )
# start the client # start the client

View file

@ -22,6 +22,7 @@ class LLMClient(AsyncClient):
openai_api_endpoint: str, openai_api_endpoint: str,
openai_temperature: float, openai_temperature: float,
openai_max_tokens: int, openai_max_tokens: int,
history_size: int,
) -> None: ) -> None:
"""Create a new LLMClient instance. """Create a new LLMClient instance.
@ -44,6 +45,8 @@ class LLMClient(AsyncClient):
The OpenAI temperature to use. The OpenAI temperature to use.
openai_max_tokens (`int`): openai_max_tokens (`int`):
The OpenAI max tokens to use. The OpenAI max tokens to use.
history_size (`int`, default `3`):
The number of messages to keep in history.
""" """
self.uid = f"@{username}:{homeserver.removeprefix('https://')}" self.uid = f"@{username}:{homeserver.removeprefix('https://')}"
self.spawn_time = time.time() * 1000 self.spawn_time = time.time() * 1000
@ -65,7 +68,7 @@ class LLMClient(AsyncClient):
) )
# create message history queue # create message history queue
self.history: deque[RoomMessageText] = deque(maxlen=10) self.history: deque[RoomMessageText] = deque(maxlen=history_size)
# add callbacks # add callbacks
self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore