@@ -51,10 +51,37 @@ public IAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellati
5151 {
5252 ( ( CancellationTokenSource ) ctsState ) . Cancel ( ) ;
5353 } , _cts ) ;
54+
55+ return new CancelableEnumerator < TResult > ( _asyncEnumerable . GetAsyncEnumerator ( ) , registration ) ;
5456 }
5557
5658 return enumerator ;
5759 }
60+
61+ private class CancelableEnumerator < T > : IAsyncEnumerator < T >
62+ {
63+ private IAsyncEnumerator < T > _asyncEnumerator ;
64+ private readonly CancellationTokenRegistration _cancellationTokenRegistration ;
65+
66+ public T Current => ( T ) _asyncEnumerator . Current ;
67+
68+ public CancelableEnumerator ( IAsyncEnumerator < T > asyncEnumerator , CancellationTokenRegistration registration )
69+ {
70+ _asyncEnumerator = asyncEnumerator ;
71+ _cancellationTokenRegistration = registration ;
72+ }
73+
74+ public ValueTask < bool > MoveNextAsync ( )
75+ {
76+ return _asyncEnumerator . MoveNextAsync ( ) ;
77+ }
78+
79+ public ValueTask DisposeAsync ( )
80+ {
81+ _cancellationTokenRegistration . Dispose ( ) ;
82+ return _asyncEnumerator . DisposeAsync ( ) ;
83+ }
84+ }
5885 }
5986
6087 /// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
@@ -71,6 +98,10 @@ public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, Cancellati
7198
7299 public IAsyncEnumerator < object > GetAsyncEnumerator ( CancellationToken cancellationToken = default )
73100 {
101+ // Assume that this will be iterated through with await foreach which always passes a default token.
102+ // Instead use the token from the ctor.
103+ Debug . Assert ( cancellationToken == default ) ;
104+
74105 var enumeratorOfT = _asyncEnumerable . GetAsyncEnumerator ( _cancellationToken ) ;
75106 return enumeratorOfT as IAsyncEnumerator < object > ?? new BoxedAsyncEnumerator ( enumeratorOfT ) ;
76107 }
0 commit comments