diff --git a/eventcontent_test.go b/eventcontent_test.go index ee13de4c..53920851 100644 --- a/eventcontent_test.go +++ b/eventcontent_test.go @@ -217,11 +217,13 @@ func TestMXIDMapping_SignValidate(t *testing.T) { assert.NoError(t, err) // this should pass - err = validateMXIDMappingSignature(context.Background(), ev, &StubVerifier{}, verImpl) + evMapping, err := getMXIDMapping(ev) + assert.NoError(t, err) + err = validateMXIDMappingSignatures(context.Background(), ev, *evMapping, &StubVerifier{}, verImpl) assert.NoError(t, err) // this fails, for some random reason - err = validateMXIDMappingSignature(context.Background(), ev, &StubVerifier{ + err = validateMXIDMappingSignatures(context.Background(), ev, *evMapping, &StubVerifier{ results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}, }, verImpl) assert.Error(t, err) @@ -231,7 +233,6 @@ func TestMXIDMapping_SignValidate(t *testing.T) { ev, err = eb.Build(time.Now(), serverName, keyID, priv) assert.NoError(t, err) - err = validateMXIDMappingSignature(context.Background(), ev, &StubVerifier{}, verImpl) + _, err = getMXIDMapping(ev) assert.Error(t, err) - } diff --git a/eventcrypto.go b/eventcrypto.go index cc30f2c1..6d4a7add 100644 --- a/eventcrypto.go +++ b/eventcrypto.go @@ -85,7 +85,11 @@ func VerifyEventSignatures(ctx context.Context, e PDU, verifier JSONVerifier, us // Validate the MXIDMapping is signed correctly if verImpl.Version() == RoomVersionPseudoIDs && membership == spec.Join { - err = validateMXIDMappingSignature(ctx, e, verifier, verImpl) + mapping, err := getMXIDMapping(e) + if err != nil { + return err + } + err = validateMXIDMappingSignatures(ctx, e, *mapping, verifier, verImpl) if err != nil { return err } @@ -154,28 +158,32 @@ func VerifyEventSignatures(ctx context.Context, e PDU, verifier JSONVerifier, us return nil } -// validateMXIDMappingSignature validates that the MXIDMapping is correctly signed -func validateMXIDMappingSignature(ctx context.Context, e PDU, verifier JSONVerifier, verImpl IRoomVersion) error { +func getMXIDMapping(e PDU) (*MXIDMapping, error) { var content MemberContent err := json.Unmarshal(e.Content(), &content) if err != nil { - return err + return nil, err } // if there is no mapping, we can't check the signature if content.MXIDMapping == nil { - return fmt.Errorf("missing mxid_mapping, unable to validate event") + return nil, fmt.Errorf("missing mxid_mapping") } - var toVerify []VerifyJSONRequest + return content.MXIDMapping, nil +} - mapping, err := json.Marshal(content.MXIDMapping) +// validateMXIDMappingSignatures validates that the MXIDMapping is correctly signed +func validateMXIDMappingSignatures(ctx context.Context, e PDU, mapping MXIDMapping, verifier JSONVerifier, verImpl IRoomVersion) error { + mappingBytes, err := json.Marshal(mapping) if err != nil { return err } - for s := range content.MXIDMapping.Signatures { + + var toVerify []VerifyJSONRequest + for s := range mapping.Signatures { v := VerifyJSONRequest{ - Message: mapping, + Message: mappingBytes, AtTS: e.OriginServerTS(), ServerName: s, ValidityCheckingFunc: verImpl.SignatureValidityCheck, diff --git a/handlejoin.go b/handlejoin.go index 468d209d..56234995 100644 --- a/handlejoin.go +++ b/handlejoin.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" - "github.com/tidwall/gjson" ) type HandleMakeJoinInput struct { @@ -351,15 +350,14 @@ func HandleSendJoin(input HandleSendJoinInput) (*HandleSendJoinResponse, error) // validate the mxid_mapping of the event if input.RoomVersion == RoomVersionPseudoIDs { // validate the signature first - if err = validateMXIDMappingSignature(input.Context, event, input.Verifier, verImpl); err != nil { + mapping, err := getMXIDMapping(event) + if err != nil { + return nil, spec.BadJSON(err.Error()) + } + if err = validateMXIDMappingSignatures(input.Context, event, *mapping, input.Verifier, verImpl); err != nil { return nil, spec.Forbidden(err.Error()) } - mapping := MXIDMapping{} - err = json.Unmarshal([]byte(gjson.GetBytes(input.JoinEvent, "content.mxid_mapping").Raw), &mapping) - if err != nil { - return nil, err - } // store the user room public key -> userID mapping if err = input.StoreSenderIDFromPublicID(input.Context, mapping.UserRoomKey, mapping.UserID, input.RoomID); err != nil { return nil, err diff --git a/performjoin.go b/performjoin.go index ce76bf59..db480d5a 100644 --- a/performjoin.go +++ b/performjoin.go @@ -304,20 +304,17 @@ func storeMXIDMappings( if ev.Type() != spec.MRoomMember { continue } - mapping := MemberContent{} - if err := json.Unmarshal(ev.Content(), &mapping); err != nil { + mapping, err := getMXIDMapping(ev) + if err != nil { return err } - if mapping.MXIDMapping == nil { - continue - } // we already validated it is a valid roomversion, so this should be safe to use. verImpl := MustGetRoomVersion(ev.Version()) - if err := validateMXIDMappingSignature(ctx, ev, keyRing, verImpl); err != nil { + if err := validateMXIDMappingSignatures(ctx, ev, *mapping, keyRing, verImpl); err != nil { logrus.WithError(err).Error("invalid signature for mxid_mapping") continue } - if err := storeSenderID(ctx, ev.SenderID(), mapping.MXIDMapping.UserID, roomID); err != nil { + if err := storeSenderID(ctx, ev.SenderID(), mapping.UserID, roomID); err != nil { return err } }