"""
Implementation of Sphinx utility callbacks specific to building Gamma documentation.
"""
from __future__ import annotations
import collections.abc
import importlib
import itertools
import logging
import re
import typing
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Generator, Iterable, Mapping
from inspect import getattr_static
from re import Pattern
from types import FunctionType, GenericAlias, MethodType, UnionType
from typing import Any, ForwardRef, Generic, TypeVar, Union, cast, get_type_hints
import typing_inspect
from ...api import (
AllTracker,
inheritdoc,
public_module_prefix,
subsdoc,
update_forward_references,
)
from ...meta import SingletonABCMeta
from .. import (
AutodocBeforeProcessSignature,
AutodocProcessDocstring,
AutodocProcessSignature,
AutodocSkipMember,
ObjectDescriptionTransform,
)
try:
# import sphinx classes if available ...
from docutils.nodes import Element, Text
from sphinx.application import Sphinx
except ImportError:
# ... otherwise mock them up
# noinspection PyMissingOrEmptyDocstring,PyUnusedLocal,SpellCheckingInspection
class _Element:
children: list[Element]
attributes: dict[str, Any]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise TypeError("docutils package is not installed")
def replace(self, old: Element, new: Element) -> None: ... # noqa: E704
# noinspection PyMissingOrEmptyDocstring,SpellCheckingInspection
class _Text(_Element):
rawsource: str
# noinspection SpellCheckingInspection
def astext(self) -> str: # type: ignore
...
Sphinx = type("Sphinx", (object,), {})
Element = _Element
Text = _Text
#
# Constants
#
log = logging.getLogger(__name__)
#
# Exported names
#
__all__ = [
"AddInheritance",
"CollapseModulePaths",
"CollapseModulePathsInDocstring",
"CollapseModulePathsInSignature",
"CollapseModulePathsInXRef",
"RenamePrivateArguments",
"Replace3rdPartyDoc",
"ResolveTypeVariables",
"SkipIndirectImports",
"TrackCurrentClass",
"UpdateForwardReferences",
]
#
# Types
#
method_descriptor: type[Any] = type(str.startswith)
wrapper_descriptor: type[Any] = type(str.__add__)
internal_function_or_method: type[Any] = type(iter)
#
# Constants
#
METHOD_TYPE_DYNAMIC = 0
METHOD_TYPE_STATIC = 1
METHOD_TYPE_CLASS = 2
#
# Ensure all symbols introduced below are included in __all__
#
__tracker = AllTracker(globals())
@inheritdoc(match="""[see superclass]""")
class AddInheritance(AutodocProcessDocstring):
"""
Add list of base classes as the first line of the docstring.
Ignore built-in classes; also ignore repeat visits of a class that present a
different docstring (e.g., the ``__init__`` docstring), so the same fields are
not inserted twice.
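
For example, given the classes below, the docstring of ``B`` would gain a
``:bases:`` field cross-referencing ``A`` (a minimal sketch; the exact
cross-reference target depends on the public module path of ``A``):

.. code-block:: python

    class A:
        pass

    class B(A):
        pass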
"""
def __init__(self, collapsible_submodules: Mapping[str, str]) -> None:
"""
:param collapsible_submodules: mapping of submodule paths to shorter
*(collapsed)* versions they should be replaced with
"""
super().__init__()
self.collapsible_submodules = collapsible_submodules
#: Dict mapping visited classes to their unprocessed docstrings.
self._visited: dict[type, str] = {}
#: Field directive for base classes.
F_BASES = ":bases:"
#: Field directive for generic types.
F_GENERICS = ":generic types:"
#: Field directive for metaclasses.
F_METACLASSES = ":metaclasses:"
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
options: object,
lines: list[str],
) -> None:
"""[see superclass]"""
if what != "class":
return
# add bases and generics documentation to class
# generate the RST for bases and generics
class_ = cast(type, obj)
_current_lines = "\n".join(lines)
try:
_seen_lines = self._visited[class_]
if _current_lines != _seen_lines:
# we are seeing another part of the docstring, probably in __init__
# vs. the class docstring;
# ignore this to prevent adding the same content at two places
return
except KeyError:
# we are seeing a class for the first time; store its content, so we can
# detect and allow repeat visits
self._visited[class_] = _current_lines
bases_lines: list[str] = [""]
bases: list[type] = _get_minimal_bases(class_)
if bases:
base_names = (self._class_name_with_generics(base) for base in bases)
bases_lines.append(f'{AddInheritance.F_BASES} {", ".join(base_names)}')
generics: list[str] = self._get_generics(class_)
if generics:
bases_lines.append(f'{AddInheritance.F_GENERICS} {", ".join(generics)}')
metaclasses: list[str] = self._get_metaclasses(class_)
if metaclasses:
bases_lines.append(
f'{AddInheritance.F_METACLASSES} {", ".join(metaclasses)}'
)
bases_lines.append("")
# insert this after the intro text, and before class parameters
self._insert_bases_lines(bases_lines, lines)
@staticmethod
def _insert_bases_lines(bases_lines: list[str], lines: list[str]) -> None:
def _insert_position() -> int:
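# find the first line that opens a field list (e.g., ":param x:") at the
# start of a block; the bases lines will be inserted just before it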
for n, line in enumerate(lines):
if re.match(r"\s*:\w+(?:\s+\w+)*:", line) and (
n == 0 or not lines[n - 1].strip()
):
return n
return len(lines)
if len(bases_lines) > 0:
pos = _insert_position()
lines[pos:pos] = bases_lines
def _class_module(self, cls: type) -> str:
module_name: str = public_module_prefix(cls.__module__)
# return the collapsed submodule if it exists,
# else return the unchanged module name
return self.collapsible_submodules.get(module_name, module_name)
def _full_name(self, cls: type) -> str:
# get the full name of the class, including the module prefix
return f"{self._class_module(cls)}.{_class_name(cls)}"
def _class_name_with_generics(self, cls: Any) -> str:
def _class_tag(
name: str,
*,
is_class: bool = True,
is_local: bool = False,
is_short: bool = False,
) -> str:
if is_local:
name = f".{name}"
if is_short:
name = f"~{name}"
if is_class:
return f":class:`{name}`"
else:
return f":obj:`{name}`"
if isinstance(cls, TypeVar):
return str(cls)
if isinstance(cls, ForwardRef):
if cls.__forward_evaluated__:
cls = cls.__forward_value__
else:
return _class_tag(cls.__forward_arg__, is_local=True)
if not hasattr(cls, "__module__"):
return _class_tag(str(cls))
if cls.__module__ in ("__builtin__", "builtins"):
return _class_tag(cls.__name__)
else:
generic_args = [
self._class_name_with_generics(arg)
for arg in typing_inspect.get_args(cls, evaluate=True)
]
generic_arg_str = f" [{', '.join(generic_args)}]" if generic_args else ""
return (
_class_tag(
self._full_name(cls), is_class=isinstance(cls, type), is_short=True
)
+ generic_arg_str
)
def _typevar_name(self, cls: TypeVar) -> str:
if isinstance(cls, TypeVar):
args: list[str] = [
self._class_name_with_generics(c)
for c in getattr(cls, "__constraints__", ())
]
if getattr(cls, "__bound__", None):
args.append(f"bound= {self._class_name_with_generics(cls.__bound__)}")
return f'{cls}({", ".join(args)})' if args else str(cls)
else:
return str(cls)
def _get_generics(self, child_class: type) -> list[str]:
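# collect the type variables declared via Generic[...] in the class
# definition, e.g., ["T"] for a class declared as MyClass(Generic[T])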
return list(
itertools.chain.from_iterable(
(
self._typevar_name(arg)
for arg in typing_inspect.get_args(base, evaluate=True)
)
for base in _get_generic_bases(child_class)
if typing_inspect.get_origin(base) is Generic
)
)
def _get_metaclasses(self, class_: type) -> list[str]:
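# list all metaclasses in the hierarchy of the class's metaclass, excluding
# the default metaclass ``type``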
return [
self._class_name_with_generics(meta_)
for meta_ in _get_bases(type(class_), include_subclass=True)
if meta_ is not type
]
class CollapseModulePaths(metaclass=ABCMeta):
"""
Replace private module paths with their public prefix so that object references
can be matched by *intersphinx*.
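
For example, with a mapping ``{"pandas.core.frame": "pandas"}``, a reference
to ``pandas.core.frame.DataFrame`` is rewritten as ``pandas.DataFrame``.
A minimal sketch of the underlying substitution (using plain :mod:`re` rather
than the exact patterns built by the subclasses):

.. code-block:: python

    import re

    pattern, replacement = re.compile(re.escape("pandas.core.frame")), "pandas"
    pattern.sub(replacement, "pandas.core.frame.DataFrame")
    # -> 'pandas.DataFrame'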
"""
# matches a full name of an object, including the preceding module path with at
# least one private submodule (starting with a "_") directly preceding the item
# name
__RE_PRIVATE_MODULE_AND_ITEM = re.compile(
r"\b(?# we start with a word break so we match full words)"
r"(?# public module path)(\w+(?:\.\w+)*?)"
r"(?# private module path)((?:\._\w+)+)"
r"\."
r"(?# item name)(\w+(?![.\w]))"
)
def __init__(
self,
collapsible_submodules: Mapping[str, str],
collapse_private_modules: bool = True,
) -> None:
"""
:param collapsible_submodules: mapping from module paths to their public
prefix, e.g., ``{"pandas.core.frame": "pandas"}``
:param collapse_private_modules: if ``True``, collapse module sub-paths
consisting of one or more protected modules (i.e. the module name starts
with an underscore)
"""
super().__init__()
self._classes_visited: set[type] = set()
col = [
self._make_substitution_pattern(old.replace(".", r"\."), new)
for old, new in collapsible_submodules.items()
]
self._intersphinx_collapsible_prefixes: list[tuple[Pattern[str], str]] = col
self._collapse_private_modules = collapse_private_modules
@abstractmethod
def _make_substitution_pattern(
self, old: str, new: str
) -> tuple[Pattern[str], str]:
# create the regex substitution rule given a raw match and replacement patterns
pass
def collapse_module_paths(self, line: str) -> str:
"""
In the given line, replace all module paths with their collapsed version.
:param line: the line in which to collapse module paths
:return: the resulting line with collapsed module paths
"""
if self._collapse_private_modules:
line = self._collapse_private_module_paths(line)
for expanded, collapsed in self._intersphinx_collapsible_prefixes:
line = expanded.sub(collapsed, line)
return line
@staticmethod
def _collapse_private_module_paths(line: str) -> str:
for (
# e.g., "pytools.expression"
public_module_path,
# e.g., "._expression"
private_module_path,
# e.g., "Expression"
item_name,
) in CollapseModulePaths.__RE_PRIVATE_MODULE_AND_ITEM.findall(line):
module_path = public_module_path + private_module_path
collapsed_path = public_module_path
try:
module = importlib.import_module(name=module_path)
item = vars(module)[item_name]
collapsed_path = item.__publicmodule__
except (KeyError, AttributeError, ModuleNotFoundError):
# keep the original path if the module or item cannot be resolved to a
# public module
pass
line = line.replace(
f"{module_path}.{item_name}", f"{collapsed_path}.{item_name}"
)
return line
@inheritdoc(match="""[see superclass]""")
class CollapseModulePathsInDocstring(CollapseModulePaths, AutodocProcessDocstring):
"""
Replace private module paths in docstrings with their public prefix so that object
references can be matched by *intersphinx*.
"""
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
options: object,
lines: list[str],
) -> None:
"""[see superclass]"""
for i, line in enumerate(lines):
lines[i] = self.collapse_module_paths(line)
def _make_substitution_pattern(
self, old: str, new: str
) -> tuple[Pattern[str], str]:
return re.compile(f"(`~?){old}"), f"\\1{new}"
@inheritdoc(match="""[see superclass]""")
class CollapseModulePathsInSignature(CollapseModulePaths, AutodocProcessSignature):
"""
Replace private module paths in signatures with their public prefix so that object
references can be matched by *intersphinx*.
"""
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
options: object,
signature: str | None,
return_annotation: str | None,
) -> tuple[str | None, str | None] | None:
"""[see superclass]"""
if signature or return_annotation:
return (
self.collapse_module_paths(signature) if signature else None,
(
self.collapse_module_paths(return_annotation)
if return_annotation
else None
),
)
return None
def _make_substitution_pattern(
self, old: str, new: str
) -> tuple[Pattern[str], str]:
return re.compile(old), new
@inheritdoc(match="""[see superclass]""")
class CollapseModulePathsInXRef(ObjectDescriptionTransform, CollapseModulePaths):
# noinspection GrazieInspection
"""
Replace private module paths in documentation cross-references with their public
prefix so that object references can be matched by *intersphinx*, and module
paths correspond to what users are expected to use in their code.
"""
# noinspection SpellCheckingInspection
def process(
self, app: Sphinx, domain: str, objtype: str, contentnode: Element
) -> None:
"""[see superclass]"""
if domain == "py" and objtype == "class":
self._process_children(contentnode)
def _process_children(self, parent_node: Element) -> None:
self._process_child(parent_node)
try:
children: Iterable[Element] = parent_node.children
except AttributeError:
# parent node is not an Element instance
return
for child_node in children:
self._process_children(child_node)
def _process_child(self, content_node: Element) -> None:
if type(content_node).__name__ == "pending_xref" and tuple(
type(c).__name__ for c in content_node.children
) == ("Text",):
text_node: Text = cast(Text, content_node.children[0])
text = text_node.astext()
text_collapsed: str = self.collapse_module_paths(text)
if text_collapsed != text:
text_collapsed = text_collapsed.replace(
f'{content_node.attributes["py:module"]}.', ""
)
content_node.replace(
old=text_node,
new=type(text_node)(
data=text_collapsed, rawsource=text_node.rawsource
),
)
def _make_substitution_pattern(
self, old: str, new: str
) -> tuple[Pattern[str], str]:
return re.compile(old), new
@inheritdoc(match="""[see superclass]""")
class SkipIndirectImports(AutodocSkipMember, metaclass=SingletonABCMeta):
"""
Skip members imported by a private package.
"""
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
skip: bool,
options: object,
) -> bool | None:
"""[see superclass]"""
if not skip and what == "module" and name.startswith("_"):
log.info(f"skipping: {what}: {name}")
return False
return None
@inheritdoc(match="""[see superclass]""")
class Replace3rdPartyDoc(AutodocProcessDocstring, metaclass=SingletonABCMeta):
"""
Replace 3rd party docstrings with a reference to the 3rd party documentation.
This is necessary for methods and attributes inherited from 3rd party packages,
as these might use an incompatible format for docstrings.
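
For example, a method inherited from a 3rd party base class (a sketch below,
assuming :mod:`pandas` is installed) has its inherited docstring replaced with
a single cross-reference, e.g., a ``:meth:`` reference to
``pandas.DataFrame.to_dict``:

.. code-block:: python

    import pandas as pd

    class MyFrame(pd.DataFrame):
        # to_dict() is inherited from pandas along with its docstring, which
        # this processor replaces with a cross-reference when documented
        pass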
"""
__RE_ROOT_PACKAGE = re.compile(r"\w+(?=\.)")
__RST_DIRECTIVE = {
"module": "mod",
"class": "class",
"exception": "exception",
"function": "func",
"method": "meth",
"attribute": "attr",
"property": "attr",
}
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
options: object,
lines: list[str],
) -> None:
"""[see superclass]"""
if what == "attribute":
# we cannot determine docstrings for attributes, as the object represents
# the value of the attribute, and not the attribute itself
mod_obj_attr = name.rsplit(".", 2)
if len(mod_obj_attr) == 3:
mod_name, obj_name, _ = mod_obj_attr
obj = getattr(importlib.import_module(mod_name), obj_name)
else:
log.debug(f"could not determine module for {name}")
return
elif what == "property":
obj = cast(property, obj).fget
try:
# if the object has an __objclass__ attribute, use that to determine the
# module
obj_module = getattr(obj, "__objclass__", obj).__module__
except AttributeError:
log.debug(f"could not determine module for {name}")
return
name_root_package = self.__root_package(name)
obj_root_package = self.__root_package(obj_module) if obj_module else None
if name_root_package != obj_root_package:
# replace 3rd party docstring with cross-reference
directive = Replace3rdPartyDoc.__RST_DIRECTIVE.get(what, what)
if not isinstance(
obj,
(
FunctionType,
MethodType,
method_descriptor,
wrapper_descriptor,
internal_function_or_method,
),
):
log.warning(f"{obj!r}:{type(obj)} is not a function or method")
return
if not obj_module or obj_module == "builtins":
full_name = obj.__qualname__
else:
full_name = f"{public_module_prefix(obj_module)}.{obj.__qualname__}"
del lines[:]
lines.append(f"See :{directive}:`{full_name}`")
@staticmethod
def __root_package(name: str) -> str:
root_package_match = Replace3rdPartyDoc.__RE_ROOT_PACKAGE.match(name)
return root_package_match[0] if root_package_match else ""
#
# auxiliary functions
#
def _get_bases(subclass: type, include_subclass: bool) -> Generator[type, None, None]:
# yield the public base classes of arg subclass, recursing past protected classes
visited_classes: set[type] = set()
def _inner(_subclass: type, _include_subclass: bool) -> Generator[type, None, None]:
# ensure we have the non-generic origin class
_subclass = typing_inspect.get_origin(_subclass) or _subclass
if _subclass in visited_classes:
return
visited_classes.add(_subclass)
# get the base classes; try generic bases first then fall back to regular
# bases
base_classes: tuple[type, ...] = (
_get_generic_bases(_subclass) or _subclass.__bases__
)
# include the _subclass itself in the list of bases, if requested
if _include_subclass:
# noinspection PyTypeChecker
base_classes = (_subclass, *base_classes)
# iterate over the base classes; recurse further up the class hierarchy when
# we encounter protected classes
for base in base_classes:
# exclude object and Generic types
if base is object or typing_inspect.get_origin(base) is Generic:
continue
# exclude protected classes
elif _class_name(base).startswith("_"):
yield from _inner(base, _include_subclass=False)
# all other classes will be listed as bases
else:
yield base
return _inner(subclass, _include_subclass=include_subclass)
def _get_minimal_bases(class_: type) -> list[type]:
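# keep only those base classes that are not a superclass of another base,
# e.g., for class C(A, B) with B inheriting from A, keep only B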
bases_with_origin = [
(base, typing_inspect.get_origin(base) or base)
for base in set(_get_bases(class_, include_subclass=False))
]
return [
base
for base, origin in bases_with_origin
if not any(
origin is not other and issubclass(other, origin)
for _, other in bases_with_origin
)
]
def _class_name(cls: Any) -> str:
return cast(str, _class_attr(cls=cls, attr=["__qualname__", "__name__", "_name"]))
def _class_attr(cls: Any, attr: list[str]) -> Any:
def _get_attr(_cls: type) -> Any:
# we try to get the class attribute
for attr_name in attr:
attr_value = getattr(_cls, attr_name, None)
if attr_value is not None:
return attr_value
# if the attribute is not defined, this class is likely to have generic
# arguments, so we re-try recursively with the origin (unless the origin
# is the class itself to avoid infinite recursion)
cls_origin = typing_inspect.get_origin(_cls)
if cls_origin is not None and cls_origin != _cls:
return _get_attr(cls_origin)
else:
# as a last resort, raise an error naming the attributes we looked for
raise AttributeError(
f"none of the attributes found in class {cls}: {', '.join(attr)}"
)
return _get_attr(_cls=typing.get_origin(cls) or cls)
class _TypeVarBindings:
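# Maps the type parameters of every generic class in the MRO of the current
# class to the type arguments bound to them, as seen from the current class.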
current_class: type[Any]
_bindings: dict[
type[Any],
dict[
TypeVar,
type[Any] | TypeVar,
],
]
def __init__(self, current_class: type[Any]) -> None:
super().__init__()
self.current_class = current_class
self._bindings = self._get_parameter_bindings(
cls=current_class, subclass_bindings={}
)
def resolve_parameter(
self, defining_class: type[Any], parameter: TypeVar
) -> type[Any] | TypeVar:
"""
Resolve a type parameter, substituting it with an actual type if the parameter
is bound to a type argument in the context of the current class;
otherwise return the parameter unchanged.
:param defining_class: the class that introduced the type parameter; this is
the current class itself, or a base class of the current class
:param parameter: the type variable
:return: the resolved parameter if bound to a type argument; else the original
parameter as a type variable
"""
return self._bindings.get(defining_class, {}).get(parameter, parameter)
def _get_parameter_bindings(
self,
cls: type[Any],
subclass_bindings: dict[TypeVar, type[Any] | TypeVar],
) -> dict[type[Any], dict[TypeVar, type[Any] | TypeVar]]:
# get type variable bindings for all generic types defined in the class
# hierarchy of the given parent class, applying the given bindings derived from
# child classes
# if arg cls has generic type parameters, it will have a corresponding
# generic origin class
cls_origin: type[Any] | None = None
if typing_inspect.is_generic_type(cls):
cls_origin = typing_inspect.get_origin(cls)
class_bindings: dict[TypeVar, type[Any] | TypeVar]
if cls_origin:
class_bindings = {
param: subclass_bindings.get(arg, arg) if subclass_bindings else arg
for param, arg in zip(
typing_inspect.get_parameters(cls_origin),
typing_inspect.get_args(cls),
)
}
cls = cls_origin
else:
# this class has no generic parameters of itself, so we adopt the existing
# parameter bindings from the subclass(es)
class_bindings = subclass_bindings
superclass_bindings = {
superclass: bindings
for generic_superclass in _get_generic_bases(cls)
for superclass, bindings in (
self._get_parameter_bindings(
cls=generic_superclass, subclass_bindings=class_bindings
).items()
)
if bindings
}
if cls_origin:
# we have generic type parameters in this class, so we remember the
# associated bindings
return {cls_origin: class_bindings, **superclass_bindings}
else:
# we have no generic type parameters in this class, so we return the
# parameter bindings of the superclasses
return superclass_bindings
@inheritdoc(match="""[see superclass]""")
class ResolveTypeVariables(AutodocBeforeProcessSignature, metaclass=SingletonABCMeta):
"""
Resolve type variables that can be inferred through generic class parameters or
``self``/``cls`` special arguments.
For example, the Sphinx documentation for the inherited method ``B.f`` in the
following example will be rendered with the signature ``(int) -> int``:
.. code-block:: python
T = TypeVar("T")
class A(Generic[T]):
def f(x: T) -> T:
return x
class B(A[int]):
pass
"""
original_signatures: dict[Any, dict[str, type[Any] | TypeVar]]
_current_class: type[Any] | None
_current_class_bindings: _TypeVarBindings | None
_track_current_class: TrackCurrentClass
def __init__(self) -> None:
super().__init__()
self.original_signatures = {}
self._current_class = None
self._current_class_bindings = None
self._track_current_class = TrackCurrentClass()
def connect(self, app: Sphinx, priority: int | None = None) -> int:
"""[see superclass]"""
if TrackCurrentClass().app is not app:
raise RuntimeError(
f"connect {TrackCurrentClass.__name__}() to the same app "
f"before connecting {ResolveTypeVariables.__name__}()"
)
return super().connect(app, priority)
def _resolve_function_signature(
self, bindings: _TypeVarBindings, func: FunctionType
) -> None:
# get the class in which the method has been defined
defining_class_opt: type[Any] | None = self._get_defining_class(func)
if defining_class_opt is None:
# missing or unknown defining class: nothing to resolve in the signature
return
defining_class: type = defining_class_opt
# get the original signature and convert it to a list of (name, type) tuples
signature_original_items = list(self._get_original_signature(func).items())
def _get_self_or_cls_type_substitution() -> (
tuple[TypeVar, type[Any]] | tuple[None, None]
):
if signature_original_items:
method_type = self._get_method_type(defining_class, func)
if method_type is METHOD_TYPE_DYNAMIC:
# special case: we substitute type vars bound to the class
# when assigned to the 'self' or 'cls' parameters of methods
_, arg_0_type = signature_original_items[0]
if typing_inspect.is_typevar(arg_0_type):
return cast(TypeVar, arg_0_type), bindings.current_class
elif method_type is METHOD_TYPE_CLASS:
# special case: we substitute type vars bound to the class
# when assigned to the 'self' or 'cls' parameters of methods
_, arg_0_type = signature_original_items[0]
if (
typing_inspect.is_generic_type(arg_0_type)
and typing_inspect.get_origin(arg_0_type) is type
):
arg_0_type_args = typing_inspect.get_args(arg_0_type)
if len(arg_0_type_args) == 1 and typing_inspect.is_typevar(
arg_0_type_args[0]
):
return arg_0_type_args[0], bindings.current_class
return None, None
arg_0_type_var: TypeVar | None
arg_0_substitute: type[Any] | None
arg_0_type_var, arg_0_substitute = _get_self_or_cls_type_substitution()
def _substitute_type_vars_in_type_expression(
type_expression: type[Any] | TypeVar,
) -> type[Any] | TypeVar:
# recursively substitute type vars with their resolutions
if isinstance(type_expression, TypeVar):
if type_expression == arg_0_type_var:
# special case: substitute a type variable introduced by the
# initial self/cls argument of a dynamic or class method
assert arg_0_substitute is not None
return arg_0_substitute
else:
# resolve type variables defined by Generic[] in the
# class hierarchy
return bindings.resolve_parameter(defining_class, type_expression)
else:
# dynamically resolve type variables inside nested type expressions
return _substitute_generic_type_arguments(
type_expression=type_expression,
fn_substitute_type_vars=_substitute_type_vars_in_type_expression,
)
# get the function's annotations mapping, which we will modify in place
signature = func.__annotations__
if not signature:
return
for name, tp in signature_original_items:
signature[name] = _substitute_type_vars_in_type_expression(tp)
def _resolve_attribute_signatures(self, cls: type[Any]) -> None:
assert self._current_class_bindings is not None
bindings: _TypeVarBindings = self._current_class_bindings
def _substitute_type_vars_in_type_expression(
type_expression: type[Any] | TypeVar,
) -> type[Any] | TypeVar:
# recursively substitute type vars with their resolutions
if isinstance(type_expression, TypeVar):
# resolve type variables defined by Generic[] in the
# class hierarchy
return bindings.resolve_parameter(cls, type_expression)
else:
return _substitute_generic_type_arguments(
type_expression=type_expression,
fn_substitute_type_vars=_substitute_type_vars_in_type_expression,
)
annotations = getattr(cls, "__annotations__", None)
if annotations:
cls.__annotations__ = {
attr: _substitute_type_vars_in_type_expression(annotation)
for attr, annotation in annotations.items()
}
@staticmethod
def _get_defining_class(method: FunctionType) -> type[Any] | None:
# get the class that defined the callable
if "." not in method.__qualname__:
# this is a function, not a method
return None
method_container: str
if method.__qualname__.endswith(f".{method.__name__}"):
method_container = method.__qualname__[: -len(method.__name__) - 1]
else:
method_container = method.__qualname__[: method.__qualname__.rfind(".")]
try:
return cast(
type[Any],
eval(
method_container,
importlib.import_module(method.__module__).__dict__,
),
)
except NameError:
# we could not find the container of the given method in the method's global
# namespace - this is likely an inherited method where the parent class
# sits in a different module
log.debug(
f"failed to find container '{method.__module__}.{method_container}' "
f"of method '{method.__name__}'"
)
return None
@staticmethod
def _get_method_type(defining_class: type[Any], func: FunctionType) -> int:
# do we have a static or class method?
try:
raw_func = getattr_static(defining_class, func.__name__)
if isinstance(raw_func, staticmethod):
return METHOD_TYPE_STATIC
elif isinstance(raw_func, classmethod):
return METHOD_TYPE_CLASS
except AttributeError:
# this should not happen, but we try to handle this gracefully
log.warning(
f"failed to look up method {func.__name__!r} "
f"in class {defining_class.__name__}"
)
return METHOD_TYPE_DYNAMIC
def _get_original_signature(
self, func: FunctionType
) -> dict[str, type[Any] | TypeVar]:
# get the original signature as defined in the code
signature_original: dict[str, type[Any] | TypeVar]
try:
signature_original = self.original_signatures[func]
except KeyError:
signature_original = get_type_hints(func)
self.original_signatures[func] = signature_original
return signature_original
def process(self, app: Sphinx, obj: Any, bound_method: bool) -> None:
"""[see superclass]"""
self._update_current_class(self._track_current_class.current_class)
if isinstance(obj, FunctionType):
# instance method definitions are unbound, so we need to determine
# the class we are currently in from context
bindings: _TypeVarBindings | None
if obj.__name__ == obj.__qualname__:
# this is a function, not a method
return
defining_class = self._get_defining_class(obj)
assert (
defining_class is not None
), f"function {obj.__qualname__} has a defining class"
if obj.__name__ in ["__init__", "__init_subclass__", "__new__"] or (
obj.__name__ == "__call__" and issubclass(defining_class, type)
):
# special case of a class initializer: this usually means that we are
# starting to document a new class, or referring to a special method
# of a metaclass
bindings = self._update_current_class(defining_class)
else:
bindings = self._current_class_bindings
assert (
bindings is not None
), f"bindings are in place for function {obj.__qualname__}"
assert issubclass(bindings.current_class, defining_class), (
f"current class {bindings.current_class.__name__} "
f"is a subclass of the class of unbound method {obj.__qualname__}, "
f"class={defining_class}"
)
self._resolve_function_signature(bindings=bindings, func=obj)
elif isinstance(obj, MethodType):
# class method definitions are bound, so we can infer the current class
cls = obj.__self__
assert isinstance(cls, type), "bound methods processed here are class methods"
bindings = self._current_class_bindings
assert (
bindings is not None
), "bindings expected to be in place when processing a bound method"
assert (
bindings.current_class is cls
), "bindings expected to be for the correct class"
self._resolve_function_signature(
bindings=bindings, func=cast(FunctionType, obj.__func__)
)
def _update_current_class(self, cls: type[Any] | None) -> _TypeVarBindings | None:
if cls is None:
return None
bindings: _TypeVarBindings | None = self._current_class_bindings
if bindings is None or bindings.current_class is not cls:
# we're visiting a new class
log.debug(f"visiting new class {cls.__name__}")
# create a TypeVar bindings object for this class
bindings = self._current_class_bindings = _TypeVarBindings(cls)
# and resolve type variables in type annotations for class attributes
self._resolve_attribute_signatures(cls=cls)
return bindings
@inheritdoc(match="""[see superclass]""")
class TrackCurrentClass(AutodocProcessSignature, metaclass=SingletonABCMeta):
"""
Keep track of the class currently being processed by autodoc.
This is required to attribute unbound methods to the correct class, e.g.,
in class :class:`.ResolveTypeVariables`.
"""
#: The class currently being processed by autodoc.
current_class: type[Any] | None
def __init__(self) -> None:
super().__init__()
self.current_class = None
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
options: object,
signature: str | None,
return_annotation: str | None,
) -> tuple[str | None, str | None] | None:
"""[see superclass]"""
if what == "class":
cls = cast(type, obj)
self.current_class = cls
return None
@inheritdoc(match="""[see superclass]""")
class RenamePrivateArguments(AutodocBeforeProcessSignature, metaclass=SingletonABCMeta):
"""
Rename private argument names to their original names given in the source code.
For example, arg ``__x`` of method ``f`` in
.. code-block:: python
class A:
def f(self, __x: int) -> None:
...
will be renamed from its internal, private name ``_A__x`` back to ``__x``.
"""
def process(self, app: Sphinx, obj: Any, bound_method: bool) -> None:
"""[see superclass]"""
if not (bound_method or isinstance(obj, FunctionType)):
return
# get the name of the module or class that the function is a member of
try:
containers = obj.__qualname__.split(".")
except AttributeError:
return
if len(containers) < 2:
return
private_prefix = f"_{containers[-2]}__"
def get_public_name(name: str) -> str:
if name.startswith(private_prefix):
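# strip the name-mangled class prefix but keep the leading double underscore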
return name[len(private_prefix) - 2 :]
else:
return name
# get the original signature
try:
annotations: dict[str, Any] = obj.__annotations__
except AttributeError:
annotations = {}
annotations_original = list(annotations.items())
for name, annotation in annotations_original:
public_name = get_public_name(name)
if public_name is not name:
del annotations[name]
annotations[public_name] = annotation
try:
code = obj.__code__
arg_and_variable_names = code.co_varnames
if any(name.startswith(private_prefix) for name in arg_and_variable_names):
obj.__code__ = code.replace(
co_varnames=tuple(
get_public_name(name) for name in arg_and_variable_names
),
)
except AttributeError:
pass
@inheritdoc(match="""[see superclass]""")
class UpdateForwardReferences(AutodocProcessSignature, metaclass=SingletonABCMeta):
"""
A Sphinx autodoc signature processor that updates forward references declared
by a class.
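
For example, the quoted return annotation in the sketch below is a forward
reference; this processor resolves such references on the class being
documented so that signatures render the actual class:

.. code-block:: python

    class Node:
        def copy(self) -> "Node":
            return Node()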
"""
@subsdoc(
# match and delete the line that declares :return:
# remember this is a multiline string, so we need to match the whole line
pattern=r"\s*:return:.*",
replacement="",
using=AutodocProcessSignature.process,
)
def process(
self,
app: Sphinx,
what: str,
name: str,
obj: object,
options: object,
signature: str | None,
return_annotation: str | None,
) -> None:
"""[see superclass]"""
if what == "class":
try:
update_forward_references(cast(type, obj))
except Exception as e:
log.error(f"failed to update forward references for {name}: {e}")
# print the traceback to the console
import traceback
traceback.print_exc()
raise
return None
#
# validate __all__
#
__tracker.validate()
def _substitute_generic_type_arguments(
type_expression: type[Any] | TypeVar,
fn_substitute_type_vars: typing.Callable[
[type[Any] | TypeVar], type[Any] | TypeVar
],
) -> type[Any] | TypeVar:
# dynamically resolve type variables inside nested type expressions
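# e.g., with T resolved to int, Mapping[str, T] becomes Mapping[str, int]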
type_args: tuple[list[type[Any] | TypeVar] | type[Any] | TypeVar, ...] = (
typing_inspect.get_args(type_expression)
)
if type_args:
if isinstance(type_expression, UnionType):
# convert PEP 604 unions (X | Y) to typing.Union, preserving all union
# members, so the generic copy below works uniformly
type_expression = Union[type_args]
return _copy_generic_type_with_arguments(
type_expression=type_expression,
new_arguments=tuple(
(
list(map(fn_substitute_type_vars, arg))
if isinstance(arg, list)
else fn_substitute_type_vars(arg)
)
for arg in type_args
),
)
else:
return type_expression
def _copy_generic_type_with_arguments(
type_expression: type[Any] | TypeVar,
new_arguments: tuple[list[type[Any] | TypeVar] | type[Any] | TypeVar, ...],
) -> type[Any] | TypeVar:
# create a copy of the given type expression, replacing its type arguments with
# the given new arguments
origin = typing_inspect.get_origin(type_expression)
assert origin is not None
try:
copy_with: Callable[
[
tuple[
list[type[Any] | TypeVar] | type[Any] | TypeVar,
...,
]
],
type[Any] | TypeVar,
] = type_expression.copy_with # type: ignore
except AttributeError:
# this is a generic type that does not support copying
return cast(type[Any], origin[new_arguments])
# unpack callable args, since copy_with() expects a flat tuple
# (arg_1, arg_2, ..., arg_n, return)
# instead of ([arg_1, arg_2, ..., arg_n], return)
if (origin is collections.abc.Callable) and isinstance(new_arguments[0], list):
new_arguments = (*new_arguments[0], *new_arguments[1:])
return copy_with(new_arguments)
def _get_generic_bases(class_: type) -> tuple[type, ...]:
"""
Bugfix version of :func:`typing_inspect.get_generic_bases`.
Prevents falling back to the generic bases of a parent class when the given
class does not define generic bases of its own.
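
For example (a sketch):

.. code-block:: python

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class A(Generic[T]):
        pass

    class B(A[int]):
        pass

    class C(B):
        pass

    _get_generic_bases(B)  # (A[int],)
    _get_generic_bases(C)  # (), whereas typing_inspect.get_generic_bases(C)
                           # would fall back to B's generic bases
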
:param class_: class to get the generic bases for
:return: the generic base classes of the given class
"""
bases: tuple[type, ...] = typing_inspect.get_generic_bases(class_)
if not isinstance(
class_, GenericAlias
) and bases is typing_inspect.get_generic_bases(super(class_, class_)):
return ()
else:
return bases