Skip to content

Commit dc5cdac

Browse files
committed
CLI support for callable class instances #238.
1 parent fadff04 commit dc5cdac

File tree

5 files changed

+38
-12
lines changed

5 files changed

+38
-12
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases.
1515
v4.20.0 (2023-01-??)
1616
--------------------
1717

18+
Added
19+
^^^^^
20+
- ``CLI`` support for callable class instances `#238
21+
<https://github.com/omni-us/jsonargparse/issues/238>`__.
22+
1823
Fixed
1924
^^^^^
2025
- ``add_subcommands`` fails when parser has required argument and default config

jsonargparse/cli.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def CLI(
5252
module = inspect.getmodule(caller).__name__ # type: ignore
5353
components = [
5454
v for v in caller.f_locals.values()
55-
if (inspect.isfunction(v) or inspect.isclass(v)) and inspect.getmodule(v).__name__ == module # type: ignore
55+
if (
56+
(inspect.isclass(v) or callable(v)) and
57+
getattr(inspect.getmodule(v), '__name__', None) == module
58+
)
5659
]
5760
if len(components) == 0:
5861
raise ValueError('Either components argument must be given or there must be at least one '
@@ -61,6 +64,10 @@ def CLI(
6164
elif not isinstance(components, list):
6265
components = [components]
6366

67+
unexpected = [c for c in components if not (inspect.isclass(c) or callable(c))]
68+
if unexpected:
69+
raise ValueError(f'Unexpected components, not class or function: {unexpected}')
70+
6471
parser = parser_class(default_meta=False, **kwargs)
6572
parser.add_argument('--config', action=ActionConfigFile, help=config_help)
6673

@@ -108,11 +115,7 @@ def get_help_str(component, logger):
108115

109116
def _add_component_to_parser(component, parser, as_positional, fail_untyped, config_help):
110117
kwargs = dict(as_positional=as_positional, fail_untyped=fail_untyped, sub_configs=True)
111-
if inspect.isfunction(component):
112-
added_args = parser.add_function_arguments(component, as_group=False, **kwargs)
113-
if not parser.description:
114-
parser.description = get_help_str(component, parser.logger)
115-
else:
118+
if inspect.isclass(component):
116119
added_args = parser.add_class_arguments(component, **kwargs)
117120
subcommands = parser.add_subcommands(required=True)
118121
for key in [k for k, v in inspect.getmembers(component) if callable(v) and k[0] != '_']:
@@ -124,12 +127,16 @@ def _add_component_to_parser(component, parser, as_positional, fail_untyped, con
124127
if not added_subargs:
125128
remove_actions(subparser, (ActionConfigFile, _ActionPrintConfig))
126129
subcommands.add_subcommand(key, subparser, help=get_help_str(getattr(component, key), parser.logger))
130+
else:
131+
added_args = parser.add_function_arguments(component, as_group=False, **kwargs)
132+
if not parser.description:
133+
parser.description = get_help_str(component, parser.logger)
127134
return added_args
128135

129136

130137
def _run_component(component, cfg):
131138
cfg.pop('config', None)
132-
if inspect.isfunction(component):
139+
if not inspect.isclass(component):
133140
return component(**cfg)
134141
subcommand = cfg.pop('subcommand')
135142
subcommand_cfg = cfg.pop(subcommand, {})

jsonargparse/typehints.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,7 @@ def __init__(
146146
if sum(subtype_supported) < len(subtype_supported):
147147
discard = {typehint.__args__[n] for n, s in enumerate(subtype_supported) if not s}
148148
kwargs['logger'].debug(f'Discarding unsupported subtypes {discard} from {typehint}')
149-
orig_typehint = typehint # deepcopy does not copy ForwardRef
150-
typehint = deepcopy(orig_typehint)
151-
typehint.__args__ = tuple(
152-
orig_typehint.__args__[n] for n, s in enumerate(subtype_supported) if s
153-
)
149+
typehint = Union[tuple(t for t, s in zip(typehint.__args__, subtype_supported) if s)] # type: ignore
154150
self._typehint = typehint
155151
self._enable_path = False if is_optional(typehint, Path) else enable_path
156152
elif '_typehint' not in kwargs:

jsonargparse_tests/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def mock_module(*args):
3838
__module__ = 'jsonargparse_tests'
3939
for component in args:
4040
component.__module__ = __module__
41+
if not hasattr(component, '__name__'):
42+
component.__name__ = type(component).__name__.lower()
4143
component.__qualname__ = component.__name__
4244
if inspect.isclass(component):
4345
methods = [k for k, v in inspect.getmembers(component) if callable(v) and k[0] != '_']

jsonargparse_tests/test_cli.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys
44
import unittest
5+
import unittest.mock
56
from contextlib import redirect_stderr, redirect_stdout
67
from io import StringIO
78
from typing import Optional
@@ -16,6 +17,11 @@
1617

1718
class CLITests(unittest.TestCase):
1819

20+
def test_unexpected(self):
21+
with self.assertRaises(ValueError):
22+
CLI(0)
23+
24+
1925
def test_single_function_cli(self):
2026
def function(a1: float):
2127
return a1
@@ -36,6 +42,16 @@ def run_cli():
3642
self.assertIn('function CLITests.test_single_function_cli', out.getvalue())
3743

3844

45+
def test_callable_instance(self):
46+
class CallableClass:
47+
def __call__(self, x: int):
48+
return x
49+
50+
instance = CallableClass()
51+
with mock_module(instance):
52+
self.assertEqual(3, CLI(instance, as_positional=False, args=['--x=3']))
53+
54+
3955
def test_multiple_functions_cli(self):
4056
def cmd1(a1: int):
4157
return a1

0 commit comments

Comments
 (0)