Creating New Model Classes#

Introduction#

This notebook serves as a primer for developing ARTKIT model classes. It is most likely your first entry point to contributing to ARTKIT, so to get you started we will cover:

  • Object-oriented programming motivations: The advantages of using class hierarchies and how this enables us to distribute features across client implementations

  • ARTKIT’s model class hierarchy: How are classes structured across ARTKIT and what levels of abstraction exist

  • Model implementation example: A deep-dive of the initialization and get_response implementation of the OpenAIChat class

  • Your implementation - a checklist: Steps to complete to code up your own client implementation

This is a technical guide for users who are looking to implement new functionality within the ARTKIT library. While a basic understanding of Python classes is assumed, this guide should be accessible to anyone with a data science background.

For more information on the classes mentioned here, see the API reference.

Object-oriented programming motivations#

ARTKIT makes heavy use of class inheritance to implement connectors to Gen AI model providers. This allows us to take advantage of features specific to certain model providers while standardizing common utility methods such as asynchronous execution, managing chat history, and caching. It also means you can quickly implement classes to connect to new providers and immediately re-use these standard features without any additional coding.

When should I create a new model class? Currently ARTKIT provides interfaces for connecting to OpenAI, Anthropic, Hugging Face, Groq, and Google’s Gemini LLMs. It also supports OpenAI’s multi-modal models, specifically DALL-E models and vision endpoints. If there are additional model providers you’d like to use in your testing and evaluation pipeline, creating a new class is the best way to do so – and it also improves the usefulness of ARTKIT for everyone! To learn more about contributing to ARTKIT, see our Contributor Guide.

What if I want to change core class functionality? Reviewing the model class hierarchy in this guide is a good starting point for developing new core features for the library, though implementing and testing these features will take more careful consideration and scrutiny, as introducing changes at higher levels of abstraction can have implications for other client implementations.

ARTKIT’s model class hierarchy#

Here’s a summary of the model class hierarchy, starting from the abstract GenAIModel to the OpenAIChat class:

  1. GenAIModel is an abstract representation of a generic Gen AI model

  2. ConnectorMixin is a mixin class that adds a common interface for client connections to a GenAIModel

  3. ChatModel is an abstract class that provides a common interface for generating responses to text prompts, emulating a chat - it is a subclass of GenAIModel

  4. ChatModelConnector combines the ConnectorMixin with the ChatModel interface to create a chat model that connects to a client

  5. OpenAIChat is a subclass of ChatModelConnector that connects to OpenAI’s models and utilizes them for chats

Here’s a tree of the relevant repo structure, so you can take a look at the implementation of these classes:

model/
  base/
    _model.py     --> GenAIModel, ConnectorMixin
  llm/
    base/
      _llm.py     --> ChatModel, ChatModelConnector
    openai/
      _openai.py  --> OpenAIChat

A similar hierarchy is shared for other model providers and modalities. For example, levels 1-4 are identical for the AnthropicChat class. This means that OpenAIChat and AnthropicChat share all attributes and method signatures of ChatModelConnector, such as get_client(), but not the actual implementation of those methods. This enables model-specific parameters and response processing.

If you’re adding a new model class, you will most likely only need to make changes on hierarchy level 5.

It is similar for other modalities; for example, the OpenAIDiffusion class inherits from a DiffusionModelConnector, which is again a subclass of a DiffusionModel and a ConnectorMixin. Unlike a ChatModel, a DiffusionModel does not have a get_response() method. However, because it also inherits from ConnectorMixin, it will have the same get_client() method signature.
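
To make this hierarchy concrete, here is a heavily simplified sketch of how the classes could relate. This is illustrative only; the method bodies, signatures, and generics are assumptions, not ARTKIT’s actual source:

from abc import ABCMeta, abstractmethod
from typing import Any, Generic, TypeVar

T_Client = TypeVar("T_Client")


class GenAIModel(metaclass=ABCMeta):
    # Level 1: abstract representation of a generic Gen AI model
    model_id: str


class ConnectorMixin(Generic[T_Client], metaclass=ABCMeta):
    # Level 2: common interface for client connections
    @abstractmethod
    def _make_client(self) -> T_Client:
        ...

    def get_client(self) -> T_Client:
        # Returns a cached client instance, creating one via _make_client()
        ...


class ChatModel(GenAIModel, metaclass=ABCMeta):
    # Level 3: common interface for chat-style responses to text prompts
    @abstractmethod
    async def get_response(
        self, message: str, **model_params: dict[str, Any]
    ) -> list[str]:
        ...


class ChatModelConnector(ChatModel, ConnectorMixin[T_Client]):
    # Level 4: a chat model that connects to a client
    pass


class OpenAIChat(ChatModelConnector["AsyncOpenAI"]):
    # Level 5: concrete connector for OpenAI's chat models
    ...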

Model implementation example#

Next, we will demonstrate how the model class hierarchy is used through an example of the OpenAIChat class.

The goal is to illustrate both (i) what methods you need to implement to create a new model class and (ii) why those methods are set up the way that they are.

A ChatModelConnector has two required parameters: model_id and api_key_env. OpenAIChat has additional model-specific parameters such as temperature. So when we initialize an OpenAIChat object as below, we’re also initializing the superclass to handle client connections:

[2]:
# Import ARTKIT
import artkit.api as ak

# Load API keys
from dotenv import load_dotenv
load_dotenv()

gpt4_chat = ak.OpenAIChat(
    model_id="gpt-4",
    api_key_env="OPENAI_API_KEY",
    temperature=0.5,
)

To implement a chat model connector, the following methods must be implemented:

[3]:
#####################################
### model/llm/openai/_openai.py   ###
#####################################
from typing import Any

from openai import AsyncOpenAI

from artkit.model.llm.base import ChatModelConnector
from artkit.model.llm.history import ChatHistory


class OpenAIChat(ChatModelConnector[AsyncOpenAI]):

    # required by ConnectorMixin to establish a client connection
    @classmethod
    def get_default_api_key_env(cls) -> str:
        pass

    # required by ConnectorMixin to establish a client connection
    def _make_client(self) -> AsyncOpenAI:
        pass

    # required by ChatModel to send a message to the model
    async def get_response(  # pragma: no cover
        self,
        message: str,
        *,
        history: ChatHistory | None = None,
        **model_params: dict[str, Any],
    ) -> list[str]:
        pass
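
For orientation, the two client-related stubs could be filled in along these lines. This is a minimal, hypothetical sketch rather than ARTKIT’s actual source; in particular, the api_key_env attribute and reading the key via os.environ are assumptions about how the superclass stores and resolves the key:

import os

from openai import AsyncOpenAI


class OpenAIChatSketch:
    # Hypothetical standalone class for illustration; in ARTKIT these methods
    # live on OpenAIChat, and api_key_env is managed by the superclass
    api_key_env: str = "OPENAI_API_KEY"

    @classmethod
    def get_default_api_key_env(cls) -> str:
        # Name of the environment variable that holds the API key by default
        return "OPENAI_API_KEY"

    def _make_client(self) -> AsyncOpenAI:
        # Build the asynchronous OpenAI client, reading the key from the
        # configured environment variable
        return AsyncOpenAI(api_key=os.environ[self.api_key_env])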

By unifying the interfaces for chat models, we can follow a delegation pattern: external behaviors can be added to all ChatModelConnectors by wrapping them in a separate class. CachedChatModel is a great example of this; it enables us to cache requests to a model provider without caring about the actual implementation details. You can see us make use of this below.

[4]:
# Wrap the model in a CachedChatModel
gpt4_chat = ak.CachedChatModel(
    model=gpt4_chat,
    database="cache/creating_new_model_classes.db",
)
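
To illustrate the delegation pattern itself, here is a stripped-down, hypothetical wrapper. It is not ARTKIT’s CachedChatModel (which persists responses to a database), but it shows the core idea: expose the same get_response interface and forward to the wrapped model, adding behavior around the call:

from typing import Any


class InMemoryCachedChatModel:
    # Hypothetical wrapper for illustration only
    def __init__(self, model: Any) -> None:
        self.model = model
        self._cache: dict[str, list[str]] = {}

    async def get_response(self, message: str, **model_params: Any) -> list[str]:
        if message not in self._cache:
            # Delegate to the wrapped model only on a cache miss
            self._cache[message] = await self.model.get_response(
                message, **model_params
            )
        return self._cache[message]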

Now, let’s take a deeper dive into the OpenAIChat implementation of get_response() to highlight how we’re taking advantage of the full class hierarchy:

  1. The with_system_prompt() method, which you will see later, is inherited from the ChatModelConnector superclass to set the system_prompt property of the model

  2. Upon calling get_response(), we first use AsyncExitStack() to enter the model’s context manager

  3. Then we call the get_client() method of the ConnectorMixin superclass to fetch a cached OpenAI client instance.

  4. If no previous client instance exists, it is created via the model’s _make_client() implementation. This logic also lives in the ConnectorMixin.

  5. Next, we format an input message for OpenAI’s chat endpoint based on the model’s history, system prompt, and input message

  6. The message is sent to OpenAI’s chat endpoint, along with other parameters set during model initialization (such as temperature and max tokens)

  7. Finally, we parse and return a list of responses from OpenAI’s chat endpoint

While there are quite a few steps here, note that the only ones specific to the OpenAIChat class are 5-7.

That means if you’re creating a new custom class, all you need to worry about is getting a client instance, passing a message to the client, and returning its response.

Everything else can be abstracted away via the model superclasses.

[5]:
####################################
### model/llm/openai/_openai.py  ###
####################################
from typing import Any
from contextlib import AsyncExitStack

from artkit.model.util import RateLimitException
from artkit.model.llm.history import ChatHistory

from openai import RateLimitError


async def get_response(
    self,
    message: str,
    *,
    history: ChatHistory | None = None,
    **model_params: dict[str, Any],
) -> list[str]:

    # ARTKIT model implementations are designed to be fully asynchronous -
    #   AsyncExitStack is used to handle multiple context managers dynamically.
    async with AsyncExitStack():
        try:
            # We access the client instance via the get_client method of the "ConnectorMixin" superclass -
            #   this will fetch a cached client instance if it exists or make a new one if it does not
            # This is very helpful, as it means you can share the same client instance across model objects
            completion = await self.get_client().chat.completions.create(

                # Here is the only OpenAI specific bit of code - we're formatting the message
                #  to pass to the chat endpoint
                messages=list(
                    self._messages_to_openai_format(  # type: ignore[arg-type]
                        message, history=history
                    )
                ),
                model=self.model_id,

                # We merge the model parameters passed to the get_response method with the defaults set
                # during instantiation, by overwriting the defaults with the passed parameters
                **{**self.get_model_params(), **model_params},
            )
        except RateLimitError as e:
            # If the rate limit is exceeded, we raise a custom RateLimitException
            # This is caught for all ChatModelConnectors and handled via exponential backoff
            raise RateLimitException(
                "Rate limit exceeded. Please try again later."
            ) from e

    return list(self._responses_from_completion(completion))

Now we can see the actual output of our get_response() method:

[6]:
response = await gpt4_chat.with_system_prompt(
    "You respond only in haiku"
).get_response(message="What color is the sky?")
print(response[0])
Blue as the endless sea,
Reflecting the sun's bright glow,
Infinite and free.

Your implementation - a checklist#

Here are the basic steps you’ll need to take to create a new model class:

  1. Depending on which kind of model you want to implement, subclass the right abstract class, e.g., ChatModelConnector (see the sketch after this checklist). Ideally, your IDE assists you here. Otherwise, you can try starting from an existing implementation, but make sure to check your parent classes and method signatures.

  2. Update __init__ to only include the parameters relevant for your model

  3. Update _make_client to return an instance of your model’s client. To do so, review the model provider’s API documentation; refer to Connecting to Gen AI Models for some examples of what you’re looking for

  4. Update get_response to pass a message to the client endpoint and return its response

  5. Add unit tests for your new model implementation

If you’re working with a diffusion or vision model, you will have to subclass a different abstract model class, but the necessary steps are very similar.
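
Putting steps 1-4 together, a new chat connector skeleton might start out like this. The provider and client names are hypothetical, and the exact superclass __init__ signature is an assumption, so check ChatModelConnector before relying on it:

from typing import Any

from artkit.model.llm.base import ChatModelConnector
from artkit.model.llm.history import ChatHistory

# Hypothetical provider SDK; swap in your provider's async client class
from some_provider import AsyncSomeClient


class SomeProviderChat(ChatModelConnector[AsyncSomeClient]):

    # Step 2: expose only the parameters your model actually supports
    def __init__(
        self,
        *,
        model_id: str,
        api_key_env: str | None = None,
        **model_params: Any,
    ) -> None:
        super().__init__(
            model_id=model_id, api_key_env=api_key_env, **model_params
        )

    @classmethod
    def get_default_api_key_env(cls) -> str:
        return "SOME_PROVIDER_API_KEY"  # hypothetical variable name

    # Step 3: return a client instance for your provider
    def _make_client(self) -> AsyncSomeClient:
        return AsyncSomeClient()  # hypothetical constructor

    # Step 4: pass the message to the client and return its response(s)
    async def get_response(
        self,
        message: str,
        *,
        history: ChatHistory | None = None,
        **model_params: dict[str, Any],
    ) -> list[str]:
        ...  # call your provider's chat endpoint and parse the completion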

Here are a few other best-practices that will save you time during development:

  • Run pre-commit hooks frequently; they will help you catch any missing implementation, type errors, or general formatting inconsistencies

  • Write unit tests as you go, and run pytest intermittently to make sure you haven’t accidentally broken anything

  • Import your model class in api.py; this will allow it to be called via the ARTKIT API

  • Add a try / except ImportError at the top of your class’s module (see the sketch below); ARTKIT does not require every supported model provider to be installed on setup
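
For the last point, the import guard could look roughly like this. This is a sketch only; check an existing ARTKIT connector module for the exact convention used in the codebase:

try:
    # Optional dependency: only needed if this connector is actually used
    from openai import AsyncOpenAI
except ImportError:
    # Defer the failure until the class is used, rather than at import time
    AsyncOpenAI = None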