Skip to content

Commit ece2f7d

Browse files
refactor(core): do not change signature of functions decorated with custom_experiment (#261)
* refactor(core): custom_experiment decorator does not change signature Implementation of #252 * chore(core): apply suggestions from code review * feat(core): Validate custom_experiment parameters in wrapper
1 parent 7a79a23 commit ece2f7d

File tree

2 files changed

+131
-69
lines changed

2 files changed

+131
-69
lines changed

orchestrator/modules/actuators/custom_experiments.py

Lines changed: 100 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import logging
66
import typing
77
import uuid
8-
from functools import wraps
98

109
import pydantic
1110
import ray
@@ -31,6 +30,7 @@
3130
ObservedProperty,
3231
ObservedPropertyValue,
3332
)
33+
from orchestrator.schema.point import SpacePoint
3434
from orchestrator.schema.property import (
3535
AbstractPropertyDescriptor,
3636
ConstitutiveProperty,
@@ -313,47 +313,6 @@ def calculate_density(mass, volume):
313313
logger = logging.getLogger("custom_experiment_decorator")
314314

315315
def decorator(func):
316-
317-
@wraps(func)
318-
def wrapper(
319-
entity: Entity, experiment: Experiment
320-
) -> list[ObservedPropertyValue]:
321-
"""
322-
Wrapper function that converts Entity+Experiment to dict and calls the wrapped function.
323-
"""
324-
input_values = experiment.propertyValuesFromEntity(entity)
325-
result_dict = func(**input_values)
326-
observed_property_values = []
327-
for property_identifier, value in result_dict.items():
328-
observed_property = experiment.observedPropertyForTargetIdentifier(
329-
property_identifier
330-
)
331-
if not observed_property:
332-
raise ValueError(
333-
f"{experiment.identifier} returned a property called {property_identifier}, however "
334-
f"the experiment definition does not define an output property with this name"
335-
)
336-
337-
observed_property_value = ObservedPropertyValue(
338-
property=observed_property, value=value
339-
)
340-
observed_property_values.append(observed_property_value)
341-
342-
return observed_property_values
343-
344-
# Set the wrapper's __signature__ so it is (entity,experiment)
345-
# This is required for the wrapped function to be used with ray.remote
346-
import inspect
347-
348-
wrapper.__signature__ = inspect.Signature(
349-
[
350-
inspect.Parameter("entity", inspect.Parameter.POSITIONAL_OR_KEYWORD),
351-
inspect.Parameter(
352-
"experiment", inspect.Parameter.POSITIONAL_OR_KEYWORD
353-
),
354-
]
355-
)
356-
357316
# If we were not given information on required/optional properties
358317
# or parameterization try to infer it
359318
# This function will log a critical error message and raise exception
@@ -380,12 +339,12 @@ def wrapper(
380339
)
381340
raise
382341

383-
# Store decorator arguments as function attributes
384-
wrapper._decorator_required_properties = _required_properties
385-
wrapper._decorator_optional_properties = _optional_properties
386-
wrapper._decorator_parameterization = _parameterization
387-
wrapper._original_func = func
388-
wrapper._is_custom_experiment = True
342+
# Store decorator arguments as function attributes (on func itself)
343+
func._decorator_required_properties = _required_properties
344+
func._decorator_optional_properties = _optional_properties
345+
func._decorator_parameterization = _parameterization
346+
func._original_func = func
347+
func._is_custom_experiment = True
389348

390349
# Create an ExperimentModuleConf instance describing where the function is
391350
metadata["module"] = ExperimentModuleConf(
@@ -415,12 +374,44 @@ def wrapper(
415374
deprecated=False,
416375
metadata=metadata,
417376
)
418-
wrapper._experiment = experiment
377+
func._experiment = experiment
419378

420379
# Add the experiment to the module-level catalog
421380
_custom_experiments_catalog.addExperiment(experiment)
422381

423-
return wrapper
382+
from functools import wraps
383+
384+
@wraps(func)
385+
def validated_func(*args, **kwargs):
386+
# Build property dict from either kwargs or args
387+
# Prefer kwargs, but support positional for backwards compatibility
388+
import inspect
389+
390+
sig = inspect.signature(func)
391+
bound_args = sig.bind(*args, **kwargs)
392+
bound_args.apply_defaults()
393+
param_dict = dict(bound_args.arguments)
394+
395+
# Validate using SpacePoint and Experiment.validate_entity
396+
spoint = SpacePoint(entity=param_dict)
397+
entity = spoint.to_entity()
398+
if not experiment.validate_entity(entity, verbose=True):
399+
raise ValueError(
400+
f"Arguments {param_dict} do not match required/optional properties for experiment '{experiment.identifier}'. "
401+
f"See logs/stderr for reasons, or check experiment.requiredProperties/optionalProperties."
402+
)
403+
# Call the original with the unpacked arguments
404+
return func(*args, **kwargs)
405+
406+
# Attach metadata to validated_func not func, so end users get the right attributes
407+
validated_func._decorator_required_properties = _required_properties
408+
validated_func._decorator_optional_properties = _optional_properties
409+
validated_func._decorator_parameterization = _parameterization
410+
validated_func._original_func = func
411+
validated_func._is_custom_experiment = True
412+
validated_func._experiment = experiment
413+
414+
return validated_func
424415

425416
return decorator
426417

@@ -513,7 +504,51 @@ def get_custom_experiments_catalog() -> (
513504
return _custom_experiments_catalog
514505

515506

516-
async def custom_experiment_wrapper(
507+
def _call_decorated_custom_experiment(
508+
function: typing.Callable, target_experiment: Experiment, entity: Entity
509+
) -> list[ObservedPropertyValue]:
510+
511+
# Build input dict using experiment values from entity
512+
input_values = target_experiment.propertyValuesFromEntity(entity)
513+
# Call function with unpacked parameters
514+
result_dict = function(**input_values)
515+
516+
# Create observed property values
517+
observed_property_values = []
518+
for property_identifier, value in result_dict.items():
519+
observed_property = target_experiment.observedPropertyForTargetIdentifier(
520+
property_identifier
521+
)
522+
if not observed_property:
523+
raise ValueError(
524+
f"{target_experiment.identifier} returned a property called {property_identifier}, however "
525+
f"the experiment definition does not define an output property with this name"
526+
)
527+
observed_property_value = ObservedPropertyValue(
528+
property=observed_property, value=value
529+
)
530+
observed_property_values.append(observed_property_value)
531+
532+
return observed_property_values
533+
534+
535+
def _call_legacy_custom_experiment(
536+
function: typing.Callable,
537+
target_experiment: Experiment,
538+
entity: Entity,
539+
parameters: dict | None = None,
540+
) -> list[ObservedPropertyValue]:
541+
# For legacy case or other functions, check for parameters kwarg else pass entity/experiment
542+
func_signature = inspect.signature(function)
543+
if "parameters" in func_signature.parameters:
544+
values = function(entity, target_experiment, parameters=parameters)
545+
else:
546+
values = function(entity, target_experiment)
547+
548+
return values
549+
550+
551+
async def custom_experiment_executor(
517552
function: typing.Callable,
518553
parameters: dict,
519554
measurement_request: MeasurementRequest,
@@ -532,16 +567,22 @@ async def custom_experiment_wrapper(
532567

533568
measurement_results = []
534569
for entity in measurement_request.entities:
535-
# Inspect function to see if it has a keyword parameter "parameters"
536-
func_signature = inspect.signature(function)
537-
func_param_names = set(func_signature.parameters.keys())
538570
try:
539-
if "parameters" in func_param_names:
540-
values = function(entity, target_experiment, parameters=parameters)
571+
# Check if this is a custom experiment decorated function
572+
if getattr(function, "_is_custom_experiment", False):
573+
values = _call_decorated_custom_experiment(
574+
function=function,
575+
target_experiment=target_experiment,
576+
entity=entity,
577+
)
541578
else:
542-
values = function(entity, target_experiment)
579+
values = _call_legacy_custom_experiment(
580+
function=function,
581+
target_experiment=target_experiment,
582+
entity=entity,
583+
parameters=parameters,
584+
)
543585

544-
# Record the results in the entity
545586
if len(values) > 0:
546587
measurement_result = ValidMeasurementResult(
547588
entityIdentifier=entity.identifier, measurements=values
@@ -683,7 +724,7 @@ async def submit(
683724
**targetExperiment.model_dump(),
684725
)
685726

686-
await custom_experiment_wrapper(
727+
await custom_experiment_executor(
687728
self._functionImplementations[
688729
request.experimentReference.experimentIdentifier
689730
],

website/docs/actuators/creating-custom-experiments.md

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ my_experiment = "my_custom_package.experiments"
4141

4242
In the simplest case:
4343

44-
- type the parameters (using python `typing`)
44+
- type your parameters using python `typing`
4545
- return the output in a dictionary of key value pairs
46-
- define the keys of this dictionary in the `output_property_identifiers`
47-
parameter of the decorator
46+
- define the set of output property keys in the
47+
`output_property_identifiers` parameter of the decorator
4848

4949
```python
5050
from typing import Dict, Any
@@ -286,15 +286,36 @@ This is illustrated in the above example.
286286

287287
## Using your decorated function in code
288288

289-
The decorated function is wrapped to take `ado` internal
290-
data structures, and you would not typically need to
291-
call it directly. However, the decorated experiment function is
292-
still regular Python and can be called:
289+
The decorated function can be called
290+
directly in Python as normal e.g.,
293291

294292
```python
295-
# Access the original function (undecorated)
296-
original = calculate_density._original_func
297-
print(original(8, 4)) # {'density': 2}
293+
result = calculate_density(8, 4) # {'density': 2}
294+
```
295+
296+
The `custom_experiment` decorator attaches the
297+
ado `Experiment` object generated from the decoration as an attribute e.g.
298+
299+
```python
300+
from orchestrator.schema.experiment import Experiment
301+
302+
exp_obj: Experiment = calculate_density._experiment
303+
print(exp_obj.identifier) # e.g., 'calculate_density'
304+
print(exp_obj.requiredProperties)
305+
print(exp_obj.optionalProperties)
306+
print(exp_obj.targetProperties)
307+
```
308+
309+
When you call the decorated function, its arguments are
310+
automatically validated against the required and optional inputs
311+
specified in the decorator, including domain constraints.
312+
If you call it with missing, extra, or out-of-domain arguments,
313+
the function will raise a `ValueError` describing what was invalid and why.
314+
For example:
315+
316+
```python
317+
# Value outside domain - an error will be raised
318+
result = calculate_density(mass=0, volume=10)
298319
```
299320

300321
## Next Steps

0 commit comments

Comments
 (0)