diff --git a/keyring.go b/keyring.go index 605c176c..8484ea99 100644 --- a/keyring.go +++ b/keyring.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "runtime" "strings" "sync" "time" @@ -325,39 +326,69 @@ func (k *KeyRing) checkUsingKeys( requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID, keys map[PublicKeyLookupRequest]PublicKeyLookupResult, ) { + procs := runtime.NumCPU() - 1 + if procs < 1 { + procs = 1 + } + type job struct { + index int // the original index in the requests/results array + request VerifyJSONRequest // the request itself + } + jobs := make(map[int][]job) for i := range requests { - if results[i].Error == nil { - // We've already checked this message and it passed the signature checks. - // So we can skip to the next message. - continue - } - for _, keyID := range keyIDs[i] { - serverKey, ok := keys[PublicKeyLookupRequest{requests[i].ServerName, keyID}] - if !ok { - // No key for this key ID so we continue onto the next key ID. - continue - } - if !serverKey.WasValidAt(requests[i].AtTS, requests[i].StrictValidityChecking) { - // The key wasn't valid at the timestamp we needed it to be valid at. - // So skip onto the next key. - results[i].Error = fmt.Errorf( - "gomatrixserverlib: key with ID %q for %q not valid at %d", - keyID, requests[i].ServerName, requests[i].AtTS, - ) - continue - } - if err := VerifyJSON( - string(requests[i].ServerName), keyID, ed25519.PublicKey(serverKey.Key), requests[i].Message, - ); err != nil { - // The signature wasn't valid, record the error and try the next key ID. - results[i].Error = err - continue + jobs[i%procs] = append(jobs[i%procs], job{i, requests[i]}) + } + var wg sync.WaitGroup // tracks the workers + var mu sync.RWMutex // protects results array + wg.Add(len(jobs)) + for _, j := range jobs { + go func(jobs []job) { + for _, j := range jobs { + mu.RLock() + if results[j.index].Error == nil { + // We've already checked this message and it passed the signature checks. + // So we can skip to the next message. + mu.RUnlock() + continue + } + mu.RUnlock() + for _, keyID := range keyIDs[j.index] { + serverKey, ok := keys[PublicKeyLookupRequest{j.request.ServerName, keyID}] + if !ok { + // No key for this key ID so we continue onto the next key ID. + continue + } + if !serverKey.WasValidAt(j.request.AtTS, j.request.StrictValidityChecking) { + // The key wasn't valid at the timestamp we needed it to be valid at. + // So skip onto the next key. + mu.Lock() + results[j.index].Error = fmt.Errorf( + "gomatrixserverlib: key with ID %q for %q not valid at %d", + keyID, j.request.ServerName, j.request.AtTS, + ) + mu.Unlock() + continue + } + if err := VerifyJSON( + string(j.request.ServerName), keyID, ed25519.PublicKey(serverKey.Key), j.request.Message, + ); err != nil { + // The signature wasn't valid, record the error and try the next key ID. + mu.Lock() + results[j.index].Error = err + mu.Unlock() + continue + } + // The signature is valid, set the result to nil. + mu.Lock() + results[j.index].Error = nil + mu.Unlock() + break + } } - // The signature is valid, set the result to nil. - results[i].Error = nil - break - } + wg.Done() + }(j) } + wg.Wait() } type KeyClient interface {