@@ -2,7 +2,7 @@

 # pyre-strict

-from typing import no_type_check, Optional, Tuple, Union
+from typing import Iterable, no_type_check, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -417,6 +417,77 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)


+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer output
+    that is not supported by gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: Optional[Iterable[Tensor]]
+    ) -> Optional[Iterable[Tensor]]:
+        return unsupported_layer_output
+
+
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: Optional[Iterable[Tensor]] = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=inplace)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    # pyre-fixme[3]: Return type must be annotated.
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None):
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            self.unsupportedLayer(self.unsupported_layer_output)
+            # unsupportedLayer's result is unused; it exists only to check whether the layer output is supported.
+        self.relu_alt(
+            lin1_out_alt
+        )  # relu_alt's output is supported but it's unused in the forward func.
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:


 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2
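For context on what these helpers exercise: layer attribution captures a layer's output during the forward pass and backpropagates the model output to that captured tensor, which only works when the layer output is a differentiable floating-point tensor. Below is a minimal sketch of that mechanism, assuming the classes from this diff are importable; the hook-based framing is illustrative and not part of the diff itself.

import torch

model = BasicModel_GradientLayerAttribution()
captured = {}

def save_output(module, inputs, output):
    # Capture the layer's output tensor during the forward pass.
    captured["relu"] = output

handle = model.relu.register_forward_hook(save_output)
x = torch.randn(2, 3)
out = model(x)  # shape (2, 4): lin2_out concatenated with the int64 lin3_out
handle.remove()

# Differentiating an output column w.r.t. the hooked layer's output succeeds,
# because relu produces a float tensor that is part of the autograd graph.
grads = torch.autograd.grad(out[:, 0].sum(), captured["relu"])

# By contrast, linear3's output is cast to int64, and integer tensors cannot
# require grad. Passing a non-None unsupported_layer_output to the model
# routes such an unsupported value through GradientUnsupportedLayerOutput,
# which is the failure mode these helpers let the tests detect.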