Source code for artkit.model.llm._historized

# -----------------------------------------------------------------------------
# © 2024 Boston Consulting Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------

"""
Implementation of ``HistorizedChatModel``.
"""

from __future__ import annotations

import logging
from typing import Any, Generic, TypeVar

from pytools.api import appenddoc

from .base import ChatModel, ChatModelAdapter
from .history import AssistantMessage, ChatHistory, UserMessage

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

# Public API of this module.
__all__ = ["HistorizedChatModel"]


#
# Type variables
#

# Type of the chat model wrapped by ``HistorizedChatModel``; bound to ``ChatModel``.
T_ChatModel_ret = TypeVar("T_ChatModel_ret", covariant=True, bound=ChatModel)


#
# Classes
#


class HistorizedChatModel(ChatModelAdapter[T_ChatModel_ret], Generic[T_ChatModel_ret]):
    """
    An LLM chat system that maintains a chat history.

    The wrapped chat model is always called with the internally managed
    :attr:`history`; after each successful call, the user message and the
    first response are recorded in that history.
    """

    #: The chat history.
    history: ChatHistory

    @appenddoc(to=ChatModelAdapter.__init__)
    def __init__(
        self,
        model: T_ChatModel_ret,
        *,
        max_history: int | None = None,
    ) -> None:
        """
        :param max_history: the maximum number of messages to store in the chat
            history (defaults to ``None`` for no limit)
        """
        super().__init__(model=model)
        self.history = ChatHistory(max_length=max_history)

    @property
    def max_history(self) -> int | None:
        """
        The maximum length of the chat history.

        ``None`` for no limit.
        """
        return self.history.max_length

    async def get_response(
        self,
        message: str,
        *,
        history: ChatHistory | None = None,
        # Each extra keyword argument may be of any type; annotating ``**kwargs``
        # with ``dict[str, Any]`` would wrongly require every value to be a dict.
        **model_params: Any,
    ) -> list[str]:
        """
        Get a response, or multiple alternative responses, to a user message.

        The internally managed chat history object is passed to the wrapped
        chat model. Once the model has responded, the user message and the
        first response are appended to the chat history; if the model returns
        no responses, only the user message is appended. If the model call
        raises an exception, the chat history is left unchanged.

        :param message: the user message
        :param history: must be ``None`` since the chat history is managed
            internally
        :param model_params: additional parameters for the chat system
        :return: the response or alternative responses
        :raises ValueError: if a history is passed as the ``history`` argument
        """
        if history is not None:
            raise ValueError("Cannot provide a history to a historized chat system")

        responses = await self.model.get_response(
            message=message, history=self.history, **model_params
        )

        # Update the history only after the model call has returned, so that a
        # failed call does not leave a dangling user message behind.
        self.history.add_message(UserMessage(message))
        if responses:
            # Only the first of several alternative responses is historized.
            self.history.add_message(AssistantMessage(responses[0]))

        return responses