diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b33799dc9..63e25f7e1 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -209,32 +209,60 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { for _, membership := range localMembers { // Copy any existing push rules from old -> new room - if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { + changed, err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain) + if err != nil { return err } + // Inform the SyncAPI about the updated push_rules + if changed { + if err = s.syncProducer.SendAccountData(membership.Localpart, eventutil.AccountData{ + Type: "m.push_rules", + }); err != nil { + return err + } + } // preserve m.direct room state - if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil { + changed, err = s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize) + if err != nil { return err } + // Inform the SyncAPI about the updated m.direct + if changed { + if err = s.syncProducer.SendAccountData(membership.Localpart, eventutil.AccountData{ + Type: "m.direct", + }); err != nil { + return err + } + } // copy existing m.tag entries, if any - if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { + changed, err = s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain) + if err != nil { return err } + // Inform the SyncAPI about the updated m.tag + if changed { + if err = s.syncProducer.SendAccountData(membership.Localpart, eventutil.AccountData{ + Type: "m.tag", + }); err != nil { + return err + } + } } return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName spec.ServerName) error { +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) (hasChanges bool, err error) { pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) if err != nil { - return fmt.Errorf("failed to query pushrules for user: %w", err) + return false, err } if pushRules == nil { - return nil + return false, err } + var rulesBytes []byte for _, roomRule := range pushRules.Global.Room { if roomRule.RuleID != oldRoomID { continue @@ -242,27 +270,28 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, cpRool := *roomRule cpRool.RuleID = newRoomID pushRules.Global.Room = append(pushRules.Global.Room, &cpRool) - rules, err := json.Marshal(pushRules) + rulesBytes, err = json.Marshal(pushRules) if err != nil { - return err + return false, err } - if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil { - return fmt.Errorf("failed to update pushrules: %w", err) + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rulesBytes); err != nil { + return false, err } + hasChanges = true } - return nil + return hasChanges, err } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName, roomSize int) (hasChanges bool, err error) { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { - return nil + return false, nil } // Get direct message state directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct") if err != nil { - return fmt.Errorf("failed to get m.direct from database: %w", err) + return false, fmt.Errorf("failed to get m.direct from database: %w", err) } directChats := gjson.ParseBytes(directChatsRaw) newDirectChats := make(map[string][]string) @@ -285,25 +314,29 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, var data []byte data, err = json.Marshal(newDirectChats) if err != nil { - return err + return false, err } if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil { - return fmt.Errorf("failed to update m.direct state: %w", err) + return false, fmt.Errorf("failed to update m.direct state: %w", err) } + return true, nil } - return nil + return false, nil } -func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) error { +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) (hasChanges bool, err error) { tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") if err != nil && !errors.Is(err, sql.ErrNoRows) { - return err + return false, err } if tag == nil { - return nil + return false, nil + } + if err := s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag); err != nil { + return false, err } - return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag) + return true, nil } func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rstypes.HeaderedEvent, streamPos uint64) error {