Skip to content

Commit 537418b

Browse files
committed
refactor samlsp to be modular
1 parent 62f4c47 commit 537418b

24 files changed

+1169
-684
lines changed

example/service.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package main
33

44
import (
55
"bytes"
6+
"context"
67
"crypto/rsa"
78
"crypto/tls"
89
"crypto/x509"
910
"encoding/xml"
1011
"flag"
1112
"fmt"
13+
"log"
1214
"net/http"
1315
"net/url"
1416
"strings"
@@ -18,7 +20,6 @@ import (
1820
"github.com/zenazn/goji"
1921
"github.com/zenazn/goji/web"
2022

21-
"github.com/crewjam/saml/logger"
2223
"github.com/crewjam/saml/samlsp"
2324
)
2425

@@ -100,7 +101,6 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
100101
)
101102

102103
func main() {
103-
logr := logger.DefaultLogger
104104
rootURLstr := flag.String("url", "https://962766ce.ngrok.io", "The base URL of this service")
105105
idpMetadataURLstr := flag.String("idp", "https://516becc2.ngrok.io/metadata", "The metadata URL for the IDP")
106106
flag.Parse()
@@ -119,6 +119,9 @@ func main() {
119119
panic(err) // TODO handle error
120120
}
121121

122+
idpMetadata, err := samlsp.FetchMetadata(context.Background(), http.DefaultClient,
123+
*idpMetadataURL)
124+
122125
rootURL, err := url.Parse(*rootURLstr)
123126
if err != nil {
124127
panic(err) // TODO handle error
@@ -127,13 +130,12 @@ func main() {
127130
samlSP, err := samlsp.New(samlsp.Options{
128131
URL: *rootURL,
129132
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
130-
Logger: logr,
131133
Certificate: keyPair.Leaf,
132134
AllowIDPInitiated: true,
133-
IDPMetadataURL: idpMetadataURL,
135+
IDPMetadata: idpMetadata,
134136
})
135137
if err != nil {
136-
logr.Fatalf("%s", err)
138+
log.Fatalf("%s", err)
137139
}
138140

139141
// register with the service provider

example/trivial/trivial.go

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
package main
22

33
import (
4+
"context"
5+
"crypto/rsa"
6+
"crypto/tls"
7+
"crypto/x509"
48
"fmt"
59
"net/http"
610
"net/url"
711

8-
"crypto/tls"
9-
"crypto/x509"
10-
11-
"crypto/rsa"
12-
1312
"github.com/crewjam/saml/samlsp"
1413
)
1514

1615
func hello(w http.ResponseWriter, r *http.Request) {
17-
fmt.Fprintf(w, "Hello, %s!", samlsp.Token(r.Context()).Attributes.Get("cn"))
16+
fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "cn"))
1817
}
1918

2019
func main() {
@@ -27,22 +26,21 @@ func main() {
2726
panic(err) // TODO handle error
2827
}
2928

30-
idpMetadataURL, err := url.Parse("https://www.testshib.org/metadata/testshib-providers.xml")
31-
if err != nil {
32-
panic(err) // TODO handle error
33-
}
29+
rootURL, _ := url.Parse("http://localhost:8000")
30+
idpMetadataURL, _ := url.Parse("https://www.testshib.org/metadata/testshib-providers.xml")
3431

35-
rootURL, err := url.Parse("http://localhost:8000")
36-
if err != nil {
37-
panic(err) // TODO handle error
38-
}
32+
idpMetadata, err := samlsp.FetchMetadata(
33+
context.Background(),
34+
http.DefaultClient,
35+
*idpMetadataURL)
3936

40-
samlSP, _ := samlsp.New(samlsp.Options{
41-
IDPMetadataURL: idpMetadataURL,
42-
URL: *rootURL,
43-
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
44-
Certificate: keyPair.Leaf,
37+
samlSP, err := samlsp.New(samlsp.Options{
38+
URL: *rootURL,
39+
IDPMetadata: idpMetadata,
40+
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
41+
Certificate: keyPair.Leaf,
4542
})
43+
4644
app := http.HandlerFunc(hello)
4745
http.Handle("/hello", samlSP.RequireAccount(app))
4846
http.Handle("/saml/", samlSP)

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ go 1.13
44

55
require (
66
github.com/beevik/etree v1.1.0
7+
github.com/crewjam/httperr v0.0.0-20190612203328-a946449404da
8+
github.com/davecgh/go-spew v1.1.1 // indirect
79
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9
810
github.com/dgrijalva/jwt-go v3.2.0+incompatible
911
github.com/jonboulle/clockwork v0.1.0 // indirect
1012
github.com/kr/pretty v0.1.0
13+
github.com/pkg/errors v0.8.1 // indirect
1114
github.com/russellhaering/goxmldsig v0.0.0-20180430223755-7acd5e4a6ef7
1215
github.com/stretchr/testify v1.4.0
1316
github.com/zenazn/goji v0.9.1-0.20160507202103-64eb34159fe5

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs=
22
github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=
3+
github.com/crewjam/httperr v0.0.0-20190612203328-a946449404da h1:WXnT88cFG2davqSFqvaFfzkSMC0lqh/8/rKZ+z7tYvI=
4+
github.com/crewjam/httperr v0.0.0-20190612203328-a946449404da/go.mod h1:+rmNIXRvYMqLQeR4DHyTvs6y0MEMymTz4vyFpFkKTPs=
35
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
46
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
8+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
59
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9 h1:74lLNRzvsdIlkTgfDSMuaPjBr4cf6k7pwQQANm/yLKU=
610
github.com/dchest/uniuri v0.0.0-20160212164326-8902c56451e9/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4=
711
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
@@ -13,6 +17,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
1317
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
1418
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
1519
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
20+
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
21+
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
1622
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
1723
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
1824
github.com/russellhaering/goxmldsig v0.0.0-20180430223755-7acd5e4a6ef7 h1:J4AOUcOh/t1XbQcJfkEqhzgvMJ2tDxdCVvmHxW5QXao=

samlidp/samlidp_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
119119
MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"),
120120
AcsURL: mustParseURL("https://sp.example.com/saml2/acs"),
121121
IDPMetadata: &saml.EntityDescriptor{},
122-
Logger: logger.DefaultLogger,
123122
}
124123
test.Key = mustParsePrivateKey("-----BEGIN RSA PRIVATE KEY-----\nMIICXgIBAAKBgQDU8wdiaFmPfTyRYuFlVPi866WrH/2JubkHzp89bBQopDaLXYxi\n3PTu3O6Q/KaKxMOFBqrInwqpv/omOGZ4ycQ51O9I+Yc7ybVlW94lTo2gpGf+Y/8E\nPsVbnZaFutRctJ4dVIp9aQ2TpLiGT0xX1OzBO/JEgq9GzDRf+B+eqSuglwIDAQAB\nAoGBAMuy1eN6cgFiCOgBsB3gVDdTKpww87Qk5ivjqEt28SmXO13A1KNVPS6oQ8SJ\nCT5Azc6X/BIAoJCURVL+LHdqebogKljhH/3yIel1kH19vr4E2kTM/tYH+qj8afUS\nJEmArUzsmmK8ccuNqBcllqdwCZjxL4CHDUmyRudFcHVX9oyhAkEA/OV1OkjM3CLU\nN3sqELdMmHq5QZCUihBmk3/N5OvGdqAFGBlEeewlepEVxkh7JnaNXAXrKHRVu/f/\nfbCQxH+qrwJBANeQERF97b9Sibp9xgolb749UWNlAdqmEpmlvmS202TdcaaT1msU\n4rRLiQN3X9O9mq4LZMSVethrQAdX1whawpkCQQDk1yGf7xZpMJ8F4U5sN+F4rLyM\nRq8Sy8p2OBTwzCUXXK+fYeXjybsUUMr6VMYTRP2fQr/LKJIX+E5ZxvcIyFmDAkEA\nyfjNVUNVaIbQTzEbRlRvT6MqR+PTCefC072NF9aJWR93JimspGZMR7viY6IM4lrr\nvBkm0F5yXKaYtoiiDMzlOQJADqmEwXl0D72ZG/2KDg8b4QZEmC9i5gidpQwJXUc6\nhU+IVQoLxRq0fBib/36K9tcrrO5Ba4iEvDcNY+D8yGbUtA==\n-----END RSA PRIVATE KEY-----\n")
125124
test.Certificate = mustParseCertificate("-----BEGIN CERTIFICATE-----\nMIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJV\nUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0\nMB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMx\nCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCB\nnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9\nibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmH\nO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKv\nRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgk\nakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeT\nQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvn\nOwJlNCASPZRH/JmF8tX0hoHuAQ==\n-----END CERTIFICATE-----\n")

samlsp/cookie.go

Lines changed: 0 additions & 111 deletions
This file was deleted.

samlsp/error.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package samlsp
2+
3+
import (
4+
"log"
5+
"net/http"
6+
7+
"github.com/crewjam/saml"
8+
"github.com/crewjam/saml/logger"
9+
)
10+
11+
// ErrorFunction is a callback that is invoked to return an error to the
12+
// web user.
13+
type ErrorFunction func(w http.ResponseWriter, r *http.Request, err error)
14+
15+
// DefaultOnError is the default ErrorFunction implementation. It prints
16+
// an message via the standard log package and returns a simple text
17+
// "Forbidden" message to the user.
18+
func DefaultOnError(w http.ResponseWriter, r *http.Request, err error) {
19+
if parseErr, ok := err.(*saml.InvalidResponseError); ok {
20+
log.Printf("WARNING: received invalid saml response: %s (now: %s) %s",
21+
parseErr.Response, parseErr.Now, parseErr.PrivateErr)
22+
} else {
23+
log.Printf("ERROR: %s", err)
24+
}
25+
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
26+
}
27+
28+
// defaultOnErrorWithLogger is like DefaultOnError but accepts a custom logger.
29+
// This is a bridge for backward compatability with people use provide the
30+
// deprecated Logger options field to New().
31+
func defaultOnErrorWithLogger(log logger.Interface) ErrorFunction {
32+
return func(w http.ResponseWriter, r *http.Request, err error) {
33+
if parseErr, ok := err.(*saml.InvalidResponseError); ok {
34+
log.Printf("WARNING: received invalid saml response: %s (now: %s) %s",
35+
parseErr.Response, parseErr.Now, parseErr.PrivateErr)
36+
} else {
37+
log.Printf("ERROR: %s", err)
38+
}
39+
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
40+
}
41+
}

samlsp/fetch_metadata.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package samlsp
2+
3+
import (
4+
"context"
5+
"encoding/xml"
6+
"fmt"
7+
"io/ioutil"
8+
"net/http"
9+
"net/url"
10+
11+
"github.com/crewjam/httperr"
12+
13+
"github.com/crewjam/saml"
14+
)
15+
16+
// ParseMetadata parses arbitrary SAML IDP metadata.
17+
//
18+
// Note: this is needed because IDP metadata is sometimes wrapped in
19+
// an <EntitiesDescriptor>, and sometimes the top level element is an
20+
// <EntityDescriptor>.
21+
func ParseMetadata(data []byte) (*saml.EntityDescriptor, error) {
22+
entity := &saml.EntityDescriptor{}
23+
err := xml.Unmarshal(data, entity)
24+
25+
// this comparison is ugly, but it is how the error is generated in encoding/xml
26+
if err != nil && err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
27+
entities := &saml.EntitiesDescriptor{}
28+
if err := xml.Unmarshal(data, entities); err != nil {
29+
return nil, err
30+
}
31+
32+
err = fmt.Errorf("no entity found with IDPSSODescriptor")
33+
for i, e := range entities.EntityDescriptors {
34+
if len(e.IDPSSODescriptors) > 0 {
35+
entity = &entities.EntityDescriptors[i]
36+
err = nil
37+
}
38+
}
39+
}
40+
if err != nil {
41+
return nil, err
42+
}
43+
return entity, nil
44+
}
45+
46+
// FetchMetadata returns metadata from an IDP metadata URL.
47+
func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url.URL) (*saml.EntityDescriptor, error) {
48+
req, err := http.NewRequest("GET", metadataURL.String(), nil)
49+
if err != nil {
50+
return nil, err
51+
}
52+
req = req.WithContext(ctx)
53+
54+
resp, err := httpClient.Do(req)
55+
if err != nil {
56+
return nil, err
57+
}
58+
if resp.StatusCode >= 400 {
59+
return nil, httperr.Response(*resp)
60+
}
61+
defer resp.Body.Close()
62+
63+
data, err := ioutil.ReadAll(resp.Body)
64+
if err != nil {
65+
return nil, err
66+
}
67+
68+
return ParseMetadata(data)
69+
}

0 commit comments

Comments
 (0)