File tree Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Expand file tree Collapse file tree 2 files changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -808,7 +808,10 @@ def _construct_ablated_input(
808808 current_mask = current_mask .to (expanded_input .device )
809809 assert baseline is not None , "baseline must be provided"
810810 ablated_tensor = (
811- expanded_input * (1 - current_mask ).to (expanded_input .dtype )
811+ expanded_input
812+ * (1 - current_mask ).to (expanded_input .dtype )
813+ # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float,
814+ # Tensor]` and `Tensor`.
812815 ) + (baseline * current_mask .to (expanded_input .dtype ))
813816 return ablated_tensor , current_mask
814817
Original file line number Diff line number Diff line change @@ -323,6 +323,8 @@ def _construct_ablated_input(
323323 torch .ones (1 , dtype = torch .long , device = expanded_input .device )
324324 - input_mask
325325 ).to (expanded_input .dtype )
326+ # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float,
327+ # Tensor]` and `Tensor`.
326328 ) + (baseline * input_mask .to (expanded_input .dtype ))
327329 return ablated_tensor , input_mask
328330
You can’t perform that action at this time.
0 commit comments