18
18
19
19
import java .io .IOException ;
20
20
import java .util .Collection ;
21
+ import java .util .concurrent .atomic .AtomicBoolean ;
21
22
import javax .servlet .AsyncContext ;
22
23
import javax .servlet .AsyncEvent ;
23
24
import javax .servlet .AsyncListener ;
25
+ import javax .servlet .DispatcherType ;
24
26
import javax .servlet .Servlet ;
25
27
import javax .servlet .ServletConfig ;
28
+ import javax .servlet .ServletException ;
26
29
import javax .servlet .ServletRegistration ;
27
30
import javax .servlet .ServletRequest ;
28
31
import javax .servlet .ServletResponse ;
41
44
import org .springframework .http .HttpMethod ;
42
45
import org .springframework .lang .Nullable ;
43
46
import org .springframework .util .Assert ;
47
+ import org .springframework .web .util .NestedServletException ;
44
48
45
49
/**
46
50
* Adapt {@link HttpHandler} to an {@link HttpServlet} using Servlet Async
@@ -56,9 +60,10 @@ public class ServletHttpHandlerAdapter implements Servlet {
56
60
57
61
private static final Log logger = LogFactory .getLog (ServletHttpHandlerAdapter .class );
58
62
59
-
60
63
private static final int DEFAULT_BUFFER_SIZE = 8192 ;
61
64
65
+ private static final String WRITE_ERROR_ATTRIBUTE_NAME = ServletHttpHandlerAdapter .class .getName () + ".ERROR" ;
66
+
62
67
63
68
private final HttpHandler httpHandler ;
64
69
@@ -151,7 +156,14 @@ private String getServletPath(ServletConfig config) {
151
156
152
157
153
158
@ Override
154
- public void service (ServletRequest request , ServletResponse response ) throws IOException {
159
+ public void service (ServletRequest request , ServletResponse response ) throws ServletException , IOException {
160
+
161
+ if (DispatcherType .ASYNC .equals (request .getDispatcherType ())) {
162
+ Throwable ex = (Throwable ) request .getAttribute (WRITE_ERROR_ATTRIBUTE_NAME );
163
+ Assert .notNull (ex , "Unexpected async dispatch" );
164
+ throw new NestedServletException ("Write publisher error" , ex );
165
+ }
166
+
155
167
// Start async before Read/WriteListener registration
156
168
AsyncContext asyncContext = request .startAsync ();
157
169
asyncContext .setTimeout (-1 );
@@ -163,9 +175,11 @@ public void service(ServletRequest request, ServletResponse response) throws IOE
163
175
httpResponse = new HttpHeadResponseDecorator (httpResponse );
164
176
}
165
177
166
- asyncContext .addListener (ERROR_LISTENER );
178
+ AtomicBoolean isCompleted = new AtomicBoolean ();
179
+ HandlerResultAsyncListener listener = new HandlerResultAsyncListener (isCompleted );
180
+ asyncContext .addListener (listener );
167
181
168
- HandlerResultSubscriber subscriber = new HandlerResultSubscriber (asyncContext );
182
+ HandlerResultSubscriber subscriber = new HandlerResultSubscriber (asyncContext , isCompleted );
169
183
this .httpHandler .handle (httpRequest , httpResponse ).subscribe (subscriber );
170
184
}
171
185
@@ -199,9 +213,9 @@ public void destroy() {
199
213
* We cannot combine ERROR_LISTENER and HandlerResultSubscriber due to:
200
214
* https://issues.jboss.org/browse/WFLY-8515
201
215
*/
202
- private static void runIfAsyncNotComplete (AsyncContext asyncContext , Runnable task ) {
216
+ private static void runIfAsyncNotComplete (AsyncContext asyncContext , AtomicBoolean isCompleted , Runnable task ) {
203
217
try {
204
- if (asyncContext .getRequest ().isAsyncStarted ()) {
218
+ if (asyncContext .getRequest ().isAsyncStarted () && isCompleted . compareAndSet ( false , true ) ) {
205
219
task .run ();
206
220
}
207
221
}
@@ -212,18 +226,27 @@ private static void runIfAsyncNotComplete(AsyncContext asyncContext, Runnable ta
212
226
}
213
227
214
228
215
- private final static AsyncListener ERROR_LISTENER = new AsyncListener () {
229
+ private static class HandlerResultAsyncListener implements AsyncListener {
230
+
231
+ private final AtomicBoolean isCompleted ;
232
+
233
+
234
+ public HandlerResultAsyncListener (AtomicBoolean isCompleted ) {
235
+ this .isCompleted = isCompleted ;
236
+ }
216
237
217
238
@ Override
218
239
public void onTimeout (AsyncEvent event ) {
240
+ logger .debug ("Timeout notification from Servlet container" );
219
241
AsyncContext context = event .getAsyncContext ();
220
- runIfAsyncNotComplete (context , context ::complete );
242
+ runIfAsyncNotComplete (context , this . isCompleted , context ::complete );
221
243
}
222
244
223
245
@ Override
224
246
public void onError (AsyncEvent event ) {
247
+ logger .debug ("Error notification from Servlet container" );
225
248
AsyncContext context = event .getAsyncContext ();
226
- runIfAsyncNotComplete (context , context ::complete );
249
+ runIfAsyncNotComplete (context , this . isCompleted , context ::complete );
227
250
}
228
251
229
252
@ Override
@@ -242,8 +265,12 @@ private class HandlerResultSubscriber implements Subscriber<Void> {
242
265
243
266
private final AsyncContext asyncContext ;
244
267
245
- public HandlerResultSubscriber (AsyncContext asyncContext ) {
268
+ private final AtomicBoolean isCompleted ;
269
+
270
+
271
+ public HandlerResultSubscriber (AsyncContext asyncContext , AtomicBoolean isCompleted ) {
246
272
this .asyncContext = asyncContext ;
273
+ this .isCompleted = isCompleted ;
247
274
}
248
275
249
276
@ Override
@@ -258,20 +285,30 @@ public void onNext(Void aVoid) {
258
285
259
286
@ Override
260
287
public void onError (Throwable ex ) {
261
- runIfAsyncNotComplete (this .asyncContext , () -> {
262
- logger .error ("Could not complete request" , ex );
263
- HttpServletResponse response = (HttpServletResponse ) this .asyncContext .getResponse ();
264
- response .setStatus (HttpServletResponse .SC_INTERNAL_SERVER_ERROR );
265
- this .asyncContext .complete ();
288
+ logger .error ("Handling completed with error" , ex );
289
+ runIfAsyncNotComplete (this .asyncContext , this .isCompleted , () -> {
290
+ if (this .asyncContext .getResponse ().isCommitted ()) {
291
+ logger .debug ("Dispatching into container to raise error" );
292
+ this .asyncContext .getRequest ().setAttribute (WRITE_ERROR_ATTRIBUTE_NAME , ex );
293
+ this .asyncContext .dispatch ();
294
+ }
295
+ else {
296
+ try {
297
+ logger .debug ("Setting response status code to 500" );
298
+ this .asyncContext .getResponse ().resetBuffer ();
299
+ ((HttpServletResponse ) this .asyncContext .getResponse ()).setStatus (500 );
300
+ }
301
+ finally {
302
+ this .asyncContext .complete ();
303
+ }
304
+ }
266
305
});
267
306
}
268
307
269
308
@ Override
270
309
public void onComplete () {
271
- runIfAsyncNotComplete (this .asyncContext , () -> {
272
- logger .debug ("Successfully completed request" );
273
- this .asyncContext .complete ();
274
- });
310
+ logger .debug ("Handling completed with success" );
311
+ runIfAsyncNotComplete (this .asyncContext , this .isCompleted , this .asyncContext ::complete );
275
312
}
276
313
}
277
314
0 commit comments