Skip to content

Commit 78de0b0

Browse files
committed
Implement CopyToAsync in the FileBufferingReadStream
- overrride Span and Memory overloads and implement array overloads in terms of those overloads. - Implemented CopyToAsync (but not CopyTo) - Added tests Fixes #24032
1 parent 723e32a commit 78de0b0

File tree

2 files changed

+140
-24
lines changed

2 files changed

+140
-24
lines changed

src/Http/WebUtilities/src/FileBufferingReadStream.cs

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -208,39 +208,41 @@ private Stream CreateTempFile()
208208
FileOptions.Asynchronous | FileOptions.DeleteOnClose | FileOptions.SequentialScan);
209209
}
210210

211-
public override int Read(byte[] buffer, int offset, int count)
211+
public override int Read(Span<byte> buffer)
212212
{
213213
ThrowIfDisposed();
214-
if (_buffer.Position < _buffer.Length || _completelyBuffered)
214+
215+
if (_completelyBuffered)
215216
{
216217
// Just read from the buffer
217-
return _buffer.Read(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position));
218+
return _buffer.Read(buffer);
218219
}
219220

220-
int read = _inner.Read(buffer, offset, count);
221+
var read = _inner.Read(buffer);
221222

222223
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
223224
{
224-
Dispose();
225225
throw new IOException("Buffer limit exceeded.");
226226
}
227227

228-
if (_inMemory && _buffer.Length + read > _memoryThreshold)
228+
// We're about to go over the threshold, switch to a file
229+
if (_inMemory && _memoryThreshold - read < _buffer.Length)
229230
{
230231
_inMemory = false;
231232
var oldBuffer = _buffer;
232233
_buffer = CreateTempFile();
233234
if (_rentedBuffer == null)
234235
{
236+
// Copy data from the in memory buffer to the file stream using a pooled buffer
235237
oldBuffer.Position = 0;
236238
var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize));
237239
try
238240
{
239-
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
241+
var copyRead = oldBuffer.Read(rentedBuffer);
240242
while (copyRead > 0)
241243
{
242-
_buffer.Write(rentedBuffer, 0, copyRead);
243-
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
244+
_buffer.Write(rentedBuffer.AsSpan(0, copyRead));
245+
copyRead = oldBuffer.Read(rentedBuffer);
244246
}
245247
}
246248
finally
@@ -250,15 +252,15 @@ public override int Read(byte[] buffer, int offset, int count)
250252
}
251253
else
252254
{
253-
_buffer.Write(_rentedBuffer, 0, (int)oldBuffer.Length);
255+
_buffer.Write(_rentedBuffer.AsSpan(0, (int)oldBuffer.Length));
254256
_bytePool.Return(_rentedBuffer);
255257
_rentedBuffer = null;
256258
}
257259
}
258260

259261
if (read > 0)
260262
{
261-
_buffer.Write(buffer, offset, read);
263+
_buffer.Write(buffer.Slice(0, read));
262264
}
263265
else
264266
{
@@ -268,24 +270,34 @@ public override int Read(byte[] buffer, int offset, int count)
268270
return read;
269271
}
270272

271-
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
273+
public override int Read(byte[] buffer, int offset, int count)
274+
{
275+
return Read(buffer.AsSpan(offset, count));
276+
}
277+
278+
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
279+
{
280+
return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
281+
}
282+
283+
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
272284
{
273285
ThrowIfDisposed();
274-
if (_buffer.Position < _buffer.Length || _completelyBuffered)
286+
287+
if (_completelyBuffered)
275288
{
276289
// Just read from the buffer
277-
return await _buffer.ReadAsync(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position), cancellationToken);
290+
return await _buffer.ReadAsync(buffer, cancellationToken);
278291
}
279292

280-
int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
293+
var read = await _inner.ReadAsync(buffer, cancellationToken);
281294

282295
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
283296
{
284-
Dispose();
285297
throw new IOException("Buffer limit exceeded.");
286298
}
287299

288-
if (_inMemory && _buffer.Length + read > _memoryThreshold)
300+
if (_inMemory && _memoryThreshold - read < _buffer.Length)
289301
{
290302
_inMemory = false;
291303
var oldBuffer = _buffer;
@@ -297,11 +309,11 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
297309
try
298310
{
299311
// oldBuffer is a MemoryStream, no need to do async reads.
300-
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
312+
var copyRead = oldBuffer.Read(rentedBuffer);
301313
while (copyRead > 0)
302314
{
303-
await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken);
304-
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
315+
await _buffer.WriteAsync(rentedBuffer.AsMemory(0, copyRead), cancellationToken);
316+
copyRead = oldBuffer.Read(rentedBuffer);
305317
}
306318
}
307319
finally
@@ -311,15 +323,15 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
311323
}
312324
else
313325
{
314-
await _buffer.WriteAsync(_rentedBuffer, 0, (int)oldBuffer.Length, cancellationToken);
326+
await _buffer.WriteAsync(_rentedBuffer.AsMemory(0, (int)oldBuffer.Length), cancellationToken);
315327
_bytePool.Return(_rentedBuffer);
316328
_rentedBuffer = null;
317329
}
318330
}
319331

320332
if (read > 0)
321333
{
322-
await _buffer.WriteAsync(buffer, offset, read, cancellationToken);
334+
await _buffer.WriteAsync(buffer.Slice(0, read), cancellationToken);
323335
}
324336
else
325337
{
@@ -349,6 +361,39 @@ public override void Flush()
349361
throw new NotSupportedException();
350362
}
351363

364+
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
365+
{
366+
// If we're completed buffered then copy from the underlying source
367+
if (_completelyBuffered)
368+
{
369+
return _buffer.CopyToAsync(destination, bufferSize, cancellationToken);
370+
}
371+
372+
async Task CopyToAsyncImpl()
373+
{
374+
// At least a 4K buffer
375+
byte[] buffer = _bytePool.Rent(Math.Min(bufferSize, 4096));
376+
try
377+
{
378+
while (true)
379+
{
380+
int bytesRead = await ReadAsync(buffer, cancellationToken);
381+
if (bytesRead == 0)
382+
{
383+
break;
384+
}
385+
await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken);
386+
}
387+
}
388+
finally
389+
{
390+
_bytePool.Return(buffer);
391+
}
392+
}
393+
394+
return CopyToAsyncImpl();
395+
}
396+
352397
protected override void Dispose(bool disposing)
353398
{
354399
if (!_disposed)

src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Buffers;
66
using System.IO;
7+
using System.Linq;
78
using System.Text;
89
using System.Threading.Tasks;
910
using Moq;
@@ -157,7 +158,6 @@ public void FileBufferingReadStream_SyncReadWithOnDiskLimit_EnforcesLimit()
157158
Assert.Equal("Buffer limit exceeded.", exception.Message);
158159
Assert.False(stream.InMemory);
159160
Assert.NotNull(stream.TempFileName);
160-
Assert.False(File.Exists(tempFileName));
161161
}
162162

163163
Assert.False(File.Exists(tempFileName));
@@ -287,7 +287,6 @@ public async Task FileBufferingReadStream_AsyncReadWithOnDiskLimit_EnforcesLimit
287287
Assert.Equal("Buffer limit exceeded.", exception.Message);
288288
Assert.False(stream.InMemory);
289289
Assert.NotNull(stream.TempFileName);
290-
Assert.False(File.Exists(tempFileName));
291290
}
292291

293292
Assert.False(File.Exists(tempFileName));
@@ -351,6 +350,78 @@ public async Task FileBufferingReadStream_UsingMemoryStream_RentsAndReturnsRente
351350
Assert.False(File.Exists(tempFileName));
352351
}
353352

353+
[Fact]
354+
public async Task CopyToAsyncWorks()
355+
{
356+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
357+
var inner = new MemoryStream(data);
358+
359+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
360+
361+
var withoutBufferMs = new MemoryStream();
362+
await stream.CopyToAsync(withoutBufferMs);
363+
364+
var withBufferMs = new MemoryStream();
365+
stream.Position = 0;
366+
await stream.CopyToAsync(withBufferMs);
367+
368+
Assert.Equal(data, withoutBufferMs.ToArray());
369+
Assert.Equal(data, withBufferMs.ToArray());
370+
}
371+
372+
[Fact]
373+
public async Task CopyToAsyncWorksWithFileThreshold()
374+
{
375+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
376+
var inner = new MemoryStream(data);
377+
378+
using var stream = new FileBufferingReadStream(inner, 100, bufferLimit: null, GetCurrentDirectory());
379+
380+
var withoutBufferMs = new MemoryStream();
381+
await stream.CopyToAsync(withoutBufferMs);
382+
383+
var withBufferMs = new MemoryStream();
384+
stream.Position = 0;
385+
await stream.CopyToAsync(withBufferMs);
386+
387+
Assert.Equal(data, withoutBufferMs.ToArray());
388+
Assert.Equal(data, withBufferMs.ToArray());
389+
}
390+
391+
[Fact]
392+
public async Task ReadAsyncThenCopyToAsyncWorks()
393+
{
394+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
395+
var inner = new MemoryStream(data);
396+
397+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
398+
399+
var withoutBufferMs = new MemoryStream();
400+
var buffer = new byte[100];
401+
await stream.ReadAsync(buffer);
402+
await stream.CopyToAsync(withoutBufferMs);
403+
404+
Assert.Equal(data.AsMemory(0, 100).ToArray(), buffer);
405+
Assert.Equal(data.AsMemory(100).ToArray(), withoutBufferMs.ToArray());
406+
}
407+
408+
[Fact]
409+
public async Task ReadThenCopyToAsyncWorks()
410+
{
411+
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
412+
var inner = new MemoryStream(data);
413+
414+
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
415+
416+
var withoutBufferMs = new MemoryStream();
417+
var buffer = new byte[100];
418+
stream.Read(buffer);
419+
await stream.CopyToAsync(withoutBufferMs);
420+
421+
Assert.Equal(data.AsMemory(0, 100).ToArray(), buffer);
422+
Assert.Equal(data.AsMemory(100).ToArray(), withoutBufferMs.ToArray());
423+
}
424+
354425
private static string GetCurrentDirectory()
355426
{
356427
return AppContext.BaseDirectory;

0 commit comments

Comments
 (0)