@@ -177,8 +177,8 @@
<data name="ThrowArgument_InPlaceInvalidShape" xml:space="preserve">
<value>In place operations require the same shape for both tensors</value>
</data>
<data name="ThrowArgument_InvalidAxis" xml:space="preserve">
<value>Invalid axis provided. Must be greater then or equal to 0 and less than the tensor rank.</value>
<data name="ThrowArgument_InvalidDimension" xml:space="preserve">
<value>Invalid dimension provided. Must be greater then or equal to 0 and less than the tensor rank.</value>
</data>
<data name="ThrowArgument_InvalidConcatenateShape" xml:space="preserve">
<value>The tensors must have the same shape, except in the dimension corresponding to axis.</value>
@@ -133,18 +133,14 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();

if (dimension < -1 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_InvalidAxis();
ThrowHelper.ThrowArgument_InvalidDimension();

// Calculate total space needed.
nint totalLength = 0;
for (int i = 0; i < tensors.Length; i++)
totalLength += tensors[i].FlattenedLength;
Tensor<T> tensor;

nint sumOfAxis = 0;
// If axis != -1, make sure all dimensions except the one to concatenate on match.
if (dimension != -1)
{
sumOfAxis = tensors[0].Lengths[dimension];
nint sumOfAxis = tensors[0].Lengths[dimension];
for (int i = 1; i < tensors.Length; i++)
{
if (tensors[0].Rank != tensors[i].Rank)
@@ -157,22 +153,31 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
}
}
sumOfAxis += tensors[i].Lengths[dimension];
checked
{
sumOfAxis += tensors[i].Lengths[dimension];
}
}
}

Tensor<T> tensor;
if (dimension == -1)
{
tensor = Tensor.Create<T>([totalLength]);
}
else
{
nint[] lengths = new nint[tensors[0].Rank];
tensors[0].Lengths.CopyTo(lengths);
lengths[dimension] = sumOfAxis;
tensor = Tensor.Create<T>(lengths);
}
else
{
// Calculate total space needed.
nint totalLength = 0;
for (int i = 0; i < tensors.Length; i++)
{
checked
{
totalLength += tensors[i].FlattenedLength;
}
}

tensor = Tensor.Create<T>([totalLength]);
Member:
Don't think this is correct and it's now calling the Tensor.Create<T>(T[] array) overload, rather than the Tensor.Create<T>(scoped ReadOnlySpan<nint> lengths, bool pinned = false) overload

I think this should stay nint and we can rely on Tensor.Create<T>(scoped ReadOnlySpan<nint> lengths, bool pinned = false) validating that it can be allocated by the underlying tensor storage.

We just need to ensure that adding the combined tensors' flattened lengths together doesn't overflow the nint.


I also think that this represents a UX issue with the Create APIs.

I think we may want to disambiguate as CreateFromShape or similar so that a user passing in an nint[] isn't confused about whether it's going to create a Tensor<nint> or a Tensor<T> where nint is the lengths.
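
A rough sketch of what that disambiguation could look like; CreateFromShape is hypothetical here, taken only from the suggestion above, and the snippet assumes using System.Numerics.Tensors:

public static class TensorShapeFactory
{
    // Hypothetical helper illustrating the naming suggestion above; not an existing API.
    // It simply forwards to the lengths-based Create overload so the caller's intent is explicit.
    public static Tensor<T> CreateFromShape<T>(scoped ReadOnlySpan<nint> lengths, bool pinned = false)
        => Tensor.Create<T>(lengths, pinned);
}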

Contributor Author:
Fixed back to nint.

But it still was calling the right overload, not the Tensor.Create<T>(T[] array) one since T != int.
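
For reference, a short sketch of the overload binding being discussed; the method names come from this thread (assuming using System.Numerics.Tensors), and the exact resolution behavior shown is an assumption rather than verified output:

// Hypothetical illustration, not code from this PR.
nint[] lengths = new nint[] { 2, 3 };

// With T inferred as nint, the identity match wins, so this binds to
// Tensor.Create<T>(T[] array): a 1-D Tensor<nint> holding the values 2 and 3.
Tensor<nint> values = Tensor.Create(lengths);

// Forcing the span conversion selects the lengths-based overload instead,
// allocating a 2x3 Tensor<nint>.
Tensor<nint> shaped = Tensor.Create<nint>((ReadOnlySpan<nint>)lengths);

// For any other element type, an nint[] argument can only match the lengths
// overload, which is why the call in this method still bound correctly.
Tensor<float> floats = Tensor.Create<float>(lengths);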

}

ConcatenateOnDimension(dimension, tensors, tensor);
return tensor;
@@ -201,7 +206,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
ThrowHelper.ThrowArgument_ConcatenateTooFewTensors();

if (dimension < -1 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_InvalidAxis();
ThrowHelper.ThrowArgument_InvalidDimension();

// Calculate total space needed.
nint totalLength = 0;
@@ -212,11 +217,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
if (dimension != -1)
{
nint sumOfAxis = tensors[0].Lengths[dimension];
int rank = tensors[0].Rank;
for (int i = 1; i < tensors.Length; i++)
{
if (tensors[0].Rank != tensors[i].Rank)
if (rank != tensors[i].Rank)
ThrowHelper.ThrowArgument_InvalidConcatenateShape();
for (int j = 0; j < tensors[0].Rank; j++)
for (int j = 0; j < rank; j++)
{
if (j != dimension)
{
@@ -228,7 +234,7 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
}

// Make sure the destination tensor has the correct shape.
nint[] lengths = new nint[tensors[0].Rank];
nint[] lengths = new nint[rank];
tensors[0].Lengths.CopyTo(lengths);
lengths[dimension] = sumOfAxis;

@@ -339,18 +345,17 @@ public static Tensor<T> Create<T>(T[] array, int start, scoped ReadOnlySpan<nint
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" />.</returns>
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, bool pinned = false)
{
T[] array = enumerable.ToArray();

if (pinned)
{
T[] array = enumerable.ToArray();

Tensor<T> tensor = CreateUninitialized<T>([array.Length], pinned);
array.CopyTo(tensor._values);

return tensor;
}
else
{
T[] array = enumerable.ToArray();
return Create(array);
}
}
@@ -364,18 +369,17 @@ public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan
/// <returns>A new tensor that contains elements copied from <paramref name="enumerable" /> and with the specified <paramref name="lengths" /> and <paramref name="strides" />.</returns>
public static Tensor<T> Create<T>(IEnumerable<T> enumerable, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned = false)
{
T[] array = enumerable.ToArray();

if (pinned)
{
T[] array = enumerable.ToArray();

Tensor<T> tensor = CreateUninitialized<T>(lengths, strides, pinned);
array.CopyTo(tensor._values);

return tensor;
}
else
{
T[] array = enumerable.ToArray();
return Create(array, lengths, strides);
}
}
@@ -620,20 +624,8 @@ public static bool EqualsAny<T>(in ReadOnlyTensorSpan<T> x, T y)
/// <param name="value">Value to update in the <paramref name="tensor"/>.</param>
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, T value)
{
if (filter.Lengths.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));

Span<T> srcSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);

for (int i = 0; i < filterSpan.Length; i++)
{
if (filterSpan[i])
{
srcSpan[i] = value;
}
}

TensorOperation.ValidateCompatibility(filter, tensor);
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, value, tensor);
return ref tensor;
}

@@ -646,24 +638,8 @@ public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T>
/// <param name="values">Values to update in the <paramref name="tensor"/>.</param>
public static ref readonly TensorSpan<T> FilteredUpdate<T>(in this TensorSpan<T> tensor, scoped in ReadOnlyTensorSpan<bool> filter, scoped in ReadOnlyTensorSpan<T> values)
{
if (filter.Lengths.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(filter));
if (values.Rank != 1)
ThrowHelper.ThrowArgument_1DTensorRequired(nameof(values));

Span<T> dstSpan = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<bool> filterSpan = MemoryMarshal.CreateSpan(ref filter._reference, (int)tensor._shape.LinearLength);
Span<T> valuesSpan = MemoryMarshal.CreateSpan(ref values._reference, (int)values._shape.LinearLength);

int index = 0;
for (int i = 0; i < filterSpan.Length; i++)
{
if (filterSpan[i])
{
dstSpan[i] = valuesSpan[index++];
}
}

TensorOperation.ValidateCompatibility(filter, values, tensor);
TensorOperation.Invoke<TensorOperation.FilteredUpdate<T>, bool, T, T>(filter, values, tensor);
return ref tensor;
}
#endregion
@@ -1409,6 +1385,9 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, ReadOnlySpan
}
else
{
if (!dimensions.IsEmpty && dimensions.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_PermuteAxisOrder();

scoped Span<nint> newLengths = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> lengthsRentedBuffer);
scoped Span<nint> newStrides = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<nint> stridesRentedBuffer);
scoped Span<int> newLinearOrder = TensorOperation.RentedBuffer.CreateUninitialized(tensor.Rank, out TensorOperation.RentedBuffer<int> linearOrderRentedBuffer);
@@ -1426,11 +1405,12 @@
}
else
{
if (dimensions.Length != tensor.Lengths.Length)
ThrowHelper.ThrowArgument_PermuteAxisOrder();

for (int i = 0; i < dimensions.Length; i++)
{
if (dimensions[i] >= tensor.Lengths.Length || dimensions[i] < 0)
{
ThrowHelper.ThrowArgument_InvalidDimension();
}
newLengths[i] = tensor.Lengths[dimensions[i]];
newStrides[i] = tensor.Strides[dimensions[i]];
newLinearOrder[i] = tensor._shape.LinearRankOrder[dimensions[i]];
@@ -1467,7 +1447,8 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, ReadOnlySpan<nint> len

nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
int wildcardIndex = lengths.IndexOf(-1);
if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1479,7 +1460,7 @@
tempTotal /= lengths[i];
}
}
newLengths[lengths.IndexOf(-1)] = tempTotal;
newLengths[wildcardIndex] = tempTotal;
}

nint tempLinear = TensorPrimitives.Product(newLengths);
@@ -1538,8 +1519,8 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, scoped Read
}

nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
int wildcardIndex = lengths.IndexOf(-1);
if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1551,7 +1532,7 @@
tempTotal /= lengths[i];
}
}
newLengths[lengths.IndexOf(-1)] = tempTotal;
newLengths[wildcardIndex] = tempTotal;

}

@@ -1615,7 +1596,8 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten

nint[] newLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
int wildcardIndex = lengths.IndexOf(-1);
if (wildcardIndex >= 0)
{
if (lengths.Count(-1) > 1)
ThrowHelper.ThrowArgument_OnlyOneWildcard();
@@ -1627,7 +1609,7 @@
tempTotal /= lengths[i];
}
}
newLengths[lengths.IndexOf(-1)] = tempTotal;
newLengths[wildcardIndex] = tempTotal;

}

@@ -1701,12 +1683,7 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> destination)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
span.Slice(0, ospan.Length).CopyTo(ospan);
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
}

/// <summary>
@@ -1717,12 +1694,7 @@ public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> dest
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
public static void ResizeTo<T>(scoped in TensorSpan<T> tensor, in TensorSpan<T> destination)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
span.Slice(0, ospan.Length).CopyTo(ospan);
ResizeTo(tensor.AsReadOnlyTensorSpan(), destination);
}

/// <summary>
@@ -1890,6 +1862,8 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
/// <param name="dimension">The axis to split on.</param>
public static Tensor<T>[] Split<T>(scoped in ReadOnlyTensorSpan<T> tensor, int splitCount, nint dimension)
{
if (dimension < 0 || dimension >= tensor.Rank)
ThrowHelper.ThrowArgument_AxisLargerThanRank();
if (tensor.Lengths[(int)dimension] % splitCount != 0)
ThrowHelper.ThrowArgument_SplitNotSplitEvenly();

@@ -2221,8 +2195,10 @@ public static Tensor<T> StackAlongDimension<T>(int dimension, params ReadOnlySpa
ThrowHelper.ThrowArgument_StackShapesNotSame();
}

if (dimension < 0)
dimension = tensors[0].Rank - dimension;
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
// with our call to Unsqueeze.
if (dimension < 0 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_AxisLargerThanRank();

Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
for (int i = 0; i < tensors.Length; i++)
Expand Down Expand Up @@ -2259,8 +2235,10 @@ public static ref readonly TensorSpan<T> StackAlongDimension<T>(scoped ReadOnlyS
ThrowHelper.ThrowArgument_StackShapesNotSame();
}

if (dimension < 0)
dimension = tensors[0].Rank - dimension;
// We are safe to do dimension > tensors[0].Rank instead of >= because we are adding a new dimension
// with our call to Unsqueeze.
if (dimension < 0 || dimension > tensors[0].Rank)
ThrowHelper.ThrowArgument_AxisLargerThanRank();

Tensor<T>[] outputs = new Tensor<T>[tensors.Length];
for (int i = 0; i < tensors.Length; i++)