Skip to content

Add Session.changeSessionId #835

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 20, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/test/java/docs/IndexDocTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class IndexDocTests {

@Test
public void repositoryDemo() {
RepositoryDemo<Session> demo = new RepositoryDemo<>();
RepositoryDemo<MapSession> demo = new RepositoryDemo<>();
demo.repository = new MapSessionRepository();

demo.demo();
Expand Down Expand Up @@ -82,7 +82,7 @@ public void demo() {

@Test
public void expireRepositoryDemo() {
ExpiringRepositoryDemo<Session> demo = new ExpiringRepositoryDemo<>();
ExpiringRepositoryDemo<MapSession> demo = new ExpiringRepositoryDemo<>();
demo.repository = new MapSessionRepository();

demo.demo();
Expand Down
5 changes: 2 additions & 3 deletions samples/misc/hazelcast/src/main/java/sample/Initializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.springframework.session.MapSession;
import org.springframework.session.MapSessionRepository;
import org.springframework.session.Session;
import org.springframework.session.SessionRepository;
import org.springframework.session.web.http.SessionRepositoryFilter;

@WebListener
Expand All @@ -48,8 +47,8 @@ public void contextInitialized(ServletContextEvent sce) {
this.instance = createHazelcastInstance();
Map<String, Session> sessions = this.instance.getMap(SESSION_MAP_NAME);

SessionRepository<Session> sessionRepository = new MapSessionRepository(sessions);
SessionRepositoryFilter<Session> filter = new SessionRepositoryFilter<>(
MapSessionRepository sessionRepository = new MapSessionRepository(sessions);
SessionRepositoryFilter<? extends Session> filter = new SessionRepositoryFilter<>(
sessionRepository);

Dynamic fr = sce.getServletContext().addFilter("springSessionFilter", filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
* @author Rob Winch
* @since 2.0
*/
public class MapReactorSessionRepository implements ReactorSessionRepository<Session> {
public class MapReactorSessionRepository implements ReactorSessionRepository<MapSession> {
/**
* If non-null, this value is used to override
* {@link Session#setMaxInactiveInterval(Duration)}.
Expand Down Expand Up @@ -80,7 +80,7 @@ public MapReactorSessionRepository(Session... sessions) {
}
this.sessions = new ConcurrentHashMap<>();
for (Session session : sessions) {
this.performSave(session);
this.performSave(new MapSession(session));
}
}

Expand All @@ -96,7 +96,7 @@ public MapReactorSessionRepository(Iterable<Session> sessions) {
}
this.sessions = new ConcurrentHashMap<>();
for (Session session : sessions) {
this.performSave(session);
this.performSave(new MapSession(session));
}
}

Expand All @@ -110,15 +110,19 @@ public void setDefaultMaxInactiveInterval(int defaultMaxInactiveInterval) {
this.defaultMaxInactiveInterval = Integer.valueOf(defaultMaxInactiveInterval);
}

public Mono<Void> save(Session session) {
public Mono<Void> save(MapSession session) {
return Mono.fromRunnable(() -> performSave(session));
}

private void performSave(Session session) {
private void performSave(MapSession session) {
if (!session.getId().equals(session.getOriginalId())) {
this.sessions.remove(session.getOriginalId());
session.setOriginalId(session.getId());
}
this.sessions.put(session.getId(), new MapSession(session));
}

public Mono<Session> findById(String id) {
public Mono<MapSession> findById(String id) {
return Mono.defer(() -> {
Session saved = this.sessions.get(id);
if (saved == null) {
Expand All @@ -136,9 +140,9 @@ public Mono<Void> delete(String id) {
return Mono.fromRunnable(() -> this.sessions.remove(id));
}

public Mono<Session> createSession() {
public Mono<MapSession> createSession() {
return Mono.defer(() -> {
Session result = new MapSession();
MapSession result = new MapSession();
if (this.defaultMaxInactiveInterval != null) {
result.setMaxInactiveInterval(
Duration.ofSeconds(this.defaultMaxInactiveInterval));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public final class MapSession implements Session, Serializable {
public static final int DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS = 1800;

private String id;
private String originalId;
private Map<String, Object> sessionAttrs = new HashMap<>();
private Instant creationTime = Instant.now();
private Instant lastAccessedTime = this.creationTime;
Expand All @@ -65,9 +66,10 @@ public final class MapSession implements Session, Serializable {
* Creates a new instance with a secure randomly generated identifier.
*/
public MapSession() {
this(UUID.randomUUID().toString());
this(generateId());
}


/**
* Creates a new instance with the specified id. This is preferred to the default
* constructor when the id is known to prevent unnecessary consumption on entropy
Expand All @@ -77,6 +79,7 @@ public MapSession() {
*/
public MapSession(String id) {
this.id = id;
this.originalId = id;
}

/**
Expand All @@ -90,6 +93,7 @@ public MapSession(Session session) {
throw new IllegalArgumentException("session cannot be null");
}
this.id = session.getId();
this.originalId = this.id;
this.sessionAttrs = new HashMap<>(
session.getAttributeNames().size());
for (String attrName : session.getAttributeNames()) {
Expand All @@ -115,6 +119,20 @@ public String getId() {
return this.id;
}

String getOriginalId() {
return this.originalId;
}

void setOriginalId(String originalId) {
this.originalId = originalId;
}

public String changeSessionId() {
String changedId = generateId();
setId(changedId);
return changedId;
}

public Instant getLastAccessedTime() {
return this.lastAccessedTime;
}
Expand Down Expand Up @@ -188,5 +206,9 @@ public int hashCode() {
return this.id.hashCode();
}

private static String generateId() {
return UUID.randomUUID().toString();
}

private static final long serialVersionUID = 7160779239673823561L;
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
* @author Rob Winch
* @since 1.0
*/
public class MapSessionRepository implements SessionRepository<Session> {
public class MapSessionRepository implements SessionRepository<MapSession> {
/**
* If non-null, this value is used to override
* {@link Session#setMaxInactiveInterval(Duration)}.
Expand Down Expand Up @@ -76,11 +76,15 @@ public void setDefaultMaxInactiveInterval(int defaultMaxInactiveInterval) {
this.defaultMaxInactiveInterval = Integer.valueOf(defaultMaxInactiveInterval);
}

public void save(Session session) {
public void save(MapSession session) {
if (!session.getId().equals(session.getOriginalId())) {
this.sessions.remove(session.getOriginalId());
session.setOriginalId(session.getId());
}
this.sessions.put(session.getId(), new MapSession(session));
}

public Session findById(String id) {
public MapSession findById(String id) {
Session saved = this.sessions.get(id);
if (saved == null) {
return null;
Expand All @@ -96,8 +100,8 @@ public void deleteById(String id) {
this.sessions.remove(id);
}

public Session createSession() {
Session result = new MapSession();
public MapSession createSession() {
MapSession result = new MapSession();
if (this.defaultMaxInactiveInterval != null) {
result.setMaxInactiveInterval(
Duration.ofSeconds(this.defaultMaxInactiveInterval));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ public interface Session {
*/
String getId();

/**
* Changes the session id. After invoking the {@link #getId()} will return a new identifier.
* @return the new session id which {@link #getId()} will now return
*/
String changeSessionId();

/**
* Gets the Object associated with the specified name or null if no Object is
* associated to that name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

import java.io.IOException;
import java.time.Instant;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;

import javax.servlet.FilterChain;
import javax.servlet.ServletContext;
Expand Down Expand Up @@ -274,30 +271,7 @@ public String changeSessionId() {
"Cannot change session ID. There is no session associated with this request.");
}

// eagerly get session attributes in case implementation lazily loads them
Map<String, Object> attrs = new HashMap<>();
Enumeration<String> iAttrNames = session.getAttributeNames();
while (iAttrNames.hasMoreElements()) {
String attrName = iAttrNames.nextElement();
Object value = session.getAttribute(attrName);

attrs.put(attrName, value);
}

SessionRepositoryFilter.this.sessionRepository.deleteById(session.getId());
HttpSessionWrapper original = getCurrentSession();
setCurrentSession(null);

HttpSessionWrapper newSession = getSession();
original.setSession(newSession.getSession());

newSession.setMaxInactiveInterval(session.getMaxInactiveInterval());
for (Map.Entry<String, Object> attr : attrs.entrySet()) {
String attrName = attr.getKey();
Object attrValue = attr.getValue();
newSession.setAttribute(attrName, attrValue);
}
return newSession.getId();
return getCurrentSession().getSession().changeSessionId();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,32 @@ public void createSessionWhenCustomMaxInactiveIntervalThenCustomMaxInactiveInter
assertThat(session.getMaxInactiveInterval())
.isEqualTo(expectedMaxInterval);
}

@Test
public void changeSessionIdWhenNotYetSaved() {
MapSession createSession = this.repository.createSession().block();

String originalId = createSession.getId();
createSession.changeSessionId();

this.repository.save(createSession).block();

assertThat(this.repository.findById(originalId).block()).isNull();
assertThat(this.repository.findById(createSession.getId()).block()).isNotNull();
}

@Test
public void changeSessionIdWhenSaved() {
MapSession createSession = this.repository.createSession().block();

this.repository.save(createSession).block();

String originalId = createSession.getId();
createSession.changeSessionId();

this.repository.save(createSession).block();

assertThat(this.repository.findById(originalId).block()).isNull();
assertThat(this.repository.findById(createSession.getId()).block()).isNotNull();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,32 @@ public void createSessionCustomDefaultExpiration() {
assertThat(session.getMaxInactiveInterval())
.isEqualTo(expectedMaxInterval);
}

@Test
public void changeSessionIdWhenNotYetSaved() {
MapSession createSession = this.repository.createSession();

String originalId = createSession.getId();
createSession.changeSessionId();

this.repository.save(createSession);

assertThat(this.repository.findById(originalId)).isNull();
assertThat(this.repository.findById(createSession.getId())).isNotNull();
}

@Test
public void changeSessionIdWhenSaved() {
MapSession createSession = this.repository.createSession();

this.repository.save(createSession);

String originalId = createSession.getId();
createSession.changeSessionId();

this.repository.save(createSession);

assertThat(this.repository.findById(originalId)).isNull();
assertThat(this.repository.findById(createSession.getId())).isNotNull();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ public Instant getCreationTime() {
return Instant.EPOCH;
}

public String changeSessionId() {
throw new UnsupportedOperationException();
}

public String getId() {
return "id";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ public class SessionRepositoryFilterTests {

private Map<String, Session> sessions;

private SessionRepository<Session> sessionRepository;
private SessionRepository<MapSession> sessionRepository;

private SessionRepositoryFilter<Session> filter;
private SessionRepositoryFilter<MapSession> filter;

private MockHttpServletRequest request;

Expand Down Expand Up @@ -422,7 +422,7 @@ public void doFilter(HttpServletRequest wrappedRequest) {
public void doFilterSetsCookieIfChanged() throws Exception {
this.sessionRepository = new MapSessionRepository() {
@Override
public Session findById(String id) {
public MapSession findById(String id) {
return createSession();
}
};
Expand Down Expand Up @@ -1256,7 +1256,7 @@ public void doFilter(HttpServletRequest wrappedRequest,
@SuppressWarnings("unchecked")
public void doFilterRequestSessionNoRequestSessionNoSessionRepositoryInteractions()
throws Exception {
SessionRepository<Session> sessionRepository = spy(new MapSessionRepository());
SessionRepository<MapSession> sessionRepository = spy(new MapSessionRepository());

this.filter = new SessionRepositoryFilter<>(sessionRepository);

Expand All @@ -1283,7 +1283,7 @@ public void doFilter(HttpServletRequest wrappedRequest,

@Test
public void doFilterLazySessionCreation() throws Exception {
SessionRepository<Session> sessionRepository = spy(new MapSessionRepository());
SessionRepository<MapSession> sessionRepository = spy(new MapSessionRepository());

this.filter = new SessionRepositoryFilter<>(sessionRepository);

Expand All @@ -1299,9 +1299,9 @@ public void doFilter(HttpServletRequest wrappedRequest,

@Test
public void doFilterLazySessionUpdates() throws Exception {
Session session = this.sessionRepository.createSession();
MapSession session = this.sessionRepository.createSession();
this.sessionRepository.save(session);
SessionRepository<Session> sessionRepository = spy(this.sessionRepository);
SessionRepository<MapSession> sessionRepository = spy(this.sessionRepository);
setSessionCookie(session.getId());

this.filter = new SessionRepositoryFilter<>(sessionRepository);
Expand Down
Loading