55import logging
66import typing
77import uuid
8- from functools import wraps
98
109import pydantic
1110import ray
3130 ObservedProperty ,
3231 ObservedPropertyValue ,
3332)
33+ from orchestrator .schema .point import SpacePoint
3434from orchestrator .schema .property import (
3535 AbstractPropertyDescriptor ,
3636 ConstitutiveProperty ,
@@ -313,47 +313,6 @@ def calculate_density(mass, volume):
313313 logger = logging .getLogger ("custom_experiment_decorator" )
314314
315315 def decorator (func ):
316-
317- @wraps (func )
318- def wrapper (
319- entity : Entity , experiment : Experiment
320- ) -> list [ObservedPropertyValue ]:
321- """
322- Wrapper function that converts Entity+Experiment to dict and calls the wrapped function.
323- """
324- input_values = experiment .propertyValuesFromEntity (entity )
325- result_dict = func (** input_values )
326- observed_property_values = []
327- for property_identifier , value in result_dict .items ():
328- observed_property = experiment .observedPropertyForTargetIdentifier (
329- property_identifier
330- )
331- if not observed_property :
332- raise ValueError (
333- f"{ experiment .identifier } returned a property called { property_identifier } , however "
334- f"the experiment definition does not define an output property with this name"
335- )
336-
337- observed_property_value = ObservedPropertyValue (
338- property = observed_property , value = value
339- )
340- observed_property_values .append (observed_property_value )
341-
342- return observed_property_values
343-
344- # Set the wrapper's __signature__ so it is (entity,experiment)
345- # This is required for the wrapped function to be used with ray.remote
346- import inspect
347-
348- wrapper .__signature__ = inspect .Signature (
349- [
350- inspect .Parameter ("entity" , inspect .Parameter .POSITIONAL_OR_KEYWORD ),
351- inspect .Parameter (
352- "experiment" , inspect .Parameter .POSITIONAL_OR_KEYWORD
353- ),
354- ]
355- )
356-
357316 # If we were not given information on required/optional properties
358317 # or parameterization try to infer it
359318 # This function will log a critical error message and raise exception
@@ -380,12 +339,12 @@ def wrapper(
380339 )
381340 raise
382341
383- # Store decorator arguments as function attributes
384- wrapper ._decorator_required_properties = _required_properties
385- wrapper ._decorator_optional_properties = _optional_properties
386- wrapper ._decorator_parameterization = _parameterization
387- wrapper ._original_func = func
388- wrapper ._is_custom_experiment = True
342+ # Store decorator arguments as function attributes (on func itself)
343+ func ._decorator_required_properties = _required_properties
344+ func ._decorator_optional_properties = _optional_properties
345+ func ._decorator_parameterization = _parameterization
346+ func ._original_func = func
347+ func ._is_custom_experiment = True
389348
390349 # Create an ExperimentModuleConf instance describing where the function is
391350 metadata ["module" ] = ExperimentModuleConf (
@@ -415,12 +374,44 @@ def wrapper(
415374 deprecated = False ,
416375 metadata = metadata ,
417376 )
418- wrapper ._experiment = experiment
377+ func ._experiment = experiment
419378
420379 # Add the experiment to the module-level catalog
421380 _custom_experiments_catalog .addExperiment (experiment )
422381
423- return wrapper
382+ from functools import wraps
383+
384+ @wraps (func )
385+ def validated_func (* args , ** kwargs ):
386+ # Build property dict from either kwargs or args
387+ # Prefer kwargs, but support positional for backwards compatibility
388+ import inspect
389+
390+ sig = inspect .signature (func )
391+ bound_args = sig .bind (* args , ** kwargs )
392+ bound_args .apply_defaults ()
393+ param_dict = dict (bound_args .arguments )
394+
395+ # Validate using SpacePoint and Experiment.validate_entity
396+ spoint = SpacePoint (entity = param_dict )
397+ entity = spoint .to_entity ()
398+ if not experiment .validate_entity (entity , verbose = True ):
399+ raise ValueError (
400+ f"Arguments { param_dict } do not match required/optional properties for experiment '{ experiment .identifier } '. "
401+ f"See logs/stderr for reasons, or check experiment.requiredProperties/optionalProperties."
402+ )
403+ # Call the original with the unpacked arguments
404+ return func (* args , ** kwargs )
405+
406+ # Attach metadata to validated_func not func, so end users get the right attributes
407+ validated_func ._decorator_required_properties = _required_properties
408+ validated_func ._decorator_optional_properties = _optional_properties
409+ validated_func ._decorator_parameterization = _parameterization
410+ validated_func ._original_func = func
411+ validated_func ._is_custom_experiment = True
412+ validated_func ._experiment = experiment
413+
414+ return validated_func
424415
425416 return decorator
426417
@@ -513,7 +504,51 @@ def get_custom_experiments_catalog() -> (
513504 return _custom_experiments_catalog
514505
515506
516- async def custom_experiment_wrapper (
507+ def _call_decorated_custom_experiment (
508+ function : typing .Callable , target_experiment : Experiment , entity : Entity
509+ ) -> list [ObservedPropertyValue ]:
510+
511+ # Build input dict using experiment values from entity
512+ input_values = target_experiment .propertyValuesFromEntity (entity )
513+ # Call function with unpacked parameters
514+ result_dict = function (** input_values )
515+
516+ # Create observed property values
517+ observed_property_values = []
518+ for property_identifier , value in result_dict .items ():
519+ observed_property = target_experiment .observedPropertyForTargetIdentifier (
520+ property_identifier
521+ )
522+ if not observed_property :
523+ raise ValueError (
524+ f"{ target_experiment .identifier } returned a property called { property_identifier } , however "
525+ f"the experiment definition does not define an output property with this name"
526+ )
527+ observed_property_value = ObservedPropertyValue (
528+ property = observed_property , value = value
529+ )
530+ observed_property_values .append (observed_property_value )
531+
532+ return observed_property_values
533+
534+
535+ def _call_legacy_custom_experiment (
536+ function : typing .Callable ,
537+ target_experiment : Experiment ,
538+ entity : Entity ,
539+ parameters : dict | None = None ,
540+ ) -> list [ObservedPropertyValue ]:
541+ # For legacy case or other functions, check for parameters kwarg else pass entity/experiment
542+ func_signature = inspect .signature (function )
543+ if "parameters" in func_signature .parameters :
544+ values = function (entity , target_experiment , parameters = parameters )
545+ else :
546+ values = function (entity , target_experiment )
547+
548+ return values
549+
550+
551+ async def custom_experiment_executor (
517552 function : typing .Callable ,
518553 parameters : dict ,
519554 measurement_request : MeasurementRequest ,
@@ -532,16 +567,22 @@ async def custom_experiment_wrapper(
532567
533568 measurement_results = []
534569 for entity in measurement_request .entities :
535- # Inspect function to see if it has a keyword parameter "parameters"
536- func_signature = inspect .signature (function )
537- func_param_names = set (func_signature .parameters .keys ())
538570 try :
539- if "parameters" in func_param_names :
540- values = function (entity , target_experiment , parameters = parameters )
571+ # Check if this is a custom experiment decorated function
572+ if getattr (function , "_is_custom_experiment" , False ):
573+ values = _call_decorated_custom_experiment (
574+ function = function ,
575+ target_experiment = target_experiment ,
576+ entity = entity ,
577+ )
541578 else :
542- values = function (entity , target_experiment )
579+ values = _call_legacy_custom_experiment (
580+ function = function ,
581+ target_experiment = target_experiment ,
582+ entity = entity ,
583+ parameters = parameters ,
584+ )
543585
544- # Record the results in the entity
545586 if len (values ) > 0 :
546587 measurement_result = ValidMeasurementResult (
547588 entityIdentifier = entity .identifier , measurements = values
@@ -683,7 +724,7 @@ async def submit(
683724 ** targetExperiment .model_dump (),
684725 )
685726
686- await custom_experiment_wrapper (
727+ await custom_experiment_executor (
687728 self ._functionImplementations [
688729 request .experimentReference .experimentIdentifier
689730 ],
0 commit comments