Source code for artkit.model.diffusion.base._diffusion

# -----------------------------------------------------------------------------
# © 2024 Boston Consulting Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------

"""
Base implementation of diffusion models.
"""

from __future__ import annotations

from abc import ABCMeta, abstractmethod

from ....util import Image

__all__ = [
    "DiffusionModel",
    "DiffusionModelConnector",
]

from typing import Any, Generic, TypeVar

from ...base import ConnectorMixin, GenAIModel
from ...util import retry_with_exponential_backoff

#
# Type variables
#

T_Client = TypeVar("T_Client")


#
# Classes
#


class DiffusionModel(GenAIModel, metaclass=ABCMeta):
    """
    An abstract diffusion model.

    Concrete subclasses implement :meth:`text_to_image`.
    """

    @abstractmethod
    async def text_to_image(self, text: str, **model_params: Any) -> list[Image]:
        """
        Generate one or more images from text input.

        :param text: the input text for the diffusion model
        :param model_params: additional parameters for the diffusion model,
            passed as keyword arguments
        :return: a list of :class:`.Image` objects generated by the diffusion
            model
        """
        # NOTE: the annotation on ``**model_params`` describes the type of each
        # individual keyword value, not of the collected mapping; hence ``Any``
        # rather than ``dict[str, Any]``.
class DiffusionModelConnector(
    DiffusionModel, ConnectorMixin[T_Client], Generic[T_Client], metaclass=ABCMeta
):
    """
    A diffusion model that connects to a client.

    When a subclass is created, its :meth:`text_to_image` method is wrapped
    with :func:`retry_with_exponential_backoff`.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Apply the retry strategy to the ``text_to_image`` method of the new
        subclass.

        :param kwargs: keyword arguments forwarded to the parent
            ``__init_subclass__`` implementation
        """
        super().__init_subclass__(**kwargs)

        # Attribute set on the wrapper so that wrapping can be detected later.
        retry_marker = "_retry_with_backoff_applied"

        method = cls.text_to_image
        if getattr(method, retry_marker, False):
            # The subclass inherited a method that an ancestor's
            # ``__init_subclass__`` call already wrapped; wrapping it again
            # would stack the retry decorator on top of itself.
            return

        # Apply the retry strategy to the text_to_image method.
        # NOTE(review): assumes the decorator returns an object that accepts
        # attribute assignment (e.g. a regular function) - confirm.
        wrapped = retry_with_exponential_backoff(method)  # type: ignore
        setattr(wrapped, retry_marker, True)
        cls.text_to_image = wrapped  # type: ignore[method-assign]