@@ -706,6 +706,25 @@ def version_7(cls, ctx, node, **kwargs):
         _make_softmax_cross_entropy_with_logits(ctx, labels, logits, node)


+def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
+    label_dtype = ctx.get_dtype(label.output[0])
+    logit_dtype = ctx.get_dtype(logit.output[0])
+    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")
+
+    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
+    # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
+    mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
+    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
+                                        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
+    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+    shapes = tf_ori_node.output_shapes
+    dtypes = tf_ori_node.output_dtypes
+    ctx.remove_node(tf_ori_node.name)
+    ctx.make_node(op_type="Squeeze", inputs=[mul2.output[0]], attr={"axes": [1]},
+                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
+
+
 @tf_op("SparseSoftmaxCrossEntropyWithLogits")
 class SparseSoftmaxCrossEntropyWithLogits:
     @classmethod
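For reference, the new helper lowers the loss to `-1 * ReduceSum(label * LogSoftmax(logit), axes=[-1])` followed by a `Squeeze` (ONNX `ReduceSum` keeps the reduced axis by default, so the size-1 axis is squeezed away to match TF's per-example loss shape). A minimal NumPy sketch of that math, assuming the labels have already been one-hot encoded to the logit dtype as the handler does before calling the helper; all names below are illustrative and the log-softmax is computed naively rather than in a numerically stable form:

```python
import numpy as np

def sparse_softmax_xent_reference(logits, labels):
    """Reference for -sum(onehot(labels) * log_softmax(logits), axis=-1)."""
    num_classes = logits.shape[-1]
    onehot = np.eye(num_classes, dtype=logits.dtype)[labels]        # dense labels, as the handler builds via OneHot
    log_softmax = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
    reduced = np.sum(onehot * log_softmax, axis=-1, keepdims=True)  # ReduceSum with keepdims (ONNX default)
    return np.squeeze(-1.0 * reduced, axis=1)                       # Mul by -1, then Squeeze axes=[1]

logits = np.random.randn(4, 3).astype(np.float32)
labels = np.array([0, 2, 1, 2])
print(sparse_softmax_xent_reference(logits, labels))  # shape (4,), one loss per example
```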
@@ -778,4 +797,4 @@ def version_9(cls, ctx, node, **kwargs):
         if logit_dtype != TensorProto.INT64:
             label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])

-        _make_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
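For context, a quick usage sketch of the TensorFlow op this handler converts, independent of tf2onnx (labels are integer class ids rather than one-hot vectors, and the result is one loss value per example):

```python
import numpy as np
import tensorflow as tf

logits = tf.constant(np.random.randn(4, 3).astype(np.float32))  # [batch, num_classes]
labels = tf.constant([0, 2, 1, 2], dtype=tf.int64)               # sparse class ids, not one-hot
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(loss.shape)  # (4,) -- matches the squeezed output produced by the new helper
```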