Skip to content

Commit 6435d0c

Browse files
committed
crypto/tls: implement TLS 1.3 PSK authentication (server side)
Added some assertions to testHandshake, but avoided checking the error of one of the Close() because the one that would lose the race would write the closeNotify to a connection closed on the other side which is broken on js/wasm (#28650). Moved that Close() after the chan sync to ensure it happens second. Accepting a ticket with client certificates when NoClientCert is configured is probably not a problem, and we could hide them to avoid confusing the application, but the current behavior is to skip the ticket, and I'd rather keep behavior changes to a minimum. Updates #9671 Change-Id: I93b56e44ddfe3d48c2bef52c83285ba2f46f297a Reviewed-on: https://go-review.googlesource.com/c/147445 Reviewed-by: Adam Langley <[email protected]>
1 parent d669cc4 commit 6435d0c

25 files changed

+1960
-955
lines changed

src/crypto/tls/common.go

+15
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,17 @@ const (
234234
RequireAndVerifyClientCert
235235
)
236236

237+
// requiresClientCert returns whether the ClientAuthType requires a client
238+
// certificate to be provided.
239+
func requiresClientCert(c ClientAuthType) bool {
240+
switch c {
241+
case RequireAnyClientCert, RequireAndVerifyClientCert:
242+
return true
243+
default:
244+
return false
245+
}
246+
}
247+
237248
// ClientSessionState contains the state needed by clients to resume TLS
238249
// sessions.
239250
type ClientSessionState struct {
@@ -599,6 +610,10 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) {
599610
return key
600611
}
601612

613+
// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session
614+
// ticket, and the lifetime we set for tickets we send.
615+
const maxSessionTicketLifetime = 7 * 24 * time.Hour
616+
602617
// Clone returns a shallow clone of c. It is safe to clone a Config that is
603618
// being used concurrently by a TLS client or server.
604619
func (c *Config) Clone() *Config {

src/crypto/tls/handshake_client_test.go

+33-18
Original file line numberDiff line numberDiff line change
@@ -869,11 +869,14 @@ func TestClientKeyUpdate(t *testing.T) {
869869
runClientTestTLS13(t, test)
870870
}
871871

872-
func TestClientResumption(t *testing.T) {
873-
// TODO(filippo): update to test both TLS 1.3 and 1.2 once PSK are
874-
// supported server-side.
872+
func TestResumption(t *testing.T) {
873+
t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
874+
t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
875+
}
875876

877+
func testResumption(t *testing.T, version uint16) {
876878
serverConfig := &Config{
879+
MaxVersion: version,
877880
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
878881
Certificates: testConfig.Certificates,
879882
}
@@ -887,6 +890,7 @@ func TestClientResumption(t *testing.T) {
887890
rootCAs.AddCert(issuer)
888891

889892
clientConfig := &Config{
893+
MaxVersion: version,
890894
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
891895
ClientSessionCache: NewLRUClientSessionCache(32),
892896
RootCAs: rootCAs,
@@ -924,9 +928,12 @@ func TestClientResumption(t *testing.T) {
924928
testResumeState("Handshake", false)
925929
ticket := getTicket()
926930
testResumeState("Resume", true)
927-
if !bytes.Equal(ticket, getTicket()) {
931+
if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
928932
t.Fatal("first ticket doesn't match ticket after resumption")
929933
}
934+
if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
935+
t.Fatal("ticket didn't change after resumption")
936+
}
930937

931938
key1 := randomKey()
932939
serverConfig.SetSessionTicketKeys([][32]byte{key1})
@@ -946,16 +953,21 @@ func TestClientResumption(t *testing.T) {
946953
// Reset serverConfig to ensure that calling SetSessionTicketKeys
947954
// before the serverConfig is used works.
948955
serverConfig = &Config{
956+
MaxVersion: version,
949957
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
950958
Certificates: testConfig.Certificates,
951959
}
952960
serverConfig.SetSessionTicketKeys([][32]byte{key2})
953961

954962
testResumeState("FreshConfig", true)
955963

956-
clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
957-
testResumeState("DifferentCipherSuite", false)
958-
testResumeState("DifferentCipherSuiteRecovers", true)
964+
// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
965+
// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
966+
if version != VersionTLS13 {
967+
clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
968+
testResumeState("DifferentCipherSuite", false)
969+
testResumeState("DifferentCipherSuiteRecovers", true)
970+
}
959971

960972
deleteTicket()
961973
testResumeState("WithoutSessionTicket", false)
@@ -966,18 +978,21 @@ func TestClientResumption(t *testing.T) {
966978
serverConfig.ClientAuth = RequireAndVerifyClientCert
967979
clientConfig.Certificates = serverConfig.Certificates
968980
testResumeState("InitialHandshake", false)
969-
testResumeState("WithClientCertificates", true)
970-
971-
// Tickets should be removed from the session cache on TLS handshake failure
972-
farFuture := func() time.Time { return time.Unix(16725225600, 0) }
973-
serverConfig.Time = farFuture
974-
_, _, err = testHandshake(t, clientConfig, serverConfig)
975-
if err == nil {
976-
t.Fatalf("handshake did not fail after client certificate expiry")
981+
if version != VersionTLS13 {
982+
// TODO(filippo): reenable when client authentication is implemented
983+
testResumeState("WithClientCertificates", true)
984+
985+
// Tickets should be removed from the session cache on TLS handshake failure
986+
farFuture := func() time.Time { return time.Unix(16725225600, 0) }
987+
serverConfig.Time = farFuture
988+
_, _, err = testHandshake(t, clientConfig, serverConfig)
989+
if err == nil {
990+
t.Fatalf("handshake did not fail after client certificate expiry")
991+
}
992+
serverConfig.Time = nil
993+
testResumeState("AfterHandshakeFailure", false)
994+
serverConfig.ClientAuth = NoClientCert
977995
}
978-
serverConfig.Time = nil
979-
testResumeState("AfterHandshakeFailure", false)
980-
serverConfig.ClientAuth = NoClientCert
981996

982997
clientConfig.ClientSessionCache = nil
983998
testResumeState("WithoutSessionCache", false)

src/crypto/tls/handshake_client_tls13.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
565565
return nil
566566
}
567567
lifetime := time.Duration(msg.lifetime) * time.Second
568-
if lifetime > 7*24*time.Hour {
568+
if lifetime > maxSessionTicketLifetime {
569569
c.sendAlert(alertIllegalParameter)
570570
return errors.New("tls: received a session ticket with invalid lifetime")
571571
}

src/crypto/tls/handshake_messages.go

+82-44
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@ func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
3030
}))
3131
}
3232

33+
// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
34+
func addUint64(b *cryptobyte.Builder, v uint64) {
35+
b.AddUint32(uint32(v >> 32))
36+
b.AddUint32(uint32(v))
37+
}
38+
39+
// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
40+
// It reports whether the read was successful.
41+
func readUint64(s *cryptobyte.String, out *uint64) bool {
42+
var hi, lo uint32
43+
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
44+
return false
45+
}
46+
*out = uint64(hi)<<32 | uint64(lo)
47+
return true
48+
}
49+
3350
// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
3451
// []byte instead of a cryptobyte.String.
3552
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
@@ -1266,89 +1283,110 @@ func (m *certificateMsgTLS13) marshal() []byte {
12661283
b.AddUint8(typeCertificate)
12671284
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
12681285
b.AddUint8(0) // certificate_request_context
1269-
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1270-
for i, cert := range m.certificate.Certificate {
1271-
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1272-
b.AddBytes(cert)
1273-
})
1274-
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1275-
if i > 0 {
1276-
// This library only supports OCSP and SCT for leaf certificates.
1277-
return
1278-
}
1279-
if m.ocspStapling {
1280-
b.AddUint16(extensionStatusRequest)
1281-
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1282-
b.AddUint8(statusTypeOCSP)
1283-
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1284-
b.AddBytes(m.certificate.OCSPStaple)
1285-
})
1286-
})
1287-
}
1288-
if m.scts {
1289-
b.AddUint16(extensionSCT)
1290-
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1291-
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1292-
for _, sct := range m.certificate.SignedCertificateTimestamps {
1293-
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1294-
b.AddBytes(sct)
1295-
})
1296-
}
1297-
})
1298-
})
1299-
}
1300-
})
1301-
}
1302-
})
1286+
1287+
certificate := m.certificate
1288+
if !m.ocspStapling {
1289+
certificate.OCSPStaple = nil
1290+
}
1291+
if !m.scts {
1292+
certificate.SignedCertificateTimestamps = nil
1293+
}
1294+
marshalCertificate(b, certificate)
13031295
})
13041296

13051297
m.raw = b.BytesOrPanic()
13061298
return m.raw
13071299
}
13081300

1301+
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
1302+
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1303+
for i, cert := range certificate.Certificate {
1304+
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1305+
b.AddBytes(cert)
1306+
})
1307+
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1308+
if i > 0 {
1309+
// This library only supports OCSP and SCT for leaf certificates.
1310+
return
1311+
}
1312+
if certificate.OCSPStaple != nil {
1313+
b.AddUint16(extensionStatusRequest)
1314+
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1315+
b.AddUint8(statusTypeOCSP)
1316+
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1317+
b.AddBytes(certificate.OCSPStaple)
1318+
})
1319+
})
1320+
}
1321+
if certificate.SignedCertificateTimestamps != nil {
1322+
b.AddUint16(extensionSCT)
1323+
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1324+
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1325+
for _, sct := range certificate.SignedCertificateTimestamps {
1326+
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1327+
b.AddBytes(sct)
1328+
})
1329+
}
1330+
})
1331+
})
1332+
}
1333+
})
1334+
}
1335+
})
1336+
}
1337+
13091338
func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
13101339
*m = certificateMsgTLS13{raw: data}
13111340
s := cryptobyte.String(data)
13121341

1313-
var context, certList cryptobyte.String
1342+
var context cryptobyte.String
13141343
if !s.Skip(4) || // message type and uint24 length field
13151344
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
1316-
!s.ReadUint24LengthPrefixed(&certList) ||
1345+
!unmarshalCertificate(&s, &m.certificate) ||
13171346
!s.Empty() {
13181347
return false
13191348
}
13201349

1350+
m.scts = m.certificate.SignedCertificateTimestamps != nil
1351+
m.ocspStapling = m.certificate.OCSPStaple != nil
1352+
1353+
return true
1354+
}
1355+
1356+
func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
1357+
var certList cryptobyte.String
1358+
if !s.ReadUint24LengthPrefixed(&certList) {
1359+
return false
1360+
}
13211361
for !certList.Empty() {
13221362
var cert []byte
13231363
var extensions cryptobyte.String
13241364
if !readUint24LengthPrefixed(&certList, &cert) ||
13251365
!certList.ReadUint16LengthPrefixed(&extensions) {
13261366
return false
13271367
}
1328-
m.certificate.Certificate = append(m.certificate.Certificate, cert)
1368+
certificate.Certificate = append(certificate.Certificate, cert)
13291369
for !extensions.Empty() {
13301370
var extension uint16
13311371
var extData cryptobyte.String
13321372
if !extensions.ReadUint16(&extension) ||
13331373
!extensions.ReadUint16LengthPrefixed(&extData) {
13341374
return false
13351375
}
1336-
if len(m.certificate.Certificate) > 1 {
1376+
if len(certificate.Certificate) > 1 {
13371377
// This library only supports OCSP and SCT for leaf certificates.
13381378
continue
13391379
}
13401380

13411381
switch extension {
13421382
case extensionStatusRequest:
1343-
m.ocspStapling = true
13441383
var statusType uint8
13451384
if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
1346-
!readUint24LengthPrefixed(&extData, &m.certificate.OCSPStaple) ||
1347-
len(m.certificate.OCSPStaple) == 0 {
1385+
!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
1386+
len(certificate.OCSPStaple) == 0 {
13481387
return false
13491388
}
13501389
case extensionSCT:
1351-
m.scts = true
13521390
var sctList cryptobyte.String
13531391
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
13541392
return false
@@ -1359,8 +1397,8 @@ func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
13591397
len(sct) == 0 {
13601398
return false
13611399
}
1362-
m.certificate.SignedCertificateTimestamps = append(
1363-
m.certificate.SignedCertificateTimestamps, sct)
1400+
certificate.SignedCertificateTimestamps = append(
1401+
certificate.SignedCertificateTimestamps, sct)
13641402
}
13651403
default:
13661404
// Ignore unknown extensions.

src/crypto/tls/handshake_messages_test.go

+22
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var tests = []interface{}{
2929
&nextProtoMsg{},
3030
&newSessionTicketMsg{},
3131
&sessionState{},
32+
&sessionStateTLS13{},
3233
&encryptedExtensionsMsg{},
3334
&endOfEarlyDataMsg{},
3435
&keyUpdateMsg{},
@@ -332,6 +333,27 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
332333
return reflect.ValueOf(s)
333334
}
334335

336+
func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
337+
s := &sessionStateTLS13{}
338+
s.cipherSuite = uint16(rand.Intn(10000))
339+
s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand)
340+
s.createdAt = uint64(rand.Int63())
341+
for i := 0; i < rand.Intn(2)+1; i++ {
342+
s.certificate.Certificate = append(
343+
s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
344+
}
345+
if rand.Intn(10) > 5 {
346+
s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
347+
}
348+
if rand.Intn(10) > 5 {
349+
for i := 0; i < rand.Intn(2)+1; i++ {
350+
s.certificate.SignedCertificateTimestamps = append(
351+
s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
352+
}
353+
}
354+
return reflect.ValueOf(s)
355+
}
356+
335357
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
336358
m := &endOfEarlyDataMsg{}
337359
return reflect.ValueOf(m)

src/crypto/tls/handshake_server.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,13 @@ func (hs *serverHandshakeState) checkForResumption() bool {
323323
return false
324324
}
325325

326-
var ok bool
327-
var sessionTicket = append([]uint8{}, hs.clientHello.sessionTicket...)
328-
if hs.sessionState, ok = c.decryptTicket(sessionTicket); !ok {
326+
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
327+
if plaintext == nil {
328+
return false
329+
}
330+
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
331+
ok := hs.sessionState.unmarshal(plaintext)
332+
if !ok {
329333
return false
330334
}
331335

@@ -352,7 +356,7 @@ func (hs *serverHandshakeState) checkForResumption() bool {
352356
}
353357

354358
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
355-
needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert
359+
needClientCerts := requiresClientCert(c.config.ClientAuth)
356360
if needClientCerts && !sessionHasClientCerts {
357361
return false
358362
}
@@ -657,7 +661,7 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
657661
masterSecret: hs.masterSecret,
658662
certificates: hs.certsFromClient,
659663
}
660-
m.ticket, err = c.encryptTicket(&state)
664+
m.ticket, err = c.encryptTicket(state.marshal())
661665
if err != nil {
662666
return err
663667
}

0 commit comments

Comments
 (0)