@@ -28,35 +28,42 @@ public class VarianceScaling : IInitializer
28
28
{
29
29
protected float _scale ;
30
30
protected string _mode ;
31
- protected string _distribution ;
32
31
protected int ? _seed ;
33
32
protected TF_DataType _dtype ;
34
- protected bool _uniform ;
33
+ protected string _distribution ;
35
34
private readonly Dictionary < string , object > _config ;
36
35
37
36
public virtual string ClassName => "VarianceScaling" ;
38
37
39
38
public virtual IDictionary < string , object > Config => _config ;
40
39
41
- public VarianceScaling ( float factor = 2 .0f,
42
- string mode = "FAN_IN " ,
43
- bool uniform = false ,
40
+ public VarianceScaling ( float scale = 1 .0f,
41
+ string mode = "fan_in " ,
42
+ string distribution = "truncated_normal" ,
44
43
int ? seed = null ,
45
44
TF_DataType dtype = TF_DataType . TF_FLOAT )
46
45
{
47
46
if ( ! dtype . is_floating ( ) )
48
47
throw new TypeError ( "Cannot create initializer for non-floating point type." ) ;
49
- if ( ! new string [ ] { "FAN_IN" , "FAN_OUT" , "FAN_AVG" } . Contains ( mode ) )
50
- throw new TypeError ( $ "Unknown { mode } %s [FAN_IN, FAN_OUT, FAN_AVG]") ;
48
+ if ( ! new string [ ] { "fan_in" , "fan_out" , "fan_avg" } . Contains ( mode ) )
49
+ throw new TypeError ( $ "Unknown { mode } %s [fan_in, fan_out, fan_avg]") ;
50
+ if ( distribution == "normal" )
51
+ {
52
+ distribution = "truncated_normal" ;
53
+ }
54
+ if ( ! new string [ ] { "uniform" , "truncated_normal" , "untruncated_normal" } . Contains ( distribution ) )
55
+ {
56
+ throw new ValueError ( $ "Invalid `distribution` argument: { distribution } ") ;
57
+ }
51
58
52
- if ( factor < 0 )
59
+ if ( scale <= 0 )
53
60
throw new ValueError ( "`scale` must be positive float." ) ;
54
61
55
- _scale = factor ;
62
+ _scale = scale ;
56
63
_mode = mode ;
57
64
_seed = seed ;
58
65
_dtype = dtype ;
59
- _uniform = uniform ;
66
+ _distribution = distribution ;
60
67
61
68
_config = new ( ) ;
62
69
_config [ "scale" ] = _scale ;
@@ -72,23 +79,28 @@ public Tensor Apply(InitializerArgs args)
72
79
73
80
float n = 0 ;
74
81
var ( fan_in , fan_out ) = _compute_fans ( args . Shape ) ;
75
- if ( _mode == "FAN_IN" )
76
- n = fan_in ;
77
- else if ( _mode == "FAN_OUT" )
78
- n = fan_out ;
79
- else if ( _mode == "FAN_AVG" )
80
- n = ( fan_in + fan_out ) / 2.0f ;
82
+ var scale = this . _scale ;
83
+ if ( _mode == "fan_in" )
84
+ scale /= Math . Max ( 1.0f , fan_in ) ;
85
+ else if ( _mode == "fan_out" )
86
+ scale /= Math . Max ( 1.0f , fan_out ) ;
87
+ else
88
+ scale /= Math . Max ( 1.0f , ( fan_in + fan_out ) / 2 ) ;
81
89
82
- if ( _uniform )
90
+ if ( _distribution == "truncated_normal" )
83
91
{
84
- var limit = Convert . ToSingle ( Math . Sqrt ( 3.0f * _scale / n ) ) ;
85
- return random_ops . random_uniform ( args . Shape , - limit , limit , args . DType ) ;
92
+ var stddev = Math . Sqrt ( scale ) / .87962566103423978f ;
93
+ return random_ops . truncated_normal ( args . Shape , 0.0f , ( float ) stddev , args . DType ) ;
94
+ }
95
+ else if ( _distribution == "untruncated_normal" )
96
+ {
97
+ var stddev = Math . Sqrt ( scale ) ;
98
+ return random_ops . random_normal ( args . Shape , 0.0f , ( float ) stddev , args . DType ) ;
86
99
}
87
100
else
88
101
{
89
- var trunc_stddev = Convert . ToSingle ( Math . Sqrt ( 1.3f * _scale / n ) ) ;
90
- return random_ops . truncated_normal ( args . Shape , 0.0f , trunc_stddev , args . DType ,
91
- seed : _seed ) ;
102
+ var limit = ( float ) Math . Sqrt ( scale * 3.0f ) ;
103
+ return random_ops . random_uniform ( args . Shape , - limit , limit , args . DType ) ;
92
104
}
93
105
}
94
106
0 commit comments