From 12080ad3a56216f4641b5d2cf56e380f1823fb18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Thu, 19 Oct 2023 14:15:09 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20make=20history=5Fsize=20configurabl?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/nio_llm/__main__.py | 5 +++++ src/nio_llm/client.py | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/nio_llm/__main__.py b/src/nio_llm/__main__.py index 510e39e..183ef39 100644 --- a/src/nio_llm/__main__.py +++ b/src/nio_llm/__main__.py @@ -23,6 +23,7 @@ def main( openai_api_endpoint: str = "http://localhost:8000/v1", openai_temperature: float = 0, openai_max_tokens: int = 256, + history_size: int = 3, ) -> None: """Instantiate and start the client. @@ -57,6 +58,9 @@ def main( openai_max_tokens (`int`): The OpenAI max tokens to use. Defaults to `256`. + history_size (`int`): + The number of messages to keep in history. + Defaults to `3`. """ # create the client client = LLMClient( @@ -69,6 +73,7 @@ def main( openai_api_endpoint=openai_api_endpoint, openai_temperature=openai_temperature, openai_max_tokens=openai_max_tokens, + history_size=history_size, ) # start the client diff --git a/src/nio_llm/client.py b/src/nio_llm/client.py index d374cb7..9cf1844 100644 --- a/src/nio_llm/client.py +++ b/src/nio_llm/client.py @@ -22,6 +22,7 @@ class LLMClient(AsyncClient): openai_api_endpoint: str, openai_temperature: float, openai_max_tokens: int, + history_size: int, ) -> None: """Create a new LLMClient instance. @@ -44,6 +45,8 @@ class LLMClient(AsyncClient): The OpenAI temperature to use. openai_max_tokens (`int`): The OpenAI max tokens to use. + history_size (`int`): + The number of messages to keep in history. 
""" self.uid = f"@{username}:{homeserver.removeprefix('https://')}" self.spawn_time = time.time() * 1000 @@ -65,7 +68,7 @@ class LLMClient(AsyncClient): ) # create message history queue - self.history: deque[RoomMessageText] = deque(maxlen=10) + self.history: deque[RoomMessageText] = deque(maxlen=history_size) # add callbacks self.add_event_callback(self.message_callback, RoomMessageText) # type: ignore