Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 54 additions & 21 deletions userapi/consumers/roomserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,60 +209,89 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
for _, membership := range localMembers {
// Copy any existing push rules from old -> new room
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
changed, err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain)
if err != nil {
return err
}
// Inform the SyncAPI about the updated push_rules
if changed {
if err = s.syncProducer.SendAccountData(membership.Localpart, eventutil.AccountData{
Type: "m.push_rules",
}); err != nil {
return err
}
}

// preserve m.direct room state
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
changed, err = s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize)
if err != nil {
return err
}
// Inform the SyncAPI about the updated m.direct
if changed {
if err = s.syncProducer.SendAccountData(membership.Localpart, eventutil.AccountData{
Type: "m.direct",
}); err != nil {
return err
}
}

// copy existing m.tag entries, if any
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
changed, err = s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain)
if err != nil {
return err
}
// Inform the SyncAPI about the updated m.tag
if changed {
if err = s.syncProducer.SendAccountData(membership.Localpart, eventutil.AccountData{
Type: "m.tag",
}); err != nil {
return err
}
}
}
return nil
}

func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName spec.ServerName) error {
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) (hasChanges bool, err error) {
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
if err != nil {
return fmt.Errorf("failed to query pushrules for user: %w", err)
return false, err
}
if pushRules == nil {
return nil
return false, err
}

var rulesBytes []byte
for _, roomRule := range pushRules.Global.Room {
if roomRule.RuleID != oldRoomID {
continue
}
cpRool := *roomRule
cpRool.RuleID = newRoomID
pushRules.Global.Room = append(pushRules.Global.Room, &cpRool)
rules, err := json.Marshal(pushRules)
rulesBytes, err = json.Marshal(pushRules)
if err != nil {
return err
return false, err
}
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
return fmt.Errorf("failed to update pushrules: %w", err)
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rulesBytes); err != nil {
return false, err
}
hasChanges = true
}
return nil
return hasChanges, err
}

// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName, roomSize int) error {
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName, roomSize int) (hasChanges bool, err error) {
// this is most likely not a DM, so skip updating m.direct state
if roomSize > 2 {
return nil
return false, nil
}
// Get direct message state
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
if err != nil {
return fmt.Errorf("failed to get m.direct from database: %w", err)
return false, fmt.Errorf("failed to get m.direct from database: %w", err)
}
directChats := gjson.ParseBytes(directChatsRaw)
newDirectChats := make(map[string][]string)
Expand All @@ -285,25 +314,29 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
var data []byte
data, err = json.Marshal(newDirectChats)
if err != nil {
return err
return false, err
}
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
return fmt.Errorf("failed to update m.direct state: %w", err)
return false, fmt.Errorf("failed to update m.direct state: %w", err)
}
return true, nil
}

return nil
return false, nil
}

func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) error {
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName spec.ServerName) (hasChanges bool, err error) {
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
return false, err
}
if tag == nil {
return nil
return false, nil
}
if err := s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag); err != nil {
return false, err
}
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
return true, nil
}

func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rstypes.HeaderedEvent, streamPos uint64) error {
Expand Down
Loading