2323from torch .utils .data import DataLoader , Dataset
2424
2525
26- # pyre-fixme[3]: Return type must be annotated.
2726# pyre-fixme[2]: Parameter must be annotated.
28- def _isSorted (x , key = lambda x : x , descending = True ):
27+ def _isSorted (x , key = lambda x : x , descending = True ) -> bool :
2928 if descending :
30- return all ([ key (x [i ]) >= key (x [i + 1 ]) for i in range (len (x ) - 1 )] )
29+ return all (key (x [i ]) >= key (x [i + 1 ]) for i in range (len (x ) - 1 ))
3130 else :
32- return all ([ key (x [i ]) <= key (x [i + 1 ]) for i in range (len (x ) - 1 )] )
31+ return all (key (x [i ]) <= key (x [i + 1 ]) for i in range (len (x ) - 1 ))
3332
3433
35- # pyre-fixme[3]: Return type must be annotated.
3634# pyre-fixme[2]: Parameter must be annotated.
37- def _wrap_model_in_dataparallel (net ):
35+ def _wrap_model_in_dataparallel (net ) -> Module :
3836 alt_device_ids = [0 ] + [x for x in range (torch .cuda .device_count () - 1 , 0 , - 1 )]
3937 net = net .cuda ()
4038 return torch .nn .DataParallel (net , device_ids = alt_device_ids )
@@ -60,9 +58,7 @@ def __init__(
6058 def __len__ (self ) -> int :
6159 return len (self .samples )
6260
63- # pyre-fixme[3]: Return type must be annotated.
64- # pyre-fixme[2]: Parameter must be annotated.
65- def __getitem__ (self , idx ):
61+ def __getitem__ (self , idx : int ) -> Tuple [Tensor , Tensor ]:
6662 return (self .samples [idx ], self .labels [idx ])
6763
6864
@@ -83,8 +79,7 @@ def __len__(self) -> int:
8379 return len (self .samples [0 ])
8480
8581 # pyre-fixme[3]: Return type must be annotated.
86- # pyre-fixme[2]: Parameter must be annotated.
87- def __getitem__ (self , idx ):
82+ def __getitem__ (self , idx : int ):
8883 """
8984 The signature of the returning item is: List[List], where the contents
9085 are: [sample_0, sample_1, ...] + [labels] (two lists concacenated).
@@ -98,10 +93,8 @@ def __init__(
9893 num_features : int ,
9994 use_gpu : bool = False ,
10095 ) -> None :
101- # pyre-fixme[4]: Attribute must be annotated.
102- self .samples = torch .diag (torch .ones (num_features ))
103- # pyre-fixme[4]: Attribute must be annotated.
104- self .labels = torch .zeros (num_features ).unsqueeze (1 )
96+ self .samples : Tensor = torch .diag (torch .ones (num_features ))
97+ self .labels : Tensor = torch .zeros (num_features ).unsqueeze (1 )
10598 if use_gpu :
10699 self .samples = self .samples .cuda ()
107100 self .labels = self .labels .cuda ()
@@ -115,23 +108,22 @@ def __init__(
115108 num_features : int ,
116109 use_gpu : bool = False ,
117110 ) -> None :
118- # pyre-fixme[4]: Attribute must be annotated.
119- self .samples = (
111+ self .samples : Tensor = (
120112 torch .arange (start = low , end = high , dtype = torch .float )
121113 .repeat (num_features , 1 )
122114 .transpose (1 , 0 )
123115 )
124- # pyre-fixme[4]: Attribute must be annotated.
125- self .labels = torch .arange (start = low , end = high , dtype = torch .float ).unsqueeze (1 )
116+ self .labels : Tensor = torch .arange (
117+ start = low , end = high , dtype = torch .float
118+ ).unsqueeze (1 )
126119 if use_gpu :
127120 self .samples = self .samples .cuda ()
128121 self .labels = self .labels .cuda ()
129122
130123
131124class BinaryDataset (ExplicitDataset ):
132125 def __init__ (self , use_gpu : bool = False ) -> None :
133- # pyre-fixme[4]: Attribute must be annotated.
134- self .samples = F .normalize (
126+ self .samples : Tensor = F .normalize (
135127 torch .stack (
136128 (
137129 torch .Tensor ([1 , 1 ]),
@@ -161,8 +153,7 @@ def __init__(self, use_gpu: bool = False) -> None:
161153 )
162154 )
163155 )
164- # pyre-fixme[4]: Attribute must be annotated.
165- self .labels = torch .cat (
156+ self .labels : Tensor = torch .cat (
166157 (
167158 torch .Tensor ([1 ]).repeat (12 , 1 ),
168159 torch .Tensor ([- 1 ]).repeat (12 , 1 ),
@@ -350,13 +341,10 @@ def get_random_model_and_data(
350341 tmpdir ,
351342 # pyre-fixme[2]: Parameter must be annotated.
352343 unpack_inputs ,
353- # pyre-fixme[2]: Parameter must be annotated.
354- return_test_data = True ,
344+ return_test_data : bool = True ,
355345 gpu_setting : Optional [str ] = None ,
356- # pyre-fixme[2]: Parameter must be annotated.
357- return_hessian_data = False ,
358- # pyre-fixme[2]: Parameter must be annotated.
359- model_type = "random" ,
346+ return_hessian_data : bool = False ,
347+ model_type : str = "random" ,
360348):
361349 """
362350 returns a model, training data, and optionally data for computing the hessian
@@ -534,10 +522,9 @@ def generate_symmetric_matrix_given_eigenvalues(
534522 return torch .matmul (Q , torch .matmul (torch .diag (torch .tensor (eigenvalues )), Q .T ))
535523
536524
537- # pyre-fixme[3]: Return type must be annotated.
538525def generate_assymetric_matrix_given_eigenvalues (
539526 eigenvalues : Union [Tensor , List [float ]]
540- ):
527+ ) -> Tensor :
541528 """
542529 following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
543530 generate assymetric random matrix with specified eigenvalues. this is used in
0 commit comments