Source code for artkit.model.llm.bedrock.base._base
# -----------------------------------------------------------------------------
# © 2024 Boston Consulting Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
"""
Bedrock LLM systems.
"""
from __future__ import annotations
import logging
from abc import ABCMeta
from typing import Any, TypeVar
from pytools.api import MissingClassMeta, appenddoc, inheritdoc, subsdoc
from ...base import HTTPXChatConnector
log = logging.getLogger(__name__)
__all__ = ["BaseBedrockChat"]
try:
from botocore.auth import SigV4Auth
from botocore.session import Session
except ImportError:
class SigV4Auth(metaclass=MissingClassMeta, module="SigV4Auth"): # type: ignore
"""Placeholder class for missing ``SigV4Auth`` class."""
class Session(metaclass=MissingClassMeta, module="Session"): # type: ignore
"""Placeholder class for missing ``Session`` class."""
try:
from httpx import AsyncClient
except ImportError:
class AsyncClient(metaclass=MissingClassMeta, module="httpx"): # type: ignore
"""Placeholder class for missing ``AsyncClient`` class."""
__all__ = ["BaseBedrockChat"]
#
# Type variables
#
T_BaseBedrockChat = TypeVar("T_BaseBedrockChat", bound="BaseBedrockChat")
#
# Class declarations
#
log = logging.getLogger(__name__)
[docs]
@inheritdoc(match="""[see superclass]""")
class BaseBedrockChat(HTTPXChatConnector, metaclass=ABCMeta):
"""
Base class for Bedrock LLMs.
"""
region: str | None
[docs]
@classmethod
def get_default_api_key_env(cls) -> str:
"""[see superclass]"""
return ""
@subsdoc(
# The pattern matches the row defining model_params, and move it to the end
# of the docstring.
pattern=r"(:param model_params: .*\n)((:?.|\n)*\S)(\n|\s)*",
replacement=r"\2\1",
)
@appenddoc(to=HTTPXChatConnector.__init__)
def __init__(
self,
*,
model_id: str,
api_key_env: str | None = None,
initial_delay: float | None = None,
exponential_base: float | None = None,
jitter: bool | None = None,
max_retries: int | None = None,
system_prompt: str | None = None,
httpx_client_kwargs: dict[str, Any] | None = None,
region: str | None = None,
**model_params: Any,
) -> None:
"""
:param region: The specific AWS region to connect to.
:raises CredentialsNotFoundError: if unable to find AWS Credentials.
"""
super().__init__(
model_id=model_id,
api_key_env=api_key_env,
initial_delay=initial_delay,
exponential_base=exponential_base,
jitter=jitter,
max_retries=max_retries,
system_prompt=system_prompt,
httpx_client_kwargs=httpx_client_kwargs,
**model_params,
)
self.region = region if region else "us-east-1"
self.session = Session()
self.credentials = self.session.get_credentials()
self.auth = SigV4Auth(self.credentials, "bedrock", self.region)
self.endpoint = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model_id}/invoke"