From f92f9c1d5b04aefb467355576e63cc2cc6d78d92 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Mon, 24 Feb 2025 11:44:35 +0000 Subject: [PATCH] Fix handling of timeout in SseEmitter Closes gh-34426 --- .../annotation/ResponseBodyEmitter.java | 92 ++++++++++++++----- 1 file changed, 70 insertions(+), 22 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java index e4e5d0e6b7cb..afa3008cdc1a 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import org.springframework.http.MediaType; @@ -73,21 +73,20 @@ public class ResponseBodyEmitter { @Nullable private Handler handler; + private final AtomicReference state = new AtomicReference<>(State.START); + /** Store send data before handler is initialized. */ private final Set earlySendAttempts = new LinkedHashSet<>(8); - /** Store successful completion before the handler is initialized. */ - private final AtomicBoolean complete = new AtomicBoolean(); - /** Store an error before the handler is initialized. */ @Nullable private Throwable failure; - private final DefaultCallback timeoutCallback = new DefaultCallback(); + private final TimeoutCallback timeoutCallback = new TimeoutCallback(); private final ErrorCallback errorCallback = new ErrorCallback(); - private final DefaultCallback completionCallback = new DefaultCallback(); + private final CompletionCallback completionCallback = new CompletionCallback(); /** @@ -128,7 +127,7 @@ synchronized void initialize(Handler handler) throws IOException { this.earlySendAttempts.clear(); } - if (this.complete.get()) { + if (this.state.get() == State.COMPLETE) { if (this.failure != null) { this.handler.completeWithError(this.failure); } @@ -144,7 +143,7 @@ synchronized void initialize(Handler handler) throws IOException { } void initializeWithError(Throwable ex) { - if (this.complete.compareAndSet(false, true)) { + if (this.state.compareAndSet(State.START, State.COMPLETE)) { this.failure = ex; this.earlySendAttempts.clear(); this.errorCallback.accept(ex); @@ -186,8 +185,7 @@ public void send(Object object) throws IOException { * @throws java.lang.IllegalStateException wraps any other errors */ public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException { - Assert.state(!this.complete.get(), () -> "ResponseBodyEmitter has already completed" + - (this.failure != null ? " with error: " + this.failure : "")); + assertNotComplete(); if (this.handler != null) { try { this.handler.send(object, mediaType); @@ -214,11 +212,15 @@ public synchronized void send(Object object, @Nullable MediaType mediaType) thro * @since 6.0.12 */ public synchronized void send(Set items) throws IOException { - Assert.state(!this.complete.get(), () -> "ResponseBodyEmitter has already completed" + - (this.failure != null ? " with error: " + this.failure : "")); + assertNotComplete(); sendInternal(items); } + private void assertNotComplete() { + Assert.state(this.state.get() == State.START, () -> "ResponseBodyEmitter has already completed" + + (this.failure != null ? " with error: " + this.failure : "")); + } + private void sendInternal(Set items) throws IOException { if (items.isEmpty()) { return; @@ -248,7 +250,7 @@ private void sendInternal(Set items) throws IOException { * related events such as an error while {@link #send(Object) sending}. */ public void complete() { - if (this.complete.compareAndSet(false, true) && this.handler != null) { + if (trySetComplete() && this.handler != null) { this.handler.complete(); } } @@ -265,7 +267,7 @@ public void complete() { * {@link #send(Object) sending}. */ public void completeWithError(Throwable ex) { - if (this.complete.compareAndSet(false, true)) { + if (trySetComplete()) { this.failure = ex; if (this.handler != null) { this.handler.completeWithError(ex); @@ -273,6 +275,11 @@ public void completeWithError(Throwable ex) { } } + private boolean trySetComplete() { + return (this.state.compareAndSet(State.START, State.COMPLETE) || + (this.state.compareAndSet(State.TIMEOUT, State.COMPLETE))); + } + /** * Register code to invoke when the async request times out. This method is * called from a container thread when an async request times out. @@ -369,7 +376,7 @@ public MediaType getMediaType() { } - private class DefaultCallback implements Runnable { + private class TimeoutCallback implements Runnable { private final List delegates = new ArrayList<>(1); @@ -379,9 +386,10 @@ public synchronized void addDelegate(Runnable delegate) { @Override public void run() { - ResponseBodyEmitter.this.complete.compareAndSet(false, true); - for (Runnable delegate : this.delegates) { - delegate.run(); + if (ResponseBodyEmitter.this.state.compareAndSet(State.START, State.TIMEOUT)) { + for (Runnable delegate : this.delegates) { + delegate.run(); + } } } } @@ -397,11 +405,51 @@ public synchronized void addDelegate(Consumer callback) { @Override public void accept(Throwable t) { - ResponseBodyEmitter.this.complete.compareAndSet(false, true); - for(Consumer delegate : this.delegates) { - delegate.accept(t); + if (ResponseBodyEmitter.this.state.compareAndSet(State.START, State.COMPLETE)) { + for (Consumer delegate : this.delegates) { + delegate.accept(t); + } + } + } + } + + + private class CompletionCallback implements Runnable { + + private final List delegates = new ArrayList<>(1); + + public synchronized void addDelegate(Runnable delegate) { + this.delegates.add(delegate); + } + + @Override + public void run() { + if (ResponseBodyEmitter.this.state.compareAndSet(State.START, State.COMPLETE)) { + for (Runnable delegate : this.delegates) { + delegate.run(); + } } } } + + /** + * Represents a state for {@link ResponseBodyEmitter}. + *

+	 *     START ----+
+	 *       |       |
+	 *       v       |
+	 *    TIMEOUT    |
+	 *       |       |
+	 *       v       |
+	 *   COMPLETE <--+
+	 * 
+ * @since 6.2.4 + */ + private enum State { + START, + TIMEOUT, // handling a timeout + COMPLETE + } + }