1450 lines
54 KiB
Python
1450 lines
54 KiB
Python
import inspect
|
|
import itertools
|
|
import sys
|
|
from contextlib import suppress
|
|
from functools import partial
|
|
from typing import Any, Callable, Literal, Optional, Sequence, Union, get_args, get_origin
|
|
|
|
from attrs import define, field
|
|
|
|
from cyclopts._convert import (
|
|
ITERABLE_TYPES,
|
|
convert,
|
|
token_count,
|
|
)
|
|
from cyclopts.annotations import (
|
|
contains_hint,
|
|
is_attrs,
|
|
is_dataclass,
|
|
is_namedtuple,
|
|
is_nonetype,
|
|
is_pydantic,
|
|
is_typeddict,
|
|
is_union,
|
|
resolve,
|
|
resolve_annotated,
|
|
resolve_optional,
|
|
)
|
|
from cyclopts.exceptions import (
|
|
CoercionError,
|
|
CycloptsError,
|
|
MissingArgumentError,
|
|
MixedArgumentError,
|
|
RepeatArgumentError,
|
|
ValidationError,
|
|
)
|
|
from cyclopts.field_info import (
|
|
KEYWORD_ONLY,
|
|
POSITIONAL_ONLY,
|
|
POSITIONAL_OR_KEYWORD,
|
|
VAR_KEYWORD,
|
|
VAR_POSITIONAL,
|
|
FieldInfo,
|
|
_attrs_field_infos,
|
|
_generic_class_field_infos,
|
|
_pydantic_field_infos,
|
|
_typed_dict_field_infos,
|
|
get_field_infos,
|
|
signature_parameters,
|
|
)
|
|
from cyclopts.group import Group
|
|
from cyclopts.parameter import ITERATIVE_BOOL_IMPLICIT_VALUE, Parameter, get_parameters
|
|
from cyclopts.token import Token
|
|
from cyclopts.utils import UNSET, grouper, is_builtin
|
|
|
|
# parameter subkeys should not inherit these parameter values from their parent.
# In ``ArgumentCollection._from_type`` this is combined *between* the parent's
# (upstream) parameter and the sub-field's own (immediate) parameter, so the
# explicit ``None`` values presumably reset these fields before the child's own
# settings are applied.
_PARAMETER_SUBKEY_BLOCKER = Parameter(
    name=None,
    converter=None,  # pyright: ignore
    validator=None,
    accepts_keys=None,
    consume_multiple=None,
    env_var=None,
)

# Maps ``(parent kind, child kind)`` to the effective kind of a sub-field.
# ``ArgumentCollection._from_type`` looks up every sub-field here; a ``None`` result
# means the sub-field is skipped entirely (no Argument is created for it).
_kind_parent_child_reassignment = {
    (POSITIONAL_OR_KEYWORD, POSITIONAL_OR_KEYWORD): POSITIONAL_OR_KEYWORD,
    (POSITIONAL_OR_KEYWORD, POSITIONAL_ONLY): POSITIONAL_ONLY,
    (POSITIONAL_OR_KEYWORD, KEYWORD_ONLY): KEYWORD_ONLY,
    (POSITIONAL_OR_KEYWORD, VAR_POSITIONAL): VAR_POSITIONAL,
    (POSITIONAL_OR_KEYWORD, VAR_KEYWORD): VAR_KEYWORD,
    (POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD): POSITIONAL_ONLY,
    (POSITIONAL_ONLY, POSITIONAL_ONLY): POSITIONAL_ONLY,
    (POSITIONAL_ONLY, KEYWORD_ONLY): None,
    (POSITIONAL_ONLY, VAR_POSITIONAL): VAR_POSITIONAL,
    (POSITIONAL_ONLY, VAR_KEYWORD): None,
    (KEYWORD_ONLY, POSITIONAL_OR_KEYWORD): KEYWORD_ONLY,
    (KEYWORD_ONLY, POSITIONAL_ONLY): None,
    (KEYWORD_ONLY, KEYWORD_ONLY): KEYWORD_ONLY,
    (KEYWORD_ONLY, VAR_POSITIONAL): None,
    (KEYWORD_ONLY, VAR_KEYWORD): VAR_KEYWORD,
    (VAR_POSITIONAL, POSITIONAL_OR_KEYWORD): POSITIONAL_ONLY,
    (VAR_POSITIONAL, POSITIONAL_ONLY): POSITIONAL_ONLY,
    (VAR_POSITIONAL, KEYWORD_ONLY): None,
    (VAR_POSITIONAL, VAR_POSITIONAL): VAR_POSITIONAL,
    (VAR_POSITIONAL, VAR_KEYWORD): None,
    (VAR_KEYWORD, POSITIONAL_OR_KEYWORD): KEYWORD_ONLY,
    (VAR_KEYWORD, POSITIONAL_ONLY): None,
    (VAR_KEYWORD, KEYWORD_ONLY): KEYWORD_ONLY,
    (VAR_KEYWORD, VAR_POSITIONAL): None,
    (VAR_KEYWORD, VAR_KEYWORD): VAR_KEYWORD,
}
|
|
|
|
|
|
def _startswith(string, prefix):
|
|
def normalize(s):
|
|
return s.replace("_", "-")
|
|
|
|
return normalize(string).startswith(normalize(prefix))
|
|
|
|
|
|
def _missing_keys_factory(get_field_info: Callable[[Any], dict[str, FieldInfo]]):
    """Create a checker reporting required fields that are absent from parsed data.

    Parameters
    ----------
    get_field_info: Callable
        Maps a type hint to its ``{field_name: FieldInfo}`` mapping.

    Returns
    -------
    Callable[[Argument, dict], list[str]]
        Function returning the names of fields of ``argument.hint`` that are
        ``required`` but not present as keys in ``data``, in field order.
    """

    def inner(argument: "Argument", data: dict[str, Any]) -> list[str]:
        present = set(data)
        missing = []
        for field_name, info in get_field_info(argument.hint).items():
            if info.required and field_name not in present:
                missing.append(field_name)
        return missing

    return inner
|
|
|
|
|
|
def _identity_converter(type_, token):
|
|
return token
|
|
|
|
|
|
def _get_annotated_discriminator(annotation):
|
|
for meta in get_args(annotation)[1:]:
|
|
try:
|
|
return meta.discriminator
|
|
except AttributeError:
|
|
pass
|
|
return None
|
|
|
|
|
|
class ArgumentCollection(list["Argument"]):
    """A list-like container for :class:`Argument`."""

    def __init__(self, *args):
        super().__init__(*args)

    def copy(self) -> "ArgumentCollection":
        """Returns a shallow copy of the :class:`ArgumentCollection`."""
        return type(self)(self)

    def match(
        self,
        term: Union[str, int],
        *,
        transform: Optional[Callable[[str], str]] = None,
        delimiter: str = ".",
    ) -> tuple["Argument", tuple[str, ...], Any]:
        """Matches CLI keyword or index to their :class:`Argument`.

        When several arguments match, the one leaving the fewest leftover keys wins;
        a perfect match (no leftover keys) stops the search immediately.

        Parameters
        ----------
        term: str | int
            One of:

            * :obj:`str` keyword like ``"--foo"`` or ``"-f"`` or ``"--foo.bar.baz"``.

            * :obj:`int` global positional index.

        Raises
        ------
        ValueError
            If the provided ``term`` doesn't match.

        Returns
        -------
        Argument
            Matched :class:`Argument`.
        Tuple[str, ...]
            Python keys into Argument. Non-empty iff Argument accepts keys.
        Any
            Implicit value (if a flag). :obj:`~.UNSET` otherwise.
        """
        best_match_argument, best_match_keys, best_implicit_value = None, None, UNSET
        for argument in self:
            try:
                match_keys, implicit_value = argument.match(term, transform=transform, delimiter=delimiter)
            except ValueError:
                continue
            if best_match_keys is None or len(match_keys) < len(best_match_keys):
                best_match_keys = match_keys
                best_match_argument = argument
                best_implicit_value = implicit_value
            if not match_keys:  # Perfect match
                break

        if best_match_argument is None or best_match_keys is None:
            raise ValueError(f"No Argument matches {term!r}")

        return best_match_argument, best_match_keys, best_implicit_value

    def _set_marks(self, val: bool):
        """Set the conversion mark of every contained argument to ``val``."""
        for argument in self:
            argument._marked = val

    def _convert(self):
        """Convert and validate all elements."""
        self._set_marks(False)
        # Sorting by ``keys`` visits a parent (shorter key tuple) before any of its
        # descendants, since a tuple sorts before its own extensions.
        for argument in sorted(self, key=lambda x: x.keys):
            if argument._marked:
                # Already populated, presumably as part of converting an ancestor.
                continue
            argument.convert_and_validate()

    @classmethod
    def _from_type(
        cls,
        field_info: FieldInfo,
        keys: tuple[str, ...],
        *default_parameters: Optional[Parameter],
        group_lookup: dict[str, Group],
        group_arguments: Group,
        group_parameters: Group,
        parse_docstring: bool = True,
        docstring_lookup: Optional[dict[tuple[str, ...], Parameter]] = None,
        positional_index: Optional[int] = None,
        _resolve_groups: bool = True,
    ):
        """Build the :class:`Argument` for ``field_info`` and, recursively, for its sub-fields.

        Parameters
        ----------
        field_info: FieldInfo
            Field to build an argument for.
        keys: tuple[str, ...]
            Python keys leading to this field; empty for a root (function-parameter) field.
        *default_parameters: Optional[Parameter]
            Parameters combined ahead of the field's own annotated parameters.
        group_lookup: dict[str, Group]
            Maps group names used in ``Parameter.group`` to :class:`Group` objects.
        group_arguments: Group
            Default group for positional-only / var-positional fields.
        group_parameters: Group
            Default group for all other fields.
        parse_docstring: bool
            Extract sub-field help from the hint's docstring.
        docstring_lookup: Optional[dict[tuple[str, ...], Parameter]]
            Help parameters keyed by sub-key path; these override entries
            extracted from the hint's docstring.
        positional_index: Optional[int]
            Next available positional index, or ``None`` once positional
            assignment is no longer possible.
        _resolve_groups: bool
            Resolve string group names through ``group_lookup``.

        Returns
        -------
        ArgumentCollection
            The field's argument followed by all descendant arguments.
        """
        out = cls()  # groups=list(group_lookup.values()))

        if docstring_lookup is None:
            docstring_lookup = {}

        cyclopts_parameters_no_group = []

        hint = field_info.hint
        hint, hint_parameters = get_parameters(hint)
        cyclopts_parameters_no_group.extend(hint_parameters)

        if not keys:  # root hint annotation
            if field_info.kind is field_info.VAR_KEYWORD:
                hint = dict[str, hint]
            elif field_info.kind is field_info.VAR_POSITIONAL:
                hint = tuple[hint, ...]

        if _resolve_groups:
            cyclopts_parameters = []
            for cparam in cyclopts_parameters_no_group:
                resolved_groups = []
                for group in cparam.group:  # pyright:ignore
                    if isinstance(group, str):
                        group = group_lookup[group]
                    resolved_groups.append(group)
                    cyclopts_parameters.append(group.default_parameter)
                cyclopts_parameters.append(cparam)
                if resolved_groups:
                    cyclopts_parameters.append(Parameter(group=resolved_groups))
        else:
            cyclopts_parameters = cyclopts_parameters_no_group

        upstream_parameter = Parameter.combine(
            (
                Parameter(group=group_arguments)
                if field_info.kind in (field_info.POSITIONAL_ONLY, field_info.VAR_POSITIONAL)
                else Parameter(group=group_parameters)
            ),
            *default_parameters,
        )
        immediate_parameter = Parameter.combine(*cyclopts_parameters)

        # We do NOT want to skip parse=False arguments here.
        # This makes it easier to assemble ignored arguments downstream.

        # resolve/derive the parameter name
        if keys:
            cparam = Parameter.combine(
                upstream_parameter,
                _PARAMETER_SUBKEY_BLOCKER,
                immediate_parameter,
            )
            cparam = Parameter.combine(
                cparam,
                Parameter(
                    name=_resolve_parameter_name(
                        upstream_parameter.name,  # pyright: ignore
                        (immediate_parameter.name or tuple(cparam.name_transform(x) for x in field_info.names))
                        + cparam.alias,  # pyright: ignore
                    )
                ),
            )
        else:
            # This is directly on iparam
            cparam = Parameter.combine(
                upstream_parameter,
                immediate_parameter,
            )
            assert isinstance(cparam.alias, tuple)
            if cparam.name:
                if field_info.is_keyword:
                    assert isinstance(cparam.name, tuple)
                    cparam = Parameter.combine(
                        cparam, Parameter(name=_resolve_parameter_name(cparam.name + cparam.alias))
                    )
            else:
                if field_info.kind in (field_info.POSITIONAL_ONLY, field_info.VAR_POSITIONAL):
                    # Name is only used for help-string
                    cparam = Parameter.combine(cparam, Parameter(name=(name.upper() for name in field_info.names)))
                elif field_info.kind is field_info.VAR_KEYWORD:
                    cparam = Parameter.combine(cparam, Parameter(name=("--[KEYWORD]",)))
                else:
                    # cparam.name_transform cannot be None due to:
                    #     attrs.converters.default_if_none(default_name_transform)
                    assert cparam.name_transform is not None
                    cparam = Parameter.combine(
                        cparam,
                        Parameter(
                            name=tuple("--" + cparam.name_transform(name) for name in field_info.names)
                            + _resolve_parameter_name(cparam.alias)
                        ),
                    )

        if field_info.is_keyword_only:
            positional_index = None

        argument = Argument(field_info=field_info, parameter=cparam, keys=keys, hint=hint)
        if not argument._accepts_keywords and positional_index is not None:
            argument.index = positional_index
            positional_index += 1

        out.append(argument)
        if argument._accepts_keywords:
            hint_docstring_lookup = _extract_docstring_help(argument.hint) if parse_docstring else {}
            # Explicitly supplied docstring entries take precedence over the hint's own.
            hint_docstring_lookup.update(docstring_lookup)

            for sub_field_name, sub_field_info in argument._lookup.items():
                updated_kind = _kind_parent_child_reassignment[(argument.field_info.kind, sub_field_info.kind)]  # pyright: ignore
                if updated_kind is None:
                    # This (parent kind, child kind) combination is not representable.
                    continue

                sub_field_info.kind = updated_kind

                if sub_field_info.is_keyword_only:
                    positional_index = None

                # Strip this sub-field's name from nested docstring keys.
                subkey_docstring_lookup = {
                    k[1:]: v for k, v in hint_docstring_lookup.items() if k[0] == sub_field_name and len(k) > 1
                }

                subkey_argument_collection = cls._from_type(
                    sub_field_info,
                    keys + (sub_field_name,),
                    cparam,
                    (
                        Parameter(help=sub_field_info.help)
                        if sub_field_info.help
                        else hint_docstring_lookup.get((sub_field_name,))
                    ),
                    Parameter(required=argument.required & sub_field_info.required),
                    group_lookup=group_lookup,
                    group_arguments=group_arguments,
                    group_parameters=group_parameters,
                    parse_docstring=parse_docstring,
                    docstring_lookup=subkey_docstring_lookup,
                    positional_index=positional_index,
                    _resolve_groups=_resolve_groups,
                )
                if subkey_argument_collection:
                    argument.children.append(subkey_argument_collection[0])
                    out.extend(subkey_argument_collection)

                    if positional_index is not None:
                        positional_index = subkey_argument_collection._max_index
                        if positional_index is not None:
                            positional_index += 1

        return out

    @classmethod
    def _from_callable(
        cls,
        func: Callable,
        *default_parameters: Optional[Parameter],
        group_lookup: Optional[dict[str, Group]] = None,
        group_arguments: Optional[Group] = None,
        group_parameters: Optional[Group] = None,
        parse_docstring: bool = True,
        _resolve_groups: bool = True,
    ):
        """Build an :class:`ArgumentCollection` for every parameter in ``func``'s signature.

        Positional indices are assigned consecutively across parameters (and their
        sub-fields) until a parameter yields no positional argument.

        NOTE(review): the ``group_lookup`` argument is never read; it is always
        recomputed (or set to ``{}`` when ``_resolve_groups`` is ``False``).
        """
        out = cls()

        if group_arguments is None:
            group_arguments = Group.create_default_arguments()
        if group_parameters is None:
            group_parameters = Group.create_default_parameters()

        if _resolve_groups:
            group_lookup = {
                group.name: group
                for group in _resolve_groups_from_callable(
                    func,
                    *default_parameters,
                    group_arguments=group_arguments,
                    group_parameters=group_parameters,
                )
            }
        else:
            group_lookup = {}

        docstring_lookup = _extract_docstring_help(func) if parse_docstring else {}
        positional_index = 0
        for field_info in signature_parameters(func).values():
            if parse_docstring:
                subkey_docstring_lookup = {
                    k[1:]: v for k, v in docstring_lookup.items() if k[0] == field_info.name and len(k) > 1
                }
            else:
                subkey_docstring_lookup = None
            iparam_argument_collection = cls._from_type(
                field_info,
                (),
                *default_parameters,
                Parameter(help=field_info.help) if field_info.help else docstring_lookup.get((field_info.name,)),
                group_lookup=group_lookup,
                group_arguments=group_arguments,
                group_parameters=group_parameters,
                positional_index=positional_index,
                parse_docstring=parse_docstring,
                docstring_lookup=subkey_docstring_lookup,
                _resolve_groups=_resolve_groups,
            )
            if positional_index is not None:
                positional_index = iparam_argument_collection._max_index
                if positional_index is not None:
                    positional_index += 1
            out.extend(iparam_argument_collection)

        return out

    @property
    def groups(self):
        """Unique groups of all contained arguments, in order of first appearance."""
        groups = []
        for argument in self:
            assert isinstance(argument.parameter.group, tuple)
            for group in argument.parameter.group:
                if group not in groups:
                    groups.append(group)
        return groups

    @property
    def _root_arguments(self):
        """Yield arguments with no python keys (i.e. directly representing a function parameter)."""
        for argument in self:
            if not argument.keys:
                yield argument

    @property
    def _max_index(self) -> Optional[int]:
        """Largest positional index among contained arguments, or ``None`` if none is positional."""
        return max((x.index for x in self if x.index is not None), default=None)

    def filter_by(
        self,
        *,
        group: Optional[Group] = None,
        has_tokens: Optional[bool] = None,
        has_tree_tokens: Optional[bool] = None,
        keys_prefix: Optional[tuple[str, ...]] = None,
        kind: Optional[inspect._ParameterKind] = None,
        parse: Optional[bool] = None,
        show: Optional[bool] = None,
        value_set: Optional[bool] = None,
    ) -> "ArgumentCollection":
        """Filter the :class:`ArgumentCollection`.

        All non-:obj:`None` filters will be applied.

        Parameters
        ----------
        group: Optional[Group]
            The :class:`.Group` the arguments should be in.
        has_tokens: Optional[bool]
            Immediately has tokens (not including children).
        has_tree_tokens: Optional[bool]
            Argument and/or its children have parsed tokens.
        keys_prefix: Optional[tuple[str, ...]]
            The argument's python ``keys`` start with this prefix.
        kind: Optional[inspect._ParameterKind]
            The :attr:`~inspect.Parameter.kind` of the argument.
        parse: Optional[bool]
            If the argument is intended to be parsed or not.
        show: Optional[bool]
            The Argument is intended to be shown on the help page.
        value_set: Optional[bool]
            The converted value is set.
        """
        ac = self.copy()
        cls = type(self)

        if group is not None:
            ac = cls(x for x in ac if group in x.parameter.group)  # pyright: ignore
        if kind is not None:
            ac = cls(x for x in ac if x.field_info.kind == kind)
        if has_tokens is not None:
            ac = cls(x for x in ac if not (bool(x.tokens) ^ bool(has_tokens)))
        if has_tree_tokens is not None:
            # ``Argument.has_tokens`` also accounts for tokens parsed into child arguments;
            # ``x.tokens`` alone would make this filter identical to ``has_tokens``.
            ac = cls(x for x in ac if not (bool(x.has_tokens) ^ bool(has_tree_tokens)))
        if keys_prefix is not None:
            ac = cls(x for x in ac if x.keys[: len(keys_prefix)] == keys_prefix)
        if show is not None:
            ac = cls(x for x in ac if not (x.show ^ bool(show)))
        if value_set is not None:
            ac = cls(x for x in ac if ((x.value is UNSET) ^ bool(value_set)))
        if parse is not None:
            ac = cls(x for x in ac if not (x.parameter.parse ^ parse))

        return ac
|
|
|
|
|
|
@define(kw_only=True)
class Argument:
    """Encapsulates functionality and additional contextual information for parsing a parameter.

    An argument is defined as anything that would have its own entry in the help page.
    """

    tokens: list[Token] = field(factory=list)
    """
    List of :class:`.Token` parsed from various sources.
    Do not directly mutate; see :meth:`append`.
    """

    field_info: FieldInfo = field(factory=FieldInfo)
    """
    Additional information about the parameter from surrounding python syntax.
    """

    parameter: Parameter = field(factory=Parameter)  # pyright: ignore
    """
    Fully resolved user-provided :class:`.Parameter`.
    """

    hint: Any = field(default=str, converter=resolve)
    """
    The type hint for this argument; may be different from :attr:`.FieldInfo.annotation`.
    """

    index: Optional[int] = field(default=None)
    """
    Associated python positional index for argument.
    If ``None``, then cannot be assigned positionally.
    """

    keys: tuple[str, ...] = field(default=())
    """
    **Python** keys that lead to this leaf.

    ``self.parameter.name`` and ``self.keys`` can naively disagree!
    For example, a ``self.parameter.name="--foo.bar.baz"`` could be aliased to "--fizz".
    The resulting ``self.keys`` would be ``("bar", "baz")``.

    This is populated based on type-hints and class-structure, not ``Parameter.name``.

    .. code-block:: python

        from cyclopts import App, Parameter
        from dataclasses import dataclass
        from typing import Annotated

        app = App()


        @dataclass
        class User:
            id: int
            name: Annotated[str, Parameter(name="--fullname")]


        @app.default
        def main(user: User):
            pass


        for argument in app.assemble_argument_collection():
            print(f"name: {argument.name:16} hint: {str(argument.hint):16} keys: {str(argument.keys)}")

    .. code-block:: bash

        $ my-script
        name: --user.id         hint: <class 'int'>    keys: ('id',)
        name: --fullname        hint: <class 'str'>    keys: ('name',)
    """

    # Converted value; may be stale.
    # Exposed through the ``value`` property; the attrs ``alias`` makes the init keyword ``value``.
    _value: Any = field(alias="value", default=UNSET)
    """
    Converted value from last :meth:`convert` call.
    This value may be stale if fields have changed since last :meth:`convert` call.
    :class:`.UNSET` if :meth:`convert` has not yet been called with tokens.
    """

    # Set in ``__attrs_post_init__`` when ``hint`` takes keyword sub-keys (dicts,
    # dataclasses, TypedDicts, attrs/pydantic models, ... or ``accepts_keys=True``).
    _accepts_keywords: bool = field(default=False, init=False, repr=False)

    # Fallback type for keys not found in ``_lookup`` (dict value type, or the
    # ``**kwargs`` annotation of a class ``__init__``); ``None`` means no fallback.
    _default: Any = field(default=None, init=False, repr=False)
    # Sub-field name -> FieldInfo for each known keyword sub-key of ``hint``.
    _lookup: dict[str, FieldInfo] = field(factory=dict, init=False, repr=False)

    children: "ArgumentCollection" = field(factory=ArgumentCollection, init=False, repr=False)
    """
    Collection of other :class:`Argument` that eventually culminate into the python variable represented by :attr:`field_info`.
    """

    # Together these back the ``_marked`` property: ``_marked_converted`` is reset by
    # ``ArgumentCollection._set_marks``, while ``_mark_converted_override`` is only set
    # when ``value`` is assigned to an already-marked argument.
    _marked_converted: bool = field(default=False, init=False, repr=False)  # for mark & sweep algos
    _mark_converted_override: bool = field(default=False, init=False, repr=False)

    # Validator to be called based on builtin type support.
    # When set, it is a callable ``(argument, data) -> list[str]`` of missing required keys.
    _missing_keys_checker: Optional[Callable] = field(default=None, init=False, repr=False)

    # NOTE(review): not referenced in this part of the file; presumably assigned elsewhere.
    _internal_converter: Optional[Callable] = field(default=None, init=False, repr=False)
|
|
|
|
    def __attrs_post_init__(self):
        """Infer from ``self.hint`` whether (and how) this argument accepts keyword sub-keys.

        Each member of a union hint is inspected. Depending on the hint type this sets
        ``_accepts_keywords``, ``_default``, ``_lookup`` (via :meth:`_update_lookup`)
        and ``_missing_keys_checker``. Nothing is inferred when the parameter is not
        parsed or when ``parameter.accepts_keys`` is explicitly ``False``.

        Raises
        ------
        TypeError
            If a ``dict`` hint has a key type other than ``str``.
        ValueError
            If a namedtuple hint has no ``__annotations__``.
        NotImplementedError
            Propagated from :meth:`_update_lookup` on conflicting sub-field definitions.
        """
        # By definition, self.hint is Not AnnotatedType
        hint = resolve(self.hint)
        hints = get_args(hint) if is_union(hint) else (hint,)

        if not self.parameter.parse:
            return

        if self.parameter.accepts_keys is False:  # ``None`` means to infer.
            return

        for hint in hints:
            # ``self.parameter.accepts_keys`` is either ``None`` or ``True`` here
            origin = get_origin(hint)
            hint_origin = {hint, origin}

            # Classes that ALWAYS takes keywords (accepts_keys=None)
            field_infos = get_field_infos(hint)
            if dict in hint_origin:
                self._accepts_keywords = True
                # Unparametrized ``dict`` defaults to ``dict[str, str]``.
                key_type, val_type = str, str
                args = get_args(hint)
                with suppress(IndexError):
                    key_type = args[0]
                    val_type = args[1]
                if key_type is not str:
                    raise TypeError('Dictionary type annotations must have "str" keys.')
                # Arbitrary keys: every key falls back to the value type.
                self._default = val_type
            elif is_typeddict(hint):
                self._missing_keys_checker = _missing_keys_factory(_typed_dict_field_infos)
                self._accepts_keywords = True
                self._update_lookup(field_infos)
            elif is_dataclass(hint):  # Typical usecase of a dataclass will have more than 1 field.
                self._missing_keys_checker = _missing_keys_factory(_generic_class_field_infos)
                self._accepts_keywords = True
                self._update_lookup(field_infos)
            elif is_namedtuple(hint):
                # collections.namedtuple does not have type hints, assume "str" for everything.
                self._missing_keys_checker = _missing_keys_factory(_generic_class_field_infos)
                self._accepts_keywords = True
                if not hasattr(hint, "__annotations__"):
                    raise ValueError("Cyclopts cannot handle collections.namedtuple in python <3.10.")
                self._update_lookup(field_infos)
            elif is_attrs(hint):
                self._missing_keys_checker = _missing_keys_factory(_attrs_field_infos)
                self._accepts_keywords = True
                self._update_lookup(field_infos)
            elif is_pydantic(hint):
                self._missing_keys_checker = _missing_keys_factory(_pydantic_field_infos)
                self._accepts_keywords = True
                # pydantic's __init__ signature doesn't accurately reflect its requirements.
                # so we cannot use _generic_class_required_optional(...)
                self._update_lookup(field_infos)
            elif not is_builtin(hint) and field_infos:
                # Some classic user class.
                self._missing_keys_checker = _missing_keys_factory(_generic_class_field_infos)
                self._accepts_keywords = True
                self._update_lookup(field_infos)
            elif self.parameter.accepts_keys is None:
                # Typical builtin hint
                continue

            if self.parameter.accepts_keys is None:
                continue
            # Only explicit ``self.parameter.accepts_keys == True`` from here on
            # NOTE(review): this section also runs for hints already handled by a branch
            # above (e.g. a dataclass with ``accepts_keys=True``), re-feeding its fields
            # into ``_update_lookup`` from the ``__init__`` signature — confirm intended.

            # Classes that MAY take keywords (accepts_keys=True)
            # They must be explicitly specified ``accepts_keys=True`` because otherwise
            # providing a single positional argument is what we want.
            self._accepts_keywords = True
            self._missing_keys_checker = _missing_keys_factory(_generic_class_field_infos)
            for i, field_info in enumerate(signature_parameters(hint.__init__).values()):
                # Skip the leading ``self`` of the unbound ``__init__``.
                if i == 0 and field_info.name == "self":
                    continue
                if field_info.kind is field_info.VAR_KEYWORD:
                    # ``**kwargs`` annotation becomes the fallback type for unknown keys.
                    self._default = field_info.annotation
                else:
                    self._update_lookup({field_info.name: field_info})
|
|
|
|
def _update_lookup(self, field_infos: dict[str, FieldInfo]):
|
|
discriminator = _get_annotated_discriminator(self.field_info.annotation)
|
|
|
|
for key, field_info in field_infos.items():
|
|
if existing_field_info := self._lookup.get(key):
|
|
if existing_field_info == field_info:
|
|
pass
|
|
elif discriminator and discriminator in field_info.names and discriminator in existing_field_info.names:
|
|
existing_field_info.annotation = Literal[existing_field_info.annotation, field_info.annotation] # pyright: ignore
|
|
existing_field_info.default = FieldInfo.empty
|
|
else:
|
|
raise NotImplementedError
|
|
else:
|
|
self._lookup[key] = field_info
|
|
|
|
@property
|
|
def value(self):
|
|
"""Converted value from last :meth:`convert` call.
|
|
|
|
This value may be stale if fields have changed since last :meth:`convert` call.
|
|
:class:`.UNSET` if :meth:`convert` has not yet been called with tokens.
|
|
"""
|
|
return self._value
|
|
|
|
@value.setter
|
|
def value(self, val):
|
|
if self._marked:
|
|
self._mark_converted_override = True
|
|
self._marked = True
|
|
self._value = val
|
|
|
|
@property
|
|
def _marked(self):
|
|
"""If ``True``, then this node in the tree has already been converted and ``value`` has been populated."""
|
|
return self._marked_converted | self._mark_converted_override
|
|
|
|
@_marked.setter
|
|
def _marked(self, value: bool):
|
|
self._marked_converted = value
|
|
|
|
@property
|
|
def _accepts_arbitrary_keywords(self) -> bool:
|
|
args = get_args(self.hint) if is_union(self.hint) else (self.hint,)
|
|
return any(dict in (arg, get_origin(arg)) for arg in args)
|
|
|
|
@property
|
|
def show_default(self) -> Union[bool, Callable[[Any], str]]:
|
|
"""Show the default value on the help page."""
|
|
if self.required: # By definition, a required parameter cannot have a default.
|
|
return False
|
|
elif self.parameter.show_default is None:
|
|
# Showing a default ``None`` value is typically not helpful to the end-user.
|
|
return self.field_info.default not in (None, self.field_info.empty)
|
|
elif (self.field_info.default is self.field_info.empty) or not self.parameter.show_default:
|
|
return False
|
|
else:
|
|
return self.parameter.show_default
|
|
|
|
@property
|
|
def _use_pydantic_type_adapter(self) -> bool:
|
|
return bool(
|
|
is_pydantic(self.hint)
|
|
or (
|
|
is_union(self.hint)
|
|
and (
|
|
any(is_pydantic(x) for x in get_args(self.hint))
|
|
or _get_annotated_discriminator(self.field_info.annotation)
|
|
)
|
|
)
|
|
)
|
|
|
|
def _type_hint_for_key(self, key: str):
|
|
try:
|
|
return self._lookup[key].annotation
|
|
except KeyError:
|
|
if self._default is None:
|
|
raise
|
|
return self._default
|
|
|
|
def _should_attempt_json_dict(self, tokens: Optional[Sequence[Union[Token, str]]] = None) -> bool:
|
|
"""When parsing, should attempt to parse the token(s) as json dict data."""
|
|
if tokens is None:
|
|
tokens = self.tokens
|
|
if not tokens:
|
|
return False
|
|
if not self._accepts_keywords:
|
|
return False
|
|
value = tokens[0].value if isinstance(tokens[0], Token) else tokens[0]
|
|
if not value.strip().startswith("{"):
|
|
return False
|
|
if self.parameter.json_dict is not None:
|
|
return self.parameter.json_dict
|
|
if contains_hint(self.field_info.annotation, str):
|
|
return False
|
|
return True
|
|
|
|
def _should_attempt_json_list(
|
|
self, tokens: Union[Sequence[Union[Token, str]], Token, str, None] = None, keys: tuple[str, ...] = ()
|
|
) -> bool:
|
|
"""When parsing, should attempt to parse the token(s) as json list data."""
|
|
if tokens is None:
|
|
tokens = self.tokens
|
|
if not tokens:
|
|
return False
|
|
_, consume_all = self.token_count(keys)
|
|
if not consume_all:
|
|
return False
|
|
if isinstance(tokens, Token):
|
|
value = tokens.value
|
|
elif isinstance(tokens, str):
|
|
value = tokens
|
|
else:
|
|
value = tokens[0].value if isinstance(tokens[0], Token) else tokens[0]
|
|
if not value.strip().startswith("["):
|
|
return False
|
|
if self.parameter.json_list is not None:
|
|
return self.parameter.json_list
|
|
for arg in get_args(self.field_info.annotation) or (str,):
|
|
if contains_hint(arg, str):
|
|
return False
|
|
return True
|
|
|
|
def match(
|
|
self,
|
|
term: Union[str, int],
|
|
*,
|
|
transform: Optional[Callable[[str], str]] = None,
|
|
delimiter: str = ".",
|
|
) -> tuple[tuple[str, ...], Any]:
|
|
"""Match a name search-term, or a positional integer index.
|
|
|
|
Raises
|
|
------
|
|
ValueError
|
|
If no match is found.
|
|
|
|
Returns
|
|
-------
|
|
Tuple[str, ...]
|
|
Leftover keys after matching to this argument.
|
|
Used if this argument accepts_arbitrary_keywords.
|
|
Any
|
|
Implicit value.
|
|
:obj:`~.UNSET` if no implicit value is applicable.
|
|
"""
|
|
if not self.parameter.parse:
|
|
raise ValueError
|
|
return (
|
|
self._match_index(term)
|
|
if isinstance(term, int)
|
|
else self._match_name(term, transform=transform, delimiter=delimiter)
|
|
)
|
|
|
|
    def _match_name(
        self,
        term: str,
        *,
        transform: Optional[Callable[[str], str]] = None,
        delimiter: str = ".",
    ) -> tuple[tuple[str, ...], Any]:
        """Check how well this argument matches a token keyword identifier.

        Positive names (``self.parameter.name``) are tried first, comparing with ``_``
        and ``-`` treated as equivalent. If none of them match, the negative names from
        ``self.parameter.get_negatives`` for each member of the annotation are tried
        with a plain ``str.startswith``. A prefix match only counts if it is exact or
        followed by ``delimiter``; leftover keys are only accepted when this argument
        accepts arbitrary keywords.

        Parameters
        ----------
        term: str
            Something like "--foo"
        transform: Callable
            Function that converts the cyclopts Parameter name(s) into
            something that should be compared against ``term``.
        delimiter: str
            Separator between the matched name and trailing sub-keys.

        Raises
        ------
        ValueError
            If no match found.

        Returns
        -------
        Tuple[str, ...]
            Leftover keys after matching to this argument.
            Used if this argument accepts_arbitrary_keywords.
        Any
            Implicit value.
        """
        # ``**kwargs``: any keyword matches; every delimited segment becomes a key.
        if self.field_info.kind is self.field_info.VAR_KEYWORD:
            return tuple(term.lstrip("-").split(delimiter)), UNSET

        trailing = term
        implicit_value = UNSET

        assert self.parameter.name
        for name in self.parameter.name:
            if transform:
                name = transform(name)
            if _startswith(term, name):
                trailing = term[len(name) :]
                # ``bool`` (or an ITERATIVE_BOOL_IMPLICIT_VALUE hint) flags imply ``True``.
                implicit_value = True if self.hint is bool or self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE else UNSET
                if trailing:
                    if trailing[0] == delimiter:
                        trailing = trailing[1:]
                        break
                    # Otherwise, it's not an actual match.
                else:
                    # exact match
                    return (), implicit_value
        else:
            # No positive-name matches found.
            hint = resolve_annotated(self.field_info.annotation)
            if is_union(hint):
                hints = get_args(hint)
            else:
                hints = (hint,)
            for hint in hints:
                hint = resolve_annotated(hint)
                # Set when the inner loop breaks, so the outer loop breaks too.
                double_break = False
                for name in self.parameter.get_negatives(hint):
                    if transform:
                        name = transform(name)
                    if term.startswith(name):
                        trailing = term[len(name) :]
                        if hint in ITERATIVE_BOOL_IMPLICIT_VALUE:
                            implicit_value = False
                        elif is_nonetype(hint) or hint is None:
                            implicit_value = None
                        else:
                            # Other negatives imply an empty instance of the (non-optional) type.
                            hint = resolve_optional(hint)
                            implicit_value = (get_origin(hint) or hint)()  # pyright: ignore[reportAbstractUsage]
                        if trailing:
                            if trailing[0] == delimiter:
                                trailing = trailing[1:]
                                double_break = True
                                break
                            # Otherwise, it's not an actual match.
                        else:
                            # exact match
                            return (), implicit_value
                if double_break:
                    break
            else:
                # No negative-name matches found.
                raise ValueError

        # Reached only via a name followed by ``delimiter`` and sub-keys.
        if not self._accepts_arbitrary_keywords:
            # Still not an actual match.
            raise ValueError

        return tuple(trailing.split(delimiter)), implicit_value
|
|
|
|
def _match_index(self, index: int) -> tuple[tuple[str, ...], Any]:
|
|
if self.index is None:
|
|
raise ValueError
|
|
elif self.field_info.kind is self.field_info.VAR_POSITIONAL:
|
|
if index < self.index:
|
|
raise ValueError
|
|
elif index != self.index:
|
|
raise ValueError
|
|
return (), UNSET
|
|
|
|
def append(self, token: Token):
|
|
"""Safely add a :class:`Token`."""
|
|
if not self.parameter.parse:
|
|
raise ValueError
|
|
|
|
if any(x.address == token.address for x in self.tokens):
|
|
_, consume_all = self.token_count(token.keys)
|
|
if not consume_all:
|
|
raise RepeatArgumentError(token=token)
|
|
|
|
if self.tokens:
|
|
if bool(token.keys) ^ any(x.keys for x in self.tokens):
|
|
raise MixedArgumentError(argument=self)
|
|
self.tokens.append(token)
|
|
|
|
@property
|
|
def has_tokens(self) -> bool:
|
|
"""This argument, or a child argument, has at least 1 parsed token.""" # noqa: D404
|
|
return bool(self.tokens) or any(x.has_tokens for x in self.children)
|
|
|
|
@property
|
|
def children_recursive(self) -> "ArgumentCollection":
|
|
out = ArgumentCollection()
|
|
for child in self.children:
|
|
out.append(child)
|
|
out.extend(child.children_recursive)
|
|
return out
|
|
|
|
def _convert_pydantic(self):
|
|
if self.has_tokens:
|
|
import pydantic
|
|
|
|
unstructured_data = self._json()
|
|
try:
|
|
# This inherently also invokes pydantic validators
|
|
return pydantic.TypeAdapter(self.field_info.annotation).validate_python(unstructured_data)
|
|
except pydantic.ValidationError as e:
|
|
self._handle_pydantic_validation_error(e)
|
|
else:
|
|
return UNSET
|
|
|
|
def _convert(self, converter: Optional[Callable] = None):
    """Convert this argument's tokens (or its children's values) into a Python value.

    There are three paths:

    * ``parameter.parse`` is falsy: the result is ``UNSET``.
    * No children (a leaf argument): tokens are optionally expanded from JSON lists,
      grouped into positional vs. keyword tokens, and passed to the converter.
    * Has children (a dict-like / class-like structure): optional JSON-dict tokens are
      distributed to the children, then each child with tokens is converted recursively
      and the results are passed as keyword arguments to ``self.hint``.

    Parameters
    ----------
    converter: Optional[Callable]
        Fallback converter. ``self.parameter.converter`` takes precedence when set;
        when neither is provided, ``cyclopts._convert.convert`` is used.

    Returns
    -------
    Any
        The converted value, or ``UNSET`` when there is nothing to convert.

    Raises
    ------
    CoercionError
        If the converter raises AssertionError/ValueError/TypeError, or JSON input is invalid.
    MissingArgumentError
        If a required argument (or required child) has no value.
    """
    if self.parameter.converter:
        converter = self.parameter.converter
    elif converter is None:
        converter = partial(convert, name_transform=self.parameter.name_transform)

    # Wraps ``converter`` so that common conversion exceptions become ``CoercionError``
    # carrying this argument and the target hint.
    def safe_converter(hint, tokens):
        if isinstance(tokens, dict):
            try:
                return converter(hint, tokens)  # pyright: ignore
            except (AssertionError, ValueError, TypeError) as e:
                raise CoercionError(msg=e.args[0] if e.args else None, argument=self, target_type=hint) from e
        else:
            try:
                return converter(hint, tokens)  # pyright: ignore
            except (AssertionError, ValueError, TypeError) as e:
                # Only attribute the error to a specific token when exactly one was given.
                token = tokens[0] if len(tokens) == 1 else None
                raise CoercionError(
                    msg=e.args[0] if e.args else None, argument=self, target_type=hint, token=token
                ) from e

    if not self.parameter.parse:
        out = UNSET
    elif not self.children:
        positional: list[Token] = []
        keyword = {}

        # Yields tokens unchanged, except tokens that should be read as a JSON list,
        # which are replaced by one token per list element.
        def expand_tokens(tokens):
            for token in tokens:
                if self._should_attempt_json_list(token):
                    import json

                    try:
                        parsed_json = json.loads(token.value)
                    except json.JSONDecodeError as e:
                        raise CoercionError(token=token, target_type=self.hint) from e

                    if not isinstance(parsed_json, list):
                        raise CoercionError(token=token, target_type=self.hint)

                    for element in parsed_json:
                        if element is None:
                            # JSON ``null`` is carried as an implicit value with an empty string value.
                            yield token.evolve(value="", implicit_value=element)
                        else:
                            yield token.evolve(value=str(element))
                else:
                    yield token

        expanded_tokens = list(expand_tokens(self.tokens))
        for token in expanded_tokens:
            # A token whose implicit value is already an instance of the target type is
            # returned directly; the assert requires it to be the only token.
            if token.implicit_value is not UNSET and isinstance(
                token.implicit_value, get_origin(self.hint) or self.hint
            ):
                assert len(expanded_tokens) == 1
                return token.implicit_value

            if token.keys:
                # Build a nested dict following the token's key path; the leaf holds a list of tokens.
                lookup = keyword
                for key in token.keys[:-1]:
                    lookup = lookup.setdefault(key, {})
                lookup.setdefault(token.keys[-1], []).append(token)
            else:
                positional.append(token)

        if positional and keyword:  # pragma: no cover
            # This should never happen due to checks in ``Argument.append``
            raise MixedArgumentError(argument=self)

        if positional:
            if self.field_info and self.field_info.kind is self.field_info.VAR_POSITIONAL:
                # Apply converter to individual values
                # (``*args``: the element hint is the first type argument of ``self.hint``.)
                hint = get_args(self.hint)[0]
                tokens_per_element, _ = self.token_count()
                out = tuple(safe_converter(hint, values) for values in grouper(positional, tokens_per_element))
            else:
                out = safe_converter(self.hint, tuple(positional))
        elif keyword:
            if self.field_info and self.field_info.kind is self.field_info.VAR_KEYWORD and not self.keys:
                # Apply converter to individual values
                # (``**kwargs``: the value hint is the second type argument of ``self.hint``.)
                out = {key: safe_converter(get_args(self.hint)[1], value) for key, value in keyword.items()}
            else:
                out = safe_converter(self.hint, keyword)
        elif self.required:
            raise MissingArgumentError(argument=self)
        else:  # no tokens
            return UNSET
    else:  # A dictionary-like structure.
        data = {}

        if self._should_attempt_json_dict():
            # Dict-like structures may have incoming json data from an environment variable.
            # Pass these values along as Tokens to children.
            import json

            from cyclopts.config._common import update_argument_collection

            # Consume this argument's own tokens; their parsed JSON is pushed into the children.
            while self.tokens:
                token = self.tokens.pop(0)
                try:
                    parsed_json = json.loads(token.value)
                except json.JSONDecodeError as e:
                    raise CoercionError(token=token, target_type=self.hint) from e
                update_argument_collection(
                    {self.name.lstrip("-"): parsed_json},
                    token.source,
                    self.children_recursive,
                    root_keys=(),
                    allow_unknown=False,
                )

        if self._use_pydantic_type_adapter:
            return self._convert_pydantic()

        for child in self.children:
            # Children are exactly one key level deeper than this argument.
            assert len(child.keys) == (len(self.keys) + 1)
            if child.has_tokens:  # Either the child directly has tokens, or a nested child has tokens.
                data[child.keys[-1]] = child.convert_and_validate(converter=converter)
            elif child.required:
                # Check if the required fields are already populated.
                # NOTE(review): ``data`` is keyed by ``child.keys[-1]`` but this walks the
                # full ``child.keys`` path; since this child has no tokens the lookup is
                # expected to fail and raise MissingArgumentError — confirm intended.
                obj = data
                for k in child.keys:
                    try:
                        obj = obj[k]
                    except Exception:
                        raise MissingArgumentError(argument=child) from None
                child._marked = True

        self._run_missing_keys_checker(data)

        if data:
            # Instantiate the structured type from the converted child values.
            out = self.hint(**data)
        elif self.required:
            # This should NEVER happen: empty data to a required dict field.
            raise MissingArgumentError(argument=self)  # pragma: no cover
        else:
            out = UNSET

    return out
|
|
|
|
def convert(self, converter: Optional[Callable] = None):
    """Convert :attr:`tokens` into :attr:`value` and return it.

    Arguments that were already marked (``_marked``) are not converted again; their
    current :attr:`value` is returned unchanged.

    Parameters
    ----------
    converter: Optional[Callable]
        Converter function to use. Overrides ``self.parameter.converter``

    Returns
    -------
    Any
        The converted data. Same as :attr:`value`.
    """
    if self._marked:
        return self.value

    try:
        self.value = self._convert(converter=converter)
    except CoercionError as err:
        # Fill in context the low-level converter could not know about.
        if err.argument is None:
            err.argument = self
        if err.target_type is None:
            err.target_type = self.hint
        raise
    except CycloptsError as err:
        if err.argument is None:
            err.argument = self
        raise
    return self.value
|
|
|
|
def validate(self, value):
    """Validates provided value.

    Runs each validator in ``self.parameter.validator``. If pydantic v2+ has already been
    imported in this process, also runs a pydantic ``TypeAdapter`` check, unless the value
    was produced through the pydantic type adapter.

    For ``**kwargs`` arguments (``VAR_KEYWORD`` with no keys), every dict value is
    validated individually. For ``*args`` arguments (``VAR_POSITIONAL``), every element
    is validated individually.

    Parameters
    ----------
    value:
        Value to validate.

    Returns
    -------
    None

    Raises
    ------
    ValidationError
        If a validator raises :exc:`AssertionError`, :exc:`ValueError` or :exc:`TypeError`.
    """
    assert isinstance(self.parameter.validator, tuple)

    if "pydantic" in sys.modules:
        import pydantic

        # Only the major version matters. Parse it leniently. Pre-release/dev builds
        # (e.g. "2.0b3") have non-numeric segments and would crash ``int()`` if every
        # segment were converted. An unparseable major is treated as unsupported.
        major_segment = pydantic.__version__.split(".", 1)[0]
        major_digits = "".join(itertools.takewhile(str.isdigit, major_segment))
        if not major_digits or int(major_digits) < 2:
            # Cyclopts does NOT support/use pydantic v1.
            pydantic = None
    else:
        pydantic = None

    def validate_pydantic(hint, val):
        if not pydantic:
            return
        if self._use_pydantic_type_adapter:
            # Pydantic already called the validators
            return

        try:
            pydantic.TypeAdapter(hint).validate_python(val)
        except pydantic.ValidationError as e:
            self._handle_pydantic_validation_error(e)
        except pydantic.PydanticUserError:
            # Pydantic couldn't generate a schema for this type hint.
            pass

    try:
        if not self.keys and self.field_info and self.field_info.kind is self.field_info.VAR_KEYWORD:
            hint = get_args(self.hint)[1]
            for validator in self.parameter.validator:
                for val in value.values():
                    validator(hint, val)
            validate_pydantic(dict[str, self.field_info.annotation], value)
        elif self.field_info and self.field_info.kind is self.field_info.VAR_POSITIONAL:
            hint = get_args(self.hint)[0]
            for validator in self.parameter.validator:
                for val in value:
                    validator(hint, val)
            validate_pydantic(tuple[self.field_info.annotation, ...], value)
        else:
            for validator in self.parameter.validator:
                validator(self.hint, value)
            validate_pydantic(self.field_info.annotation, value)
    except (AssertionError, ValueError, TypeError) as e:
        raise ValidationError(exception_message=e.args[0] if e.args else "", argument=self) from e
|
|
|
|
def convert_and_validate(self, converter: Optional[Callable] = None):
    """Convert :attr:`tokens` into :attr:`value`, then validate the outcome.

    If conversion produced nothing (``UNSET``) but the field declares a default,
    the default is validated instead.

    Parameters
    ----------
    converter: Optional[Callable]
        Converter function to use. Overrides ``self.parameter.converter``

    Returns
    -------
    Any
        The converted data. Same as :attr:`value`.
    """
    result = self.convert(converter=converter)
    if result is UNSET:
        default = self.field_info.default
        if default is not FieldInfo.empty:
            self.validate(default)
    else:
        self.validate(result)
    return result
|
|
|
|
def token_count(self, keys: tuple[str, ...] = ()):
    """Number of string tokens this argument consumes.

    Parameters
    ----------
    keys: tuple[str, ...]
        The **python** keys into this argument.
        With no keys, this argument's own hint is used. With one key, the hint for
        that key is used. With more keys, ``self._default`` is used.

    Returns
    -------
    int
        Number of string tokens to create 1 element.
    consume_all: bool
        :obj:`True` if this data type is iterable.
    """
    depth = len(keys)
    if depth == 0:
        target_hint = self.hint
    elif depth == 1:
        target_hint = self._type_hint_for_key(keys[0])
    else:
        target_hint = self._default
    per_element, consume_all = token_count(target_hint)
    return per_element, consume_all
|
|
|
|
@property
def negatives(self):
    """Negative flag names, as produced by :meth:`.Parameter.get_negatives`.

    The field annotation is passed through ``resolve_annotated`` before the lookup.
    """
    annotation = resolve_annotated(self.field_info.annotation)
    return self.parameter.get_negatives(annotation)
|
|
|
|
@property
def name(self) -> str:
    """The **first** entry of :attr:`names`."""
    every_name = self.names
    return every_name[0]
|
|
|
|
@property
def names(self) -> tuple[str, ...]:
    """Every name of this argument: the positive names followed by the negative flags."""
    positive_names = self.parameter.name
    assert isinstance(positive_names, tuple)
    return (*positive_names, *self.negatives)
|
|
|
|
def env_var_split(self, value: str, delimiter: Optional[str] = None) -> list[str]:
    """Split *value* via :meth:`.Parameter.env_var_split`, using this argument's hint."""
    splitter = self.parameter.env_var_split
    return splitter(self.hint, value, delimiter=delimiter)
|
|
|
|
@property
def show(self) -> bool:
    """Whether this argument appears on the help page.

    Arguments with child arguments are never shown. Otherwise ``parameter.show`` decides.
    """
    if self.children:
        return False
    return self.parameter.show
|
|
|
|
@property
def required(self) -> bool:
    """Whether a user-provided value is required.

    An explicit ``parameter.required`` wins; when it is ``None``, ``field_info.required`` is used.
    """
    explicit = self.parameter.required
    return self.field_info.required if explicit is None else explicit
|
|
|
|
def _json(self) -> dict:
    """Build a json-like ``dict`` from this argument's tokens, for pydantic.

    Keyword tokens on this argument are nested by their ``keys``. Each child that has
    tokens contributes an entry keyed by its key relative to this argument. Leaf values
    are token strings, or a token's ``implicit_value`` when one is set. Every visited
    child gets ``_marked = True`` (``convert`` skips marked arguments).
    """

    def leaf_value(tok):
        # An explicit implicit value takes precedence over the raw string.
        return tok.value if tok.implicit_value is UNSET else tok.implicit_value

    payload = {}
    if self._accepts_keywords:
        for tok in self.tokens:
            node = payload
            for part in tok.keys[:-1]:
                node = node.setdefault(part, {})
            node[tok.keys[-1]] = leaf_value(tok)

    for child in self.children:
        child._marked = True
        if not child.has_tokens:
            continue
        relative_keys = child.keys[len(self.keys) :]
        if child._accepts_keywords:
            nested = child._json()
            if nested:
                payload[relative_keys[0]] = nested
        elif (get_origin(child.hint) or child.hint) in ITERABLE_TYPES:
            payload.setdefault(relative_keys[-1], []).extend(tok.value for tok in child.tokens)
        else:
            payload[relative_keys[0]] = leaf_value(child.tokens[0])
    return payload
|
|
|
|
def _run_missing_keys_checker(self, data):
    """Raise for the first required key that ``_missing_keys_checker`` reports as absent from *data*.

    This is a no-op when there is no checker, or when this argument is optional and
    *data* is empty.

    Raises
    ------
    MissingArgumentError
        If a child argument corresponds to the first missing key.
    ValueError
        If no child argument corresponds to it (the field can't be supplied via Cyclopts).
    """
    checker = self._missing_keys_checker
    if not checker:
        return
    if not self.required and not data:
        return

    missing_keys = checker(self, data)
    if not missing_keys:
        return

    # Only the first missing key is reported.
    full_keys = self.keys + (missing_keys[0],)
    candidates = self.children.filter_by(keys_prefix=full_keys)
    if not candidates:
        description = self.field_info.names[0] + "->" + "->".join(full_keys)
        raise ValueError(
            f'Required field "{description}" is not accessible by Cyclopts; possibly due to conflicting POSITIONAL/KEYWORD requirements.'
        )
    raise MissingArgumentError(argument=candidates[0])
|
|
|
|
def _handle_pydantic_validation_error(self, exc):
    """Translate a pydantic validation failure into the matching Cyclopts exception.

    Always raises; never returns.

    Raises
    ------
    MissingArgumentError
        If the first reported error has type ``"missing"`` and its location matches a
        descendant argument.
    ValidationError
        For any other pydantic ``ValidationError``. This includes a ``"missing"`` error
        whose location does not match a known argument.
    Exception
        *exc* itself, if it is not a pydantic ``ValidationError``.
    """
    import pydantic

    errors = exc.errors()
    first_error = errors[0] if errors else None
    if first_error is not None and first_error["type"] == "missing":
        # Pydantic locations may contain integer indices (or otherwise not line up with a
        # Cyclopts argument). Fall through to the generic error in that case instead of
        # crashing with an IndexError.
        candidates = self.children_recursive.filter_by(keys_prefix=self.keys + tuple(first_error["loc"]))
        if candidates:
            raise MissingArgumentError(argument=candidates[0]) from exc
    if isinstance(exc, pydantic.ValidationError):
        raise ValidationError(exception_message=str(exc), argument=self) from exc
    raise exc
|
|
|
|
|
|
def _resolve_groups_from_callable(
    func: Callable[..., Any],
    *default_parameters: Optional[Parameter],
    group_arguments: Optional[Group] = None,
    group_parameters: Optional[Group] = None,
) -> list[Group]:
    """Collect every :class:`Group` referenced by *func*'s parameters.

    The result starts with *group_arguments* and *group_parameters* (when not ``None``).
    Next come explicitly instantiated ``Group`` objects in first-seen order. Last come
    new ``Group`` objects for string group names not already present.

    Raises
    ------
    ValueError
        If two distinct ``Group`` objects share the same name.
    """
    argument_collection = ArgumentCollection._from_callable(
        func,
        *default_parameters,
        group_arguments=group_arguments,
        group_parameters=group_parameters,
        parse_docstring=False,
        _resolve_groups=False,
    )

    resolved_groups = [grp for grp in (group_arguments, group_parameters) if grp is not None]

    # Pass 1: explicitly instantiated Group objects.
    for argument in argument_collection:
        for group in argument.parameter.group:  # pyright: ignore
            if not isinstance(group, Group):
                continue

            # A different Group that reuses the same name is a configuration error.
            if any(group != known and known._name == group._name for known in resolved_groups):
                raise ValueError("Cannot register 2 distinct Group objects with same name.")

            if group.default_parameter is not None and group.default_parameter.group:
                # This shouldn't be possible due to ``Group`` internal checks.
                raise ValueError("Group.default_parameter cannot have a specified group.")  # pragma: no cover

            if not any(known._name == group._name for known in resolved_groups):
                resolved_groups.append(group)

    # Pass 2: groups referenced only by name get an implicitly created Group.
    for argument in argument_collection:
        for group in argument.parameter.group:  # pyright: ignore
            if isinstance(group, str) and not any(known.name == group for known in resolved_groups):
                resolved_groups.append(Group(group))

    return resolved_groups
|
|
|
|
|
|
def _extract_docstring_help(f: Callable) -> dict[tuple[str, ...], Parameter]:
    """Map each documented parameter path to a ``Parameter(help=...)``.

    Dotted docstring names (``"a.b"``) become key tuples (``("a", "b")``). If docstring
    parsing raises :exc:`TypeError`, an empty dict is returned.
    """
    from docstring_parser import parse_from_object

    # Unwrap ``functools.partial`` (anything exposing ``.func``) to reach the wrapped callable.
    target = getattr(f, "func", f)

    help_by_path = {}
    try:
        for doc_param in parse_from_object(target).params:
            path = tuple(doc_param.arg_name.split("."))
            help_by_path[path] = Parameter(help=doc_param.description)
    except TypeError:
        # Type hints like ``dict[str, str]`` trigger this.
        return {}
    return help_by_path
|
|
|
|
|
|
def _resolve_parameter_name_helper(elem):
    """Normalize a single parameter-name component.

    A trailing ``*`` (and any dots just before it) is removed. A non-empty result that
    doesn't already start with ``-`` gets a ``--`` prefix.
    """
    if elem.endswith("*"):
        elem = elem[:-1].rstrip(".")
    needs_prefix = bool(elem) and not elem.startswith("-")
    return "--" + elem if needs_prefix else elem
|
|
|
|
|
|
def _resolve_parameter_name(*argss: tuple[str, ...]) -> tuple[str, ...]:
    """Resolve parameter names by combining and formatting multiple tuples of strings.

    The tuples are folded left to right. Every name accumulated so far is combined with
    every name of the next tuple as ``"prefix.suffix"``. A suffix is kept as-is when it
    already starts with ``-`` or when the normalized prefix is empty. Finally, each
    resulting name is normalized, except a bare ``"*"``, which is kept.

    Parameters
    ----------
    *argss
        Each tuple represents a group of parameter name components. Empty tuples are ignored.

    Returns
    -------
    tuple[str, ...]
        A tuple of resolved parameter names.
    """
    groups = [names for names in argss if names]
    if not groups:
        return ()

    accumulated = groups[0]
    for following in groups[1:]:
        combined = []
        for raw_prefix in accumulated:
            head = _resolve_parameter_name_helper(raw_prefix)
            for suffix in following:
                if suffix.startswith("-") or not head:
                    combined.append(suffix)
                else:
                    combined.append(head + "." + suffix)
        accumulated = tuple(combined)

    return tuple(name if name == "*" else _resolve_parameter_name_helper(name) for name in accumulated)
|