1414import sys
1515import tempfile
1616import time
17+ import types
1718
1819from collections import defaultdict , OrderedDict
1920from dataclasses import asdict , dataclass , fields
3536)
3637from tritonbench .components .export import export_data
3738
39+ from tritonbench .utils .constants import (DEFAULT_WARMUP ,DEFAULT_REP ,DEFAULT_QUANTILES ,DEFAULT_SLEEP )
3840from tritonbench .utils .env_utils import (
3941 apply_precision ,
4042 is_fbcode ,
4345 set_random_seed ,
4446)
4547from tritonbench .utils .input import input_cast
48+ from tritonbench .utils .parser import get_parser
4649from tritonbench .utils .path_utils import add_cmd_parameter , remove_cmd_parameter
4750
4851if is_hip ():
@@ -77,11 +80,6 @@ class BenchmarkOperatorBackend:
7780 # ci = False implies enabled = False
7881 ci : bool = True
7982
80-
81- DEFAULT_WARMUP = 25
82- DEFAULT_REP = 100
83- DEFAULT_QUANTILES = [0.5 , 0.1 , 0.9 ]
84- DEFAULT_SLEEP = 0.0
8583REGISTERED_BENCHMARKS : Dict [str , OrderedDict [str , BenchmarkOperatorBackend ]] = {}
8684REGISTERED_METRICS : defaultdict [str , List [str ]] = defaultdict (list )
8785OVERRIDDEN_METRICS : defaultdict [str , List [str ]] = defaultdict (list )
@@ -590,10 +588,11 @@ def register_benchmark(
590588 label : Optional [str ] = None ,
591589):
592590 def decorator (function ):
591+
593592 op_name = (
594- _find_op_name_from_module_path ( function . __module__ )
595- if not operator_name
596- else operator_name
593+ operator_name
594+ if operator_name
595+ else _find_op_name_from_module_path ( function . __module__ )
597596 )
598597 fn_name = function .__name__ if not func_name else func_name
599598 backend_config = BenchmarkOperatorBackend (
@@ -667,6 +666,11 @@ def _has_and_true(attr):
667666 if _has_and_true ("fwd_no_grad" ):
668667 tb_args .mode = "fwd_no_grad"
669668
669+ def override_args (args_to_override ):
670+ parser = get_parser ()
671+ tb_args , extra_args = parser .parse_known_args (args_to_override )
672+ return tb_args , extra_args
673+
670674
671675class BenchmarkOperator (metaclass = PostInitProcessor ):
672676 mode : Mode = Mode .FWD
@@ -692,11 +696,19 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
692696 """
693697
694698 def __init__ (
695- self , tb_args : argparse .Namespace , extra_args : Optional [List [str ]] = None
699+ self , tb_args : argparse .Namespace = None , extra_args : Optional [List [str ]] = None , args_list_override : List [ str ] = None
696700 ):
697701 set_env ()
698702 set_random_seed ()
699- self .name = _find_op_name_from_module_path (self .__class__ .__module__ )
703+ if args_list_override :
704+ tb_args , extra_args = override_args (args_list_override )
705+ elif not tb_args :
706+ raise ValueError ('no args selected. Either pass in argparse namespace or give list override' )
707+
708+ if tb_args .benchmark_name :
709+ self .name = tb_args .benchmark_name
710+ else :
711+ self .name = _find_op_name_from_module_path (self .__class__ .__module__ )
700712 self ._raw_extra_args = copy .deepcopy (extra_args )
701713 self .tb_args = tb_args
702714 self .add_production_shapes = (
@@ -807,6 +819,24 @@ def fwd_no_grad_fn():
807819
808820 setattr (fwd_no_grad_fn , "_name" , bm_func_name )
809821 return fwd_no_grad_fn
822+
823+ def set_input_iter (self , input_iter : Callable ):
824+ def input_decorator (input_iter ):
825+ def input_callable (self ):
826+ return input_iter ()
827+ return input_callable
828+ self .get_input_iter = input_decorator (input_iter )
829+ self .get_input_iter = input_decorator (input_iter ).__get__ (self , BenchmarkOperator )
830+ self .input_iter = input_iter
831+ self ._available_num_inputs = sum (1 for _ in self .get_input_iter ())
832+ self ._num_inputs = self ._available_num_inputs - self ._input_id
833+
834+ def add_benchmark (self , bm_func_name : str , bm_callable : Callable ):
835+ decorator_kwargs = {"operator_name" :self .name ,"func_name" :bm_func_name ,"enabled" :True }
836+ decorated_func = register_benchmark (** decorator_kwargs )(bm_callable )
837+ bound_method = types .MethodType (decorated_func , self )
838+ setattr (self , bm_func_name or bm_callable .__name__ , bound_method )
839+ REGISTERED_BENCHMARKS [bm_func_name ] = bm_callable
810840
811841 def run (
812842 self ,
@@ -959,9 +989,10 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
959989
960990 def get_input_iter (self ) -> Generator :
961991 """Return the dynamic input iterator for the model."""
962- raise NotImplementedError (
992+ logger . warning (
963993 "Each operator must implement its own input iterator."
964994 )
995+ return []
965996
966997 def get_grad_to_none (self , args ):
967998 return None
0 commit comments