4
4
5
5
package rsa
6
6
7
- // This file implements the PSS signature scheme [1].
8
- //
9
- // [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf
7
+ // This file implements the RSASSA-PSS signature scheme according to RFC 8017.
10
8
11
9
import (
12
10
"bytes"
@@ -17,8 +15,22 @@ import (
17
15
"math/big"
18
16
)
19
17
18
+ // Per RFC 8017, Section 9.1
19
+ //
20
+ // EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
21
+ //
22
+ // where
23
+ //
24
+ // DB = PS || 0x01 || salt
25
+ //
26
+ // and PS can be empty so
27
+ //
28
+ // emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
29
+ //
30
+
20
31
func emsaPSSEncode (mHash []byte , emBits int , salt []byte , hash hash.Hash ) ([]byte , error ) {
21
- // See [1], section 9.1.1
32
+ // See RFC 8017, Section 9.1.1.
33
+
22
34
hLen := hash .Size ()
23
35
sLen := len (salt )
24
36
emLen := (emBits + 7 ) / 8
@@ -30,7 +42,7 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
30
42
// 2. Let mHash = Hash(M), an octet string of length hLen.
31
43
32
44
if len (mHash ) != hLen {
33
- return nil , errors .New ("crypto/rsa: input must be hashed message " )
45
+ return nil , errors .New ("crypto/rsa: input must be hashed with given hash " )
34
46
}
35
47
36
48
// 3. If emLen < hLen + sLen + 2, output "encoding error" and stop.
@@ -40,8 +52,9 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
40
52
}
41
53
42
54
em := make ([]byte , emLen )
43
- db := em [:emLen - sLen - hLen - 2 + 1 + sLen ]
44
- h := em [emLen - sLen - hLen - 2 + 1 + sLen : emLen - 1 ]
55
+ psLen := emLen - sLen - hLen - 2
56
+ db := em [:psLen + 1 + sLen ]
57
+ h := em [psLen + 1 + sLen : emLen - 1 ]
45
58
46
59
// 4. Generate a random octet string salt of length sLen; if sLen = 0,
47
60
// then salt is the empty string.
@@ -69,8 +82,8 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
69
82
// 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
70
83
// emLen - hLen - 1.
71
84
72
- db [emLen - sLen - hLen - 2 ] = 0x01
73
- copy (db [emLen - sLen - hLen - 1 :], salt )
85
+ db [psLen ] = 0x01
86
+ copy (db [psLen + 1 :], salt )
74
87
75
88
// 9. Let dbMask = MGF(H, emLen - hLen - 1).
76
89
//
@@ -81,47 +94,57 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
81
94
// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
82
95
// maskedDB to zero.
83
96
84
- db [0 ] &= ( 0xFF >> uint (8 * emLen - emBits ) )
97
+ db [0 ] &= 0xff >> (8 * emLen - emBits )
85
98
86
99
// 12. Let EM = maskedDB || H || 0xbc.
87
- em [emLen - 1 ] = 0xBC
100
+ em [emLen - 1 ] = 0xbc
88
101
89
102
// 13. Output EM.
90
103
return em , nil
91
104
}
92
105
93
106
func emsaPSSVerify (mHash , em []byte , emBits , sLen int , hash hash.Hash ) error {
107
+ // See RFC 8017, Section 9.1.2.
108
+
109
+ hLen := hash .Size ()
110
+ if sLen == PSSSaltLengthEqualsHash {
111
+ sLen = hLen
112
+ }
113
+ emLen := (emBits + 7 ) / 8
114
+ if emLen != len (em ) {
115
+ return errors .New ("rsa: internal error: inconsistent length" )
116
+ }
117
+
94
118
// 1. If the length of M is greater than the input limitation for the
95
119
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
96
120
// and stop.
97
121
//
98
122
// 2. Let mHash = Hash(M), an octet string of length hLen.
99
- hLen := hash .Size ()
100
123
if hLen != len (mHash ) {
101
124
return ErrVerification
102
125
}
103
126
104
127
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
105
- emLen := (emBits + 7 ) / 8
106
128
if emLen < hLen + sLen + 2 {
107
129
return ErrVerification
108
130
}
109
131
110
132
// 4. If the rightmost octet of EM does not have hexadecimal value
111
133
// 0xbc, output "inconsistent" and stop.
112
- if em [len ( em ) - 1 ] != 0xBC {
134
+ if em [emLen - 1 ] != 0xbc {
113
135
return ErrVerification
114
136
}
115
137
116
138
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
117
139
// let H be the next hLen octets.
118
140
db := em [:emLen - hLen - 1 ]
119
- h := em [emLen - hLen - 1 : len ( em ) - 1 ]
141
+ h := em [emLen - hLen - 1 : emLen - 1 ]
120
142
121
143
// 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in
122
144
// maskedDB are not all equal to zero, output "inconsistent" and
123
145
// stop.
124
- if em [0 ]& (0xFF << uint (8 - (8 * emLen - emBits ))) != 0 {
146
+ var bitMask byte = 0xff >> (8 * emLen - emBits )
147
+ if em [0 ] & ^ bitMask != 0 {
125
148
return ErrVerification
126
149
}
127
150
@@ -132,37 +155,30 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
132
155
133
156
// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
134
157
// to zero.
135
- db [0 ] &= ( 0xFF >> uint ( 8 * emLen - emBits ))
158
+ db [0 ] &= bitMask
136
159
160
+ // If we don't know the salt length, look for the 0x01 delimiter.
137
161
if sLen == PSSSaltLengthAuto {
138
- FindSaltLength:
139
- for sLen = emLen - (hLen + 2 ); sLen >= 0 ; sLen -- {
140
- switch db [emLen - hLen - sLen - 2 ] {
141
- case 1 :
142
- break FindSaltLength
143
- case 0 :
144
- continue
145
- default :
146
- return ErrVerification
147
- }
148
- }
149
- if sLen < 0 {
162
+ psLen := bytes .IndexByte (db , 0x01 )
163
+ if psLen < 0 {
150
164
return ErrVerification
151
165
}
152
- } else {
153
- // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
154
- // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
155
- // position is "position 1") does not have hexadecimal value 0x01,
156
- // output "inconsistent" and stop.
157
- for _ , e := range db [:emLen - hLen - sLen - 2 ] {
158
- if e != 0x00 {
159
- return ErrVerification
160
- }
161
- }
162
- if db [emLen - hLen - sLen - 2 ] != 0x01 {
166
+ sLen = len (db ) - psLen - 1
167
+ }
168
+
169
+ // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
170
+ // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
171
+ // position is "position 1") does not have hexadecimal value 0x01,
172
+ // output "inconsistent" and stop.
173
+ psLen := emLen - hLen - sLen - 2
174
+ for _ , e := range db [:psLen ] {
175
+ if e != 0x00 {
163
176
return ErrVerification
164
177
}
165
178
}
179
+ if db [psLen ] != 0x01 {
180
+ return ErrVerification
181
+ }
166
182
167
183
// 11. Let salt be the last sLen octets of DB.
168
184
salt := db [len (db )- sLen :]
@@ -181,19 +197,19 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
181
197
h0 := hash .Sum (nil )
182
198
183
199
// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
184
- if ! bytes .Equal (h0 , h ) {
200
+ if ! bytes .Equal (h0 , h ) { // TODO: constant time?
185
201
return ErrVerification
186
202
}
187
203
return nil
188
204
}
189
205
190
- // signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
206
+ // signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
191
207
// Note that hashed must be the result of hashing the input message using the
192
208
// given hash function. salt is a random sequence of bytes whose length will be
193
209
// later used to verify the signature.
194
210
func signPSSWithSalt (rand io.Reader , priv * PrivateKey , hash crypto.Hash , hashed , salt []byte ) (s []byte , err error ) {
195
- nBits := priv .N .BitLen ()
196
- em , err := emsaPSSEncode (hashed , nBits - 1 , salt , hash .New ())
211
+ emBits := priv .N .BitLen () - 1
212
+ em , err := emsaPSSEncode (hashed , emBits , salt , hash .New ())
197
213
if err != nil {
198
214
return
199
215
}
@@ -202,7 +218,7 @@ func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed,
202
218
if err != nil {
203
219
return
204
220
}
205
- s = make ([]byte , ( nBits + 7 ) / 8 )
221
+ s = make ([]byte , priv . Size () )
206
222
copyWithLeftPad (s , c .Bytes ())
207
223
return
208
224
}
@@ -223,16 +239,15 @@ type PSSOptions struct {
223
239
// PSSSaltLength constants.
224
240
SaltLength int
225
241
226
- // Hash, if not zero, overrides the hash function passed to SignPSS.
227
- // This is the only way to specify the hash function when using the
228
- // crypto.Signer interface .
242
+ // Hash is the hash function used to generate the message digest. If not
243
+ // zero, it overrides the hash function passed to SignPSS. It's required
244
+ // when using PrivateKey.Sign .
229
245
Hash crypto.Hash
230
246
}
231
247
232
- // HashFunc returns pssOpts.Hash so that PSSOptions implements
233
- // crypto.SignerOpts.
234
- func (pssOpts * PSSOptions ) HashFunc () crypto.Hash {
235
- return pssOpts .Hash
248
+ // HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
249
+ func (opts * PSSOptions ) HashFunc () crypto.Hash {
250
+ return opts .Hash
236
251
}
237
252
238
253
func (opts * PSSOptions ) saltLength () int {
@@ -242,56 +257,50 @@ func (opts *PSSOptions) saltLength() int {
242
257
return opts .SaltLength
243
258
}
244
259
245
- // SignPSS calculates the signature of hashed using RSASSA-PSS [1].
246
- // Note that hashed must be the result of hashing the input message using the
247
- // given hash function. The opts argument may be nil, in which case sensible
248
- // defaults are used.
249
- func SignPSS (rand io.Reader , priv * PrivateKey , hash crypto.Hash , hashed []byte , opts * PSSOptions ) ([]byte , error ) {
260
+ // SignPSS calculates the signature of digest using PSS.
261
+ //
262
+ // digest must be the result of hashing the input message using the given hash
263
+ // function. The opts argument may be nil, in which case sensible defaults are
264
+ // used. If opts.Hash is set, it overrides hash.
265
+ func SignPSS (rand io.Reader , priv * PrivateKey , hash crypto.Hash , digest []byte , opts * PSSOptions ) ([]byte , error ) {
266
+ if opts != nil && opts .Hash != 0 {
267
+ hash = opts .Hash
268
+ }
269
+
250
270
saltLength := opts .saltLength ()
251
271
switch saltLength {
252
272
case PSSSaltLengthAuto :
253
- saltLength = ( priv .N . BitLen () + 7 ) / 8 - 2 - hash .Size ()
273
+ saltLength = priv .Size () - 2 - hash .Size ()
254
274
case PSSSaltLengthEqualsHash :
255
275
saltLength = hash .Size ()
256
276
}
257
277
258
- if opts != nil && opts .Hash != 0 {
259
- hash = opts .Hash
260
- }
261
-
262
278
salt := make ([]byte , saltLength )
263
279
if _ , err := io .ReadFull (rand , salt ); err != nil {
264
280
return nil , err
265
281
}
266
- return signPSSWithSalt (rand , priv , hash , hashed , salt )
282
+ return signPSSWithSalt (rand , priv , hash , digest , salt )
267
283
}
268
284
269
285
// VerifyPSS verifies a PSS signature.
270
- // hashed is the result of hashing the input message using the given hash
271
- // function and sig is the signature. A valid signature is indicated by
272
- // returning a nil error. The opts argument may be nil, in which case sensible
273
- // defaults are used.
274
- func VerifyPSS (pub * PublicKey , hash crypto.Hash , hashed []byte , sig []byte , opts * PSSOptions ) error {
275
- return verifyPSS (pub , hash , hashed , sig , opts .saltLength ())
276
- }
277
-
278
- // verifyPSS verifies a PSS signature with the given salt length.
279
- func verifyPSS (pub * PublicKey , hash crypto.Hash , hashed []byte , sig []byte , saltLen int ) error {
280
- nBits := pub .N .BitLen ()
281
- if len (sig ) != (nBits + 7 )/ 8 {
286
+ //
287
+ // A valid signature is indicated by returning a nil error. digest must be the
288
+ // result of hashing the input message using the given hash function. The opts
289
+ // argument may be nil, in which case sensible defaults are used. opts.Hash is
290
+ // ignored.
291
+ func VerifyPSS (pub * PublicKey , hash crypto.Hash , digest []byte , sig []byte , opts * PSSOptions ) error {
292
+ if len (sig ) != pub .Size () {
282
293
return ErrVerification
283
294
}
284
295
s := new (big.Int ).SetBytes (sig )
285
296
m := encrypt (new (big.Int ), pub , s )
286
- emBits := nBits - 1
297
+ emBits := pub . N . BitLen () - 1
287
298
emLen := (emBits + 7 ) / 8
288
- if emLen < len (m .Bytes ()) {
299
+ emBytes := m .Bytes ()
300
+ if emLen < len (emBytes ) {
289
301
return ErrVerification
290
302
}
291
303
em := make ([]byte , emLen )
292
- copyWithLeftPad (em , m .Bytes ())
293
- if saltLen == PSSSaltLengthEqualsHash {
294
- saltLen = hash .Size ()
295
- }
296
- return emsaPSSVerify (hashed , em , emBits , saltLen , hash .New ())
304
+ copyWithLeftPad (em , emBytes )
305
+ return emsaPSSVerify (digest , em , emBits , opts .saltLength (), hash .New ())
297
306
}
0 commit comments