Skip to content

Commit 3370723

Browse files
committed
Define Keras interface in core project (WIP).
1 parent 19363cd commit 3370723

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+610
-205
lines changed

src/SciSharp.TensorFlow.Redist/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
2626

2727
#### Download pre-build package
2828

29-
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)
29+
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip)
3030

3131

3232

@@ -35,6 +35,6 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
3535
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.
3636

3737
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
38-
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
38+
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
3939

4040

src/TensorFlowNET.Console/Program.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ static void Main(string[] args)
1010
var diag = new Diagnostician();
1111
// diag.Diagnose(@"D:\memory.txt");
1212

13+
var rnn = new SimpleRnnTest();
14+
rnn.Run();
15+
1316
// this class is used explor new features.
1417
var exploring = new Exploring();
1518
// exploring.Run();
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras;
5+
using Tensorflow.NumPy;
6+
using static Tensorflow.Binding;
7+
using static Tensorflow.KerasApi;
8+
9+
namespace Tensorflow
10+
{
11+
public class SimpleRnnTest
12+
{
13+
public void Run()
14+
{
15+
tf.keras = new KerasInterface();
16+
var inputs = np.random.random((32, 10, 8)).astype(np.float32);
17+
var simple_rnn = tf.keras.layers.SimpleRNN(4);
18+
var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
19+
if (output.shape == (32, 4))
20+
{
21+
22+
}
23+
/*simple_rnn = tf.keras.layers.SimpleRNN(
24+
4, return_sequences = True, return_state = True)
25+
26+
# whole_sequence_output has shape `[32, 10, 4]`.
27+
# final_state has shape `[32, 4]`.
28+
whole_sequence_output, final_state = simple_rnn(inputs)*/
29+
}
30+
}
31+
}

src/TensorFlowNET.Console/Tensorflow.Console.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<RootNamespace>Tensorflow</RootNamespace>
77
<AssemblyName>Tensorflow</AssemblyName>
88
<Platforms>AnyCPU;x64</Platforms>
9-
<LangVersion>9.0</LangVersion>
9+
<LangVersion>11.0</LangVersion>
1010
</PropertyGroup>
1111

1212
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -20,7 +20,7 @@
2020
</PropertyGroup>
2121

2222
<ItemGroup>
23-
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.7.0" />
23+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" />
2424
</ItemGroup>
2525

2626
<ItemGroup>

src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs

Lines changed: 0 additions & 22 deletions
This file was deleted.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using Tensorflow.Keras.ArgsDefinition.Rnn;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition.Lstm
4+
{
5+
public class LSTMArgs : RNNArgs
6+
{
7+
public bool UnitForgetBias { get; set; }
8+
public float Dropout { get; set; }
9+
public float RecurrentDropout { get; set; }
10+
public int Implementation { get; set; }
11+
}
12+
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs renamed to src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace Tensorflow.Keras.ArgsDefinition
1+
namespace Tensorflow.Keras.ArgsDefinition.Lstm
22
{
33
public class LSTMCellArgs : LayerArgs
44
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs

Lines changed: 0 additions & 21 deletions
This file was deleted.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using System.Collections.Generic;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
4+
{
5+
public class RNNArgs : LayerArgs
6+
{
7+
public interface IRnnArgCell : ILayer
8+
{
9+
object state_size { get; }
10+
}
11+
12+
public IRnnArgCell Cell { get; set; } = null;
13+
public bool ReturnSequences { get; set; } = false;
14+
public bool ReturnState { get; set; } = false;
15+
public bool GoBackwards { get; set; } = false;
16+
public bool Stateful { get; set; } = false;
17+
public bool Unroll { get; set; } = false;
18+
public bool TimeMajor { get; set; } = false;
19+
public Dictionary<string, object> Kwargs { get; set; } = null;
20+
21+
public int Units { get; set; }
22+
public Activation Activation { get; set; }
23+
public Activation RecurrentActivation { get; set; }
24+
public bool UseBias { get; set; } = true;
25+
public IInitializer KernelInitializer { get; set; }
26+
public IInitializer RecurrentInitializer { get; set; }
27+
public IInitializer BiasInitializer { get; set; }
28+
29+
// kernel_regularizer=None,
30+
// recurrent_regularizer=None,
31+
// bias_regularizer=None,
32+
// activity_regularizer=None,
33+
// kernel_constraint=None,
34+
// recurrent_constraint=None,
35+
// bias_constraint=None,
36+
// dropout=0.,
37+
// recurrent_dropout=0.,
38+
// return_sequences=False,
39+
// return_state=False,
40+
// go_backwards=False,
41+
// stateful=False,
42+
// unroll=False,
43+
// **kwargs):
44+
}
45+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
2+
{
3+
public class SimpleRNNArgs : RNNArgs
4+
{
5+
6+
}
7+
}

0 commit comments

Comments
 (0)