File tree Expand file tree Collapse file tree 1 file changed +8
-8
lines changed
py/torch_tensorrt/dynamo/conversion/impl/normalization Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -48,16 +48,16 @@ def batch_norm(
4848 assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
4949
5050 if weight is None :
51- weight = np.array(1.0)
51+ weight = 1.0
5252
5353 if bias is None :
54- bias = np.array(0.0)
54+ bias = 0.0
5555
5656 if running_mean is None :
57- running_mean = np.array(0.0)
57+ running_mean = 0.0
5858
5959 if running_var is None :
60- running_var = np.array(1.0)
60+ running_var = 1.0
6161
6262 scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
6363 cast(torch.Tensor, to_numpy(running_var)) + eps
@@ -115,10 +115,10 @@ def layer_norm(
115115 )
116116
117117 if weight is None :
118- weight = np.array(1.0)
118+ weight = to_numpy(1.0)
119119
120120 if bias is None :
121- bias = np.array(0.0)
121+ bias = to_numpy(0.0)
122122
123123 gamma = (
124124 weight.detach().cpu().float().numpy()
@@ -181,10 +181,10 @@ def layer_norm_no_plugin(
181181 )
182182
183183 if weight is None :
184- weight = np.array(1.0)
184+ weight = to_numpy(1.0)
185185
186186 if bias is None :
187- bias = np.array(0.0)
187+ bias = to_numpy(0.0)
188188
189189 shape = weight.shape
190190 broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
You can’t perform that action at this time.
0 commit comments