Skip to content

Commit 1f5149a

Browse files
authored
Implement CopyToAsync in the FileBufferingReadStream (#24499)
* Implement CopyToAsync in the FileBufferingReadStream - overrride Span and Memory overloads and implement array overloads in terms of those overloads. - Implemented CopyToAsync (but not CopyTo) - Added tests Fixes #24032
1 parent e31998c commit 1f5149a

File tree

5 files changed

+268
-34
lines changed

5 files changed

+268
-34
lines changed

src/Http/WebUtilities/src/FileBufferingReadStream.cs

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -208,39 +208,41 @@ private Stream CreateTempFile()
208208
FileOptions.Asynchronous | FileOptions.DeleteOnClose | FileOptions.SequentialScan);
209209
}
210210

211-
public override int Read(byte[] buffer, int offset, int count)
211+
public override int Read(Span<byte> buffer)
212212
{
213213
ThrowIfDisposed();
214+
214215
if (_buffer.Position < _buffer.Length || _completelyBuffered)
215216
{
216217
// Just read from the buffer
217-
return _buffer.Read(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position));
218+
return _buffer.Read(buffer);
218219
}
219220

220-
int read = _inner.Read(buffer, offset, count);
221+
var read = _inner.Read(buffer);
221222

222223
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
223224
{
224-
Dispose();
225225
throw new IOException("Buffer limit exceeded.");
226226
}
227227

228-
if (_inMemory && _buffer.Length + read > _memoryThreshold)
228+
// We're about to go over the threshold, switch to a file
229+
if (_inMemory && _memoryThreshold - read < _buffer.Length)
229230
{
230231
_inMemory = false;
231232
var oldBuffer = _buffer;
232233
_buffer = CreateTempFile();
233234
if (_rentedBuffer == null)
234235
{
236+
// Copy data from the in memory buffer to the file stream using a pooled buffer
235237
oldBuffer.Position = 0;
236238
var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize));
237239
try
238240
{
239-
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
241+
var copyRead = oldBuffer.Read(rentedBuffer);
240242
while (copyRead > 0)
241243
{
242-
_buffer.Write(rentedBuffer, 0, copyRead);
243-
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
244+
_buffer.Write(rentedBuffer.AsSpan(0, copyRead));
245+
copyRead = oldBuffer.Read(rentedBuffer);
244246
}
245247
}
246248
finally
@@ -250,15 +252,15 @@ public override int Read(byte[] buffer, int offset, int count)
250252
}
251253
else
252254
{
253-
_buffer.Write(_rentedBuffer, 0, (int)oldBuffer.Length);
255+
_buffer.Write(_rentedBuffer.AsSpan(0, (int)oldBuffer.Length));
254256
_bytePool.Return(_rentedBuffer);
255257
_rentedBuffer = null;
256258
}
257259
}
258260

259261
if (read > 0)
260262
{
261-
_buffer.Write(buffer, offset, read);
263+
_buffer.Write(buffer.Slice(0, read));
262264
}
263265
else
264266
{
@@ -268,24 +270,34 @@ public override int Read(byte[] buffer, int offset, int count)
268270
return read;
269271
}
270272

271-
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
273+
public override int Read(byte[] buffer, int offset, int count)
274+
{
275+
return Read(buffer.AsSpan(offset, count));
276+
}
277+
278+
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
279+
{
280+
return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
281+
}
282+
283+
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
272284
{
273285
ThrowIfDisposed();
286+
274287
if (_buffer.Position < _buffer.Length || _completelyBuffered)
275288
{
276289
// Just read from the buffer
277-
return await _buffer.ReadAsync(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position), cancellationToken);
290+
return await _buffer.ReadAsync(buffer, cancellationToken);
278291
}
279292

280-
int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
293+
var read = await _inner.ReadAsync(buffer, cancellationToken);
281294

282295
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
283296
{
284-
Dispose();
285297
throw new IOException("Buffer limit exceeded.");
286298
}
287299

288-
if (_inMemory && _buffer.Length + read > _memoryThreshold)
300+
if (_inMemory && _memoryThreshold - read < _buffer.Length)
289301
{
290302
_inMemory = false;
291303
var oldBuffer = _buffer;
@@ -297,11 +309,11 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
297309
try
298310
{
299311
// oldBuffer is a MemoryStream, no need to do async reads.
300-
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
312+
var copyRead = oldBuffer.Read(rentedBuffer);
301313
while (copyRead > 0)
302314
{
303-
await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken);
304-
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
315+
await _buffer.WriteAsync(rentedBuffer.AsMemory(0, copyRead), cancellationToken);
316+
copyRead = oldBuffer.Read(rentedBuffer);
305317
}
306318
}
307319
finally
@@ -311,15 +323,15 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
311323
}
312324
else
313325
{
314-
await _buffer.WriteAsync(_rentedBuffer, 0, (int)oldBuffer.Length, cancellationToken);
326+
await _buffer.WriteAsync(_rentedBuffer.AsMemory(0, (int)oldBuffer.Length), cancellationToken);
315327
_bytePool.Return(_rentedBuffer);
316328
_rentedBuffer = null;
317329
}
318330
}
319331

320332
if (read > 0)
321333
{
322-
await _buffer.WriteAsync(buffer, offset, read, cancellationToken);
334+
await _buffer.WriteAsync(buffer.Slice(0, read), cancellationToken);
323335
}
324336
else
325337
{
@@ -349,6 +361,39 @@ public override void Flush()
349361
throw new NotSupportedException();
350362
}
351363

364+
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
365+
{
366+
// If we're completed buffered then copy from the underlying source
367+
if (_completelyBuffered)
368+
{
369+
return _buffer.CopyToAsync(destination, bufferSize, cancellationToken);
370+
}
371+
372+
async Task CopyToAsyncImpl()
373+
{
374+
// At least a 4K buffer
375+
byte[] buffer = _bytePool.Rent(Math.Min(bufferSize, 4096));
376+
try
377+
{
378+
while (true)
379+
{
380+
int bytesRead = await ReadAsync(buffer, cancellationToken);
381+
if (bytesRead == 0)
382+
{
383+
break;
384+
}
385+
await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken);
386+
}
387+
}
388+
finally
389+
{
390+
_bytePool.Return(buffer);
391+
}
392+
}
393+
394+
return CopyToAsyncImpl();
395+
}
396+
352397
protected override void Dispose(bool disposing)
353398
{
354399
if (!_disposed)

src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Buffers;
66
using System.IO;
7+
using System.Linq;
78
using System.Text;
89
using System.Threading.Tasks;
910
using Moq;
@@ -157,7 +158,6 @@ public void FileBufferingReadStream_SyncReadWithOnDiskLimit_EnforcesLimit()
157158
Assert.Equal("Buffer limit exceeded.", exception.Message);
158159
Assert.False(stream.InMemory);
159160
Assert.NotNull(stream.TempFileName);
160-
Assert.False(File.Exists(tempFileName));
161161
}
162162

163163
Assert.False(File.Exists(tempFileName));
@@ -287,7 +287,6 @@ public async Task FileBufferingReadStream_AsyncReadWithOnDiskLimit_EnforcesLimit
287287
Assert.Equal("Buffer limit exceeded.", exception.Message);
288288
Assert.False(stream.InMemory);
289289
Assert.NotNull(stream.TempFileName);
290-
Assert.False(File.Exists(tempFileName));
291290
}
292291

293292
Assert.False(File.Exists(tempFileName));
@@ -351,6 +350,138 @@ public async Task FileBufferingReadStream_UsingMemoryStream_RentsAndReturnsRente
351350
Assert.False(File.Exists(tempFileName));
352351
}
353352

353+
[Fact]
354+
public async Task CopyToAsyncWorks()
355+
{
356+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
357+
var inner = new MemoryStream(data);
358+
359+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
360+
361+
var withoutBufferMs = new MemoryStream();
362+
await stream.CopyToAsync(withoutBufferMs);
363+
364+
var withBufferMs = new MemoryStream();
365+
stream.Position = 0;
366+
await stream.CopyToAsync(withBufferMs);
367+
368+
Assert.Equal(data, withoutBufferMs.ToArray());
369+
Assert.Equal(data, withBufferMs.ToArray());
370+
}
371+
372+
[Fact]
373+
public async Task CopyToAsyncWorksWithFileThreshold()
374+
{
375+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
376+
var inner = new MemoryStream(data);
377+
378+
using var stream = new FileBufferingReadStream(inner, 100, bufferLimit: null, GetCurrentDirectory());
379+
380+
var withoutBufferMs = new MemoryStream();
381+
await stream.CopyToAsync(withoutBufferMs);
382+
383+
var withBufferMs = new MemoryStream();
384+
stream.Position = 0;
385+
await stream.CopyToAsync(withBufferMs);
386+
387+
Assert.Equal(data, withoutBufferMs.ToArray());
388+
Assert.Equal(data, withBufferMs.ToArray());
389+
}
390+
391+
[Fact]
392+
public async Task ReadAsyncThenCopyToAsyncWorks()
393+
{
394+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
395+
var inner = new MemoryStream(data);
396+
397+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
398+
399+
var withoutBufferMs = new MemoryStream();
400+
var buffer = new byte[100];
401+
await stream.ReadAsync(buffer);
402+
await stream.CopyToAsync(withoutBufferMs);
403+
404+
Assert.Equal(data.AsMemory(0, 100).ToArray(), buffer);
405+
Assert.Equal(data.AsMemory(100).ToArray(), withoutBufferMs.ToArray());
406+
}
407+
408+
[Fact]
409+
public async Task ReadThenCopyToAsyncWorks()
410+
{
411+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
412+
var inner = new MemoryStream(data);
413+
414+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
415+
416+
var withoutBufferMs = new MemoryStream();
417+
var buffer = new byte[100];
418+
var read = stream.Read(buffer);
419+
await stream.CopyToAsync(withoutBufferMs);
420+
421+
Assert.Equal(100, read);
422+
Assert.Equal(data.AsMemory(0, read).ToArray(), buffer);
423+
Assert.Equal(data.AsMemory(read).ToArray(), withoutBufferMs.ToArray());
424+
}
425+
426+
[Fact]
427+
public async Task ReadThenSeekThenCopyToAsyncWorks()
428+
{
429+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
430+
var inner = new MemoryStream(data);
431+
432+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
433+
434+
var withoutBufferMs = new MemoryStream();
435+
var buffer = new byte[100];
436+
var read = stream.Read(buffer);
437+
stream.Position = 0;
438+
await stream.CopyToAsync(withoutBufferMs);
439+
440+
Assert.Equal(100, read);
441+
Assert.Equal(data.AsMemory(0, read).ToArray(), buffer);
442+
Assert.Equal(data.ToArray(), withoutBufferMs.ToArray());
443+
}
444+
445+
[Fact]
446+
public void PartialReadThenSeekReplaysBuffer()
447+
{
448+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
449+
var inner = new MemoryStream(data);
450+
451+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
452+
453+
var withoutBufferMs = new MemoryStream();
454+
var buffer = new byte[100];
455+
var read1 = stream.Read(buffer);
456+
stream.Position = 0;
457+
var buffer2 = new byte[200];
458+
var read2 = stream.Read(buffer2);
459+
Assert.Equal(100, read1);
460+
Assert.Equal(100, read2);
461+
Assert.Equal(data.AsMemory(0, read1).ToArray(), buffer);
462+
Assert.Equal(data.AsMemory(0, read2).ToArray(), buffer2.AsMemory(0, read2).ToArray());
463+
}
464+
465+
[Fact]
466+
public async Task PartialReadAsyncThenSeekReplaysBuffer()
467+
{
468+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
469+
var inner = new MemoryStream(data);
470+
471+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
472+
473+
var withoutBufferMs = new MemoryStream();
474+
var buffer = new byte[100];
475+
var read1 = await stream.ReadAsync(buffer);
476+
stream.Position = 0;
477+
var buffer2 = new byte[200];
478+
var read2 = await stream.ReadAsync(buffer2);
479+
Assert.Equal(100, read1);
480+
Assert.Equal(100, read2);
481+
Assert.Equal(data.AsMemory(0, read1).ToArray(), buffer);
482+
Assert.Equal(data.AsMemory(0, read2).ToArray(), buffer2.AsMemory(0, read2).ToArray());
483+
}
484+
354485
private static string GetCurrentDirectory()
355486
{
356487
return AppContext.BaseDirectory;

src/Mvc/Mvc.Core/test/Formatters/JsonInputFormatterTestBase.cs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,8 @@ public async Task ReadAsync_DoesNotDisposeBufferedReadStream()
497497
var content = "{\"name\": \"Test\"}";
498498
var contentBytes = Encoding.UTF8.GetBytes(content);
499499
var httpContext = GetHttpContext(contentBytes);
500-
var testBufferedReadStream = new Mock<FileBufferingReadStream>(httpContext.Request.Body, 1024) { CallBase = true };
501-
httpContext.Request.Body = testBufferedReadStream.Object;
500+
var testBufferedReadStream = new VerifyDisposeFileBufferingReadStream(httpContext.Request.Body, 1024);
501+
httpContext.Request.Body = testBufferedReadStream;
502502

503503
var formatterContext = CreateInputFormatterContext(typeof(ComplexModel), httpContext);
504504

@@ -508,8 +508,7 @@ public async Task ReadAsync_DoesNotDisposeBufferedReadStream()
508508
// Assert
509509
var userModel = Assert.IsType<ComplexModel>(result.Model);
510510
Assert.Equal("Test", userModel.Name);
511-
512-
testBufferedReadStream.Verify(v => v.DisposeAsync(), Times.Never());
511+
Assert.False(testBufferedReadStream.Disposed);
513512
}
514513

515514
[Fact]
@@ -635,5 +634,25 @@ protected sealed class ComplexModel
635634

636635
public byte Small { get; set; }
637636
}
637+
638+
private class VerifyDisposeFileBufferingReadStream : FileBufferingReadStream
639+
{
640+
public bool Disposed { get; private set; }
641+
public VerifyDisposeFileBufferingReadStream(Stream inner, int memoryThreshold) : base(inner, memoryThreshold)
642+
{
643+
}
644+
645+
protected override void Dispose(bool disposing)
646+
{
647+
Disposed = true;
648+
base.Dispose(disposing);
649+
}
650+
651+
public override ValueTask DisposeAsync()
652+
{
653+
Disposed = true;
654+
return base.DisposeAsync();
655+
}
656+
}
638657
}
639658
}

0 commit comments

Comments
 (0)