Creating New Model Classes#

View notebook on GitHub

Introduction#

This notebook serves as a primer for developing ARTKIT model classes and is most likely your first entry point to contributing to ARTKIT. To get you started, we will cover:

  • Object-oriented programming motivations: The advantages of using class hierarchies and how this enables us to distribute features across client implementations

  • ARTKIT’s model class hierarchy: How are classes structured across ARTKIT and what levels of abstraction exist

  • Model implementation example: A deep-dive of the initialization and get_response implementation of the OpenAIChat class

  • Your implementation - a checklist: Steps to complete to code up your own client implementation

This is a technical guide for users who are looking to implement new functionality within the ARTKIT library. While a basic understanding of Python classes is assumed, this guide should be accessible to anyone with a data science background.

To get more information on the classes mentioned here, go to the API reference.

Object-oriented programming motivations#

ARTKIT makes heavy use of class inheritance to implement connectors to Gen AI model providers. This allows us to take advantage of features specific to certain model providers while standardizing common utility methods such as asynchronous execution, managing chat history, and caching. It also means you can quickly implement classes to connect to new providers and immediately re-use these standard features without any additional coding.

When should I create a new model class? Currently ARTKIT provides interfaces for connecting to OpenAI, Anthropic, Hugging Face, Groq, and Google’s Gemini LLMs. It also supports OpenAI’s multi-modal models, specifically DALL-E models and vision endpoints. If there are additional model providers you’d like to use in your testing and evaluation pipeline, creating a new class is the best way to do so – and also improves the usefulness of ARTKIT for everyone! To learn more about contributing to ARTKIT, see our Contributor Guide.

What if I want to change core class functionality? Reviewing the model class hierarchy in this guide is a good starting point for developing new core features for the library, though implementing and testing such features will take more careful consideration and scrutiny, as introducing changes at higher levels of abstraction can have implications for other client implementations.

ARTKIT’s model class hierarchy#

Here’s a summary of the model class hierarchy, starting from the abstract GenAIModel to the OpenAIChat class:

  1. GenAIModel is an abstract representation of a generic Gen AI model

  2. ConnectorMixin is a mixin class that adds a common interface for client connections to a GenAIModel

  3. ChatModel is an abstract class that provides a common interface for generating responses to text prompts, emulating a chat - it is a subclass of GenAIModel

  4. ChatModelConnector combines the ConnectorMixin with the ChatModel interface to create a chat model that connects to a client

  5. OpenAIChat is a subclass of ChatModelConnector that connects to OpenAI’s models and uses them for chat

Here’s a tree of the relevant repo structure, so you can take a look at the implementation of these classes:

model/
  base/
    _model.py     --> GenAIModel, ConnectorMixin
  llm/
    base/
      _llm.py     --> ChatModel, ChatModelConnector
    openai/
      _openai.py  --> OpenAIChat

A similar hierarchy is shared for other model providers and modalities. For example, levels 1-4 are identical for the AnthropicChat class. This means that OpenAIChat and AnthropicChat share all attributes and method signatures of ChatModelConnector, such as get_client(), but not the actual implementations of those methods. This enables model-specific parameters and response processing.

If you’re adding a new model class, you will most likely only need to make changes on hierarchy level 5.

It is similar for other modalities; for example, the OpenAIDiffusion class inherits from a DiffusionModelConnector, which is again a subclass of a DiffusionModel and a ConnectorMixin. Unlike a ChatModel, a DiffusionModel does not have a get_response() method. However, because it also inherits from ConnectorMixin, it will have the same get_client() method signature.
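
To make these relationships concrete, here is a minimal sketch that checks the hierarchy with issubclass and hasattr. It assumes the base classes are re-exported from artkit.model.base and artkit.model.llm.base as the repo tree above suggests, and that OpenAIDiffusion is exposed via artkit.api; adjust the imports if your version lays things out differently.

import artkit.api as ak
from artkit.model.base import ConnectorMixin, GenAIModel
from artkit.model.llm.base import ChatModel, ChatModelConnector

# level 5 inherits from level 4 ...
assert issubclass(ak.OpenAIChat, ChatModelConnector)
# ... which combines the chat interface (level 3) with the connection mixin (level 2) ...
assert issubclass(ChatModelConnector, ChatModel)
assert issubclass(ChatModelConnector, ConnectorMixin)
# ... and the chat interface is itself a generic Gen AI model (level 1)
assert issubclass(ChatModel, GenAIModel)

# other modalities share the mixin: connectors expose get_client(),
# but only chat models define get_response()
assert hasattr(ak.OpenAIDiffusion, "get_client")
assert not hasattr(ak.OpenAIDiffusion, "get_response")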

Model implementation example#

Next, we will demonstrate how the model class hierarchy is used through an example of the OpenAIChat class.

The goal is to illustrate both (i) what methods you need to implement to create a new model class and (ii) why those methods are set up the way that they are.

A ChatModelConnector has two required parameters: model_id and api_key_env. The OpenAIChat has additional model-specific parameters such as temperature. So when we initialize an OpenAIChat object as below, we’re also initializing the superclass to handle client connections:

[2]:
# Import ARTKIT
import artkit.api as ak

# Load API keys
from dotenv import load_dotenv
load_dotenv()

gpt4_chat = ak.OpenAIChat(
    model_id="gpt-4",
    api_key_env="OPENAI_API_KEY",
    temperature=0.5,
)
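
Model-specific keyword arguments such as temperature are stored as default model parameters on the object. One quick way to inspect them is get_model_params(), which the connector later merges with any per-call overrides; the output shown in the comment is illustrative:

gpt4_chat.get_model_params()
# e.g. {'temperature': 0.5}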

To implement a chat model connector, the following methods need to be implemented:

[3]:
#####################################
### model/llm/openai/_openai.py   ###
#####################################
from typing import Any

from openai import AsyncOpenAI

from artkit.model.llm.base import ChatModelConnector
from artkit.model.llm.history import ChatHistory


class OpenAIChat(ChatModelConnector[AsyncOpenAI]):

    # required by ConnectorMixin to establish a client connection
    @classmethod
    def get_default_api_key_env(cls) -> str:
        pass

    # required by ConnectorMixin to establish a client connection
    def _make_client(self) -> AsyncOpenAI:
        pass

    # required by ChatModel to send a message to the model
    async def get_response(  # pragma: no cover
        self,
        message: str,
        *,
        history: ChatHistory | None = None,
        **model_params: dict[str, Any],
    ) -> list[str]:
        pass
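
For orientation, here is a minimal sketch of how the two connector methods might be filled in for OpenAI, mirroring the skeleton above. This is not the verbatim ARTKIT implementation, but the shape is the same: the API key is resolved through the ConnectorMixin via get_api_key(), and _make_client() simply constructs the provider's asynchronous client.

class OpenAIChat(ChatModelConnector[AsyncOpenAI]):

    @classmethod
    def get_default_api_key_env(cls) -> str:
        # default environment variable that holds the OpenAI API key
        return "OPENAI_API_KEY"

    def _make_client(self) -> AsyncOpenAI:
        # construct the asynchronous OpenAI client with the resolved API key
        return AsyncOpenAI(api_key=self.get_api_key())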

By unifying the interfaces for chat models, we can follow a delegation pattern: external behavior can be added to any ChatModelConnector by wrapping it in a separate class. CachedChatModel is a great example of this - it lets us cache requests to a model provider without caring about the underlying implementation details. You can see this in use below.

[4]:
# Wrap the model in a CachedChatModel
gpt4_chat = ak.CachedChatModel(
    model=gpt4_chat,
    database="cache/creating_new_model_classes.db",
)

Now, let’s take a deeper dive into the OpenAIChat implementation of get_response() to highlight how we’re taking advantage of the full class hierarchy:

  1. The with_system_prompt() method, which you will see used later, is inherited from the ChatModelConnector superclass and sets the system_prompt property of the model

  2. Upon calling get_response(), we first use AsyncExitStack() to enter the model’s context manager

  3. Then we call the get_client() method of the ConnectorMixin superclass to fetch a cached OpenAI client instance.

  4. If no previous client instance exists, it is created via the model’s _make_client() implementation. This logic rests in the ConnectorMixin as well.

  5. Next, we format an input message for OpenAI’s chat endpoint based on the model’s history, system prompt, and input message

  6. The message is sent to OpenAI’s chat endpoint, along with other parameters set during model initialization (such as temperature and max tokens)

  7. Finally, we parse and return a list of responses from OpenAI’s chat endpoint

While there are quite a few steps here, note that the only ones specific to the OpenAIChat class are 5-7.

That means if you’re creating a new custom class, all you need to worry about is getting a client instance, passing a message to the client, and returning its response.

Everything else can be abstracted away via the model superclasses.

[5]:
####################################
### model/llm/openai/_openai.py  ###
####################################
from typing import Any
from contextlib import AsyncExitStack

from artkit.model.util import RateLimitException
from artkit.model.llm.history import ChatHistory

from openai import RateLimitError


async def get_response(
    self,
    message: str,
    *,
    history: ChatHistory | None = None,
    **model_params: dict[str, Any],
) -> list[str]:

    # ARTKIT model implementations are designed to be fully asynchronous -
    #   AsyncExitStack is used to handle multiple context managers dynamically.
    async with AsyncExitStack():
        try:
            # We access the client instance via the get_client method of the "ConnectorMixin" superclass -
            #   this will fetch a cached client instance if it exists or make a new one if it does not
            # This is very helpful, as it means you can share the same client instance across model objects
            completion = await self.get_client().chat.completions.create(

                # Here is the only OpenAI specific bit of code - we're formatting the message
                #  to pass to the chat endpoint
                messages=list(
                    self._messages_to_openai_format(  # type: ignore[arg-type]
                        message, history=history
                    )
                ),
                model=self.model_id,

                # We merge the model parameters passed to the get_response method with the defaults set
                # during instantiation, by overwriting the defaults with the passed parameters
                **{**self.get_model_params(), **model_params},
            )
        except RateLimitError as e:
            # If the rate limit is exceeded, we raise a custom RateLimitException
            # This is caught for all ChatModelConnectors and handled via exponential backoff
            raise RateLimitException(
                "Rate limit exceeded. Please try again later."
            ) from e

    return list(self._responses_from_completion(completion))

Now we can see the actual output of our get_response() method:

[6]:
haiku_chat = gpt4_chat.with_system_prompt("You respond only in haiku")
response = await haiku_chat.get_response(message="What color is the sky?")
print(response[0])
Blue as the endless sea,
Reflecting the sun's bright glow,
Infinite and free.

Your implementation - a checklist:#

Here are the basic steps you’ll need to take to create a new model class:

  1. Depending on which kind of model you want to implement, subclass the right abstract class, e.g., ChatModelConnector. Ideally, your IDE assists you here. Otherwise, you can start from an existing implementation, but make sure to check your parent classes and method signatures.

  2. Update __init__ to only include the parameters relevant for your model

  3. Update _make_client to return an instance of your model’s client. To do so, review the model provider’s API documentation; refer to Connecting to Gen AI Models for some examples of what you’re looking for

  4. Update get_response to pass a message to the client endpoint and return its response

  5. Add unit tests for your new model implementation

If you’re working with a diffusion or vision model, you will have to implement a different abstract model class, but the necessary steps are very similar.

Here are a few other best practices that will save you time during development:

  • Run pre-commit hooks frequently; they will help you catch any missing implementation, type errors, or general formatting inconsistencies

  • Write unit tests as you go, and run pytest intermittently to make sure you haven’t accidentally broken anything

  • Import your model class in api.py; this will allow it to be called via the ARTKIT API

  • Add a try / except ImportError around the provider import at the top of your module; ARTKIT does not require every supported model provider’s package to be installed on setup (see the sketch below)
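
Below is a minimal sketch of that optional-dependency pattern; the provider package and class names are purely illustrative, and the exact handling in ARTKIT may differ:

try:
    # the provider SDK is an optional dependency, so this import may fail
    from some_provider_sdk import AsyncProviderClient
except ImportError:
    # defer the failure so users who never touch this model class are unaffected
    AsyncProviderClient = None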

Calling custom endpoints via HTTP#

ARTKIT offers support for sending HTTP requests to any endpoint. You will need to subclass a version of HTTPXChatConnector (which implements the ChatModelConnector interface). It has native support for exponential retries if you do not pass any keyword arguments to the constructor, but it can be customized to the specific retry requirements of the endpoint at hand. Under the hood it uses httpx.AsyncClient to make requests to your endpoint. You can also pass a dictionary of client keyword arguments, for example to specify a read timeout:

from httpx import Timeout

client_kwargs = dict(timeout=Timeout(10.0, read=10.0))
chat_model = YourImplementedClass(
    model_id="http://test.url/api/v1/",
    httpx_client_kwargs=client_kwargs,
)

With the abstract HTTPXChatConnector class, most ARTKIT-specific implementation details are already taken care of. All you need to implement are the following functions:

def get_default_api_key_env(cls) -> str:
    # returns the default environment variable name
    # under which you keep your API key, if needed

def build_request_arguments(
    self,
    message: str,
    *,
    history: ChatHistory | None = None,
    **model_params: dict[str, Any],
) -> dict[str, Any]:
    # "Translates" the ARTKIT input to an httpx.request input

def parse_httpx_response(
    self,
    response: Response,
) -> list[str]:
    # "Translates" the httpx.request output to an ARTKIT output

To make things a bit more concrete, we provide an example using the REST API provided by OpenAI, but you can also look at the implementations of ak.HuggingFaceURLChat and ak.BaseBedrockChat.

[6]:
import json
from typing import Any

from httpx import Response
from artkit.model.llm.base import HTTPXChatConnector
from artkit.model.llm.history import ChatHistory


from dotenv import load_dotenv
load_dotenv()


class OpenAIChatConnector(HTTPXChatConnector):
    def __init__(self, model_id: str, model_name: str, **kwargs):
        """
        For this specific example, the OpenAI API endpoint requires a specific
        model name, so a model_name parameter has been added to the constructor.
        """
        super().__init__(model_id=model_id, **kwargs)
        self.model_name = model_name

    @classmethod
    def get_default_api_key_env(cls) -> str:
        """
        Get the default name of the environment variable that holds the API key.

        :return: the default name of the api key environment variable
        """
        return "OPENAI_API_KEY"


    def build_request_arguments(
        self,
        message: str,
        *,
        history: ChatHistory | None = None,
        **model_params: dict[str, Any],
    ) -> dict[str, Any]:
        """
        This method is responsible for formatting the input to the LLM chat system.
        For argument options see :class:`httpx.AsyncClient.request`.

        :param message: The input message to format.
        :param history: The chat history preceding the message.
        :param model_params: Additional parameters for the chat system.

        :return: a dictionary of arguments for :class:`httpx.AsyncClient.request`
        """
        messages = []
        if self.system_prompt:
            messages.append(
                {
                    "role": "system",
                    "content": self.system_prompt
                }
            )
        if history:
            messages.extend(
                [
                    {
                        "role": historic_message.role,
                        "content": historic_message.text,
                    }
                    for historic_message in history.messages
                ]
            )
        messages.append(
            {
                "role": "user",
                "content": message
            }
        )

        request_body = json.dumps({
            "model": self.model_name,
            "messages": messages,
            **{**self.get_model_params(), **model_params},
        })
        return dict(
            method="POST",
            url=self.model_id,
            headers={
                "Authorization": f"Bearer {self.get_api_key()}",
                "Content-Type": "application/json",
            },
            data=request_body,
        )


    def parse_httpx_response(self, response: Response) -> list[str]:
        """
        This method is responsible for formatting the :class:`httpx.Response` after
        having made the request.

        :param response: The response from the endpoint.
        :return: A list of formatted response strings.
        """
        response_json = response.json()
        return [response_json['choices'][0]['message']['content']] if 'choices' in response_json else []


openai_chat = OpenAIChatConnector(model_id="https://api.openai.com/v1/chat/completions", model_name="gpt-4o")
prompt = "What is 2+2?"
response = await openai_chat.get_response(message=prompt)
response
[6]:
['2 + 2 equals 4.']

AWS Bedrock Specifics#

The AWS Bedrock integration is also based on the HTTPXChatConnector.

For AWS Bedrock requests, as long as valid credentials are in place, ARTKIT will ensure that all necessary properties are set for API request signing. You will need to ensure that an appropriate AWS account is set up and that a user within that account has the right IAM policies to access the foundational models that you wish to use.

Additionally, you will need to implement model-specific I/O parsing (i.e., build_request_arguments and parse_httpx_response), as each foundational model has distinct request and response payloads.

The goal is to illustrate how we would do that with the existing code.

A BaseBedrockChat has two required parameters: model_id and region, corresponding to the specific foundational model and AWS region you wish to make API calls to. Unlike other components in ARTKIT, the model parameters are submitted via a JSON payload to the invoke_model API (see the AWS Bedrock documentation).

To implement a specific payload for another AWS Foundational Model, the following methods will need to be implemented:

[ ]:
# Using Claude as an example
from typing import Any

from httpx import Response

from artkit.model.llm.bedrock.base import BaseBedrockChat
from artkit.model.llm.history import ChatHistory

class ClaudeBedrockChat(BaseBedrockChat):
    async def build_request_arguments(
        self,
        message: str,
        *,
        history: ChatHistory | None = None,
        **model_params: dict[str, Any],
    ) -> dict[str, Any]:
        # your implementation here
        pass

    def parse_httpx_response(self, response: Response) -> list[str]:
        # your implementation here
        pass
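
Once these two methods are implemented, usage follows the same pattern as any other chat connector. Here is a hypothetical example; the model_id and region values are illustrative, and it assumes the Claude-specific request building and response parsing above have been filled in:

claude_chat = ClaudeBedrockChat(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region="us-east-1",
)
response = await claude_chat.get_response(message="What is 2+2?")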