|
6 | 6 | using System.IO;
|
7 | 7 | using System.Threading;
|
8 | 8 | using System.Threading.Tasks;
|
| 9 | +using Microsoft.AspNet.Server.Kestrel.Infrastructure; |
9 | 10 | using Microsoft.Extensions.Primitives;
|
10 | 11 |
|
11 | 12 | namespace Microsoft.AspNet.Server.Kestrel.Http
|
@@ -38,35 +39,40 @@ protected MessageBody(FrameContext context)
|
38 | 39 | return result;
|
39 | 40 | }
|
40 | 41 |
|
41 |
| - public async Task Consume(CancellationToken cancellationToken = default(CancellationToken)) |
| 42 | + public Task Consume(CancellationToken cancellationToken = default(CancellationToken)) |
42 | 43 | {
|
43 |
| - Task<int> result; |
44 |
| - var send100checked = false; |
45 |
| - do |
| 44 | + var result = ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken); |
| 45 | + if (!result.IsCompleted) |
46 | 46 | {
|
47 |
| - result = ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken); |
48 |
| - if (!result.IsCompleted) |
49 |
| - { |
50 |
| - if (!send100checked) |
51 |
| - { |
52 |
| - if (Interlocked.Exchange(ref _send100Continue, 0) == 1) |
53 |
| - { |
54 |
| - _context.FrameControl.ProduceContinue(); |
55 |
| - } |
56 |
| - send100checked = true; |
57 |
| - } |
58 |
| - } |
59 |
| - else if (result.GetAwaiter().GetResult() == 0) |
60 |
| - { |
61 |
| - // Completed Task, end of stream |
62 |
| - return; |
63 |
| - } |
64 |
| - else |
| 47 | + if (Interlocked.Exchange(ref _send100Continue, 0) == 1) |
65 | 48 | {
|
66 |
| - // Completed Task, get next Task rather than await |
67 |
| - continue; |
| 49 | + _context.FrameControl.ProduceContinue(); |
68 | 50 | }
|
69 |
| - } while (await result != 0); |
| 51 | + |
| 52 | + return ConsumeAwaited(result, cancellationToken); |
| 53 | + } |
| 54 | + else if (result.GetAwaiter().GetResult() == 0) |
| 55 | + { |
| 56 | + // Completed Task, end of stream |
| 57 | + return TaskUtilities.CompletedTask; |
| 58 | + } |
| 59 | + else |
| 60 | + { |
| 61 | + // Completed Task, but non-zero get next Task and await |
| 62 | + return ConsumeAwaited(ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken), cancellationToken); |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + private async Task ConsumeAwaited(Task<int> currentTask, CancellationToken cancellationToken) |
| 67 | + { |
| 68 | + var count = await currentTask; |
| 69 | + |
| 70 | + if (count == 0) return; |
| 71 | + |
| 72 | + while (await ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken) != 0) |
| 73 | + { |
| 74 | + // Consume until complete |
| 75 | + } |
70 | 76 | }
|
71 | 77 |
|
72 | 78 | public abstract Task<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken);
|
@@ -156,17 +162,37 @@ public ForContentLength(bool keepAlive, int contentLength, FrameContext context)
|
156 | 162 | _inputLength = _contentLength;
|
157 | 163 | }
|
158 | 164 |
|
159 |
| - public override async Task<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken) |
| 165 | + public override Task<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken) |
160 | 166 | {
|
161 | 167 | var input = _context.SocketInput;
|
162 | 168 |
|
163 | 169 | var limit = buffer.Array == null ? _inputLength : Math.Min(buffer.Count, _inputLength);
|
164 | 170 | if (limit == 0)
|
165 | 171 | {
|
166 |
| - return 0; |
| 172 | + return TaskUtilities.ZeroTask; |
| 173 | + } |
| 174 | + |
| 175 | + var task = _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, limit); |
| 176 | + |
| 177 | + if (task.IsCompleted) |
| 178 | + { |
| 179 | + var actual = task.GetAwaiter().GetResult(); |
| 180 | + if (actual == 0) |
| 181 | + { |
| 182 | + throw new InvalidDataException("Unexpected end of request content"); |
| 183 | + } |
| 184 | + _inputLength -= actual; |
| 185 | + return task; |
167 | 186 | }
|
| 187 | + else |
| 188 | + { |
| 189 | + return ReadAsyncImplementationAwaited(task); |
| 190 | + } |
| 191 | + } |
168 | 192 |
|
169 |
| - var actual = await _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, limit); |
| 193 | + private async Task<int> ReadAsyncImplementationAwaited(Task<int> currentTask) |
| 194 | + { |
| 195 | + var actual = await currentTask; |
170 | 196 | _inputLength -= actual;
|
171 | 197 |
|
172 | 198 | if (actual == 0)
|
|
0 commit comments