@@ -175,31 +175,30 @@ def test_lambda_fn(self):
175175 _ , new_obj , _ = self .roundtrip (obj , safe_mode = False )
176176 self .assertEqual (obj ["activation" ](3 ), new_obj ["activation" ](3 ))
177177
178- # TODO
179- # def test_lambda_layer(self):
180- # lmbda = keras.layers.Lambda(lambda x: x**2)
181- # with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
182- # self.roundtrip(lmbda, safe_mode=True)
183-
184- # _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
185- # x = ops.random.normal((2, 2))
186- # y1 = lmbda(x)
187- # y2 = new_lmbda(x)
188- # self.assertAllClose(y1, y2, atol=1e-5)
189-
190- # def test_safe_mode_scope(self):
191- # lmbda = keras.layers.Lambda(lambda x: x**2)
192- # with serialization_lib.SafeModeScope(safe_mode=True):
193- # with self.assertRaisesRegex(
194- # ValueError, "arbitrary code execution"
195- # ):
196- # self.roundtrip(lmbda)
197- # with serialization_lib.SafeModeScope(safe_mode=False):
198- # _, new_lmbda, _ = self.roundtrip(lmbda)
199- # x = ops.random.normal((2, 2))
200- # y1 = lmbda(x)
201- # y2 = new_lmbda(x)
202- # self.assertAllClose(y1, y2, atol=1e-5)
178+ def test_lambda_layer (self ):
179+ lmbda = keras .layers .Lambda (lambda x : x ** 2 )
180+ with self .assertRaisesRegex (ValueError , "Deserializing it is unsafe" ):
181+ self .roundtrip (lmbda , safe_mode = True )
182+
183+ _ , new_lmbda , _ = self .roundtrip (lmbda , safe_mode = False )
184+ x = ops .random .normal ((2 , 2 ))
185+ y1 = lmbda (x )
186+ y2 = new_lmbda (x )
187+ self .assertAllClose (y1 , y2 , atol = 1e-5 )
188+
189+ def test_safe_mode_scope (self ):
190+ lmbda = keras .layers .Lambda (lambda x : x ** 2 )
191+ with serialization_lib .SafeModeScope (safe_mode = True ):
192+ with self .assertRaisesRegex (
193+ ValueError , "Deserializing it is unsafe"
194+ ):
195+ self .roundtrip (lmbda )
196+ with serialization_lib .SafeModeScope (safe_mode = False ):
197+ _ , new_lmbda , _ = self .roundtrip (lmbda )
198+ x = ops .random .normal ((2 , 2 ))
199+ y1 = lmbda (x )
200+ y2 = new_lmbda (x )
201+ self .assertAllClose (y1 , y2 , atol = 1e-5 )
203202
204203 @pytest .mark .requires_trainable_backend
205204 def test_dict_inputs_outputs (self ):
0 commit comments