Skip to content

Commit 058437e

Browse files
committed
devapp: redirect to TLS when autocert is turned on
Updates golang/go#20691 Change-Id: I5247683f62cbe922a880246dbd2e99e31686e2d3 Reviewed-on: https://go-review.googlesource.com/46911 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 9a31848 commit 058437e

File tree

3 files changed

+111
-82
lines changed

3 files changed

+111
-82
lines changed

devapp/devapp.go

Lines changed: 25 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"os"
2121
"strconv"
2222
"strings"
23-
"sync"
2423
"sync/atomic"
2524
"time"
2625

@@ -37,11 +36,6 @@ func init() {
3736
os.Stderr.WriteString("devapp generates the dashboard that powers dev.golang.org.\n")
3837
flag.PrintDefaults()
3938
}
40-
41-
// TODO don't bind relative to a working directory.
42-
http.Handle("/", http.FileServer(http.Dir("./static/")))
43-
http.HandleFunc("/favicon.ico", faviconHandler)
44-
http.Handle("/release", hstsHandler(func(w http.ResponseWriter, r *http.Request) { servePage(w, r, "release") }))
4539
}
4640

4741
func main() {
@@ -50,11 +44,14 @@ func main() {
5044
devTLSPort = flag.Int("dev-tls-port", 0, "if non-zero, port number to run localhost self-signed TLS server")
5145
autocertBucket = flag.String("autocert-bucket", "", "if non-empty, listen on port 443 and serve a LetsEncrypt TLS cert using this Google Cloud Storage bucket as a cache")
5246
updateInterval = flag.Duration("update-interval", 5*time.Minute, "how often to update the dashboard data")
47+
staticDir = flag.String("static-dir", "./static/", "location of static directory relative to binary location")
5348
)
5449
flag.Parse()
5550

5651
go updateLoop(*updateInterval)
5752

53+
s := newServer(http.NewServeMux(), *staticDir)
54+
5855
ln, err := net.Listen("tcp", *listen)
5956
if err != nil {
6057
log.Fatalf("Error listening on %s: %v\n", *listen, err)
@@ -63,13 +60,19 @@ func main() {
6360

6461
errc := make(chan error)
6562
if ln != nil {
66-
go func() { errc <- fmt.Errorf("http.Serve = %v", http.Serve(ln, nil)) }()
63+
go func() {
64+
handler := http.Handler(s)
65+
if *autocertBucket != "" {
66+
handler = http.HandlerFunc(redirectHTTP)
67+
}
68+
errc <- fmt.Errorf("http.Serve = %v", http.Serve(ln, handler))
69+
}()
6770
}
6871
if *autocertBucket != "" {
69-
go func() { errc <- serveAutocertTLS(*autocertBucket) }()
72+
go func() { errc <- serveAutocertTLS(s, *autocertBucket) }()
7073
}
7174
if *devTLSPort != 0 {
72-
go func() { errc <- serveDevTLS(*devTLSPort) }()
75+
go func() { errc <- serveDevTLS(s, *devTLSPort) }()
7376
}
7477

7578
log.Fatal(<-errc)
@@ -85,15 +88,23 @@ func updateLoop(interval time.Duration) {
8588
}
8689
}
8790

88-
func serveDevTLS(port int) error {
91+
func redirectHTTP(w http.ResponseWriter, r *http.Request) {
92+
if r.TLS != nil || r.Host == "" {
93+
http.NotFound(w, r)
94+
return
95+
}
96+
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
97+
}
98+
99+
func serveDevTLS(h http.Handler, port int) error {
89100
ln, err := net.Listen("tcp", "localhost:"+strconv.Itoa(port))
90101
if err != nil {
91102
return err
92103
}
93104
defer ln.Close()
94105
log.Printf("Serving self-signed TLS at https://%s", ln.Addr())
95106
// Abuse httptest for its localhost TLS setup code:
96-
ts := httptest.NewUnstartedServer(http.DefaultServeMux)
107+
ts := httptest.NewUnstartedServer(h)
97108
// Ditch the provided listener, replace with our own:
98109
ts.Listener.Close()
99110
ts.Listener = ln
@@ -106,7 +117,7 @@ func serveDevTLS(port int) error {
106117
select {}
107118
}
108119

109-
func serveAutocertTLS(bucket string) error {
120+
func serveAutocertTLS(h http.Handler, bucket string) error {
110121
ln, err := net.Listen("tcp", ":443")
111122
if err != nil {
112123
return err
@@ -132,7 +143,8 @@ func serveAutocertTLS(bucket string) error {
132143
}
133144
tlsLn := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
134145
server := &http.Server{
135-
Addr: ln.Addr().String(),
146+
Addr: ln.Addr().String(),
147+
Handler: h,
136148
}
137149
if err := http2.ConfigureServer(server, nil); err != nil {
138150
log.Fatalf("http2.ConfigureServer: %v", err)
@@ -155,55 +167,6 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
155167
return tc, nil
156168
}
157169

158-
// hstsHandler wraps an http.HandlerFunc such that it sets the HSTS header.
159-
func hstsHandler(fn http.HandlerFunc) http.Handler {
160-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
161-
w.Header().Set("Strict-Transport-Security", "max-age=31536000; preload")
162-
fn(w, r)
163-
})
164-
}
165-
166-
type page struct {
167-
// Content is the complete HTML of the page.
168-
Content []byte
169-
}
170-
171-
var (
172-
pageStore = map[string]*page{}
173-
pageStoreMu sync.Mutex
174-
)
175-
176-
func getPage(name string) (*page, error) {
177-
pageStoreMu.Lock()
178-
defer pageStoreMu.Unlock()
179-
p, ok := pageStore[name]
180-
if ok {
181-
return p, nil
182-
}
183-
return nil, fmt.Errorf("page key %s not found", name)
184-
}
185-
186-
func writePage(pageStr string, content []byte) error {
187-
pageStoreMu.Lock()
188-
defer pageStoreMu.Unlock()
189-
entity := &page{
190-
Content: content,
191-
}
192-
pageStore[pageStr] = entity
193-
return nil
194-
}
195-
196-
func servePage(w http.ResponseWriter, r *http.Request, pageStr string) {
197-
entity, err := getPage(pageStr)
198-
if err != nil {
199-
log.Printf("getPage(%s) = %v", pageStr, err)
200-
http.NotFound(w, r)
201-
return
202-
}
203-
w.Header().Set("Content-Type", "text/html; charset=utf-8")
204-
w.Write(entity.Content)
205-
}
206-
207170
type countTransport struct {
208171
http.RoundTripper
209172
count int64
@@ -270,11 +233,3 @@ func newTransport(ctx context.Context) http.RoundTripper {
270233
}
271234
return t
272235
}
273-
274-
// GET /favicon.ico
275-
func faviconHandler(w http.ResponseWriter, r *http.Request) {
276-
// Need to specify content type for consistent tests, without this it's
277-
// determined from mime.types on the box the test is running on
278-
w.Header().Set("Content-Type", "image/x-icon")
279-
http.ServeFile(w, r, "./static/favicon.ico")
280-
}

devapp/server.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2017 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package main
6+
7+
import (
8+
"fmt"
9+
"log"
10+
"net/http"
11+
"sync"
12+
)
13+
14+
// A server is an http.Handler that serves content within staticDir at root and
15+
// the dynamically-generated dashboards at their respective endpoints.
16+
type server struct {
17+
mux *http.ServeMux
18+
staticDir string
19+
}
20+
21+
func newServer(mux *http.ServeMux, staticDir string) *server {
22+
s := &server{
23+
mux: mux,
24+
staticDir: staticDir,
25+
}
26+
s.mux.Handle("/", http.FileServer(http.Dir(s.staticDir)))
27+
s.mux.HandleFunc("/release", handleRelease)
28+
return s
29+
}
30+
31+
// ServeHTTP satisfies the http.Handler interface.
32+
func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
33+
if r.TLS != nil {
34+
w.Header().Set("Strict-Transport-Security", "max-age=31536000; preload")
35+
}
36+
s.mux.ServeHTTP(w, r)
37+
}
38+
39+
var (
40+
pageStoreMu sync.Mutex
41+
pageStore = map[string][]byte{}
42+
)
43+
44+
func getPage(name string) ([]byte, error) {
45+
pageStoreMu.Lock()
46+
defer pageStoreMu.Unlock()
47+
p, ok := pageStore[name]
48+
if ok {
49+
return p, nil
50+
}
51+
return nil, fmt.Errorf("page key %s not found", name)
52+
}
53+
54+
func writePage(key string, content []byte) error {
55+
pageStoreMu.Lock()
56+
defer pageStoreMu.Unlock()
57+
pageStore[key] = content
58+
return nil
59+
}
60+
61+
func servePage(w http.ResponseWriter, r *http.Request, key string) {
62+
b, err := getPage(key)
63+
if err != nil {
64+
log.Printf("getPage(%q) = %v", key, err)
65+
http.NotFound(w, r)
66+
return
67+
}
68+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
69+
w.Write(b)
70+
}
71+
72+
func handleRelease(w http.ResponseWriter, r *http.Request) {
73+
servePage(w, r, "release")
74+
}

devapp/devapp_test.go renamed to devapp/server_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,20 @@
55
package main
66

77
import (
8-
"fmt"
8+
"crypto/tls"
99
"net/http"
1010
"net/http/httptest"
1111
"testing"
1212
)
1313

14+
var testServer = newServer(http.DefaultServeMux, "./static/")
15+
1416
func TestStaticAssetsFound(t *testing.T) {
1517
req := httptest.NewRequest("GET", "/", nil)
1618
w := httptest.NewRecorder()
17-
http.DefaultServeMux.ServeHTTP(w, req)
18-
if w.Code != 200 {
19-
t.Errorf("expected code 200, got %d", w.Code)
19+
testServer.ServeHTTP(w, req)
20+
if w.Code != http.StatusOK {
21+
t.Errorf("expected code %d, got %d", http.StatusOK, w.Code)
2022
}
2123
if hdr := w.Header().Get("Content-Type"); hdr != "text/html; charset=utf-8" {
2224
t.Errorf("incorrect Content-Type header, got headers: %v", w.Header())
@@ -26,22 +28,20 @@ func TestStaticAssetsFound(t *testing.T) {
2628
func TestFaviconFound(t *testing.T) {
2729
req := httptest.NewRequest("GET", "/favicon.ico", nil)
2830
w := httptest.NewRecorder()
29-
http.DefaultServeMux.ServeHTTP(w, req)
30-
if w.Code != 200 {
31-
t.Errorf("expected code 200, got %d", w.Code)
31+
testServer.ServeHTTP(w, req)
32+
if w.Code != http.StatusOK {
33+
t.Errorf("expected code %d, got %d", http.StatusOK, w.Code)
3234
}
3335
if hdr := w.Header().Get("Content-Type"); hdr != "image/x-icon" {
3436
t.Errorf("incorrect Content-Type header, got headers: %v", w.Header())
3537
}
3638
}
3739

3840
func TestHSTSHeaderSet(t *testing.T) {
39-
http.Handle("/test_hsts", hstsHandler(func(w http.ResponseWriter, r *http.Request) {
40-
fmt.Fprintln(w, "much secure")
41-
}))
42-
req := httptest.NewRequest("GET", "/test_hsts", nil)
41+
req := httptest.NewRequest("GET", "/", nil)
42+
req.TLS = &tls.ConnectionState{}
4343
w := httptest.NewRecorder()
44-
http.DefaultServeMux.ServeHTTP(w, req)
44+
testServer.ServeHTTP(w, req)
4545
if hdr := w.Header().Get("Strict-Transport-Security"); hdr == "" {
4646
t.Errorf("missing Strict-Transport-Security header; headers = %v", w.Header())
4747
}

0 commit comments

Comments
 (0)