@@ -484,32 +484,35 @@ function nntest.RReLU()
484484 for _ ,train in ipairs ({true ,false }) do
485485 -- test with separate output buffer and inplace
486486 for _ ,inplace in ipairs ({false ,true }) do
487- module = nn .RReLU (l , u , inplace )
488- if train then
489- module :training ()
490- else
491- module :evaluate ()
492- end
493- input = torch .rand (nframe , size , kW , kH ) - 0.5
494- input :storage ()[1 ] = - 1
495- local original_input = input :clone ()
496- local output = module :forward (input )
497- mytester :assert (output :sign ():eq (original_input :sign ()):all (), ' sign flipped forward ' )
498- local gradOutput = torch .ones (output :size ())
499- local gradInput = module :backward (input , gradOutput )
500- mytester :assert (gradInput :gt (0 ):eq (input :ne (0 )):all (), ' gradient ' )
501- mytester :assert (gradInput :lt (1 ):eq (input :le (0 )):all (), ' backward negative inputs ' )
502- mytester :assert (gradInput :eq (1 ):eq (input :gt (0 )):all (), ' backward positive inputs ' )
503- if not train then
504- local err = gradInput [input :le (0 )]:mean ()- (module .lower + module .upper )/ 2
505- mytester :assertlt (err , precision , ' error on gradient ' )
506- end
487+ -- test with channel-wise
488+ for _ ,cw in ipairs ({true ,false }) do
489+ module = nn .RReLU (l , u , inplace , cw )
490+ if train then
491+ module :training ()
492+ else
493+ module :evaluate ()
494+ end
495+ input = torch .rand (nframe , size , kW , kH ) - 0.5
496+ input :storage ()[1 ] = - 1
497+ local original_input = input :clone ()
498+ local output = module :forward (input )
499+ mytester :assert (output :sign ():eq (original_input :sign ()):all (), ' sign flipped forward ' )
500+ local gradOutput = torch .ones (output :size ())
501+ local gradInput = module :backward (input , gradOutput )
502+ mytester :assert (gradInput :gt (0 ):eq (input :ne (0 )):all (), ' gradient ' )
503+ mytester :assert (gradInput :lt (1 ):eq (input :le (0 )):all (), ' backward negative inputs ' )
504+ mytester :assert (gradInput :eq (1 ):eq (input :gt (0 )):all (), ' backward positive inputs ' )
505+ if not train then
506+ local err = gradInput [input :le (0 )]:mean ()- (module .lower + module .upper )/ 2
507+ mytester :assertlt (err , precision , ' error on gradient ' )
508+ end
507509
508- input = - torch .rand (1000 )
509- module :forward (input ) -- fill internal noise tensor
510- local g = module :backward (input , torch .ones (1000 ))
511- local err = math.abs (g [input :le (0 )]:mean ()- (module .lower + module .upper )/ 2 )
512- mytester :assertlt (err , 0.05 , ' mean deviation of gradient for negative inputs ' )
510+ input = - torch .rand (1000 )
511+ module :forward (input ) -- fill internal noise tensor
512+ local g = module :backward (input , torch .ones (1000 ))
513+ local err = math.abs (g [input :le (0 )]:mean ()- (module .lower + module .upper )/ 2 )
514+ mytester :assertlt (err , 0.05 , ' mean deviation of gradient for negative inputs ' )
515+ end
513516 end
514517 end
515518end
0 commit comments