Source code for artkit.model.diffusion.base._diffusion
# -----------------------------------------------------------------------------
# © 2024 Boston Consulting Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
"""
Base implementation of diffusion models.
"""
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from ....util import Image
# Public API of this module: the abstract diffusion model and its
# client-connected variant.
__all__ = [
    "DiffusionModel",
    "DiffusionModelConnector",
]
from typing import Any, Generic, TypeVar
from ...base import ConnectorMixin, GenAIModel
from ...util import retry_with_exponential_backoff
#
# Type variables
#
# Type parameter for the client that a DiffusionModelConnector connects to
# (used to parameterize ConnectorMixin and Generic below).
T_Client = TypeVar("T_Client")
#
# Classes
#
[docs]
class DiffusionModel(GenAIModel, metaclass=ABCMeta):
    """An abstract diffusion model that generates images from text."""

    @abstractmethod
    async def text_to_image(
        self, text: str, **model_params: Any
    ) -> list[Image]:
        """
        Generate images from text input.

        Subclasses must implement this coroutine.

        :param text: the input text for the diffusion model
        :param model_params: additional parameters for the diffusion model
        :return: the list of :class:`.Image` objects generated by the
            diffusion model
        """
[docs]
class DiffusionModelConnector(
    DiffusionModel, ConnectorMixin[T_Client], Generic[T_Client], metaclass=ABCMeta
):
    """
    A diffusion model that connects to a client.

    When a subclass is created and it defines its own, non-abstract
    :meth:`text_to_image`, that implementation is wrapped with
    :func:`retry_with_exponential_backoff`.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        # Only wrap an implementation defined by this class itself. An
        # inherited text_to_image was already wrapped when the parent class
        # was created; wrapping it again would nest the retry logic once per
        # level of subclassing.
        if "text_to_image" not in vars(cls):
            return

        method = cls.text_to_image

        # Leave abstract (re)declarations untouched so the class is still
        # treated as abstract, regardless of whether the retry decorator
        # preserves the ``__isabstractmethod__`` marker.
        if getattr(method, "__isabstractmethod__", False):
            return

        # Apply the retry strategy to the text_to_image method
        cls.text_to_image = (  # type: ignore[method-assign]
            retry_with_exponential_backoff(method)  # type: ignore
        )