Skip to content

Commit ccda2c3

Browse files
committed
Add BinaryCrossentropy loss function.
1 parent 8a21ad2 commit ccda2c3

File tree

15 files changed

+202
-58
lines changed

15 files changed

+202
-58
lines changed

src/TensorFlowNET.Core/GlobalUsing.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
global using System;
2+
global using System.Collections.Generic;
3+
global using System.Text;

src/TensorFlowNET.Core/Keras/IKerasApi.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Tensorflow.Keras.Layers;
5+
using Tensorflow.Keras.Losses;
56

67
namespace Tensorflow.Keras
78
{
89
public interface IKerasApi
910
{
1011
public ILayersApi layers { get; }
12+
public ILossesApi losses { get; }
1113
public IInitializersApi initializers { get; }
1214
}
1315
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace Tensorflow.Keras.Losses;
2+
3+
public interface ILossFunc
4+
{
5+
public string Reduction { get; }
6+
public string Name { get; }
7+
Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
8+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
namespace Tensorflow.Keras.Losses;
2+
3+
public interface ILossesApi
4+
{
5+
ILossFunc BinaryCrossentropy(bool from_logits = false,
6+
float label_smoothing = 0f,
7+
int axis = -1,
8+
string reduction = "auto",
9+
string name = "binary_crossentropy");
10+
11+
ILossFunc SparseCategoricalCrossentropy(string reduction = null,
12+
string name = null,
13+
bool from_logits = false);
14+
15+
ILossFunc CategoricalCrossentropy(string reduction = null,
16+
string name = null,
17+
bool from_logits = false);
18+
19+
ILossFunc MeanSquaredError(string reduction = null,
20+
string name = null);
21+
22+
ILossFunc MeanSquaredLogarithmicError(string reduction = null,
23+
string name = null);
24+
25+
ILossFunc MeanAbsolutePercentageError(string reduction = null,
26+
string name = null);
27+
28+
ILossFunc MeanAbsoluteError(string reduction = null,
29+
string name = null);
30+
31+
ILossFunc CosineSimilarity(string reduction = null,
32+
int axis = -1,
33+
string name = null);
34+
35+
ILossFunc Huber(string reduction = null,
36+
string name = null,
37+
Tensor delta = null);
38+
39+
ILossFunc LogCosh(string reduction = null,
40+
string name = null);
41+
}

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,20 @@ public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_l
276276
return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis));
277277
}
278278

279+
public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false)
280+
{
281+
if (from_logits)
282+
return tf.nn.sigmoid_cross_entropy_with_logits(labels: target, logits: output);
283+
284+
var epsilon_ = constant_op.constant(epsilon(), dtype: output.dtype.as_base_dtype());
285+
output = tf.clip_by_value(output, epsilon_, 1.0f - epsilon_);
286+
287+
// Compute cross entropy from probabilities.
288+
var bce = target * tf.math.log(output + epsilon());
289+
bce += (1 - target) * tf.math.log(1 - output + epsilon());
290+
return -bce;
291+
}
292+
279293
/// <summary>
280294
/// Resizes the images contained in a 4D tensor.
281295
/// </summary>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
global using System;
2+
global using System.Collections.Generic;
3+
global using System.Text;
4+
global using static Tensorflow.Binding;
5+
global using static Tensorflow.KerasApi;

src/TensorFlowNET.Keras/KerasInterface.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class KerasInterface : IKerasApi
2121
public IInitializersApi initializers { get; } = new InitializersApi();
2222
public Regularizers regularizers { get; } = new Regularizers();
2323
public ILayersApi layers { get; } = new LayersApi();
24-
public LossesApi losses { get; } = new LossesApi();
24+
public ILossesApi losses { get; } = new LossesApi();
2525
public Activations activations { get; } = new Activations();
2626
public Preprocessing preprocessing { get; } = new Preprocessing();
2727
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
namespace Tensorflow.Keras.Losses;
2+
3+
public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc
4+
{
5+
float label_smoothing;
6+
public BinaryCrossentropy(
7+
bool from_logits = false,
8+
float label_smoothing = 0,
9+
string reduction = null,
10+
string name = null) :
11+
base(reduction: reduction,
12+
name: name == null ? "binary_crossentropy" : name,
13+
from_logits: from_logits)
14+
{
15+
this.label_smoothing = label_smoothing;
16+
}
17+
18+
19+
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
20+
{
21+
var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);
22+
return keras.backend.mean(sum, axis: axis);
23+
}
24+
}
Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,24 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using static Tensorflow.Binding;
5-
using static Tensorflow.KerasApi;
1+
namespace Tensorflow.Keras.Losses;
62

7-
namespace Tensorflow.Keras.Losses
3+
public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
84
{
9-
public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
5+
float label_smoothing;
6+
public CategoricalCrossentropy(
7+
bool from_logits = false,
8+
float label_smoothing = 0,
9+
string reduction = null,
10+
string name = null) :
11+
base(reduction: reduction,
12+
name: name == null ? "categorical_crossentropy" : name,
13+
from_logits: from_logits)
1014
{
11-
float label_smoothing;
12-
public CategoricalCrossentropy(
13-
bool from_logits = false,
14-
float label_smoothing = 0,
15-
string reduction = null,
16-
string name = null) :
17-
base(reduction: reduction,
18-
name: name == null ? "categorical_crossentropy" : name,
19-
from_logits: from_logits)
20-
{
21-
this.label_smoothing = label_smoothing;
22-
}
15+
this.label_smoothing = label_smoothing;
16+
}
2317

2418

25-
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
26-
{
27-
// Try to adjust the shape so that rank of labels = rank of logits - 1.
28-
return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits);
29-
}
19+
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
20+
{
21+
// Try to adjust the shape so that rank of labels = rank of logits - 1.
22+
return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits);
3023
}
3124
}

src/TensorFlowNET.Keras/Losses/ILossFunc.cs

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)