@@ -54,12 +54,10 @@ public static Tensor<T> Broadcast<T>(scoped in ReadOnlyTensorSpan<T> source, sco
/// <exception cref="ArgumentException">Thrown when the shapes are not broadcast compatible.</exception>
public static Tensor<T> Broadcast<T>(scoped in ReadOnlyTensorSpan<T> source, scoped ReadOnlySpan<nint> lengths)
{
nint[] newSize = Tensor.GetSmallestBroadcastableLengths(source.Lengths, lengths);

ReadOnlyTensorSpan<T> intermediate = LazyBroadcast(source, newSize);
Tensor<T> output = Tensor.CreateUninitialized<T>(intermediate.Lengths);
intermediate.FlattenTo(MemoryMarshal.CreateSpan(ref output._values[0], (int)output.FlattenedLength));
return output;
TensorOperation.ValidateCompatibility<T>(source, lengths);
Tensor<T> destination = Tensor.CreateUninitialized<T>(lengths);
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(source, destination);
return destination;
}
#endregion
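The rewritten Broadcast above materializes the broadcast eagerly: it validates that the source shape is broadcast-compatible with the requested lengths, allocates an uninitialized destination of those lengths, and lets the CopyTo operation repeat every size-1 axis of the source. A minimal plain-C# sketch of that copy rule for the two-dimensional case (BroadcastColumn is an illustrative helper, not part of the library):

// Illustrative only: broadcasting a [rows, 1] source into a [rows, cols]
// destination repeats the single column for every output column.
static T[,] BroadcastColumn<T>(T[] sourceColumn, int cols)
{
    int rows = sourceColumn.Length;
    var destination = new T[rows, cols];
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            destination[r, c] = sourceColumn[r]; // the size-1 axis is repeated
    return destination;
}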

@@ -71,12 +69,8 @@ public static Tensor<T> Broadcast<T>(scoped in ReadOnlyTensorSpan<T> source, sco
/// <param name="destination"></param>
public static void BroadcastTo<T>(this Tensor<T> source, in TensorSpan<T> destination)
{
nint[] newSize = Tensor.GetSmallestBroadcastableLengths(source.Lengths, destination.Lengths);
if (!destination.Lengths.SequenceEqual(newSize))
ThrowHelper.ThrowArgument_LengthsNotCompatible();

ReadOnlyTensorSpan<T> intermediate = LazyBroadcast(source, newSize);
intermediate.FlattenTo(MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength));
TensorOperation.ValidateCompatibility<T, T>(source, destination);
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(source, destination);
}

/// <summary>
@@ -86,12 +80,8 @@ public static void BroadcastTo<T>(this Tensor<T> source, in TensorSpan<T> destin
/// <param name="destination">Other <see cref="TensorSpan{T}"/> to make shapes broadcastable.</param>
public static void BroadcastTo<T>(in this TensorSpan<T> source, in TensorSpan<T> destination)
{
nint[] newSize = Tensor.GetSmallestBroadcastableLengths(source.Lengths, destination.Lengths);
if (!destination.Lengths.SequenceEqual(newSize))
ThrowHelper.ThrowArgument_LengthsNotCompatible();

ReadOnlyTensorSpan<T> intermediate = LazyBroadcast(source, newSize);
intermediate.FlattenTo(MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength));
TensorOperation.ValidateCompatibility<T, T>(source, destination);
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(source, destination);
}
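A shapes-only usage sketch for the BroadcastTo extension above; it relies only on Tensor.CreateUninitialized and AsTensorSpan as they appear elsewhere in this diff, and element values are omitted for brevity:

Tensor<int> row  = Tensor.CreateUninitialized<int>([1, 4]);
Tensor<int> grid = Tensor.CreateUninitialized<int>([3, 4]);
row.BroadcastTo(grid.AsTensorSpan()); // each of grid's 3 rows receives a copy of row's 4 values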

/// <summary>
@@ -101,141 +91,8 @@ public static void BroadcastTo<T>(in this TensorSpan<T> source, in TensorSpan<T>
/// <param name="destination"></param>
public static void BroadcastTo<T>(in this ReadOnlyTensorSpan<T> source, in TensorSpan<T> destination)
{
nint[] newSize = Tensor.GetSmallestBroadcastableLengths(source.Lengths, destination.Lengths);
if (!destination.Lengths.SequenceEqual(newSize))
ThrowHelper.ThrowArgument_LengthsNotCompatible();

ReadOnlyTensorSpan<T> intermediate = LazyBroadcast(source, newSize);
intermediate.FlattenTo(MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength));
}

// Lazy/non-copy broadcasting, internal only for now.
/// <summary>
/// Broadcast the data from <paramref name="input"/> to the new shape <paramref name="lengths"/>. Creates a new <see cref="Tensor{T}"/>
/// but no memory is allocated. It manipulates the strides to achieve this effect.
/// If the shape of the <paramref name="input"/> is not compatible with the new shape, an exception is thrown.
/// </summary>
/// <param name="input">Input <see cref="TensorSpan{T}"/>.</param>
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
/// <exception cref="ArgumentException">Thrown when the shapes are not broadcast compatible.</exception>
internal static TensorSpan<T> LazyBroadcast<T>(in TensorSpan<T> input, ReadOnlySpan<nint> lengths)
{
if (input.Lengths.SequenceEqual(lengths))
return new TensorSpan<T>(ref input._reference, input._shape.LinearLength, lengths, input.Strides);

if (!TensorHelpers.IsBroadcastableTo(input.Lengths, lengths))
ThrowHelper.ThrowArgument_LengthsNotCompatible();

nint newSize = TensorSpanHelpers.CalculateFlattenedLength(lengths);

if (newSize == input.FlattenedLength)
return Reshape(input, lengths);

nint[] intermediateShape = TensorHelpers.GetIntermediateShape(input.Lengths, lengths.Length);
nint[] strides = new nint[lengths.Length];

nint stride = 1;

for (int i = strides.Length - 1; i >= 0; i--)
{
if ((intermediateShape[i] == 1 && lengths[i] != 1) || (intermediateShape[i] == 1 && lengths[i] == 1))
strides[i] = 0;
else
{
strides[i] = stride;
stride *= intermediateShape[i];
}
}

TensorSpan<T> output = new TensorSpan<T>(ref input._reference, input._shape.LinearLength, lengths, strides);

return output;
}

// Lazy/non-copy broadcasting, internal only for now.
/// <summary>
/// Broadcast the data from <paramref name="input"/> to the new shape <paramref name="shape"/>. Creates a new <see cref="Tensor{T}"/>
/// but no memory is allocated. It manipulates the strides to achieve this effect.
/// If the shape of the <paramref name="input"/> is not compatible with the new shape, an exception is thrown.
/// </summary>
/// <param name="input">Input <see cref="TensorSpan{T}"/>.</param>
/// <param name="shape"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
/// <exception cref="ArgumentException">Thrown when the shapes are not broadcast compatible.</exception>
internal static ReadOnlyTensorSpan<T> LazyBroadcast<T>(in ReadOnlyTensorSpan<T> input, ReadOnlySpan<nint> shape)
{
if (input.Lengths.SequenceEqual(shape))
return new TensorSpan<T>(ref input._reference, input._shape.LinearLength, shape, input.Strides);

if (!TensorHelpers.IsBroadcastableTo(input.Lengths, shape))
ThrowHelper.ThrowArgument_LengthsNotCompatible();

nint newSize = TensorSpanHelpers.CalculateFlattenedLength(shape);

if (newSize == input.FlattenedLength)
return Reshape(input, shape);

nint[] intermediateShape = TensorHelpers.GetIntermediateShape(input.Lengths, shape.Length);
nint[] strides = new nint[shape.Length];

nint stride = 1;

for (int i = strides.Length - 1; i >= 0; i--)
{
if ((intermediateShape[i] == 1 && shape[i] != 1) || (intermediateShape[i] == 1 && shape[i] == 1))
strides[i] = 0;
else
{
strides[i] = stride;
stride *= intermediateShape[i];
}
}

TensorSpan<T> output = new TensorSpan<T>(ref input._reference, input._shape.LinearLength, shape, strides);

return output;
}

// Lazy/non-copy broadcasting, internal only for now.
/// <summary>
/// Broadcast the data from <paramref name="input"/> to the new shape <paramref name="lengths"/>. Creates a new <see cref="Tensor{T}"/>
/// but no memory is allocated. It manipulates the strides to achieve this effect.
/// If the shape of the <paramref name="input"/> is not compatible with the new shape, an exception is thrown.
/// </summary>
/// <param name="input">Input <see cref="Tensor{T}"/>.</param>
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
/// <exception cref="ArgumentException">Thrown when the shapes are not broadcast compatible.</exception>
internal static Tensor<T> LazyBroadcast<T>(Tensor<T> input, ReadOnlySpan<nint> lengths)
{
if (input.Lengths.SequenceEqual(lengths))
return new Tensor<T>(input._values, lengths, input._start, isPinned: false);

if (!TensorHelpers.IsBroadcastableTo(input.Lengths, lengths))
ThrowHelper.ThrowArgument_LengthsNotCompatible();

nint newSize = TensorSpanHelpers.CalculateFlattenedLength(lengths);

if (newSize == input.FlattenedLength)
return Reshape(input, lengths);

nint[] intermediateShape = TensorHelpers.GetIntermediateShape(input.Lengths, lengths.Length);
nint[] strides = new nint[lengths.Length];

nint stride = 1;

for (int i = strides.Length - 1; i >= 0; i--)
{
if ((intermediateShape[i] == 1 && lengths[i] != 1) || (intermediateShape[i] == 1 && lengths[i] == 1))
strides[i] = 0;
else
{
strides[i] = stride;
stride *= intermediateShape[i];
}
}

Tensor<T> output = new Tensor<T>(input._values, input._start, lengths, strides);

return output;
TensorOperation.ValidateCompatibility<T, T>(source, destination);
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(source, destination);
}
#endregion
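The removed LazyBroadcast helpers above never copied data: they broadcast by giving every size-1 axis a stride of 0, so every index along that axis resolves to the same element. A self-contained sketch of that stride computation (BroadcastStrides is illustrative, not a library API):

static nint[] BroadcastStrides(ReadOnlySpan<nint> sourceLengths, ReadOnlySpan<nint> targetLengths)
{
    // Right-align the source shape against the target shape, padding missing
    // leading dimensions with 1 (the "intermediate shape" in the removed code).
    var padded = new nint[targetLengths.Length];
    Array.Fill(padded, (nint)1);
    sourceLengths.CopyTo(padded.AsSpan(targetLengths.Length - sourceLengths.Length));

    var strides = new nint[targetLengths.Length];
    nint stride = 1;
    for (int i = strides.Length - 1; i >= 0; i--)
    {
        if (padded[i] == 1)
        {
            strides[i] = 0;          // broadcast axis: every index maps to the same element
        }
        else
        {
            strides[i] = stride;     // ordinary row-major stride
            stride *= padded[i];
        }
    }
    return strides;
}
// Example: BroadcastStrides([1, 4], [3, 4]) returns [0, 1], so all three rows
// of the broadcast view read the same four source values.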

@@ -265,7 +122,7 @@ public static Tensor<T> ConcatenateOnDimension<T>(int dimension, params scoped R
// Calculate total space needed.
nint totalLength = 0;
for (int i = 0; i < tensors.Length; i++)
totalLength += TensorSpanHelpers.CalculateFlattenedLength(tensors[i].Lengths);
totalLength += tensors[i].FlattenedLength;

nint sumOfAxis = 0;
// If axis != -1, make sure all dimensions except the one to concatenate on match.
@@ -333,13 +190,12 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
// Calculate total space needed.
nint totalLength = 0;
for (int i = 0; i < tensors.Length; i++)
totalLength += TensorSpanHelpers.CalculateFlattenedLength(tensors[i].Lengths);
totalLength += tensors[i].FlattenedLength;

nint sumOfAxis = 0;
// If axis != -1, make sure all dimensions except the one to concatenate on match.
if (dimension != -1)
{
sumOfAxis = tensors[0].Lengths[dimension];
nint sumOfAxis = tensors[0].Lengths[dimension];
for (int i = 1; i < tensors.Length; i++)
{
if (tensors[0].Rank != tensors[i].Rank)
@@ -364,42 +220,13 @@ public static ref readonly TensorSpan<T> ConcatenateOnDimension<T>(int dimension
ThrowHelper.ThrowArgument_DimensionsNotSame(nameof(destination));
}
Span<T> dstSpan = MemoryMarshal.CreateSpan(ref destination._reference, (int)totalLength);
nint valuesCopied = 0;

scoped Span<nint> curIndex;
nint[]? curIndexArray;

if (tensors[0].Rank > TensorShape.MaxInlineRank)
{
curIndexArray = ArrayPool<nint>.Shared.Rent(tensors[0].Rank);
curIndex = curIndexArray.AsSpan(0, tensors[0].Rank);
}
else
{
curIndexArray = null;
curIndex = stackalloc nint[tensors[0].Rank];
}
curIndex.Clear();

nint srcIndex;
nint copyLength;

while (valuesCopied < totalLength)
for (int i = 0; i < tensors.Length; i++)
{
for (int i = 0; i < tensors.Length; i++)
{
srcIndex = TensorSpanHelpers.ComputeLinearIndex(curIndex, tensors[i].Strides, tensors[i].Lengths);
copyLength = CalculateCopyLength(tensors[i].Lengths, dimension);
Span<T> srcSpan = MemoryMarshal.CreateSpan(ref tensors[i]._values[srcIndex], (int)copyLength);
TensorSpanHelpers.Memmove(dstSpan, srcSpan, copyLength, valuesCopied);
valuesCopied += copyLength;
}
TensorSpanHelpers.AdjustIndexes(dimension - 1, 1, curIndex, tensors[0].Lengths);
TensorOperation.Invoke<TensorOperation.CopyTo<T>, T, T>(tensors[i], dstSpan);
dstSpan = dstSpan.Slice((int)tensors[i].FlattenedLength);
}

if (curIndexArray != null)
ArrayPool<nint>.Shared.Return(curIndexArray);

return ref destination;
}
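The rewritten loop above drops the explicit index bookkeeping: each source tensor is copied whole into the destination span, which is then sliced forward by that tensor's flattened length. For concatenation along the leading dimension this amounts to consecutive block copies, as in this illustrative sketch:

// Illustrative only: concatenation along the leading dimension copies each
// source's flattened elements into consecutive slices of the destination buffer.
static T[] ConcatFlat<T>(params T[][] sources)
{
    int total = 0;
    foreach (var s in sources)
        total += s.Length;

    var destination = new T[total];
    Span<T> remaining = destination;
    foreach (var s in sources)
    {
        s.AsSpan().CopyTo(remaining);           // copy this block
        remaining = remaining.Slice(s.Length);  // advance past it
    }
    return destination;
}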

@@ -1803,15 +1630,15 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
{
nint newSize = TensorSpanHelpers.CalculateFlattenedLength(lengths);
nint newSize = TensorPrimitives.Product(lengths);
T[] values = tensor.IsPinned ? GC.AllocateArray<T>((int)newSize) : (new T[newSize]);
Tensor<T> output = new Tensor<T>(values, lengths, tensor._start, isPinned: false);
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor.AsTensorSpan()._reference, (int)tensor._values.Length);
Tensor<T> output = Tensor.Create(values, 0, lengths, []);
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start);
Span<T> ospan = MemoryMarshal.CreateSpan(ref output.AsTensorSpan()._reference, (int)output.FlattenedLength);
if (newSize > tensor._values.Length)
TensorSpanHelpers.Memmove(ospan, span, tensor._values.Length);
if (newSize >= span.Length)
span.CopyTo(ospan);
else
TensorSpanHelpers.Memmove(ospan, span, newSize);
span.Slice(0, ospan.Length).CopyTo(ospan);

return output;
}
@@ -1824,12 +1651,12 @@ public static Tensor<T> Resize<T>(Tensor<T> tensor, ReadOnlySpan<nint> lengths)
/// <param name="destination">Destination <see cref="TensorSpan{T}"/> with the desired new shape.</param>
public static void ResizeTo<T>(scoped in Tensor<T> tensor, in TensorSpan<T> destination)
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._values[0], tensor._values.Length);
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref tensor.AsTensorSpan()._reference, tensor._start), (int)tensor._values.Length - tensor._start);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (destination._shape.LinearLength > tensor._values.Length)
TensorSpanHelpers.Memmove(ospan, span, tensor._values.Length);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
TensorSpanHelpers.Memmove(ospan, span, destination._shape.LinearLength);
span.Slice(0, ospan.Length).CopyTo(ospan);
}

/// <summary>
@@ -1842,10 +1669,10 @@ public static void ResizeTo<T>(scoped in TensorSpan<T> tensor, in TensorSpan<T>
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (destination._shape.LinearLength > tensor._shape.LinearLength)
TensorSpanHelpers.Memmove(ospan, span, tensor._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
TensorSpanHelpers.Memmove(ospan, span, destination._shape.LinearLength);
span.Slice(0, ospan.Length).CopyTo(ospan);
}

/// <summary>
@@ -1858,10 +1685,10 @@ public static void ResizeTo<T>(scoped in ReadOnlyTensorSpan<T> tensor, in Tensor
{
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref tensor._reference, (int)tensor._shape.LinearLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape.LinearLength);
if (destination._shape.LinearLength > tensor._shape.LinearLength)
TensorSpanHelpers.Memmove(ospan, span, tensor._shape.LinearLength);
if (ospan.Length >= span.Length)
span.CopyTo(ospan);
else
TensorSpanHelpers.Memmove(ospan, span, destination._shape.LinearLength);
span.Slice(0, ospan.Length).CopyTo(ospan);
}
#endregion
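Resize and the ResizeTo overloads above now share one copy rule: when the destination has at least as many elements as the source, everything is copied (Resize allocates a zeroed buffer, so any extra elements stay at default); otherwise only as many source elements as fit are copied. A minimal sketch of that rule over flat buffers (ResizeFlat is illustrative, not a library API):

// Illustrative only: the grow-or-truncate rule used by Resize/ResizeTo above.
static T[] ResizeFlat<T>(ReadOnlySpan<T> source, int newLength)
{
    var destination = new T[newLength];                  // default-initialized
    if (newLength >= source.Length)
        source.CopyTo(destination);                      // grow: copy everything
    else
        source.Slice(0, newLength).CopyTo(destination);  // shrink: truncate
    return destination;
}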
