Skip to content

Commit 27526d5

Browse files
Refactoring: never assign unacceptable TLS versions
This commit makes security linting easier by never setting a TLS version outside v1.2 or v1.3, even in case of an unacceptable user input.
1 parent b41b946 commit 27526d5

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

main.go

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -370,14 +370,19 @@ func concurrency(c int) controller.Options {
370370
func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) {
371371
var tlsOptions []func(config *tls.Config)
372372

373-
tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion)
374-
if err != nil {
375-
return nil, err
376-
}
377-
378-
tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion)
379-
if err != nil {
380-
return nil, err
373+
// To make a static analyzer happy, this block ensures there is no code
374+
// path that sets a TLS version outside the acceptable values, even in
375+
// case of unexpected user input.
376+
var tlsMinVersion, tlsMaxVersion uint16
377+
for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} {
378+
switch option {
379+
case TLSVersion12:
380+
*version = tls.VersionTLS12
381+
case TLSVersion13:
382+
*version = tls.VersionTLS13
383+
default:
384+
return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", "))
385+
}
381386
}
382387

383388
if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion {
@@ -419,18 +424,3 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error)
419424

420425
return tlsOptions, nil
421426
}
422-
423-
// GetTLSVersion returns the corresponding tls.Version or error.
424-
func GetTLSVersion(version string) (uint16, error) {
425-
var v uint16
426-
427-
switch version {
428-
case TLSVersion12:
429-
v = tls.VersionTLS12
430-
case TLSVersion13:
431-
v = tls.VersionTLS13
432-
default:
433-
return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", "))
434-
}
435-
return v, nil
436-
}

main_test.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package main
1818

1919
import (
2020
"bytes"
21+
"crypto/tls"
2122
"testing"
2223

2324
. "github.com/onsi/gomega"
@@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) {
7576
klog.SetOutput(bufWriter)
7677
klog.LogToStderr(false) // this is important, because klog by default logs to stderr only
7778
_, err := GetTLSOptionOverrideFuncs(tlsMockOptions)
78-
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
7979
g.Expect(err).Should(BeNil())
80+
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
8081
})
8182
}
8283

83-
func TestGetTLSVersion(t *testing.T) {
84-
t.Run("should error out when incorrect tls version passed", func(t *testing.T) {
84+
func TestGetTLSOverrideFuncs(t *testing.T) {
85+
t.Run("should error out when incorrect min tls version passed", func(t *testing.T) {
86+
g := NewWithT(t)
87+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
88+
TLSMinVersion: "TLS11",
89+
TLSMaxVersion: "TLS12",
90+
})
91+
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
92+
})
93+
t.Run("should error out when incorrect max tls version passed", func(t *testing.T) {
8594
g := NewWithT(t)
86-
tlsVersion := "TLS11"
87-
_, err := GetTLSVersion(tlsVersion)
95+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
96+
TLSMinVersion: "TLS12",
97+
TLSMaxVersion: "TLS11",
98+
})
8899
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
89100
})
90-
t.Run("should pass and output correct tls version", func(t *testing.T) {
91-
const VersionTLS12 uint16 = 771
101+
t.Run("should apply the requested TLS versions", func(t *testing.T) {
102+
g := NewWithT(t)
103+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
104+
TLSMinVersion: "TLS12",
105+
TLSMaxVersion: "TLS13",
106+
})
107+
108+
var tlsConfig tls.Config
109+
for _, apply := range tlsOptionOverrides {
110+
apply(&tlsConfig)
111+
}
112+
113+
g.Expect(err).Should(BeNil())
114+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
115+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
116+
})
117+
t.Run("should apply the requested non-default TLS versions", func(t *testing.T) {
92118
g := NewWithT(t)
93-
tlsVersion := "TLS12"
94-
version, err := GetTLSVersion(tlsVersion)
95-
g.Expect(version).To(Equal(VersionTLS12))
119+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
120+
TLSMinVersion: "TLS13",
121+
TLSMaxVersion: "TLS13",
122+
})
123+
124+
var tlsConfig tls.Config
125+
for _, apply := range tlsOptionOverrides {
126+
apply(&tlsConfig)
127+
}
128+
96129
g.Expect(err).Should(BeNil())
130+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
131+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
97132
})
98133
}
99134

0 commit comments

Comments
 (0)