Skip to content

Commit 22b5c14

Browse files
FiloSottilegopherbot
authored andcommitted
crypto/internal/fips140/rsa: add Miller-Rabin test
A following CL will move key generation to crypto/internal/fips140/rsa. Updates #69799 For #69536 Change-Id: Icdf9b8424da20453939c6587af7dc922aad9e0ca Reviewed-on: https://go-review.googlesource.com/c/go/+/632215 Auto-Submit: Filippo Valsorda <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Russ Cox <[email protected]> Reviewed-by: Daniel McCarney <[email protected]>
1 parent caee788 commit 22b5c14

File tree

6 files changed

+799
-19
lines changed

6 files changed

+799
-19
lines changed

src/crypto/internal/fips140/bigmod/nat.go

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,56 @@ func (x *Nat) IsZero() choice {
244244
return zero
245245
}
246246

247+
// IsOne returns 1 if x == 1, and 0 otherwise.
248+
func (x *Nat) IsOne() choice {
249+
// Eliminate bounds checks in the loop.
250+
size := len(x.limbs)
251+
xLimbs := x.limbs[:size]
252+
253+
if len(xLimbs) == 0 {
254+
return no
255+
}
256+
257+
one := ctEq(xLimbs[0], 1)
258+
for i := 1; i < size; i++ {
259+
one &= ctEq(xLimbs[i], 0)
260+
}
261+
return one
262+
}
263+
264+
// IsMinusOne returns 1 if x == -1 mod m, and 0 otherwise.
265+
//
266+
// The length of x must be the same as the modulus. x must already be reduced
267+
// modulo m.
268+
func (x *Nat) IsMinusOne(m *Modulus) choice {
269+
minusOne := m.Nat()
270+
minusOne.SubOne(m)
271+
return x.Equal(minusOne)
272+
}
273+
274+
// IsOdd returns 1 if x is odd, and 0 otherwise.
275+
func (x *Nat) IsOdd() choice {
276+
if len(x.limbs) == 0 {
277+
return no
278+
}
279+
return choice(x.limbs[0] & 1)
280+
}
281+
282+
// TrailingZeroBitsVarTime returns the number of trailing zero bits in x.
283+
func (x *Nat) TrailingZeroBitsVarTime() uint {
284+
var t uint
285+
limbs := x.limbs
286+
for _, l := range limbs {
287+
if l == 0 {
288+
t += _W
289+
continue
290+
}
291+
t += uint(bits.TrailingZeros(l))
292+
break
293+
}
294+
return t
295+
}
296+
247297
// cmpGeq returns 1 if x >= y, and 0 otherwise.
248298
//
249299
// Both operands must have the same announced length.
@@ -308,6 +358,37 @@ func (x *Nat) sub(y *Nat) (c uint) {
308358
return
309359
}
310360

361+
// ShiftRightVarTime sets x = x >> n.
362+
//
363+
// The announced length of x is unchanged.
364+
func (x *Nat) ShiftRightVarTime(n uint) *Nat {
365+
// Eliminate bounds checks in the loop.
366+
size := len(x.limbs)
367+
xLimbs := x.limbs[:size]
368+
369+
shift := int(n % _W)
370+
shiftLimbs := int(n / _W)
371+
372+
var shiftedLimbs []uint
373+
if shiftLimbs < size {
374+
shiftedLimbs = xLimbs[shiftLimbs:]
375+
}
376+
377+
for i := range xLimbs {
378+
if i >= len(shiftedLimbs) {
379+
xLimbs[i] = 0
380+
continue
381+
}
382+
383+
xLimbs[i] = shiftedLimbs[i] >> shift
384+
if i+1 < len(shiftedLimbs) {
385+
xLimbs[i] |= shiftedLimbs[i+1] << (_W - shift)
386+
}
387+
}
388+
389+
return x
390+
}
391+
311392
// Modulus is used for modular arithmetic, precomputing relevant constants.
312393
//
313394
// A Modulus can leak the exact number of bits needed to store its value
@@ -403,7 +484,7 @@ func NewModulus(b []byte) (*Modulus, error) {
403484
return nil, errors.New("modulus must be > 0")
404485
}
405486
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
406-
if m.nat.limbs[0]&1 == 1 {
487+
if m.nat.IsOdd() == 1 {
407488
m.odd = true
408489
m.m0inv = minusInverseModW(m.nat.limbs[0])
409490
m.rr = rr(m)
@@ -435,9 +516,13 @@ func (m *Modulus) BitLen() int {
435516
return len(m.nat.limbs)*_W - int(m.leading)
436517
}
437518

438-
// Nat returns m as a Nat. The return value must not be written to.
519+
// Nat returns m as a Nat.
439520
func (m *Modulus) Nat() *Nat {
440-
return m.nat
521+
// Make a copy so that the caller can't modify m.nat or alias it with
522+
// another Nat in a modulus operation.
523+
n := NewNat()
524+
n.set(m.nat)
525+
return n
441526
}
442527

443528
// shiftIn calculates x = x << _W + y mod m.
@@ -553,6 +638,16 @@ func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
553638
return x
554639
}
555640

641+
// SubOne computes x = x - 1 mod m.
642+
//
643+
// The length of x must be the same as the modulus. x must already be reduced
644+
// modulo m.
645+
func (x *Nat) SubOne(m *Modulus) *Nat {
646+
one := NewNat().ExpandFor(m)
647+
one.limbs[0] = 1
648+
return x.Sub(one, m)
649+
}
650+
556651
// Add computes x = x + y mod m.
557652
//
558653
// The length of both operands must be the same as the modulus. Both operands

src/crypto/internal/fips140/bigmod/nat_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ func (x *Nat) setBig(n *big.Int) *Nat {
3131
return x
3232
}
3333

34+
func (n *Nat) asBig() *big.Int {
35+
bits := make([]big.Word, len(n.limbs))
36+
for i := range n.limbs {
37+
bits[i] = big.Word(n.limbs[i])
38+
}
39+
return new(big.Int).SetBits(bits)
40+
}
41+
3442
func (n *Nat) String() string {
3543
var limbs []string
3644
for i := range n.limbs {
@@ -404,6 +412,98 @@ func testMul(t *testing.T, n int) {
404412
}
405413
}
406414

415+
func TestIs(t *testing.T) {
416+
checkYes := func(c choice, err string) {
417+
t.Helper()
418+
if c != yes {
419+
t.Error(err)
420+
}
421+
}
422+
checkNot := func(c choice, err string) {
423+
t.Helper()
424+
if c != no {
425+
t.Error(err)
426+
}
427+
}
428+
429+
mFour := modulusFromBytes([]byte{4})
430+
n, err := NewNat().SetBytes([]byte{3}, mFour)
431+
if err != nil {
432+
t.Fatal(err)
433+
}
434+
checkYes(n.IsMinusOne(mFour), "3 is not -1 mod 4")
435+
checkNot(n.IsZero(), "3 is zero")
436+
checkNot(n.IsOne(), "3 is one")
437+
checkYes(n.IsOdd(), "3 is not odd")
438+
n.SubOne(mFour)
439+
checkNot(n.IsMinusOne(mFour), "2 is -1 mod 4")
440+
checkNot(n.IsZero(), "2 is zero")
441+
checkNot(n.IsOne(), "2 is one")
442+
checkNot(n.IsOdd(), "2 is odd")
443+
n.SubOne(mFour)
444+
checkNot(n.IsMinusOne(mFour), "1 is -1 mod 4")
445+
checkNot(n.IsZero(), "1 is zero")
446+
checkYes(n.IsOne(), "1 is not one")
447+
checkYes(n.IsOdd(), "1 is not odd")
448+
n.SubOne(mFour)
449+
checkNot(n.IsMinusOne(mFour), "0 is -1 mod 4")
450+
checkYes(n.IsZero(), "0 is not zero")
451+
checkNot(n.IsOne(), "0 is one")
452+
checkNot(n.IsOdd(), "0 is odd")
453+
n.SubOne(mFour)
454+
checkYes(n.IsMinusOne(mFour), "-1 is not -1 mod 4")
455+
checkNot(n.IsZero(), "-1 is zero")
456+
checkNot(n.IsOne(), "-1 is one")
457+
checkYes(n.IsOdd(), "-1 mod 4 is not odd")
458+
459+
mTwoLimbs := maxModulus(2)
460+
n, err = NewNat().SetBytes([]byte{0x01}, mTwoLimbs)
461+
if err != nil {
462+
t.Fatal(err)
463+
}
464+
if n.IsOne() != 1 {
465+
t.Errorf("1 is not one")
466+
}
467+
}
468+
469+
func TestTrailingZeroBits(t *testing.T) {
470+
nb := new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7e})
471+
nb.Lsh(nb, 128)
472+
expected := 129
473+
for expected >= 0 {
474+
n := NewNat().setBig(nb)
475+
if n.TrailingZeroBitsVarTime() != uint(expected) {
476+
t.Errorf("%d != %d", n.TrailingZeroBitsVarTime(), expected)
477+
}
478+
nb.Rsh(nb, 1)
479+
expected--
480+
}
481+
}
482+
483+
func TestRightShift(t *testing.T) {
484+
nb, err := cryptorand.Int(cryptorand.Reader, new(big.Int).Lsh(big.NewInt(1), 1024))
485+
if err != nil {
486+
t.Fatal(err)
487+
}
488+
for _, shift := range []uint{1, 32, 64, 128, 1024 - 128, 1024 - 64, 1024 - 32, 1024 - 1} {
489+
testShift := func(t *testing.T, shift uint) {
490+
n := NewNat().setBig(nb)
491+
oldLen := len(n.limbs)
492+
n.ShiftRightVarTime(shift)
493+
if len(n.limbs) != oldLen {
494+
t.Errorf("len(n.limbs) = %d, want %d", len(n.limbs), oldLen)
495+
}
496+
exp := new(big.Int).Rsh(nb, shift)
497+
if n.asBig().Cmp(exp) != 0 {
498+
t.Errorf("%v != %v", n.asBig(), exp)
499+
}
500+
}
501+
t.Run(fmt.Sprint(shift-1), func(t *testing.T) { testShift(t, shift-1) })
502+
t.Run(fmt.Sprint(shift), func(t *testing.T) { testShift(t, shift) })
503+
t.Run(fmt.Sprint(shift+1), func(t *testing.T) { testShift(t, shift+1) })
504+
}
505+
}
506+
407507
func natBytes(n *Nat) []byte {
408508
return n.Bytes(maxModulus(uint(len(n.limbs))))
409509
}

0 commit comments

Comments
 (0)