Skip to content

Commit 6824765

Browse files
committed
crypto/tls: add WrapSession and UnwrapSession
There was a bug in TestResumption: the first ExpiredSessionTicket was inserting a ticket far in the future, so the second ExpiredSessionTicket wasn't actually supposed to fail. However, there was a bug in checkForResumption->sendSessionTicket, too: if a session was not resumed because it was too old, its createdAt was still persisted in the next ticket. The two bugs used to cancel each other out. For #60105 Fixes #19199 Change-Id: Ic9b2aab943dcbf0de62b8758a6195319dc286e2f Reviewed-on: https://go-review.googlesource.com/c/go/+/496821 Run-TryBot: Filippo Valsorda <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]> Reviewed-by: Damien Neil <[email protected]>
1 parent 371ebe7 commit 6824765

File tree

7 files changed

+182
-55
lines changed

7 files changed

+182
-55
lines changed

api/next/60105.txt

+4
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,7 @@ pkg crypto/tls, method (*SessionState) Bytes() ([]uint8, error) #60105
33
pkg crypto/tls, type SessionState struct #60105
44
pkg crypto/tls, func NewResumptionState([]uint8, *SessionState) (*ClientSessionState, error) #60105
55
pkg crypto/tls, method (*ClientSessionState) ResumptionState() ([]uint8, *SessionState, error) #60105
6+
pkg crypto/tls, method (*Config) DecryptTicket([]uint8, ConnectionState) (*SessionState, error) #60105
7+
pkg crypto/tls, method (*Config) EncryptTicket(ConnectionState, *SessionState) ([]uint8, error) #60105
8+
pkg crypto/tls, type Config struct, UnwrapSession func([]uint8, ConnectionState) (*SessionState, error) #60105
9+
pkg crypto/tls, type Config struct, WrapSession func(ConnectionState, *SessionState) ([]uint8, error) #60105

src/crypto/tls/common.go

+31
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,35 @@ type Config struct {
673673
// session resumption. It is only used by clients.
674674
ClientSessionCache ClientSessionCache
675675

676+
// UnwrapSession is called on the server to turn a ticket/identity
677+
// previously produced by [WrapSession] into a usable session.
678+
//
679+
// UnwrapSession will usually either decrypt a session state in the ticket
680+
// (for example with [Config.EncryptTicket]), or use the ticket as a handle
681+
// to recover a previously stored state. It must use [ParseSessionState] to
682+
// deserialize the session state.
683+
//
684+
// If UnwrapSession returns an error, the connection is terminated. If it
685+
// returns (nil, nil), the session is ignored. crypto/tls may still choose
686+
// not to resume the returned session.
687+
UnwrapSession func(identity []byte, cs ConnectionState) (*SessionState, error)
688+
689+
// WrapSession is called on the server to produce a session ticket/identity.
690+
//
691+
// WrapSession must serialize the session state with [SessionState.Bytes].
692+
// It may then encrypt the serialized state (for example with
693+
// [Config.DecryptTicket]) and use it as the ticket, or store the state and
694+
// return a handle for it.
695+
//
696+
// If WrapSession returns an error, the connection is terminated.
697+
//
698+
// Warning: the return value will be exposed on the wire and to clients in
699+
// plaintext. The application is in charge of encrypting and authenticating
700+
// it (and rotating keys) or returning high-entropy identifiers. Failing to
701+
// do so correctly can compromise current, previous, and future connections
702+
// depending on the protocol version.
703+
WrapSession func(ConnectionState, *SessionState) ([]byte, error)
704+
676705
// MinVersion contains the minimum TLS version that is acceptable.
677706
//
678707
// By default, TLS 1.2 is currently used as the minimum when acting as a
@@ -794,6 +823,8 @@ func (c *Config) Clone() *Config {
794823
SessionTicketsDisabled: c.SessionTicketsDisabled,
795824
SessionTicketKey: c.SessionTicketKey,
796825
ClientSessionCache: c.ClientSessionCache,
826+
UnwrapSession: c.UnwrapSession,
827+
WrapSession: c.WrapSession,
797828
MinVersion: c.MinVersion,
798829
MaxVersion: c.MaxVersion,
799830
CurvePreferences: c.CurvePreferences,

src/crypto/tls/handshake_client_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,7 @@ func testResumption(t *testing.T, version uint16) {
900900
}
901901

902902
testResumeState := func(test string, didResume bool) {
903+
t.Helper()
903904
_, hs, err := testHandshake(t, clientConfig, serverConfig)
904905
if err != nil {
905906
t.Fatalf("%s: handshake failed: %s", test, err)
@@ -985,17 +986,18 @@ func testResumption(t *testing.T, version uint16) {
985986

986987
// Age the session ticket a bit at a time, but don't expire it.
987988
d := 0 * time.Hour
989+
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
990+
deleteTicket()
991+
testResumeState("GetFreshSessionTicket", false)
988992
for i := 0; i < 13; i++ {
989993
d += 12 * time.Hour
990-
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
991994
testResumeState("OldSessionTicket", true)
992995
}
993996
// Expire it (now a little more than 7 days) and make sure a full
994997
// handshake occurs for TLS 1.2. Resumption should still occur for
995998
// TLS 1.3 since the client should be using a fresh ticket sent over
996999
// by the server.
9971000
d += 12 * time.Hour
998-
serverConfig.Time = func() time.Time { return time.Now().Add(d) }
9991001
if version == VersionTLS13 {
10001002
testResumeState("ExpiredSessionTicket", true)
10011003
} else {

src/crypto/tls/handshake_server.go

+56-32
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,11 @@ func (hs *serverHandshakeState) handshake() error {
7070

7171
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
7272
c.buffering = true
73-
if hs.checkForResumption() {
73+
if err := hs.checkForResumption(); err != nil {
74+
return err
75+
}
76+
if hs.sessionState != nil {
7477
// The client has included a session ticket and so we do an abbreviated handshake.
75-
c.didResume = true
7678
if err := hs.doResumeHandshake(); err != nil {
7779
return err
7880
}
@@ -399,65 +401,80 @@ func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
399401
}
400402

401403
// checkForResumption reports whether we should perform resumption on this connection.
402-
func (hs *serverHandshakeState) checkForResumption() bool {
404+
func (hs *serverHandshakeState) checkForResumption() error {
403405
c := hs.c
404406

405407
if c.config.SessionTicketsDisabled {
406-
return false
408+
return nil
407409
}
408410

409-
plaintext := c.decryptTicket(hs.clientHello.sessionTicket)
410-
if plaintext == nil {
411-
return false
412-
}
413-
ss, err := ParseSessionState(plaintext)
414-
if err != nil {
415-
return false
411+
var sessionState *SessionState
412+
if c.config.UnwrapSession != nil {
413+
ss, err := c.config.UnwrapSession(hs.clientHello.sessionTicket, c.connectionStateLocked())
414+
if err != nil {
415+
return err
416+
}
417+
if ss == nil {
418+
return nil
419+
}
420+
sessionState = ss
421+
} else {
422+
plaintext := c.config.decryptTicket(hs.clientHello.sessionTicket, c.ticketKeys)
423+
if plaintext == nil {
424+
return nil
425+
}
426+
ss, err := ParseSessionState(plaintext)
427+
if err != nil {
428+
return nil
429+
}
430+
sessionState = ss
416431
}
417-
hs.sessionState = ss
418432

419433
// TLS 1.2 tickets don't natively have a lifetime, but we want to avoid
420434
// re-wrapping the same master secret in different tickets over and over for
421435
// too long, weakening forward secrecy.
422-
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
436+
createdAt := time.Unix(int64(sessionState.createdAt), 0)
423437
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
424-
return false
438+
return nil
425439
}
426440

427441
// Never resume a session for a different TLS version.
428-
if c.vers != hs.sessionState.version {
429-
return false
442+
if c.vers != sessionState.version {
443+
return nil
430444
}
431445

432446
cipherSuiteOk := false
433447
// Check that the client is still offering the ciphersuite in the session.
434448
for _, id := range hs.clientHello.cipherSuites {
435-
if id == hs.sessionState.cipherSuite {
449+
if id == sessionState.cipherSuite {
436450
cipherSuiteOk = true
437451
break
438452
}
439453
}
440454
if !cipherSuiteOk {
441-
return false
455+
return nil
442456
}
443457

444458
// Check that we also support the ciphersuite from the session.
445-
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
459+
suite := selectCipherSuite([]uint16{sessionState.cipherSuite},
446460
c.config.cipherSuites(), hs.cipherSuiteOk)
447-
if hs.suite == nil {
448-
return false
461+
if suite == nil {
462+
return nil
449463
}
450464

451-
sessionHasClientCerts := len(hs.sessionState.peerCertificates) != 0
465+
sessionHasClientCerts := len(sessionState.peerCertificates) != 0
452466
needClientCerts := requiresClientCert(c.config.ClientAuth)
453467
if needClientCerts && !sessionHasClientCerts {
454-
return false
468+
return nil
455469
}
456470
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
457-
return false
471+
return nil
458472
}
459473

460-
return true
474+
hs.sessionState = sessionState
475+
hs.suite = suite
476+
c.didResume = true
477+
return nil
461478
}
462479

463480
func (hs *serverHandshakeState) doResumeHandshake() error {
@@ -769,13 +786,20 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
769786
// the original time it was created.
770787
state.createdAt = hs.sessionState.createdAt
771788
}
772-
stateBytes, err := state.Bytes()
773-
if err != nil {
774-
return err
775-
}
776-
m.ticket, err = c.encryptTicket(stateBytes)
777-
if err != nil {
778-
return err
789+
if c.config.WrapSession != nil {
790+
m.ticket, err = c.config.WrapSession(c.connectionStateLocked(), state)
791+
if err != nil {
792+
return err
793+
}
794+
} else {
795+
stateBytes, err := state.Bytes()
796+
if err != nil {
797+
return err
798+
}
799+
m.ticket, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
800+
if err != nil {
801+
return err
802+
}
779803
}
780804

781805
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {

src/crypto/tls/handshake_server_tls13.go

+37-13
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,29 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
275275
break
276276
}
277277

278-
plaintext := c.decryptTicket(identity.label)
279-
if plaintext == nil {
280-
continue
278+
var sessionState *SessionState
279+
if c.config.UnwrapSession != nil {
280+
var err error
281+
sessionState, err = c.config.UnwrapSession(identity.label, c.connectionStateLocked())
282+
if err != nil {
283+
return err
284+
}
285+
if sessionState == nil {
286+
continue
287+
}
288+
} else {
289+
plaintext := c.config.decryptTicket(identity.label, c.ticketKeys)
290+
if plaintext == nil {
291+
continue
292+
}
293+
var err error
294+
sessionState, err = ParseSessionState(plaintext)
295+
if err != nil {
296+
continue
297+
}
281298
}
282-
sessionState, err := ParseSessionState(plaintext)
283-
if err != nil || sessionState.version != VersionTLS13 {
299+
300+
if sessionState.version != VersionTLS13 {
284301
continue
285302
}
286303

@@ -781,14 +798,21 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
781798
return err
782799
}
783800
state.secret = psk
784-
stateBytes, err := state.Bytes()
785-
if err != nil {
786-
c.sendAlert(alertInternalError)
787-
return err
788-
}
789-
m.label, err = c.encryptTicket(stateBytes)
790-
if err != nil {
791-
return err
801+
if c.config.WrapSession != nil {
802+
m.label, err = c.config.WrapSession(c.connectionStateLocked(), state)
803+
if err != nil {
804+
return err
805+
}
806+
} else {
807+
stateBytes, err := state.Bytes()
808+
if err != nil {
809+
c.sendAlert(alertInternalError)
810+
return err
811+
}
812+
m.label, err = c.config.encryptTicket(stateBytes, c.ticketKeys)
813+
if err != nil {
814+
return err
815+
}
792816
}
793817
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
794818

src/crypto/tls/ticket.go

+38-6
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,21 @@ func (c *Conn) sessionState() (*SessionState, error) {
228228
}, nil
229229
}
230230

231-
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
232-
if len(c.ticketKeys) == 0 {
231+
// EncryptTicket encrypts a ticket with the Config's configured (or default)
232+
// session ticket keys. It can be used as a [Config.WrapSession] implementation.
233+
func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
234+
ticketKeys := c.ticketKeys(nil)
235+
stateBytes, err := ss.Bytes()
236+
if err != nil {
237+
return nil, err
238+
}
239+
return c.encryptTicket(stateBytes, ticketKeys)
240+
}
241+
242+
var _ = &Config{WrapSession: (&Config{}).EncryptTicket}
243+
244+
func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
245+
if len(ticketKeys) == 0 {
233246
return nil, errors.New("tls: internal error: session ticket keys unavailable")
234247
}
235248

@@ -239,10 +252,10 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
239252
authenticated := encrypted[:len(encrypted)-sha256.Size]
240253
macBytes := encrypted[len(encrypted)-sha256.Size:]
241254

242-
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
255+
if _, err := io.ReadFull(c.rand(), iv); err != nil {
243256
return nil, err
244257
}
245-
key := c.ticketKeys[0]
258+
key := ticketKeys[0]
246259
block, err := aes.NewCipher(key.aesKey[:])
247260
if err != nil {
248261
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
@@ -256,7 +269,26 @@ func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
256269
return encrypted, nil
257270
}
258271

259-
func (c *Conn) decryptTicket(encrypted []byte) []byte {
272+
// DecryptTicket decrypts a ticket encrypted by [Config.EncryptTicket]. It can
273+
// be used as a [Config.UnwrapSession] implementation.
274+
//
275+
// If the ticket can't be decrypted or parsed, DecryptTicket returns (nil, nil).
276+
func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
277+
ticketKeys := c.ticketKeys(nil)
278+
stateBytes := c.decryptTicket(identity, ticketKeys)
279+
if stateBytes == nil {
280+
return nil, nil
281+
}
282+
s, err := ParseSessionState(stateBytes)
283+
if err != nil {
284+
return nil, nil // drop unparsable tickets on the floor
285+
}
286+
return s, nil
287+
}
288+
289+
var _ = &Config{UnwrapSession: (&Config{}).DecryptTicket}
290+
291+
func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
260292
if len(encrypted) < aes.BlockSize+sha256.Size {
261293
return nil
262294
}
@@ -266,7 +298,7 @@ func (c *Conn) decryptTicket(encrypted []byte) []byte {
266298
authenticated := encrypted[:len(encrypted)-sha256.Size]
267299
macBytes := encrypted[len(encrypted)-sha256.Size:]
268300

269-
for _, key := range c.ticketKeys {
301+
for _, key := range ticketKeys {
270302
mac := hmac.New(sha256.New, key.hmacKey[:])
271303
mac.Write(authenticated)
272304
expected := mac.Sum(nil)

0 commit comments

Comments
 (0)