Skip to content

feat: Support training of RNN and LSTM. #1110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Core/APIs/c_api.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
Expand Down Expand Up @@ -50,6 +51,19 @@ public static string StringPiece(IntPtr handle)
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
{
byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);

Expand Down
10 changes: 5 additions & 5 deletions src/TensorFlowNET.Core/APIs/tf.control_flow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
Expand All @@ -58,9 +58,9 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,
Expand Down
6 changes: 3 additions & 3 deletions src/TensorFlowNET.Core/APIs/tf.tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides = n
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
axis: axis,
num_or_size_splits: num_split,
axis: ops.convert_to_tensor(axis),
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Binding.Util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ public static TF_DataType GetDataType(this object data)
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Extensions
namespace Tensorflow.Common.Extensions
{
public static class JObjectExtensions
{
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
{
var res = obj[key];
if(res is null)
if (res is null)
{
return default(T);
return default;
}
else
{
Expand Down
38 changes: 38 additions & 0 deletions src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Common.Extensions
{
public static class LinqExtensions
{
#if NETSTANDARD2_0
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
{
return sequence.Skip(sequence.Count() - count);
}

public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
{
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}

public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
{
first = values.Item1;
second = values.Item2;
third = values.Item3;
}
}
}
33 changes: 33 additions & 0 deletions src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Common.Extensions
{
public static class NestExtensions
{
public static Tensors ToTensors(this INestable<Tensor> tensors)
{
return new Tensors(tensors.AsNest());
}

public static Tensors? ToTensors(this Nest<Tensor> tensors)
{
return Tensors.FromNest(tensors);
}

/// <summary>
/// If the nested object is already a nested type, this function could reduce it.
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
{
return Nest<TOut>.ReduceFrom(input);
}
}
}
20 changes: 20 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}
69 changes: 69 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: Nest<Shape>
{
public GeneralizedTensorShape(Shape value, string? name = null)
{
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(Nest<Shape> other)
{
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
return shapes[0];
}

public long ToNumber()
{
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
return shapes[0].dims[0];
}

public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

public static implicit operator GeneralizedTensorShape(Shape shape)
{
return new GeneralizedTensorShape(shape);
}
}
}
40 changes: 40 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/INestStructure.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface indicates that a class may have a nested structure and provide
/// methods to manipulate with the structure.
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
/// </summary>
/// <returns></returns>
IEnumerable<T> Flatten();
/// <summary>
/// Construct a new object with the same nested structure.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <returns></returns>
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
}
}
11 changes: 11 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/INestable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public interface INestable<T>
{
Nest<T> AsNest();
}
}
21 changes: 21 additions & 0 deletions src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface is used when some corresponding python methods have optional args.
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
/// as the parameter of the method.
/// </summary>
public interface IOptionalArgs
{
/// <summary>
/// The identifier of the class. It is not an argument but only something to
/// separate different OptionalArgs.
/// </summary>
string Identifier { get; }
}
}
Loading