# Source code for artkit.model.llm._generator
# -----------------------------------------------------------------------------
# © 2024 Boston Consulting Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
"""
Implementation of the text producer.
"""
import logging
import re
from collections.abc import Iterable
from pytools.api import inheritdoc
from pytools.expression import (
Expression,
HasExpressionRepr,
expression_from_init_params,
)
from pytools.text import TextTemplate
from .base import ChatModel
# Module-level logger, named after this module so logging can be configured
# hierarchically by package.
log = logging.getLogger(__name__)
# Explicit public API of this module.
__all__ = [
"TextGenerator",
]
@inheritdoc(match="""[see superclass]""")
class TextGenerator(HasExpressionRepr):
    """
    Generates text using an LLM.

    Sets up an LLM with a system prompt that instructs the LLM to generate text
    items. The system prompt must:

    - include a set of required formatting keys, in the form of ``{key}``, that
      will be substituted with values before sending the system prompt to the LLM
    - instruct the LLM to accept an integer as the user prompt, and generate a
      number of texts indicated by the user prompt
    - instruct the LLM to separate generated text items by lines consisting of
      ``#####``
    """

    #: The LLM system used to generate text items.
    llm: ChatModel

    def __init__(
        self, *, llm: ChatModel, system_prompt: str, formatting_keys: Iterable[str]
    ) -> None:
        """
        :param llm: the LLM system used to generate text
        :param system_prompt: the system prompt instructing the LLM to generate
            the text items
        :param formatting_keys: the names of the formatting keys used in the
            system prompt
        """
        super().__init__()
        self.llm = llm
        # The template pairs the prompt text with its required keys; substitution
        # happens later, per call, in :meth:`make_text_items`.
        self._system_prompt_template = TextTemplate(
            format_string=system_prompt, required_keys=formatting_keys
        )

    @property
    def system_prompt(self) -> str:
        """
        The system prompt used to generate text items.
        """
        return self._system_prompt_template.format_string

    @property
    def formatting_keys(self) -> set[str]:
        """
        The formatting keys used in the system prompt.
        """
        return self._system_prompt_template.formatting_keys

    async def make_text_items(self, *, n: int, attributes: dict[str, str]) -> list[str]:
        """
        Use the LLM to generate the given number of text items, substituting the
        given attributes for the formatting keys in the system prompt.

        Calls method :meth:`parse_llm_response` to parse the LLM response into
        individual text items.

        :param n: the number of text items to generate
        :param attributes: the attributes to substitute for the formatting keys in
            the system prompt
        :return: the generated text items
        :raises ValueError: if the LLM produced fewer than ``n`` text items
        """
        # we make at least 10% more text items than requested, since the LLM
        # sometimes does not produce the full number of requested text items
        n_with_buffer = n * 11 // 10 + 1
        response = await self.llm.with_system_prompt(
            system_prompt=(
                self._system_prompt_template.format_with_attributes(**attributes)
            )
        ).get_response(message=str(n_with_buffer))
        # get the text items from the last response
        # NOTE(review): assumes the LLM returns at least one response; an empty
        # response list would raise IndexError below — confirm against ChatModel
        if len(response) > 1:
            log.warning("LLM returned more than one response; using the last one")
        text_items = self.parse_llm_response(response[-1])
        if len(text_items) < n:
            raise ValueError(
                f"Expected {n} text items but only got {len(text_items)} "
                f"in LLM response\n{response[-1]!r}"
            )
        # trim the over-provisioned buffer back down to exactly n items
        return text_items[:n]

    @staticmethod
    def parse_llm_response(response: str) -> list[str]:
        """
        Parse an LLM response into individual text items.

        Unless overridden, splits the text using ``"#####"`` as the separator.

        Called by method :meth:`.make_text_items` with the LLM response.

        :param response: the LLM response to parse
        :return: a list of text items
        """
        return list(
            filter(
                # exclude empty descriptions, e.g., after a trailing "#####"
                None,
                map(
                    # strip whitespace
                    str.strip,
                    # split the response by the separator string
                    # be flexible with the number of '#' characters
                    # since the LLM may not use the precise number indicated
                    re.split(r"###+", response),
                ),
            )
        )

    def to_expression(self) -> Expression:
        """[see superclass]"""
        return expression_from_init_params(self)
#
# Auxiliary functions
#