Skip to content

Commit 1eb3609

Browse files
Vivek Miglanifacebook-github-bot
authored andcommitted
Fix approximation utils pyre fix me issues
Differential Revision: D67706741
1 parent b8df17b commit 1eb3609

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

captum/attr/_utils/approximation_methods.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pyre-strict
44
from enum import Enum
5-
from typing import Callable, List, Tuple
5+
from typing import Callable, cast, List, Tuple
66

77
import torch
88

@@ -121,19 +121,20 @@ def gauss_legendre_builders() -> (
121121

122122
# allow using riemann even without np
123123
import numpy as np
124+
from numpy.typing import NDArray
124125

125126
def step_sizes(n: int) -> List[float]:
126127
assert n > 0, "The number of steps has to be larger than zero"
127128
# Scaling from 2 to 1
128-
# pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
129-
# `float`.
130-
return list(0.5 * np.polynomial.legendre.leggauss(n)[1])
129+
return cast(
130+
NDArray[np.float64], 0.5 * np.polynomial.legendre.leggauss(n)[1]
131+
).tolist()
131132

132133
def alphas(n: int) -> List[float]:
133134
assert n > 0, "The number of steps has to be larger than zero"
134135
# Scaling from [-1, 1] to [0, 1]
135-
# pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
136-
# `float`.
137-
return list(0.5 * (1 + np.polynomial.legendre.leggauss(n)[0]))
136+
return cast(
137+
NDArray[np.float64], 0.5 * (1 + np.polynomial.legendre.leggauss(n)[0])
138+
).tolist()
138139

139140
return step_sizes, alphas

0 commit comments

Comments
 (0)