99 SourceIR ,
1010 cast_trt_tensor ,
1111 get_trt_tensor ,
12+ set_layer_name ,
1213)
1314from torch_tensorrt .fx .types import TRTTensor
1415
@@ -22,36 +23,45 @@ def arange(
2223 end : Union [int , TRTTensor ],
2324 step : Union [int , TRTTensor ],
2425) -> TRTTensor :
25- if any (isinstance (tensor , TRTTensor ) for tensor in (start , end , step )):
26+ """
27+ Creates a sequence of values (arange) either dynamically or statically,
28+ then outputs a TensorRT tensor.
29+
30+ If any of (start, end, step) is a TRT tensor, it sets up a dynamic arange
31+ using a Fill layer. Otherwise, it creates a static NumPy array and converts
32+ it into a TensorRT constant tensor.
33+ """
34+ # If any argument is a TRT tensor, use dynamic arange with a Fill layer
35+ if any (isinstance (x , TRTTensor ) for x in (start , end , step )):
36+ # Convert start, end, step into TRT tensors with appropriate rank
2637 start_rank_0 = get_trt_tensor (ctx , start , name + "_start_rank_0" , min_rank = 0 )
2738 start_rank_1 = get_trt_tensor (ctx , start , name + "_start_rank_1" , min_rank = 1 )
2839 end = get_trt_tensor (ctx , end , name + "_end" , min_rank = 1 )
2940 step = get_trt_tensor (ctx , step , name + "_step" , min_rank = 1 )
30- # Calculate shape = (end-start) / step
41+
42+ # Compute (end - start) / step to determine the output length
3143 shape = impl .elementwise .sub (
32- ctx ,
33- target ,
34- source_ir ,
35- name + "_sub" ,
36- end ,
37- start_rank_1 ,
44+ ctx , target , source_ir , name + "_sub" , end , start_rank_1
3845 )
3946 shape = impl .elementwise .trunc_div (
40- ctx ,
41- target ,
42- source_ir ,
43- name + "_shape" ,
44- shape ,
45- step ,
47+ ctx , target , source_ir , name + "_shape" , shape , step
4648 )
4749 shape = cast_trt_tensor (ctx , shape , end .dtype , name + "_shape_casted" )
50+
51+ # Build a Fill layer in LINSPACE mode
4852 fill_layer = ctx .net .add_fill (
4953 shape .shape , trt .FillOperation .LINSPACE , shape .dtype
5054 )
51- fill_layer .set_input (0 , shape )
52- # Set start index
53- fill_layer .set_input (1 , start_rank_0 )
54- # Set delta/step
55- fill_layer .set_input (2 , step )
55+ fill_layer .set_input (0 , shape ) # output length
56+ fill_layer .set_input (1 , start_rank_0 ) # start value
57+ fill_layer .set_input (2 , step ) # step size
58+
5659 return fill_layer .get_output (0 )
57- return np .arange (start , end , step )
60+
61+ else :
62+ # All arguments are static, so use NumPy arange and create a TRT constant
63+ arr = np .arange (start , end , step , dtype = np .int32 )
64+ weights = trt .Weights (arr )
65+ const_layer = ctx .net .add_constant (arr .shape , weights )
66+ set_layer_name (const_layer , target , f"{ name } _arange_const" , source_ir )
67+ return const_layer .get_output (0 )
0 commit comments