Skip to content

Implement CopyToAsync in the FileBufferingReadStream #24499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 65 additions & 20 deletions src/Http/WebUtilities/src/FileBufferingReadStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,39 +208,41 @@ private Stream CreateTempFile()
FileOptions.Asynchronous | FileOptions.DeleteOnClose | FileOptions.SequentialScan);
}

public override int Read(byte[] buffer, int offset, int count)
public override int Read(Span<byte> buffer)
Copy link
Member

@halter73 halter73 Aug 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we shouldn't do something like the runtime does with its Sync/AsyncReadWriteAdapter to avoid duplicating this somewhat complicated logic. Seems like that way we'd be a lot less likely to introduce bugs in just one version of Read(Async).

{
ThrowIfDisposed();

if (_buffer.Position < _buffer.Length || _completelyBuffered)
{
// Just read from the buffer
return _buffer.Read(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position));
return _buffer.Read(buffer);
}

int read = _inner.Read(buffer, offset, count);
var read = _inner.Read(buffer);

if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
{
Dispose();
throw new IOException("Buffer limit exceeded.");
}

if (_inMemory && _buffer.Length + read > _memoryThreshold)
// We're about to go over the threshold, switch to a file
if (_inMemory && _memoryThreshold - read < _buffer.Length)
{
_inMemory = false;
var oldBuffer = _buffer;
_buffer = CreateTempFile();
if (_rentedBuffer == null)
{
// Copy data from the in memory buffer to the file stream using a pooled buffer
oldBuffer.Position = 0;
var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize));
try
{
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
var copyRead = oldBuffer.Read(rentedBuffer);
while (copyRead > 0)
{
_buffer.Write(rentedBuffer, 0, copyRead);
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
_buffer.Write(rentedBuffer.AsSpan(0, copyRead));
copyRead = oldBuffer.Read(rentedBuffer);
}
}
finally
Expand All @@ -250,15 +252,15 @@ public override int Read(byte[] buffer, int offset, int count)
}
else
{
_buffer.Write(_rentedBuffer, 0, (int)oldBuffer.Length);
_buffer.Write(_rentedBuffer.AsSpan(0, (int)oldBuffer.Length));
_bytePool.Return(_rentedBuffer);
_rentedBuffer = null;
}
}

if (read > 0)
{
_buffer.Write(buffer, offset, read);
_buffer.Write(buffer.Slice(0, read));
}
else
{
Expand All @@ -268,24 +270,34 @@ public override int Read(byte[] buffer, int offset, int count)
return read;
}

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();

if (_buffer.Position < _buffer.Length || _completelyBuffered)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic isn't 100% right, we should be looping in here to fill buffer as much as we can before returning. That isn't a strict contract of Stream though...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I wait for you to do this before reviewing, or do you want to merge this as is for now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not making this change in the PR since it isn't a regression.

{
// Just read from the buffer
return await _buffer.ReadAsync(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position), cancellationToken);
return await _buffer.ReadAsync(buffer, cancellationToken);
}

int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
var read = await _inner.ReadAsync(buffer, cancellationToken);

if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
{
Dispose();
throw new IOException("Buffer limit exceeded.");
}

if (_inMemory && _buffer.Length + read > _memoryThreshold)
if (_inMemory && _memoryThreshold - read < _buffer.Length)
{
_inMemory = false;
var oldBuffer = _buffer;
Expand All @@ -297,11 +309,11 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
try
{
// oldBuffer is a MemoryStream, no need to do async reads.
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
var copyRead = oldBuffer.Read(rentedBuffer);
while (copyRead > 0)
{
await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken);
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
await _buffer.WriteAsync(rentedBuffer.AsMemory(0, copyRead), cancellationToken);
copyRead = oldBuffer.Read(rentedBuffer);
}
}
finally
Expand All @@ -311,15 +323,15 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
}
else
{
await _buffer.WriteAsync(_rentedBuffer, 0, (int)oldBuffer.Length, cancellationToken);
await _buffer.WriteAsync(_rentedBuffer.AsMemory(0, (int)oldBuffer.Length), cancellationToken);
_bytePool.Return(_rentedBuffer);
_rentedBuffer = null;
}
}

if (read > 0)
{
await _buffer.WriteAsync(buffer, offset, read, cancellationToken);
await _buffer.WriteAsync(buffer.Slice(0, read), cancellationToken);
}
else
{
Expand Down Expand Up @@ -349,6 +361,39 @@ public override void Flush()
throw new NotSupportedException();
}

public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
Copy link
Member

@Tratcher Tratcher Aug 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How sure are you that you fixed the CopyToAsync issue from #24032? The issue there was that the bufferSize parameter was being set to 1 by the non-virtual CopyToAsync(Stream) if the current stream length was 0 (nothing buffered yet). You're still using the given bufferSize.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops

{
// If we're completed buffered then copy from the underlying source
if (_completelyBuffered)
{
return _buffer.CopyToAsync(destination, bufferSize, cancellationToken);
}

async Task CopyToAsyncImpl()
{
// At least a 4K buffer
byte[] buffer = _bytePool.Rent(Math.Min(bufferSize, 4096));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
byte[] buffer = _bytePool.Rent(Math.Min(bufferSize, 4096));
var buffer = _bytePool.Rent(Math.Min(bufferSize, 4096));

try
{
while (true)
{
int bytesRead = await ReadAsync(buffer, cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int bytesRead = await ReadAsync(buffer, cancellationToken);
var bytesRead = await ReadAsync(buffer, cancellationToken);

if (bytesRead == 0)
{
break;
}
await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken);
}
}
finally
{
_bytePool.Return(buffer);
}
}

return CopyToAsyncImpl();
}

protected override void Dispose(bool disposing)
{
if (!_disposed)
Expand Down
135 changes: 133 additions & 2 deletions src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Buffers;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Moq;
Expand Down Expand Up @@ -157,7 +158,6 @@ public void FileBufferingReadStream_SyncReadWithOnDiskLimit_EnforcesLimit()
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.False(File.Exists(tempFileName));
}

Assert.False(File.Exists(tempFileName));
Expand Down Expand Up @@ -287,7 +287,6 @@ public async Task FileBufferingReadStream_AsyncReadWithOnDiskLimit_EnforcesLimit
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.False(File.Exists(tempFileName));
}

Assert.False(File.Exists(tempFileName));
Expand Down Expand Up @@ -351,6 +350,138 @@ public async Task FileBufferingReadStream_UsingMemoryStream_RentsAndReturnsRente
Assert.False(File.Exists(tempFileName));
}

[Fact]
public async Task CopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
await stream.CopyToAsync(withoutBufferMs);

var withBufferMs = new MemoryStream();
stream.Position = 0;
await stream.CopyToAsync(withBufferMs);

Assert.Equal(data, withoutBufferMs.ToArray());
Assert.Equal(data, withBufferMs.ToArray());
}

[Fact]
public async Task CopyToAsyncWorksWithFileThreshold()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 100, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
await stream.CopyToAsync(withoutBufferMs);

var withBufferMs = new MemoryStream();
stream.Position = 0;
await stream.CopyToAsync(withBufferMs);

Assert.Equal(data, withoutBufferMs.ToArray());
Assert.Equal(data, withBufferMs.ToArray());
}

[Fact]
public async Task ReadAsyncThenCopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
await stream.ReadAsync(buffer);
await stream.CopyToAsync(withoutBufferMs);

Assert.Equal(data.AsMemory(0, 100).ToArray(), buffer);
Assert.Equal(data.AsMemory(100).ToArray(), withoutBufferMs.ToArray());
}

[Fact]
public async Task ReadThenCopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read = stream.Read(buffer);
await stream.CopyToAsync(withoutBufferMs);

Assert.Equal(100, read);
Assert.Equal(data.AsMemory(0, read).ToArray(), buffer);
Assert.Equal(data.AsMemory(read).ToArray(), withoutBufferMs.ToArray());
}

[Fact]
public async Task ReadThenSeekThenCopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read = stream.Read(buffer);
stream.Position = 0;
await stream.CopyToAsync(withoutBufferMs);

Assert.Equal(100, read);
Assert.Equal(data.AsMemory(0, read).ToArray(), buffer);
Assert.Equal(data.ToArray(), withoutBufferMs.ToArray());
}

[Fact]
public void PartialReadThenSeekReplaysBuffer()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read1 = stream.Read(buffer);
stream.Position = 0;
var buffer2 = new byte[200];
var read2 = stream.Read(buffer2);
Assert.Equal(100, read1);
Assert.Equal(100, read2);
Assert.Equal(data.AsMemory(0, read1).ToArray(), buffer);
Assert.Equal(data.AsMemory(0, read2).ToArray(), buffer2.AsMemory(0, read2).ToArray());
}

[Fact]
public async Task PartialReadAsyncThenSeekReplaysBuffer()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);

using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());

var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read1 = await stream.ReadAsync(buffer);
stream.Position = 0;
var buffer2 = new byte[200];
var read2 = await stream.ReadAsync(buffer2);
Assert.Equal(100, read1);
Assert.Equal(100, read2);
Assert.Equal(data.AsMemory(0, read1).ToArray(), buffer);
Assert.Equal(data.AsMemory(0, read2).ToArray(), buffer2.AsMemory(0, read2).ToArray());
}

private static string GetCurrentDirectory()
{
return AppContext.BaseDirectory;
Expand Down
27 changes: 23 additions & 4 deletions src/Mvc/Mvc.Core/test/Formatters/JsonInputFormatterTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ public async Task ReadAsync_DoesNotDisposeBufferedReadStream()
var content = "{\"name\": \"Test\"}";
var contentBytes = Encoding.UTF8.GetBytes(content);
var httpContext = GetHttpContext(contentBytes);
var testBufferedReadStream = new Mock<FileBufferingReadStream>(httpContext.Request.Body, 1024) { CallBase = true };
httpContext.Request.Body = testBufferedReadStream.Object;
var testBufferedReadStream = new VerifyDisposeFileBufferingReadStream(httpContext.Request.Body, 1024);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moq doesn't support Span.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inline classes would be so nice!

httpContext.Request.Body = testBufferedReadStream;

var formatterContext = CreateInputFormatterContext(typeof(ComplexModel), httpContext);

Expand All @@ -508,8 +508,7 @@ public async Task ReadAsync_DoesNotDisposeBufferedReadStream()
// Assert
var userModel = Assert.IsType<ComplexModel>(result.Model);
Assert.Equal("Test", userModel.Name);

testBufferedReadStream.Verify(v => v.DisposeAsync(), Times.Never());
Assert.False(testBufferedReadStream.Disposed);
}

[Fact]
Expand Down Expand Up @@ -635,5 +634,25 @@ protected sealed class ComplexModel

public byte Small { get; set; }
}

private class VerifyDisposeFileBufferingReadStream : FileBufferingReadStream
{
public bool Disposed { get; private set; }
public VerifyDisposeFileBufferingReadStream(Stream inner, int memoryThreshold) : base(inner, memoryThreshold)
{
}

protected override void Dispose(bool disposing)
{
Disposed = true;
base.Dispose(disposing);
}

public override ValueTask DisposeAsync()
{
Disposed = true;
return base.DisposeAsync();
}
}
}
}
Loading