Skip to content

Commit 9baafab

Browse files
committed
crypto/rsa: refactor RSA-PSS signing and verification
Cleaned up for readability and consistency. There is one tiny behavioral change: when PSSSaltLengthEqualsHash is used and both hash and opts.Hash were set, hash.Size() was used for the salt length instead of opts.Hash.Size(). That's clearly wrong because opts.Hash is documented to override hash. Change-Id: I3e25dad933961eac827c6d2e3bbfe45fc5a6fb0e Reviewed-on: https://go-review.googlesource.com/c/go/+/226937 Run-TryBot: Filippo Valsorda <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Katie Hockman <[email protected]>
1 parent aa4d92b commit 9baafab

File tree

2 files changed

+96
-86
lines changed

2 files changed

+96
-86
lines changed

src/crypto/rsa/pss.go

Lines changed: 91 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
package rsa
66

7-
// This file implements the PSS signature scheme [1].
8-
//
9-
// [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf
7+
// This file implements the RSASSA-PSS signature scheme according to RFC 8017.
108

119
import (
1210
"bytes"
@@ -17,8 +15,22 @@ import (
1715
"math/big"
1816
)
1917

18+
// Per RFC 8017, Section 9.1
19+
//
20+
// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
21+
//
22+
// where
23+
//
24+
// DB = PS || 0x01 || salt
25+
//
26+
// and PS can be empty so
27+
//
28+
// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
29+
//
30+
2031
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
21-
// See [1], section 9.1.1
32+
// See RFC 8017, Section 9.1.1.
33+
2234
hLen := hash.Size()
2335
sLen := len(salt)
2436
emLen := (emBits + 7) / 8
@@ -30,7 +42,7 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
3042
// 2. Let mHash = Hash(M), an octet string of length hLen.
3143

3244
if len(mHash) != hLen {
33-
return nil, errors.New("crypto/rsa: input must be hashed message")
45+
return nil, errors.New("crypto/rsa: input must be hashed with given hash")
3446
}
3547

3648
// 3. If emLen < hLen + sLen + 2, output "encoding error" and stop.
@@ -40,8 +52,9 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
4052
}
4153

4254
em := make([]byte, emLen)
43-
db := em[:emLen-sLen-hLen-2+1+sLen]
44-
h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
55+
psLen := emLen - sLen - hLen - 2
56+
db := em[:psLen+1+sLen]
57+
h := em[psLen+1+sLen : emLen-1]
4558

4659
// 4. Generate a random octet string salt of length sLen; if sLen = 0,
4760
// then salt is the empty string.
@@ -69,8 +82,8 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
6982
// 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
7083
// emLen - hLen - 1.
7184

72-
db[emLen-sLen-hLen-2] = 0x01
73-
copy(db[emLen-sLen-hLen-1:], salt)
85+
db[psLen] = 0x01
86+
copy(db[psLen+1:], salt)
7487

7588
// 9. Let dbMask = MGF(H, emLen - hLen - 1).
7689
//
@@ -81,47 +94,57 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
8194
// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
8295
// maskedDB to zero.
8396

84-
db[0] &= (0xFF >> uint(8*emLen-emBits))
97+
db[0] &= 0xff >> (8*emLen - emBits)
8598

8699
// 12. Let EM = maskedDB || H || 0xbc.
87-
em[emLen-1] = 0xBC
100+
em[emLen-1] = 0xbc
88101

89102
// 13. Output EM.
90103
return em, nil
91104
}
92105

93106
func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
107+
// See RFC 8017, Section 9.1.2.
108+
109+
hLen := hash.Size()
110+
if sLen == PSSSaltLengthEqualsHash {
111+
sLen = hLen
112+
}
113+
emLen := (emBits + 7) / 8
114+
if emLen != len(em) {
115+
return errors.New("rsa: internal error: inconsistent length")
116+
}
117+
94118
// 1. If the length of M is greater than the input limitation for the
95119
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
96120
// and stop.
97121
//
98122
// 2. Let mHash = Hash(M), an octet string of length hLen.
99-
hLen := hash.Size()
100123
if hLen != len(mHash) {
101124
return ErrVerification
102125
}
103126

104127
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
105-
emLen := (emBits + 7) / 8
106128
if emLen < hLen+sLen+2 {
107129
return ErrVerification
108130
}
109131

110132
// 4. If the rightmost octet of EM does not have hexadecimal value
111133
// 0xbc, output "inconsistent" and stop.
112-
if em[len(em)-1] != 0xBC {
134+
if em[emLen-1] != 0xbc {
113135
return ErrVerification
114136
}
115137

116138
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
117139
// let H be the next hLen octets.
118140
db := em[:emLen-hLen-1]
119-
h := em[emLen-hLen-1 : len(em)-1]
141+
h := em[emLen-hLen-1 : emLen-1]
120142

121143
// 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in
122144
// maskedDB are not all equal to zero, output "inconsistent" and
123145
// stop.
124-
if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
146+
var bitMask byte = 0xff >> (8*emLen - emBits)
147+
if em[0] & ^bitMask != 0 {
125148
return ErrVerification
126149
}
127150

@@ -132,37 +155,30 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
132155

133156
// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
134157
// to zero.
135-
db[0] &= (0xFF >> uint(8*emLen-emBits))
158+
db[0] &= bitMask
136159

160+
// If we don't know the salt length, look for the 0x01 delimiter.
137161
if sLen == PSSSaltLengthAuto {
138-
FindSaltLength:
139-
for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
140-
switch db[emLen-hLen-sLen-2] {
141-
case 1:
142-
break FindSaltLength
143-
case 0:
144-
continue
145-
default:
146-
return ErrVerification
147-
}
148-
}
149-
if sLen < 0 {
162+
psLen := bytes.IndexByte(db, 0x01)
163+
if psLen < 0 {
150164
return ErrVerification
151165
}
152-
} else {
153-
// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
154-
// or if the octet at position emLen - hLen - sLen - 1 (the leftmost
155-
// position is "position 1") does not have hexadecimal value 0x01,
156-
// output "inconsistent" and stop.
157-
for _, e := range db[:emLen-hLen-sLen-2] {
158-
if e != 0x00 {
159-
return ErrVerification
160-
}
161-
}
162-
if db[emLen-hLen-sLen-2] != 0x01 {
166+
sLen = len(db) - psLen - 1
167+
}
168+
169+
// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
170+
// or if the octet at position emLen - hLen - sLen - 1 (the leftmost
171+
// position is "position 1") does not have hexadecimal value 0x01,
172+
// output "inconsistent" and stop.
173+
psLen := emLen - hLen - sLen - 2
174+
for _, e := range db[:psLen] {
175+
if e != 0x00 {
163176
return ErrVerification
164177
}
165178
}
179+
if db[psLen] != 0x01 {
180+
return ErrVerification
181+
}
166182

167183
// 11. Let salt be the last sLen octets of DB.
168184
salt := db[len(db)-sLen:]
@@ -181,19 +197,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
181197
h0 := hash.Sum(nil)
182198

183199
// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
184-
if !bytes.Equal(h0, h) {
200+
if !bytes.Equal(h0, h) { // TODO: constant time?
185201
return ErrVerification
186202
}
187203
return nil
188204
}
189205

190-
// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
206+
// signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
191207
// Note that hashed must be the result of hashing the input message using the
192208
// given hash function. salt is a random sequence of bytes whose length will be
193209
// later used to verify the signature.
194210
func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
195-
nBits := priv.N.BitLen()
196-
em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
211+
emBits := priv.N.BitLen() - 1
212+
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
197213
if err != nil {
198214
return
199215
}
@@ -202,7 +218,7 @@ func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed,
202218
if err != nil {
203219
return
204220
}
205-
s = make([]byte, (nBits+7)/8)
221+
s = make([]byte, priv.Size())
206222
copyWithLeftPad(s, c.Bytes())
207223
return
208224
}
@@ -223,16 +239,15 @@ type PSSOptions struct {
223239
// PSSSaltLength constants.
224240
SaltLength int
225241

226-
// Hash, if not zero, overrides the hash function passed to SignPSS.
227-
// This is the only way to specify the hash function when using the
228-
// crypto.Signer interface.
242+
// Hash is the hash function used to generate the message digest. If not
243+
// zero, it overrides the hash function passed to SignPSS. It's required
244+
// when using PrivateKey.Sign.
229245
Hash crypto.Hash
230246
}
231247

232-
// HashFunc returns pssOpts.Hash so that PSSOptions implements
233-
// crypto.SignerOpts.
234-
func (pssOpts *PSSOptions) HashFunc() crypto.Hash {
235-
return pssOpts.Hash
248+
// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
249+
func (opts *PSSOptions) HashFunc() crypto.Hash {
250+
return opts.Hash
236251
}
237252

238253
func (opts *PSSOptions) saltLength() int {
@@ -242,56 +257,50 @@ func (opts *PSSOptions) saltLength() int {
242257
return opts.SaltLength
243258
}
244259

245-
// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
246-
// Note that hashed must be the result of hashing the input message using the
247-
// given hash function. The opts argument may be nil, in which case sensible
248-
// defaults are used.
249-
func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
260+
// SignPSS calculates the signature of digest using PSS.
261+
//
262+
// digest must be the result of hashing the input message using the given hash
263+
// function. The opts argument may be nil, in which case sensible defaults are
264+
// used. If opts.Hash is set, it overrides hash.
265+
func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
266+
if opts != nil && opts.Hash != 0 {
267+
hash = opts.Hash
268+
}
269+
250270
saltLength := opts.saltLength()
251271
switch saltLength {
252272
case PSSSaltLengthAuto:
253-
saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
273+
saltLength = priv.Size() - 2 - hash.Size()
254274
case PSSSaltLengthEqualsHash:
255275
saltLength = hash.Size()
256276
}
257277

258-
if opts != nil && opts.Hash != 0 {
259-
hash = opts.Hash
260-
}
261-
262278
salt := make([]byte, saltLength)
263279
if _, err := io.ReadFull(rand, salt); err != nil {
264280
return nil, err
265281
}
266-
return signPSSWithSalt(rand, priv, hash, hashed, salt)
282+
return signPSSWithSalt(rand, priv, hash, digest, salt)
267283
}
268284

269285
// VerifyPSS verifies a PSS signature.
270-
// hashed is the result of hashing the input message using the given hash
271-
// function and sig is the signature. A valid signature is indicated by
272-
// returning a nil error. The opts argument may be nil, in which case sensible
273-
// defaults are used.
274-
func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
275-
return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
276-
}
277-
278-
// verifyPSS verifies a PSS signature with the given salt length.
279-
func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
280-
nBits := pub.N.BitLen()
281-
if len(sig) != (nBits+7)/8 {
286+
//
287+
// A valid signature is indicated by returning a nil error. digest must be the
288+
// result of hashing the input message using the given hash function. The opts
289+
// argument may be nil, in which case sensible defaults are used. opts.Hash is
290+
// ignored.
291+
func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
292+
if len(sig) != pub.Size() {
282293
return ErrVerification
283294
}
284295
s := new(big.Int).SetBytes(sig)
285296
m := encrypt(new(big.Int), pub, s)
286-
emBits := nBits - 1
297+
emBits := pub.N.BitLen() - 1
287298
emLen := (emBits + 7) / 8
288-
if emLen < len(m.Bytes()) {
299+
emBytes := m.Bytes()
300+
if emLen < len(emBytes) {
289301
return ErrVerification
290302
}
291303
em := make([]byte, emLen)
292-
copyWithLeftPad(em, m.Bytes())
293-
if saltLen == PSSSaltLengthEqualsHash {
294-
saltLen = hash.Size()
295-
}
296-
return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
304+
copyWithLeftPad(em, emBytes)
305+
return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
297306
}

src/crypto/rsa/rsa.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
// Package rsa implements RSA encryption as specified in PKCS#1.
5+
// Package rsa implements RSA encryption as specified in PKCS#1 and RFC 8017.
66
//
77
// RSA is a single, fundamental operation that is used in this package to
88
// implement either public-key encryption or public-key signatures.
99
//
1010
// The original specification for encryption and signatures with RSA is PKCS#1
1111
// and the terms "RSA encryption" and "RSA signatures" by default refer to
1212
// PKCS#1 version 1.5. However, that specification has flaws and new designs
13-
// should use version two, usually called by just OAEP and PSS, where
13+
// should use version 2, usually called by just OAEP and PSS, where
1414
// possible.
1515
//
1616
// Two sets of interfaces are included in this package. When a more abstract
1717
// interface isn't necessary, there are functions for encrypting/decrypting
1818
// with v1.5/OAEP and signing/verifying with v1.5/PSS. If one needs to abstract
19-
// over the public-key primitive, the PrivateKey struct implements the
19+
// over the public key primitive, the PrivateKey type implements the
2020
// Decrypter and Signer interfaces from the crypto package.
2121
//
2222
// The RSA operations in this package are not implemented using constant-time algorithms.
@@ -111,7 +111,8 @@ func (priv *PrivateKey) Public() crypto.PublicKey {
111111

112112
// Sign signs digest with priv, reading randomness from rand. If opts is a
113113
// *PSSOptions then the PSS algorithm will be used, otherwise PKCS#1 v1.5 will
114-
// be used.
114+
// be used. digest must be the result of hashing the input message using
115+
// opts.HashFunc().
115116
//
116117
// This method implements crypto.Signer, which is an interface to support keys
117118
// where the private part is kept in, for example, a hardware module. Common

0 commit comments

Comments
 (0)