diff --git a/changelog.d/19556.feature b/changelog.d/19556.feature new file mode 100644 index 00000000000..bcb6c5c983c --- /dev/null +++ b/changelog.d/19556.feature @@ -0,0 +1,2 @@ +Add optional support for [MSC4429: Profile Updates for Legacy Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4429). +Currently defaults to not enabled, and is limited to local users only for the sync results. \ No newline at end of file diff --git a/docker/complement/conf/start_for_complement.sh b/docker/complement/conf/start_for_complement.sh index e0d30abed36..9d8fd93334f 100755 --- a/docker/complement/conf/start_for_complement.sh +++ b/docker/complement/conf/start_for_complement.sh @@ -60,12 +60,13 @@ if [[ -n "$SYNAPSE_COMPLEMENT_USE_WORKERS" ]]; then federation_inbound, \ federation_reader, \ federation_sender, \ + profile_updates, \ synchrotron, \ client_reader, \ appservice, \ pusher, \ device_lists:2, \ - stream_writers=account_data+presence+receipts+to_device+typing" + stream_writers=account_data+presence+profile_updates+receipts+to_device+typing" fi log "Workers requested: $SYNAPSE_WORKER_TYPES" diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index e829292aca9..64a36522fa4 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -15,6 +15,8 @@ enable_registration_without_verification: true bcrypt_rounds: 4 url_preview_enabled: true url_preview_ip_range_blacklist: [] +# MSC4429 Profile updates down legacy /sync +include_profile_updates_in_sync: true ## Registration ## diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 26c8556eff4..895c287ce9d 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -269,7 +269,10 @@ "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$", "^/_matrix/client/(api/v1|r0|v3|unstable)/join/", "^/_matrix/client/(api/v1|r0|v3|unstable)/knock/", - "^/_matrix/client/(api/v1|r0|v3|unstable)/profile/", + # The [^/] differentiates this endpoint from + # `ProfileRestFieldsServlet`, which we want to instead go to the + # `profile_updates` worker below. + "^/_matrix/client/(api/v1|r0|v3|unstable)/profile/[^/]+", ], "shared_extra_conf": {}, "worker_extra_conf": "", @@ -308,6 +311,15 @@ "shared_extra_conf": {}, "worker_extra_conf": "", }, + "profile_updates": { + "app": "synapse.app.generic_worker", + "listener_resources": ["client", "replication"], + "endpoint_patterns": [ + "^/_matrix/client/(unstable/uk.tcpip.msc4133|api/v1|r0|v3|unstable)/profile/.+/" + ], + "shared_extra_conf": {}, + "worker_extra_conf": "", + }, "device_lists": { "app": "synapse.app.generic_worker", "listener_resources": ["client", "replication"], @@ -517,6 +529,7 @@ def add_worker_roles_to_shared_config( "typing", "push_rules", "thread_subscriptions", + "profile_updates", } # Worker-type specific sharding config. Now a single worker can fulfill multiple diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 92eca4a7ffd..4d4694c0eb3 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -4486,6 +4486,8 @@ This setting has the following sub-options: * `device_lists` (string): Name of a worker assigned to the `device_lists` stream. +* `profile_updates` (string): Name of a worker assigned to the `profile_updates` stream. + Example configuration: ```yaml stream_writers: diff --git a/docs/workers.md b/docs/workers.md index d987ff8980f..6259e3ac7f1 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -585,6 +585,7 @@ configured as stream writer for the `quarantined_media_changes` stream: ^/_synapse/admin/v1/quarantine_media/.*$ + #### Restrict outbound federation traffic to a specific set of workers The diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index 0e345b7b69d..8a889339b4e 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -5541,6 +5541,9 @@ properties: device_lists: type: string description: Name of a worker assigned to the `device_lists` stream. + profile_updates: + type: string + description: Name of a worker assigned to the `profile_updates` stream. default: {} examples: - events: worker1 diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index cca87d42a9e..d8e553f446f 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -286,6 +286,7 @@ main() { ./tests/msc4155 ./tests/msc4306 ./tests/msc4222 + ./tests/msc4429 ) # Export the list of test packages as a space-separated environment variable, so other diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 0b8a289d92d..f4f598d27c7 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -917,6 +917,10 @@ def alter_table(txn: LoggingTransaction) -> None: "quarantined_media_id_seq", [("quarantined_media_changes", "stream_id")], ) + await self._setup_sequence( + "profile_updates_sequence", + [("profile_updates", "stream_id")], + ) # Step 3. Get tables. self.progress.set_state("Fetching tables") diff --git a/synapse/api/constants.py b/synapse/api/constants.py index acac0573340..a5ee617344b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -410,6 +410,56 @@ class ProfileFields: AVATAR_URL: Final = "avatar_url" +class ProfileUpdateAction(str, enum.Enum): + """ + Enum representing the action of a row in the profile updates stream tables. + These actions are used to determine how profiles, and what data is included in the + sync responses, depending on field updates and room membership changes. + """ + + JOINED_ROOM = "joined_room" + """ + This profile update row action represents a user joining a room. + + When gathering an incremental sync non-lazy response for profile updates, + we always include the full profile of users who have joined a room the syncing + user is a member of, where full profile means all the current profile values the + client asked for, regardless of whether they have changed recently. This ensures + that clients have profile re-populated for any users who have recently left + shared rooms. + + A scenario example would be as follows: + + * Alice leaves a room with Bob + * Bob's client clears all profile fields from Alice + * Alice joins a room with Bob + * Bob's client does an incremental non-lazy sync + + At the end of the flow Bob should receive all the profile fields the client + is interested in, not just the potential diff, which non-lazy incremental sync + normally includes. This update action currently has no meaning for sync responses + that are not incremental and non-lazy. + """ + LEFT_ROOM = "left_room" + """ + This profile update row action represents a user leaving a room. + + Clients will want to know when they no longer share rooms with a user. This + profile action row allows the sync code to deliver a `null` response for those + profiles, so clients can clear their cache containing the users profile data + they are no longer interested in. + """ + UPDATE = "update" + """ + This profile update row action represents a user updating a profile field. + + Depending on the type of sync (initial/incremental, lazy/non-lazy), either the + diff of profile field updates or all the current profile fields are included + in the sync response. In the latter case the profile update action row signifies + a change, but the client may still get fields that have not changed. + """ + + class StickyEventField(TypedDict): """ Dict content of the `sticky` part of an event. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 9b47c20437b..cbae9133c66 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -123,6 +123,13 @@ "filter": FILTER_SCHEMA, "room_filter": ROOM_FILTER_SCHEMA, "room_event_filter": ROOM_EVENT_FILTER_SCHEMA, + "profile_fields_filter": { + "type": "object", + "properties": { + "ids": {"type": "array", "items": {"type": "string"}}, + }, + "additionalProperties": True, + }, }, "properties": { "presence": {"$ref": "#/definitions/filter"}, @@ -130,6 +137,9 @@ "room": {"$ref": "#/definitions/room_filter"}, "event_format": {"type": "string", "enum": ["client", "federation"]}, "event_fields": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc4429.profile_fields": { + "$ref": "#/definitions/profile_fields_filter" + }, }, "additionalProperties": True, # Allow new fields for forward compatibility } @@ -217,6 +227,13 @@ def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self.event_fields = filter_json.get("event_fields", []) self.event_format = filter_json.get("event_format", "client") + self.profile_fields: set[str] = set() + if hs.config.server.include_profile_updates_in_sync: + profile_fields_filter = filter_json.get("org.matrix.msc4429.profile_fields") + + if isinstance(profile_fields_filter, Mapping): + self.profile_fields = set(profile_fields_filter.get("ids", [])) + def __repr__(self) -> str: return "" % (json.dumps(self._filter_json),) diff --git a/synapse/config/server.py b/synapse/config/server.py index ca94c224ea5..ffd4ef5ab89 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -545,6 +545,12 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: " 'allow_public_rooms_over_federation' is set." ) + # Whether to support MSC4299 profile updates down legacy /sync + self.include_profile_updates_in_sync = config.get( + "include_profile_updates_in_sync", + False, + ) + # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This # flag is now obsolete but we need to check it for backward-compatibility. if config.get("restrict_public_rooms_to_local_users", False): diff --git a/synapse/config/workers.py b/synapse/config/workers.py index fb7378bfc81..a05843d900e 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -142,6 +142,9 @@ class WriterLocations: push_rules: The instances that write to the push stream. Currently can only be a single instance. device_lists: The instances that write to the device list stream. + thread_subscriptions: The instances that write to the thread subscriptions + stream. + profile_updates: The instances that write to the profile updates stream. quarantined_media_changes: The instances that write to the quarantined media changes stream. """ @@ -179,7 +182,11 @@ class WriterLocations: converter=_instance_to_list_converter, ) thread_subscriptions: list[str] = attr.ib( - default=["master"], + default=[MAIN_PROCESS_INSTANCE_NAME], + converter=_instance_to_list_converter, + ) + profile_updates: list[str] = attr.ib( + default=[MAIN_PROCESS_INSTANCE_NAME], converter=_instance_to_list_converter, ) quarantined_media_changes: list[str] = attr.ib( @@ -361,8 +368,7 @@ def read_config( writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writers for events and typing also appears in - # `instance_map`. + # Check that the configured writers also appear in `instance_map`. for stream in ( "events", "typing", @@ -371,6 +377,9 @@ def read_config( "receipts", "presence", "push_rules", + "device_lists", + "thread_subscriptions", + "profile_updates", ): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: @@ -421,6 +430,16 @@ def read_config( "Must specify at least one instance to handle `device_lists` messages." ) + if len(self.writers.thread_subscriptions) == 0: + raise ConfigError( + "Must specify at least one instance to handle `thread_subscriptions` messages." + ) + + if len(self.writers.profile_updates) == 0: + raise ConfigError( + "Must specify at least one instance to handle `profile_updates` messages." + ) + self.events_shard_config = RoutableShardedWorkerHandlingConfig( self.writers.events ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c3886795b66..3b857d6b12e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -25,7 +25,7 @@ from twisted.internet.defer import CancelledError -from synapse.api.constants import ProfileFields +from synapse.api.constants import ProfileFields, ProfileUpdateAction from synapse.api.errors import ( AuthError, Codes, @@ -34,6 +34,7 @@ StoreError, SynapseError, ) +from synapse.replication.http.profile import ReplicationProfileRecordFieldUpdates from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.storage.roommember import ProfileInfo from synapse.types import ( @@ -42,6 +43,7 @@ JsonValue, Requester, ScheduledTask, + StreamKeyType, TaskStatus, UserID, create_requester, @@ -75,6 +77,8 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() # nb must be called this for @cached self.store = hs.get_datastores().main self.hs = hs + self._notifier = hs.get_notifier() + self._msc4429_enabled = hs.config.server.include_profile_updates_in_sync self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -98,6 +102,54 @@ def __init__(self, hs: "HomeServer"): self._update_join_states_task, UPDATE_JOIN_STATES_ACTION_NAME ) self._worker_locks = hs.get_worker_locks_handler() + self._is_profile_worker = ( + hs.get_instance_name() in hs.config.worker.writers.profile_updates + ) + self._record_profile_updates_client = ( + ReplicationProfileRecordFieldUpdates.make_client(self.hs) + ) + self._profile_updates_writer_instance = ( + self.hs.config.worker.writers.profile_updates[0] + ) + + async def record_profile_updates( + self, user_id: UserID, updated_fields: set[str] + ) -> None: + """ + Record user profile updates to our stream updates table. + + Args: + user_id: The user whose profile has had updates. + updated_fields: A set of the names of the fields that were updated. + + Returns: + None + """ + if not self._msc4429_enabled or not updated_fields: + return + + stream_id = await self.store.add_profile_updates( + user_id=user_id, + updated_fields=updated_fields, + action=ProfileUpdateAction.UPDATE, + ) + room_ids = await self.store.get_rooms_for_user(user_id.to_string()) + if not room_ids: + return + + users_who_share_rooms = ( + await self.store.get_local_users_who_share_room_with_user( + user_id.to_string() + ) + ) + await self.store.track_profile_updates_per_user( + stream_id=stream_id, + user_ids=users_who_share_rooms, + ) + + self._notifier.on_new_event( + StreamKeyType.PROFILE_UPDATES, stream_id, rooms=room_ids + ) async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: """ @@ -253,6 +305,10 @@ async def set_displayname( ) await self.store.set_profile_displayname(target_user, displayname_to_set) + await self._dispatch_record_profile_updates( + target_user, + {ProfileFields.DISPLAYNAME}, + ) profile = await self.store.get_profileinfo(target_user) @@ -362,6 +418,10 @@ async def set_avatar_url( ) await self.store.set_profile_avatar_url(target_user, avatar_url_to_set) + await self._dispatch_record_profile_updates( + target_user, + {ProfileFields.AVATAR_URL}, + ) profile = await self.store.get_profileinfo(target_user) @@ -376,6 +436,87 @@ async def set_avatar_url( if propagate: await self._update_join_states(requester, target_user) + async def user_left_room(self, user_id: UserID, room_id: str) -> None: + """ + A user left a room. We now: + + * Add a row to `profile_updates` stating that the user left a room. + * Check if this user no longer shares any rooms with certain users. + * Insert a row for each of those users into `profile_updates_per_user`. + * Remove any previous profile update stream rows concerning + this user. This is done to stop leaking any updates to users who no longer + share a room. + * Now, when any of those users sync, the sync code will check + `profile_updates` and see that the user left a room. And thus a "clear + this user's profile" instruction will be sent down to the client. + + If that is the case, + """ + if not self._msc4429_enabled: + return + user_id_str = user_id.to_string() + + users_in_left_room = set(await self.store.get_local_users_in_room(room_id)) + users_in_left_room.discard(user_id_str) + if not users_in_left_room: + return + + users_still_sharing_rooms = await self.store.do_users_share_a_room( + user_id_str, users_in_left_room + ) + + users_to_update = users_in_left_room - users_still_sharing_rooms + if users_to_update: + # First clear any old profile updates for these users + await self.store.clear_profile_updates_for_user( + user_id=user_id, + users_to_remove=users_to_update, + ) + + # Record our leave + stream_id = await self.store.add_profile_updates( + user_id=user_id, + action=ProfileUpdateAction.LEFT_ROOM, + updated_fields=None, + ) + await self.store.track_profile_updates_per_user( + stream_id=stream_id, + user_ids=users_to_update, + ) + + async def user_joined_room(self, user_id: UserID, room_id: str) -> None: + """ + A user joined a room. We now: + + * Add a row to `profile_updates` stating that the user joined a room. + * Get list of users in that room. + * Insert a row for each of those users into `profile_updates_per_user`. + * Now, when any of those users sync, the sync code will check + `profile_updates` and see that the user joined a room. Thus, we can + include the users full profile in the case that we need to do so. + + If that is the case, + """ + if not self._msc4429_enabled: + return + user_id_str = user_id.to_string() + + users_in_room = set(await self.store.get_local_users_in_room(room_id)) + users_in_room.discard(user_id_str) + if not users_in_room: + return + + stream_id = await self.store.add_profile_updates( + user_id=user_id, + action=ProfileUpdateAction.JOINED_ROOM, + updated_fields=None, + ) + + await self.store.track_profile_updates_per_user( + stream_id=stream_id, + user_ids=users_in_room, + ) + async def delete_profile_upon_deactivation( self, target_user: UserID, @@ -406,8 +547,30 @@ async def delete_profile_upon_deactivation( # have it. raise AuthError(400, "Cannot remove another user's profile") + profile_updates: list[tuple[str, JsonValue | None]] = [] + current_profile: ProfileInfo | None = None + + if self._msc4429_enabled: + if current_profile is None: + current_profile = await self.store.get_profileinfo(target_user) + + if current_profile.display_name is not None: + profile_updates.append((ProfileFields.DISPLAYNAME, None)) + if current_profile.avatar_url is not None: + profile_updates.append((ProfileFields.AVATAR_URL, None)) + + custom_fields = await self.store.get_profile_fields(target_user) + for field_name in custom_fields.keys(): + profile_updates.append((field_name, None)) + await self.store.delete_profile(target_user) + # Record profile updates for the profile update stream + if len(profile_updates): + await self._dispatch_record_profile_updates( + target_user, {field_name for field_name, _value in profile_updates} + ) + await self._third_party_rules.on_profile_update( target_user.to_string(), ProfileInfo(None, None), @@ -415,6 +578,36 @@ async def delete_profile_upon_deactivation( deactivation=True, ) + async def _dispatch_record_profile_updates( + self, user_id: UserID, updated_fields: set[str] + ) -> None: + """ + Dispatch the recording of profile updates, either directly via the current + instance, if we're a profile worker, otherwise push via replication. + + Args: + user_id: The user whose profile has had updates. + updated_fields: A set of the names of the fields that were updated. + + Returns: + None + """ + if not self._msc4429_enabled: + return + + if self._is_profile_worker: + await self.record_profile_updates( + user_id, + updated_fields, + ) + else: + # Offload to the right worker via http replication + await self._record_profile_updates_client( + instance_name=self._profile_updates_writer_instance, + user_id=user_id.to_string(), + updated_fields=updated_fields, + ) + @cached() async def check_avatar_size_and_mime_type(self, mxc: str) -> bool: """Check that the size and content type of the avatar at the given MXC URI are @@ -492,7 +685,7 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool: async def get_profile_field( self, target_user: UserID, field_name: str - ) -> JsonValue: + ) -> JsonValue | dict[str, JsonValue]: """ Fetch a user's profile from the database for local users and over federation for remote users. @@ -530,12 +723,56 @@ async def get_profile_field( return result.get(field_name) + async def set_field( + self, + *, + target_user: UserID, + requester: Requester, + field_name: str, + new_value: JsonValue | dict[str, JsonValue], + by_admin: bool = False, + propagate: bool = False, + ) -> None: + """Wrapper function for setting any profile field for a user.""" + if field_name == ProfileFields.DISPLAYNAME: + if not isinstance(new_value, str): + raise SynapseError( + 400, "'displayname' must be a string", errcode=Codes.INVALID_PARAM + ) + await self.set_displayname( + target_user=target_user, + requester=requester, + new_displayname=new_value, + by_admin=by_admin, + propagate=propagate, + ) + elif field_name == ProfileFields.AVATAR_URL: + if not isinstance(new_value, str): + raise SynapseError( + 400, "'avatar_url' must be a string", errcode=Codes.INVALID_PARAM + ) + await self.set_avatar_url( + target_user=target_user, + requester=requester, + new_avatar_url=new_value, + by_admin=by_admin, + propagate=propagate, + ) + else: + await self.set_profile_field( + target_user=target_user, + requester=requester, + field_name=field_name, + new_value=new_value, + by_admin=by_admin, + ) + async def set_profile_field( self, target_user: UserID, requester: Requester, field_name: str, - new_value: JsonValue, + new_value: JsonValue | dict[str, JsonValue], *, by_admin: bool = False, ) -> None: @@ -560,6 +797,7 @@ async def set_profile_field( raise AuthError(403, "Cannot set another user's profile") await self.store.set_profile_field(target_user, field_name, new_value) + await self._dispatch_record_profile_updates(target_user, {field_name}) # Custom fields do not propagate into the user directory *or* rooms. profile = await self.store.get_profileinfo(target_user) @@ -595,6 +833,7 @@ async def delete_profile_field( raise AuthError(400, "Cannot set another user's profile") await self.store.delete_profile_field(target_user, field_name) + await self._dispatch_record_profile_updates(target_user, {field_name}) # Custom fields do not propagate into the user directory *or* rooms. profile = await self.store.get_profileinfo(target_user) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5152d0b522e..e34fc79d4ac 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -53,6 +53,7 @@ from synapse.logging import opentracing from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions +from synapse.replication.http.profile import ReplicationProfileUserRoomMembershipChange from synapse.replication.http.push import ReplicationCopyPusherRestServlet from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.invite_rule import InviteRule @@ -205,6 +206,16 @@ def __init__(self, hs: "HomeServer"): ) self._push_writer = hs.config.worker.writers.push_rules[0] self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs) + self._msc4429_enabled = hs.config.server.include_profile_updates_in_sync + self._is_profile_worker = ( + hs.get_instance_name() in hs.config.worker.writers.profile_updates + ) + self._profile_updates_writer_instance = ( + self.hs.config.worker.writers.profile_updates[0] + ) + self._profile_user_room_membership_change_client = ( + ReplicationProfileUserRoomMembershipChange.make_client(self.hs) + ) def _on_user_joined_room(self, event_id: str, room_id: str) -> None: """Notify the rate limiter that a room join has occurred. @@ -526,6 +537,36 @@ async def _local_membership_update( ) if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) + if self._msc4429_enabled: + # Notify the profile handler. We only want to do this once + # in a multi-worker setup, so we can't listen on the dispatched + # event above. + if self._is_profile_worker: + await self.profile_handler.user_left_room( + target, room_id + ) + else: + # Offload to the right worker via http replication + await self._profile_user_room_membership_change_client( + instance_name=self._profile_updates_writer_instance, + user_id=target.to_string(), + room_id=room_id, + membership=Membership.LEAVE, + ) + + elif self._msc4429_enabled and event.membership == Membership.JOIN: + # Notify the profile handler. We only want to do this once + # in a multi-worker setup, so we can't dispatch a hook to all workers. + if self._is_profile_worker: + await self.profile_handler.user_joined_room(target, room_id) + else: + # Offload to the right worker via http replication + await self._profile_user_room_membership_change_client( + instance_name=self._profile_updates_writer_instance, + user_id=target.to_string(), + room_id=room_id, + membership=Membership.JOIN, + ) break except PartialStateConflictError as e: @@ -1541,6 +1582,35 @@ async def send_membership_event( prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) + if self._msc4429_enabled: + # Notify the profile handler. We only want to do this once + # in a multi-worker setup, so we can't listen on the dispatched + # event above. + if self._is_profile_worker: + await self.profile_handler.user_left_room( + target_user, room_id + ) + else: + # Offload to the right worker via http replication + await self._profile_user_room_membership_change_client( + instance_name=self._profile_updates_writer_instance, + user_id=target_user.to_string(), + room_id=room_id, + membership=Membership.LEAVE, + ) + elif self._msc4429_enabled and event.membership == Membership.JOIN: + # Notify the profile handler. We only want to do this once + # in a multi-worker setup, so we can't dispatch a hook to all workers. + if self._is_profile_worker: + await self.profile_handler.user_joined_room(target_user, room_id) + else: + # Offload to the right worker via http replication + await self._profile_user_room_membership_change_client( + instance_name=self._profile_updates_writer_instance, + user_id=target_user.to_string(), + room_id=room_id, + membership=Membership.JOIN, + ) async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool: """ diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9ecfe0da0f2..2c1597417ee 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ EventContentFields, EventTypes, Membership, + ProfileUpdateAction, StickyEvent, ) from synapse.api.filtering import FilterCollection @@ -64,6 +65,7 @@ DeviceListUpdates, JsonDict, JsonMapping, + JsonValue, MultiWriterStreamToken, MutableStateMap, Requester, @@ -104,10 +106,18 @@ # client for no more than 30 minutes. LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 +# Store the cache that tracks which lazy-loaded profile fields have been sent to a given +# client for no more than 30 minutes. +LAZY_LOADED_PROFILE_FIELDS_CACHE_MAX_AGE = 30 * 60 * 1000 + # Remember the last 100 members we sent to a client for the purposes of # avoiding redundantly sending the same lazy-loaded members to the client LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 +# Remember the last 100 profile field updates we sent to a client for the purposes of +# avoiding redundantly sending the same lazy-loaded full profiles to the client +LAZY_LOADED_PROFILE_FIELDS_CACHE_MAX_SIZE = 100 + SyncRequestKey = tuple[Any, ...] @@ -224,6 +234,7 @@ class SyncResult: next_batch: Token for the next sync presence: List of presence events for the user. account_data: List of account_data events for the user. + profile_updates: Map of user_id to profile field updates for that user. joined: JoinedSyncResult for each joined room. invited: InvitedSyncResult for each invited room. knocked: KnockedSyncResult for each knocked on room. @@ -239,6 +250,8 @@ class SyncResult: next_batch: StreamToken presence: list[UserPresenceState] account_data: list[JsonDict] + # user ID -> {profile field -> value | null if unset } + profile_updates: dict[str, dict[str, JsonValue | dict[str, JsonValue]] | None] joined: list[JoinedSyncResult] invited: list[InvitedSyncResult] knocked: list[KnockedSyncResult] @@ -260,6 +273,7 @@ def __bool__(self) -> bool: or self.knocked or self.archived or self.account_data + or self.profile_updates or self.to_device or self.device_lists ) @@ -275,6 +289,7 @@ def empty( next_batch=next_batch, presence=[], account_data=[], + profile_updates={}, joined=[], invited=[], knocked=[], @@ -291,6 +306,7 @@ def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname self.hs_config = hs.config self.store = hs.get_datastores().main + self._is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() self._relations_handler = hs.get_relations_handler() @@ -329,6 +345,27 @@ def __init__(self, hs: "HomeServer"): max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + # ExpiringCache((User, Device)) -> LruCache(Other User ID + Field Name -> bool) + self.lazy_loaded_profile_fields_cache: ExpiringCache[ + tuple[str, str | None], LruCache[str, bool] + ] = ExpiringCache( + cache_name="lazy_loaded_profile_fields_cache", + server_name=self.server_name, + hs=hs, + clock=self.clock, + max_len=0, + expiry_ms=LAZY_LOADED_PROFILE_FIELDS_CACHE_MAX_AGE, + ) + """This cache contains fields we have sent to clients as profile updates, + for a particular user + device combo. The cache entry is a combination of the + user + field name, with the value existing indicating the field has recently + been sent. The boolean value does not hold other significance. A missing + cache entry means "we have not sent this user + field name combo to the + syncing user". + + We don't manually remove entries from this cache, though it may be ignored + in cases where the sync must send the field down to the client. + """ self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync @@ -1036,6 +1073,24 @@ def get_lazy_loaded_members_cache( logger.debug("found LruCache for %r", cache_key) return cache + def get_lazy_loaded_profile_fields_cache( + self, cache_key: tuple[str, str | None] + ) -> LruCache[str, bool]: + cache: LruCache[str, bool] | None = self.lazy_loaded_profile_fields_cache.get( + cache_key + ) + if cache is None: + logger.debug("creating LruCache for %r", cache_key) + cache = LruCache( + max_size=LAZY_LOADED_PROFILE_FIELDS_CACHE_MAX_SIZE, + clock=self.clock, + server_name=self.server_name, + ) + self.lazy_loaded_profile_fields_cache[cache_key] = cache + else: + logger.debug("found LruCache for %r", cache_key) + return cache + async def compute_state_delta( self, room_id: str, @@ -1853,10 +1908,18 @@ async def generate_sync_result( } ) + # Note, this needs to be after we collect `joined`, `invited`, `knocked` and + # `archived` sync results since we want to utilize the work we did to collect + # events in those responses as a basis for which users to include profiles + # for when lazy loading. + if self.hs_config.server.include_profile_updates_in_sync: + await self._generate_sync_entry_for_profile_updates(sync_result_builder) + logger.debug("Sync response calculation complete") return SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, + profile_updates=sync_result_builder.profile_updates, joined=sync_result_builder.joined, invited=sync_result_builder.invited, knocked=sync_result_builder.knocked, @@ -2121,6 +2184,243 @@ async def _generate_sync_entry_for_account_data( sync_result_builder.account_data = account_data_for_user + async def _generate_initial_sync_entry_for_profile_updates( + self, + *, + user_id: str, + sync_result_builder: "SyncResultBuilder", + profile_fields: set[str], + include_users: set[str] | None, + ) -> None: + """ + Build an initial sync entry for profile updates and attach it to the + given `sync_result_builder`. + + Note: Currently, only profile updates of local users are generated. + + Args: + user_id: The Matrix ID of the user to generate the sync entry for. + sync_result_builder: + profile_fields: The list of field IDs to filter for. + include_users: List of users profiles to include in the sync response, + for when we have calculated a list of users in our lazy loading + sync and want to only return those. + """ + # Currently, limited to only local profiles, so filter remote servers out + user_ids = await self.store.get_local_users_who_share_room_with_user(user_id) + if include_users: + # Filter down to selected included users + user_ids = {user_id for user_id in user_ids if user_id in include_users} + + if not user_ids: + return + + profile_data_by_user = await self.store.get_profile_data_for_users(user_ids) + + # Serialise the profile updates into the sync response format. + profile_updates: dict[ + str, dict[str, JsonValue | dict[str, JsonValue]] | None + ] = {} + for other_user_id in user_ids: + profile_data = profile_data_by_user.get(other_user_id) + if profile_data is None: + # Don't generate anything for users with no profile data + # in initial sync. + continue + + per_user_updates: dict[str, JsonValue | dict[str, JsonValue]] = {} + for field_name in profile_fields: + if field_name in profile_data.keys(): + per_user_updates[field_name] = profile_data[field_name] + + if per_user_updates: + profile_updates[other_user_id] = per_user_updates + + if profile_updates: + sync_result_builder.profile_updates = profile_updates + + async def _generate_sync_entry_for_profile_updates( + self, sync_result_builder: "SyncResultBuilder" + ) -> None: + """ + Build a sync entry for profile updates and attach it to the given + `sync_result_builder`. + + Currently only local profiles updates will be included in the sync response. + + Args: + sync_result_builder: + """ + sync_config = sync_result_builder.sync_config + profile_fields = sync_config.filter_collection.profile_fields + if not profile_fields: + return + + user_id = sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + + sync_config = sync_result_builder.sync_config + lazy_load_members = sync_config.filter_collection.lazy_load_members() + include_users = None + if lazy_load_members: + # Collect members from the existing `sync_result_builder` data. + # Ensure we filter out any remove users until we support profile + # updates for federated users. + include_users = set() + # invited + for invited in sync_result_builder.invited: + if self._is_mine_id(invited.invite.sender): + include_users.add(invited.invite.sender) + # joined + for joined in sync_result_builder.joined: + for timeline_event in joined.timeline.events: + if self._is_mine_id(timeline_event.event.sender): + include_users.add(timeline_event.event.sender) + # knocked + for knocked in sync_result_builder.knocked: + if self._is_mine_id(knocked.knock.sender): + include_users.add(knocked.knock.sender) + # archived + for archived in sync_result_builder.archived: + for timeline_event in archived.timeline.events: + if self._is_mine_id(timeline_event.event.sender): + include_users.add(timeline_event.event.sender) + + if since_token is None: + await self._generate_initial_sync_entry_for_profile_updates( + user_id=user_id, + sync_result_builder=sync_result_builder, + profile_fields=profile_fields, + include_users=include_users, + ) + return + + updates = await self.store.get_profile_updates_for_user_and_fields( + from_id=since_token.profile_updates_key, + to_id=now_token.profile_updates_key, + user_id=user_id, + field_names=profile_fields, + ) + + left_room_user_ids = { + update.user_id + for update in updates + if update.action == ProfileUpdateAction.LEFT_ROOM.value + } + joined_room_user_ids = { + update.user_id + for update in updates + if update.action == ProfileUpdateAction.JOINED_ROOM.value + } + users = set() + updated_users = { + update.user_id + for update in updates + if update.action == ProfileUpdateAction.UPDATE.value + } + # Add any users in the timeline, if we collected them due to lazy loading + if include_users: + users.update(include_users) + # Add users with updates + users.update(updated_users) + # Add any newly joined users + users.update(joined_room_user_ids) + + if not users and not left_room_user_ids: + return + + # Serialise the profile updates into the sync response format. + # user ID -> {profile field -> value | null if unset } + profile_updates: dict[ + str, dict[str, JsonValue | dict[str, JsonValue]] | None + ] = {} + + # Process field updates and users who have events in the sync response + if users: + updated_user_fields: dict[str, set[str]] = {} + # Set fields from updates + for update in updates: + # Skip the update if there is no field update (a joined or left room + # action), the client didn't ask for this field, or we're not + # interested in this user. + if ( + not update.field_name + or update.field_name not in profile_fields + or update.user_id not in users + ): + continue + updated_user_fields.setdefault(update.user_id, set()).add( + update.field_name + ) + + # Note: there's a small race condition here where a profile update may + # occur between fetching `now_token` above and reaching this step. In + # that case, the profile information will be newer than `now_token`. + # This is fine, as users will generally always want the latest profile + # information. However, it does mean that on the next sync, the same + # profile update will come down a second time. + # + # Hopefully clients can just filter these out. + profile_data_by_user = await self.store.get_profile_data_for_users(users) + + for other_user_id in users: + profile_data = profile_data_by_user.get(other_user_id) + if profile_data is None: + # No profile data for this user, just return a blank dictionary + # in incremental sync, telling the clients to remove all profile + # information for this user. + profile_updates[other_user_id] = None + continue + + per_user_updates: dict[str, JsonValue | dict[str, JsonValue]] = {} + if include_users and other_user_id in include_users: + # Include all the fields the client asked for, as this user + # has events in a lazy loaded sync response, except for + # fields we've recently sent in a previous lazy loaded sync response + fields = set(profile_data.keys()).intersection(profile_fields) + for field_name in fields: + cache_key = ( + sync_config.user.to_string(), + sync_config.device_id, + ) + cache = self.get_lazy_loaded_profile_fields_cache(cache_key) + # Only send this users field if we haven't recently sent it + if cache.get(f"{other_user_id}-{field_name}") is None: + per_user_updates[field_name] = profile_data.get(field_name) + # Update our cache to indicate this user/field combo + # has been recently sent. + cache.set( + f"{other_user_id}-{field_name}", + True, + ) + else: + # Include only the diff, unless the user recently joined, + # then send all the fields the client asked for. + # We don't use a cache here as for non-lazy sync we always + # send changes and/or fields the client asked for, if relevant + # as above joined condition. + fields = ( + profile_fields + if other_user_id in joined_room_user_ids + else set(updated_user_fields.get(other_user_id, [])) + ) + fields = set(profile_data.keys()).intersection(fields) + for field_name in fields: + per_user_updates[field_name] = profile_data[field_name] + + if per_user_updates: + profile_updates[other_user_id] = per_user_updates + + # Process left rooms + if left_room_user_ids: + for other_user_id in left_room_user_ids: + # Return an empty dictionary to the client + profile_updates[other_user_id] = None + + if profile_updates: + sync_result_builder.profile_updates = profile_updates + async def _generate_sync_entry_for_presence( self, sync_result_builder: "SyncResultBuilder", @@ -3137,6 +3437,7 @@ class SyncResultBuilder: # The following mirror the fields in a sync response presence account_data + profile_updates joined invited knocked @@ -3155,6 +3456,9 @@ class SyncResultBuilder: presence: list[UserPresenceState] = attr.Factory(list) account_data: list[JsonDict] = attr.Factory(list) + profile_updates: dict[str, dict[str, JsonValue | dict[str, JsonValue]] | None] = ( + attr.Factory(dict) + ) joined: list[JoinedSyncResult] = attr.Factory(list) invited: list[InvitedSyncResult] = attr.Factory(list) knocked: list[KnockedSyncResult] = attr.Factory(list) diff --git a/synapse/notifier.py b/synapse/notifier.py index 6a057ac09fa..e24d0ef5a25 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -528,6 +528,7 @@ def on_new_event( StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, StreamKeyType.STICKY_EVENTS, + StreamKeyType.PROFILE_UPDATES, ], new_token: int, users: Collection[str | UserID] | None = None, diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 68cc6ce1fc6..d934ef80678 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -30,6 +30,7 @@ login, membership, presence, + profile, push, register, send_events, @@ -59,6 +60,7 @@ def register_servlets(self, hs: "HomeServer") -> None: push.register_servlets(hs, self) state.register_servlets(hs, self) devices.register_servlets(hs, self) + profile.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/profile.py b/synapse/replication/http/profile.py new file mode 100644 index 00000000000..2ebae2c0078 --- /dev/null +++ b/synapse/replication/http/profile.py @@ -0,0 +1,140 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# + + +import logging +from typing import TYPE_CHECKING + +from twisted.web.server import Request + +from synapse.api.constants import Membership +from synapse.http.server import HttpServer +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationProfileUserRoomMembershipChange(ReplicationEndpoint): + """Store user profile update action regarding membership changes. + + The POST looks like: + + POST /_synapse/replication/profile_user_room_membership_change/ + + { + "room_id": "!1234:domain.tld", + "membership": "join | leave" + } + + 200 OK + + {} + """ + + NAME = "profile_user_room_membership_change" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._profile_handler = hs.get_profile_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, + room_id: str, + membership: str, + ) -> JsonDict: + assert membership in (Membership.JOIN, Membership.LEAVE) + return { + "room_id": room_id, + "membership": membership, + } + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict, user_id: str + ) -> tuple[int, JsonDict]: + assert content["membership"] in (Membership.JOIN, Membership.LEAVE) + if content["membership"] == Membership.JOIN: + await self._profile_handler.user_joined_room( + user_id=UserID.from_string(user_id), + room_id=content["room_id"], + ) + else: + await self._profile_handler.user_left_room( + user_id=UserID.from_string(user_id), + room_id=content["room_id"], + ) + + return (200, {}) + + +class ReplicationProfileRecordFieldUpdates(ReplicationEndpoint): + """Record user profile field updates for the profile updates stream. + + The POST looks like: + + POST /_synapse/replication/profile_record_field_updates/ + + { + "updated_fields": ["list", "of", "fields"] + } + + 200 OK + + {} + """ + + NAME = "profile_record_field_updates" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._profile_handler = hs.get_profile_handler() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + user_id: str, + updated_fields: set[str], + ) -> JsonDict: + assert len(updated_fields) > 0 + return { + "updated_fields": list(updated_fields), + } + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict, user_id: str + ) -> tuple[int, JsonDict]: + assert len(content["updated_fields"]) > 0 + await self._profile_handler.record_profile_updates( + user_id=UserID.from_string(user_id), + updated_fields=set(content["updated_fields"]), + ) + + return (200, {}) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.server.include_profile_updates_in_sync: + ReplicationProfileUserRoomMembershipChange(hs).register(http_server) + ReplicationProfileRecordFieldUpdates(hs).register(http_server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index bc7e46d4c92..00a24db749b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -44,6 +44,7 @@ UnPartialStatedRoomStream, ) from synapse.replication.tcp.streams._base import ( + ProfileUpdatesStream, StickyEventsStream, ThreadSubscriptionsStream, ) @@ -265,6 +266,23 @@ async def on_rdata( token, users=[row.user_id for row in rows], ) + elif stream_name == ProfileUpdatesStream.NAME: + updated_user_ids = {row.user_id for row in rows} + if updated_user_ids: + room_ids: set[str] = set() + # Get all the rooms of the updated users, dict of + # User ID -> [Room ID] + users_and_rooms = await self.store.get_rooms_for_users(updated_user_ids) + # Loop through each users room ID's and add to our set of rooms + for user_room_ids in users_and_rooms.values(): + room_ids.update(user_room_ids) + + if room_ids: + self.notifier.on_new_event( + StreamKeyType.PROFILE_UPDATES, + token, + rooms=room_ids, + ) elif stream_name == StickyEventsStream.NAME: self.notifier.on_new_event( StreamKeyType.STICKY_EVENTS, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index ad9fed72dd8..65befb77064 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -67,6 +67,7 @@ ) from synapse.replication.tcp.streams._base import ( DeviceListsStream, + ProfileUpdatesStream, StickyEventsStream, ThreadSubscriptionsStream, ) @@ -218,6 +219,12 @@ def __init__(self, hs: "HomeServer"): continue + if isinstance(stream, ProfileUpdatesStream): + if hs.get_instance_name() in hs.config.worker.writers.profile_updates: + self._streams_to_replicate.append(stream) + + continue + if isinstance(stream, StickyEventsStream): if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index e41573cf689..e657822da70 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -37,6 +37,7 @@ DeviceListsStream, PresenceFederationStream, PresenceStream, + ProfileUpdatesStream, PushersStream, PushRulesStream, QuarantinedMediaStream, @@ -70,6 +71,7 @@ ToDeviceStream, FederationStream, AccountDataStream, + ProfileUpdatesStream, StickyEventsStream, ThreadSubscriptionsStream, UnPartialStatedRoomStream, @@ -94,6 +96,7 @@ "ToDeviceStream", "FederationStream", "AccountDataStream", + "ProfileUpdatesStream", "StickyEventsStream", "ThreadSubscriptionsStream", "UnPartialStatedRoomStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index a73f767add2..86242b23413 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -31,8 +31,9 @@ import attr -from synapse.api.constants import AccountDataTypes +from synapse.api.constants import AccountDataTypes, ProfileUpdateAction from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.types import UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -765,6 +766,57 @@ async def _update_function( return rows, rows[-1][0], len(updates) == limit +@attr.s(slots=True, auto_attribs=True) +class ProfileUpdatesStreamRow: + """Stream to inform workers about profile updates.""" + + user_id: UserID + """The full user ID with the profile update.""" + action: ProfileUpdateAction + """The action, either 'update' for a field update, 'left_room' if the user left + a room or `joined_room` if the user joined a room, see ProfileUpdateAction constant. + """ + field_name: str | None + """The profile field that was updated, see https://spec.matrix.org/unstable/client-server-api/#profiles. + This can be None if `action` is not 'update'. + """ + + +class ProfileUpdatesStream(_StreamFromIdGen): + """A user profile field was changed.""" + + NAME = "profile_updates" + ROW_TYPE = ProfileUpdatesStreamRow + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + super().__init__( + hs.get_instance_name(), + self._update_function, + self.store._profile_updates_id_gen, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + updates = await self.store.get_updated_profile_updates( + from_id=from_token, to_id=to_token, limit=limit + ) + rows = [ + ( + stream_id, + # These are the args to `ProfileUpdatesStreamRow` + (user_id, action, field_name), + ) + for stream_id, user_id, action, field_name in updates + ] + + if not rows: + return [], to_token, False + + return rows, rows[-1][0], len(updates) == limit + + @attr.s(slots=True, auto_attribs=True) class StickyEventsStreamRow: """Stream to inform workers about changes to sticky events.""" diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index c2ec5b36114..93a5b102d11 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -58,7 +58,7 @@ def _read_propagate(hs: "HomeServer", request: SynapseRequest) -> bool: class ProfileRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) + PATTERNS = client_patterns("/profile/(?P[^/]*)$", v1=True) CATEGORY = "Event sending requests" def __init__(self, hs: "HomeServer"): @@ -109,6 +109,9 @@ def __init__(self, hs: "HomeServer"): self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() + self._is_profile_worker = ( + hs.get_instance_name() in hs.config.worker.writers.profile_updates + ) if hs.config.experimental.msc4133_enabled: self.PATTERNS.append( re.compile( @@ -146,7 +149,9 @@ async def on_GET( await self.profile_handler.check_profile_query_allowed(user, requester_user) if field_name == ProfileFields.DISPLAYNAME: - field_value: JsonValue = await self.profile_handler.get_displayname(user) + field_value: ( + JsonValue | dict[str, JsonValue] + ) = await self.profile_handler.get_displayname(user) elif field_name == ProfileFields.AVATAR_URL: field_value = await self.profile_handler.get_avatar_url(user) else: @@ -204,18 +209,14 @@ async def on_PUT( Codes.USER_ACCOUNT_SUSPENDED, ) - if field_name == ProfileFields.DISPLAYNAME: - await self.profile_handler.set_displayname( - user, requester, new_value, by_admin=is_admin, propagate=propagate - ) - elif field_name == ProfileFields.AVATAR_URL: - await self.profile_handler.set_avatar_url( - user, requester, new_value, by_admin=is_admin, propagate=propagate - ) - else: - await self.profile_handler.set_profile_field( - user, requester, field_name, new_value, by_admin=is_admin - ) + await self.profile_handler.set_field( + target_user=user, + requester=requester, + field_name=field_name, + new_value=new_value, + by_admin=is_admin, + propagate=propagate, + ) return 200, {} @@ -261,17 +262,21 @@ async def on_DELETE( Codes.USER_ACCOUNT_SUSPENDED, ) - if field_name == ProfileFields.DISPLAYNAME: - await self.profile_handler.set_displayname( - user, requester, "", by_admin=is_admin, propagate=propagate - ) - elif field_name == ProfileFields.AVATAR_URL: - await self.profile_handler.set_avatar_url( - user, requester, "", by_admin=is_admin, propagate=propagate + if field_name in (ProfileFields.DISPLAYNAME, ProfileFields.AVATAR_URL): + await self.profile_handler.set_field( + target_user=user, + requester=requester, + field_name=field_name, + new_value="", + by_admin=is_admin, + propagate=propagate, ) else: await self.profile_handler.delete_profile_field( - user, requester, field_name, by_admin=is_admin + target_user=user, + requester=requester, + field_name=field_name, + by_admin=is_admin, ) return 200, {} @@ -284,8 +289,9 @@ class UnstableProfileFieldRestServlet(ProfileFieldRestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - # The specific field endpoint *must* appear before the generic profile endpoint. ProfileFieldRestServlet(hs).register(http_server) - ProfileRestServlet(hs).register(http_server) + if hs.config.experimental.msc4133_enabled: UnstableProfileFieldRestServlet(hs).register(http_server) + + ProfileRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 702ddcd6ca1..042f3091e09 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -124,6 +124,7 @@ def __init__(self, hs: "HomeServer"): self._event_serializer = hs.get_event_client_serializer() self._msc2654_enabled = hs.config.experimental.msc2654_enabled self._msc3773_enabled = hs.config.experimental.msc3773_enabled + self._msc4429_enabled = hs.config.server.include_profile_updates_in_sync self._json_filter_cache: LruCache[str, bool] = LruCache( max_size=1000, @@ -352,6 +353,12 @@ async def encode_response( if sync_result.to_device: response["to_device"] = {"events": sync_result.to_device} + if self._msc4429_enabled and sync_result.profile_updates: + response["org.matrix.msc4429.users"] = { + user_id: {"profile_updates": updates} + for user_id, updates in sync_result.profile_updates.items() + } + if sync_result.device_lists.changed: response["device_lists"]["changed"] = list(sync_result.device_lists.changed) if sync_result.device_lists.left: diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 809e920a2bb..605d4ee7a5a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -198,6 +198,8 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: # Arbitrary key-value profile fields. "uk.tcpip.msc4133": self.config.experimental.msc4133_enabled, "uk.tcpip.msc4133.stable": True, + # MSC4429: Profile updates for legacy /sync. + "org.matrix.msc4429": self.config.server.include_profile_updates_in_sync, # MSC4155: Invite filtering "org.matrix.msc4155": self.config.experimental.msc4155_enabled, # MSC4306: Support for thread subscriptions diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 9b787e19a3d..173381e5fed 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -19,13 +19,17 @@ # # import json -from typing import TYPE_CHECKING, cast +import logging +from typing import TYPE_CHECKING, Collection, Iterable, cast +import attr from canonicaljson import encode_canonical_json -from synapse.api.constants import ProfileFields +from synapse.api.constants import ProfileFields, ProfileUpdateAction from synapse.api.errors import Codes, StoreError -from synapse.storage._base import SQLBaseStore +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.replication.tcp.streams._base import ProfileUpdatesStream +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -33,15 +37,36 @@ ) from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonValue, UserID +from synapse.util.duration import Duration if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) # The number of bytes that the serialized profile can have. MAX_PROFILE_SIZE = 65536 +# Prunes entries out of the `profile_updates` and `profile_updates_per_user` tables +# that are more than this old. +PRUNE_PROFILE_UPDATES_AGE = Duration(days=30) + +# The number of rows to delete at once when pruning old entries out of the +# `profile_updates` and `profile_updates_per_user` tables. +PRUNE_PROFILE_UPDATES_BATCH_SIZE = 1000 + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ProfileUpdate: + """An update to a user's profile.""" + + stream_id: int + user_id: str + action: str + field_name: str | None + class ProfileWorkerStore(SQLBaseStore): def __init__( @@ -52,6 +77,7 @@ def __init__( ): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname + self._instance_name: str = hs.get_instance_name() self.database_engine = database.engine self.db_pool.updates.register_background_index_update( "profiles_full_user_id_key_idx", @@ -65,6 +91,28 @@ def __init__( "populate_full_user_id_profiles", self.populate_full_user_id_profiles ) + self._can_write_to_profile_updates = ( + self._instance_name in hs.config.worker.writers.profile_updates + ) + self._profile_updates_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="profile_updates", + server_name=self.server_name, + instance_name=self._instance_name, + tables=[ + ("profile_updates", "instance_name", "stream_id"), + ], + sequence_name="profile_updates_sequence", + writers=hs.config.worker.writers.profile_updates, + ) + if hs.config.worker.run_background_tasks: + self.clock.looping_call( + self._prune_profile_updates, + Duration(hours=1), + ) + async def populate_full_user_id_profiles( self, progress: JsonDict, batch_size: int ) -> int: @@ -152,6 +200,13 @@ def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None: return 50 + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ProfileUpdatesStream.NAME: + self._profile_updates_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: """ Fetch the display name and avatar URL of a user. @@ -210,7 +265,9 @@ async def get_profile_avatar_url(self, user_id: UserID) -> str | None: desc="get_profile_avatar_url", ) - async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue: + async def get_profile_field( + self, user_id: UserID, field_name: str + ) -> JsonValue | dict[str, JsonValue]: """ Get a custom profile field for a user. @@ -222,7 +279,9 @@ async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue The string value if the field exists, otherwise raises 404. """ - def get_profile_field(txn: LoggingTransaction) -> JsonValue: + def get_profile_field( + txn: LoggingTransaction, + ) -> JsonValue | dict[str, JsonValue]: # This will error if field_name has double quotes in it, but that's not # possible due to the grammar. field_path = f'$."{field_name}"' @@ -240,7 +299,9 @@ def get_profile_field(txn: LoggingTransaction) -> JsonValue: # Test exists first since value being None is used for both # missing and a null JSON value. - exists, value = cast(tuple[bool, JsonValue], txn.fetchone()) + exists, value = cast( + tuple[bool, JsonValue | dict[str, JsonValue]], txn.fetchone() + ) if not exists: raise StoreError(404, "No row found") return value @@ -257,7 +318,9 @@ def get_profile_field(txn: LoggingTransaction) -> JsonValue: ) # If value_type is None, then the value did not exist. - value_type, value = cast(tuple[str | None, JsonValue], txn.fetchone()) + value_type, value = cast( + tuple[str | None, JsonValue | dict[str, JsonValue]], txn.fetchone() + ) if not value_type: raise StoreError(404, "No row found") # If value_type is object or array, then need to deserialize the JSON. @@ -291,6 +354,348 @@ async def get_profile_fields(self, user_id: UserID) -> dict[str, str]: result = json.loads(result) return result or {} + def get_max_profile_updates_stream_id(self) -> int: + """Get the current maximum stream_id for profile updates.""" + return self._profile_updates_id_gen.get_current_token() + + def get_profile_updates_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._profile_updates_id_gen + + async def get_updated_profile_updates( + self, *, from_id: int, to_id: int, limit: int + ) -> list[tuple[int, str, str, str | None]]: + """Get updates to profile updates between two stream IDs. + + Bounds: from_id < ... <= to_id + + Args: + from_id: The starting stream ID (exclusive) + to_id: The ending stream ID (inclusive) + limit: The maximum number of rows to return + + Returns: + list of tuples representing stream_id, user_id, action and field_name + """ + if from_id >= to_id: + return [] + + def _get_updated_profile_updates_txn( + txn: LoggingTransaction, + ) -> list[tuple[int, str, str, str | None]]: + sql = """ + SELECT + stream_id, user_id, action, field_name + FROM profile_updates + WHERE + ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_id, to_id, limit)) + return cast(list[tuple[int, str, str, str | None]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_updated_profile_updates", _get_updated_profile_updates_txn + ) + + async def get_profile_updates_for_fields( + self, + *, + from_id: int, + to_id: int, + field_names: Iterable[str], + ) -> list[ProfileUpdate]: + """Get profile update markers for the given fields in a stream range. + + Bounds: from_id < ... <= to_id + + Args: + from_id: The starting stream ID (exclusive) + to_id: The ending stream ID (inclusive) + field_names: List of field names to filter against. + + Returns: + list of ProfileUpdates update rows + """ + if from_id >= to_id: + return [] + + field_names = list(field_names) + if not field_names: + return [] + + def _get_profile_updates_for_fields_txn( + txn: LoggingTransaction, + ) -> list[ProfileUpdate]: + clause, args = make_in_list_sql_clause( + txn.database_engine, "field_name", field_names + ) + sql = ( + "SELECT stream_id, user_id, action, field_name" + " FROM profile_updates" + f" WHERE ? < stream_id AND stream_id <= ? AND ({clause}" + " OR action != ?) " + " ORDER BY stream_id ASC" + ) + txn.execute(sql, (from_id, to_id, *args, ProfileUpdateAction.UPDATE.value)) + rows = cast(list[tuple[int, str, str, str | None]], txn.fetchall()) + + updates: list[ProfileUpdate] = [] + for stream_id, user_id, action, field_name in rows: + updates.append( + ProfileUpdate( + stream_id=stream_id, + user_id=user_id, + action=action, + field_name=field_name, + ) + ) + + return updates + + return await self.db_pool.runInteraction( + "get_profile_updates_for_fields", _get_profile_updates_for_fields_txn + ) + + async def get_profile_updates_for_user_and_fields( + self, + *, + from_id: int, + to_id: int, + user_id: str, + field_names: set[str], + include_users: set[str] | None = None, + ) -> list[ProfileUpdate]: + """Get profile update markers for a user in a stream range. + + The returned profile update rows are restricted to those with a + corresponding `profile_updates_per_user` row for the syncing user. + + Bounds: from_id < ... <= to_id + + Args: + from_id: The starting stream ID (exclusive). + to_id: The ending stream ID (inclusive). + user_id: The full user ID to filter on. + field_names: Set of field names to filter update actions against. + include_users: If given, only include updates for these user IDs. + + Returns: + A list of ProfileUpdates update rows. + """ + if from_id >= to_id: + return [] + + if len(field_names) == 0: + return [] + + if include_users is not None and len(include_users) == 0: + # All updates have been filtered out by lazy-loading. + return [] + + def _get_profile_updates_for_user_and_fields_txn( + txn: LoggingTransaction, + ) -> list[ProfileUpdate]: + field_clause, field_args = make_in_list_sql_clause( + txn.database_engine, "pu.field_name", field_names + ) + user_clause = "" + user_args: list[str] = [] + if include_users is not None: + # Filter out rows that aren't in `include_users`, if defined. + # This is only relevant when lazy-loading. + user_clause, user_args = make_in_list_sql_clause( + txn.database_engine, "pu.user_id", include_users + ) + user_clause = f"AND {user_clause}" + + # Retrieve profile updates where there's a corresponding row in + # `profile_updates_per_user` within the given `stream_id` bounds + # and the `user_id` and `field_names` match. + sql = f""" + SELECT pu.stream_id, pu.user_id, pu.action, pu.field_name + FROM profile_updates AS pu + INNER JOIN profile_updates_per_user AS puf + ON pu.stream_id = puf.stream_id + WHERE ? < pu.stream_id AND pu.stream_id <= ? + AND puf.user_id = ? + {user_clause} + AND ({field_clause} OR pu.action != ?) + ORDER BY pu.stream_id ASC + """ + + txn.execute( + sql, + ( + from_id, + to_id, + user_id, + *user_args, + *field_args, + ProfileUpdateAction.UPDATE.value, + ), + ) + rows = cast(list[tuple[int, str, str, str | None]], txn.fetchall()) + + updates: list[ProfileUpdate] = [] + for stream_id, updated_user_id, action, field_name in rows: + updates.append( + ProfileUpdate( + stream_id=stream_id, + user_id=updated_user_id, + action=action, + field_name=field_name, + ) + ) + + return updates + + return await self.db_pool.runInteraction( + "get_profile_updates_for_user_and_fields", + _get_profile_updates_for_user_and_fields_txn, + ) + + async def get_profile_data_for_users( + self, user_ids: Collection[str] + ) -> dict[str, dict[str, JsonValue | dict[str, JsonValue]]]: + """Fetch displayname/avatar_url/custom fields for a list of users. + + Currently, this returns only local users as the `profiles` table only + tracks local users. + + Args: + user_ids: List of user IDs to filter against. + + Returns: + Dictionary of displayname/avatar_url/custom fields for a list of users. + """ + if not user_ids: + return {} + + rows = await self.db_pool.simple_select_many_batch( + table="profiles", + column="full_user_id", + iterable=user_ids, + retcols=("full_user_id", "displayname", "avatar_url", "fields"), + desc="get_profile_data_for_users", + ) + + results: dict[str, dict[str, JsonValue | dict[str, JsonValue]]] = {} + for full_user_id, displayname, avatar_url, fields in rows: + user_fields = fields or {} + # The SQLite driver doesn't automatically convert JSON to + # Python objects + if isinstance(self.database_engine, Sqlite3Engine) and fields: + user_fields = json.loads(fields) + base_fields = { + ProfileFields.DISPLAYNAME: displayname, + ProfileFields.AVATAR_URL: avatar_url, + } + user_fields.update(base_fields) + + results[full_user_id] = user_fields + + return results + + async def add_profile_updates( + self, + user_id: UserID, + action: ProfileUpdateAction, + updated_fields: set[str] | None, + ) -> int: + """Persist profile update markers and return the last stream ID.""" + assert self._can_write_to_profile_updates + + if action == ProfileUpdateAction.UPDATE and not updated_fields: + return self._profile_updates_id_gen.get_current_token() + elif action == ProfileUpdateAction.LEFT_ROOM: + assert not updated_fields + + user_id_str = user_id.to_string() + + def _add_profile_updates_txn(txn: LoggingTransaction) -> int: + values = [] + inserted_ts = self.clock.time_msec() + if updated_fields: + stream_ids = self._profile_updates_id_gen.get_next_mult_txn( + txn, len(updated_fields) + ) + for stream_id, field_name in zip(stream_ids, updated_fields): + values.append( + [ + stream_id, + self._instance_name, + user_id_str, + action.value, + field_name, + inserted_ts, + ] + ) + else: + stream_ids = [self._profile_updates_id_gen.get_next_txn(txn)] + values.append( + [ + stream_ids[0], + self._instance_name, + user_id_str, + action.value, + None, + inserted_ts, + ] + ) + self.db_pool.simple_insert_many_txn( + txn, + table="profile_updates", + keys=[ + "stream_id", + "instance_name", + "user_id", + "action", + "field_name", + "inserted_ts", + ], + values=values, + ) + + return stream_ids[-1] + + return await self.db_pool.runInteraction( + "add_profile_updates", _add_profile_updates_txn + ) + + async def track_profile_updates_per_user( + self, + stream_id: int, + user_ids: set[str], + ) -> None: + """ + Create tracking rows for profile updater per target user interested in profile + updates for the user triggering one, including themselves. + + Args: + stream_id: Stream ID referencing a `profile_updates` stream ID. + user_ids: A set of the full user IDs of the target users interested in + this change. + """ + + def _track_profile_updates_per_user_txn(txn: LoggingTransaction) -> None: + inserted_ts = self.clock.time_msec() + values = [(stream_id, user_id, inserted_ts) for user_id in user_ids] + self.db_pool.simple_insert_many_txn( + txn, + table="profile_updates_per_user", + keys=[ + "stream_id", + "user_id", + "inserted_ts", + ], + values=values, + ) + + return await self.db_pool.runInteraction( + "track_profile_updates_per_user", + _track_profile_updates_per_user_txn, + ) + async def create_profile(self, user_id: UserID) -> None: """ Create a blank profile for a user. @@ -310,7 +715,7 @@ def _check_profile_size( txn: LoggingTransaction, user_id: UserID, new_field_name: str, - new_value: JsonValue, + new_value: JsonValue | dict[str, JsonValue], ) -> None: # For each entry there are 4 quotes (2 each for key and value), 1 colon, # and 1 comma. @@ -437,7 +842,10 @@ def set_profile_avatar_url(txn: LoggingTransaction) -> None: ) async def set_profile_field( - self, user_id: UserID, field_name: str, new_value: JsonValue + self, + user_id: UserID, + field_name: str, + new_value: JsonValue | dict[str, JsonValue], ) -> None: """ Set a custom profile field for a user. @@ -546,6 +954,193 @@ async def delete_profile(self, user_id: UserID) -> None: keyvalues={"full_user_id": user_id.to_string()}, ) + async def clear_profile_updates_for_user( + self, user_id: UserID, users_to_remove: set[str] + ) -> None: + """ + Clear all the ProfileUpdateAction.UPDATE rows from the + `profile_updates_per_user` table from a particular user for + a list of target users. + + This does not remove the stream ID row from `profile_updates` as it is + likely other per user rows may refer to it. Our automatic pruning of old + stream ID's will kick in later and clean up potential orphan `profile_updates` + table rows. + + Args: + user_id: The user's ID. + users_to_remove: List of users to remove per user rows for. + + Returns: + None + """ + assert self._can_write_to_profile_updates + if not users_to_remove: + return + + def _clear_profile_updates_for_user_txn( + txn: LoggingTransaction, + ) -> None: + sql = """ + SELECT stream_id FROM profile_updates + WHERE user_id = ? AND action = ? + """ + + txn.execute(sql, (user_id.to_string(), ProfileUpdateAction.UPDATE.value)) + res = txn.fetchall() + if not res: + return + + stream_ids = [row[0] for row in res] + + user_clause, user_args = make_in_list_sql_clause( + txn.database_engine, + "user_id", + users_to_remove, + ) + stream_id_clause, stream_id_args = make_in_list_sql_clause( + txn.database_engine, + "stream_id", + stream_ids, + ) + txn.execute( + f""" + DELETE FROM profile_updates_per_user + WHERE {user_clause} + AND {stream_id_clause} + """, + (*user_args, *stream_id_args), + ) + + await self.db_pool.runInteraction( + "clear_profile_updates_for_user", + _clear_profile_updates_for_user_txn, + ) + + @wrap_as_background_process("prune_profile_updates") + async def _prune_profile_updates(self) -> None: + """Delete old entries out of the `profile_updates` and + `profile_updates_per_user` tables, so that the tables don't grow indefinitely. + """ + prune_before_ts = self.clock.time_msec() - PRUNE_PROFILE_UPDATES_AGE.as_millis() + + def get_prune_before_stream_id_txn(txn: LoggingTransaction) -> int | None: + txn.execute( + """ + SELECT stream_id FROM profile_updates + WHERE inserted_ts <= ? + ORDER BY inserted_ts DESC + LIMIT 1 + """, + (prune_before_ts,), + ) + row = txn.fetchone() + return row[0] if row else None + + prune_before_stream_id = await self.db_pool.runInteraction( + "prune_profile_updates_get_stream_id", + get_prune_before_stream_id_txn, + ) + + if prune_before_stream_id is None: + return + + # Get the max stream ID in the table so we avoid deleting it. We need + # to keep the latest row so that we can calculate the maximum stream ID + # used. + max_stream_id = await self.db_pool.simple_select_one_onecol( + table="profile_updates", + keyvalues={}, + retcol="MAX(stream_id)", + desc="prune_profile_updates_get_max_stream_id", + ) + if prune_before_stream_id >= max_stream_id: + prune_before_stream_id = max_stream_id - 1 + + logger.debug( + "Pruning profile_updates before stream ID %d (timestamp %d)", + prune_before_stream_id, + prune_before_ts, + ) + # Now delete all rows with stream_id less than the + # prune_before_stream_id. + # + # We also delete in batches to avoid massive churn when initially + # clearing out all the old entries. + # + # We set a minimum stream ID so that when we delete in batches the + # database doesn't have to scan through all the (dead) tuples that were just + # deleted to find the next batch to delete. + + # The minimum stream ID to delete in the next batch, c.f. comment above. + # We default to 0 here as that is less than all possible stream IDs. + min_stream_id = 0 + + def prune_profile_updates_txn(txn: LoggingTransaction) -> int: + nonlocal min_stream_id + + assert table in ("profile_updates", "profile_updates_per_user") + txn.execute( + f""" + DELETE FROM {table} + WHERE stream_id IN ( + SELECT stream_id FROM {table} + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + ) + RETURNING stream_id + """, + ( + min_stream_id, + prune_before_stream_id, + PRUNE_PROFILE_UPDATES_BATCH_SIZE, + ), + ) + + # We can't use rowcount as that is incorrect on SQLite when using + # RETURNING. + num_deleted = 0 + for (deleted_stream_id,) in txn: + num_deleted += 1 + min_stream_id = max(min_stream_id, deleted_stream_id) + + return num_deleted + + # Do this twice, first for the per_user table, then for the main table + for table in ("profile_updates_per_user", "profile_updates"): + progress_num_rows_deleted = 0 + while True: + batch_deleted = await self.db_pool.runInteraction( + f"prune_{table}", + prune_profile_updates_txn, + ) + + finished = batch_deleted < PRUNE_PROFILE_UPDATES_BATCH_SIZE + + progress_num_rows_deleted += batch_deleted + + # Periodically report progress in the logs. We do this either when + # we've deleted a significant number of rows or when we've finished + # deleting all rows in this round. + if finished or progress_num_rows_deleted > 10000: + logger.info( + "Pruned %d rows from %s", + progress_num_rows_deleted, + table, + ) + progress_num_rows_deleted = 0 + + if finished: + break + + # Sleep for a short time to avoid hammering the database too much if + # there are a lot of rows to delete. + await self.clock.sleep(Duration(milliseconds=100)) + + # Reset the minimum stream id for our next table + min_stream_id = 0 + class ProfileStore(ProfileWorkerStore): pass diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 736f3e4c781..8c69266d510 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -982,6 +982,20 @@ async def get_users_who_share_room_with_user(self, user_id: str) -> set[str]: return user_who_share_room + async def get_local_users_who_share_room_with_user(self, user_id: str) -> set[str]: + """Returns the set of local users who share a room with `user_id`. + + This also includes the `user_id` themselves. + """ + room_ids = await self.get_rooms_for_user(user_id) + + user_who_share_room: set[str] = set() + for room_id in room_ids: + user_ids = await self.get_local_users_in_room(room_id) + user_who_share_room.update(user_ids) + + return user_who_share_room + @cached(cache_context=True, iterable=True) async def get_mutual_rooms_between_users( self, user_ids: frozenset[str], cache_context: _CacheContext diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 1afc6d0b2a6..3495dce866a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -175,6 +175,7 @@ Changes in SCHEMA_VERSION = 94 - Add `recheck` column (boolean, default true) to the `redactions` table. - MSC4242: Add state DAG tables. + - MSC4429: Track updates to user profile fields via a new stream. """ diff --git a/synapse/storage/schema/main/delta/94/05_profile_updates.sql b/synapse/storage/schema/main/delta/94/05_profile_updates.sql new file mode 100644 index 00000000000..b612ceea861 --- /dev/null +++ b/synapse/storage/schema/main/delta/94/05_profile_updates.sql @@ -0,0 +1,55 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2026 Element Creations Ltd. +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Track updates to profile fields for MSC4429 legacy /sync. +CREATE TABLE IF NOT EXISTS profile_updates ( + stream_id BIGINT NOT NULL PRIMARY KEY, + instance_name TEXT NOT NULL, + + -- The full user ID + user_id TEXT NOT NULL, + + -- Profile action that has happened, see ProfileUpdateAction enum. + action TEXT NOT NULL, + + -- Profile field name that has been updated, + -- see https://spec.matrix.org/unstable/client-server-api/#profiles + -- This is only required if "action" is "update" + field_name TEXT NULL, + + -- Unix timestamp. Used to determine when to cull rows (to prevent the table + -- from growing indefinitely). + inserted_ts BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS profile_updates_by_user ON profile_updates (user_id, stream_id); +CREATE INDEX IF NOT EXISTS profile_updates_by_field ON profile_updates (field_name, stream_id); +CREATE INDEX IF NOT EXISTS profile_updates_inserted_ts ON profile_updates (inserted_ts); + +-- Track which local users should receive each profile update. +CREATE TABLE IF NOT EXISTS profile_updates_per_user ( + id $%AUTO_INCREMENT_PRIMARY_KEY%$, + + -- Stream ID reference to `profile_updates` + stream_id BIGINT NOT NULL REFERENCES profile_updates (stream_id), + + -- The full user ID of the local user that should receive the profile update. + user_id TEXT NOT NULL, + + -- Unix timestamp. Used to determine when to cull rows (to prevent the table + -- from growing indefinitely). + inserted_ts BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS profile_updates_per_user_by_user_stream ON profile_updates_per_user (user_id, stream_id); +CREATE INDEX IF NOT EXISTS profile_updates_per_user_inserted_ts ON profile_updates_per_user (inserted_ts); diff --git a/synapse/storage/schema/main/delta/94/05_profile_updates_seq.sql.postgres b/synapse/storage/schema/main/delta/94/05_profile_updates_seq.sql.postgres new file mode 100644 index 00000000000..9abf79b68de --- /dev/null +++ b/synapse/storage/schema/main/delta/94/05_profile_updates_seq.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2026 Element Creations Ltd. +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE SEQUENCE profile_updates_sequence; +-- Synapse streams start at 2, because the default position is 1 +-- so any item inserted at position 1 is ignored. +-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712 +SELECT nextval('profile_updates_sequence'); diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 36490fcb355..24120eb7362 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -86,6 +86,7 @@ def get_current_token(self) -> StreamToken: thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() sticky_events_key = self.store.get_max_sticky_events_stream_id() quarantined_media_key = self.store.get_quarantined_media_stream_token() + profile_updates_key = self.store.get_max_profile_updates_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -102,6 +103,7 @@ def get_current_token(self) -> StreamToken: thread_subscriptions_key=thread_subscriptions_key, sticky_events_key=sticky_events_key, quarantined_media_key=quarantined_media_key, + profile_updates_key=profile_updates_key, ) return token @@ -131,6 +133,7 @@ async def bound_future_token(self, token: StreamToken) -> StreamToken: StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(), StreamKeyType.QUARANTINED_MEDIA: self.store.get_quarantined_media_stream_id_generator(), + StreamKeyType.PROFILE_UPDATES: self.store.get_profile_updates_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index a6fc806701d..ef99d575546 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1095,6 +1095,7 @@ class StreamKeyType(Enum): THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" STICKY_EVENTS = "sticky_events_key" QUARANTINED_MEDIA = "quarantined_media_key" + PROFILE_UPDATES = "profile_updates_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1102,7 +1103,7 @@ class StreamToken: """A collection of keys joined together by underscores in the following order and which represent the position in their respective streams. - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_4343` + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242_4141_4343_4444` 1. `room_key`: `s2633508` which is a `RoomStreamToken` - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - See the docstring for `RoomStreamToken` for more details. @@ -1118,6 +1119,7 @@ class StreamToken: 11. `thread_subscriptions_key`: 4242 12. `sticky_events_key`: 4141 13. `quarantined_media_key`: 4343 + 14. `profile_updates_key`: 4444 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1181,6 +1183,7 @@ class StreamToken: quarantined_media_key: MultiWriterStreamToken = attr.ib( validator=attr.validators.instance_of(MultiWriterStreamToken) ) + profile_updates_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1211,6 +1214,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": thread_subscriptions_key, sticky_events_key, quarantined_media_key, + profile_updates_key, ) = keys return cls( @@ -1231,6 +1235,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": quarantined_media_key=await MultiWriterStreamToken.parse( store, quarantined_media_key ), + profile_updates_key=int(profile_updates_key), ) except CancelledError: raise @@ -1256,6 +1261,7 @@ async def to_string(self, store: "DataStore") -> str: str(self.thread_subscriptions_key), str(self.sticky_events_key), await self.quarantined_media_key.to_string(store), + str(self.profile_updates_key), ] ) @@ -1329,6 +1335,7 @@ def get_field( StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, StreamKeyType.STICKY_EVENTS, + StreamKeyType.PROFILE_UPDATES, ], ) -> int: ... @@ -1384,9 +1391,10 @@ def __str__(self) -> str: f"typing: {self.typing_key}, receipt: {self.receipt_key}, " f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " - f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," - f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}" - f"quarantined_media: {self.quarantined_media_key})" + f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}, " + f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key}, " + f"quarantined_media: {self.quarantined_media_key}), " + f"profile_updates: {self.profile_updates_key})" ) @@ -1404,6 +1412,7 @@ def __str__(self) -> str: thread_subscriptions_key=0, sticky_events_key=0, quarantined_media_key=MultiWriterStreamToken(stream=0), + profile_updates_key=0, ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5152e8fc536..0bf940d24b7 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -26,12 +26,13 @@ from twisted.internet.testing import MemoryReactor import synapse.types -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, ProfileUpdateAction from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.storage.databases.main.profile import ProfileUpdate +from synapse.types import JsonDict, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.util.clock import Clock from synapse.util.duration import Duration @@ -62,8 +63,10 @@ def register_query_handler( self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler + self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( + notifier=self.mock_hs_notifier, federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, @@ -83,6 +86,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.frank_token = self.login(self.frank.localpart, "frankpassword") self.handler = hs.get_profile_handler() + self.on_new_event = self.mock_hs_notifier.on_new_event def test_get_my_name(self) -> None: self.get_success(self.store.set_profile_displayname(self.frank, "Frank")) @@ -161,6 +165,414 @@ def test_update_room_membership_on_set_displayname(self) -> None: ) self.assertEqual(membership[state_tuple].content["displayname"], "Frank Jr.") + @parameterized.expand( + [ + ["displayname", "Frank"], + ["avatar_url", "mxc://foobar"], + ["m.status", '{"text": "Holiday", "emoji": "🏖"}'], + ] + ) + def test_update_profile_does_not_update_stream_on_set_field_if_msc4429_not_enabled( + self, + field_name: str, + new_value: str, + ) -> None: + """Test that profile updates don't get recorded in the profile updates stream + if MSC4429 is not enabled.""" + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name=field_name, + new_value=new_value, + ) + ) + updates = self.get_success( + self.store.get_updated_profile_updates( + from_id=1, + to_id=2, + limit=1, + ) + ) + self.assertEqual(len(updates), 0) + + @parameterized.expand( + [ + ["displayname", "Frank"], + ["avatar_url", "mxc://foobar"], + ["m.status", '{"text": "Holiday", "emoji": "🏖"}'], + ] + ) + def test_update_profile_does_not_notify_notifier_on_set_field_if_msc4429_not_enabled( + self, + field_name: str, + new_value: str, + ) -> None: + """Test that profile updates do not cause the profile updates stream notifier + to wake up if MSC4429 is not enabled.""" + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name=field_name, + new_value=new_value, + ) + ) + + calls_found = [ + call + for call in self.on_new_event.mock_calls + if call.args[0] == StreamKeyType.PROFILE_UPDATES + ] + self.assertEqual(len(calls_found), 0) + + @parameterized.expand( + [ + ["displayname", "Frank"], + ["avatar_url", "mxc://foobar"], + ["m.status", '{"text": "Holiday", "emoji": "🏖"}'], + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_update_profile_does_not_notify_notifier_on_set_field_if_user_not_in_rooms( + self, field_name: str, new_value: str + ) -> None: + """Test that profile updates do not cause the profile updates stream notifier + to wake up if the user is not in any rooms, if MSC4429 is enabled.""" + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name=field_name, + new_value=new_value, + ) + ) + calls_found = [ + call + for call in self.on_new_event.mock_calls + if call.args[0] == StreamKeyType.PROFILE_UPDATES + ] + self.assertEqual(len(calls_found), 0) + + @parameterized.expand( + [ + ["displayname", "Frank"], + ["avatar_url", "mxc://foobar"], + ["m.status", '{"text": "Holiday", "emoji": "🏖"}'], + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_update_profile_updates_stream_on_set_field( + self, field_name: str, new_value: str + ) -> None: + """Test that profile updates get recorded in the profile updates stream if + MSC4429 is enabled.""" + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name=field_name, + new_value=new_value, + ) + ) + updates = self.get_success( + self.store.get_updated_profile_updates( + from_id=1, + to_id=2, + limit=1, + ) + ) + self.assertEqual( + updates[0], + (2, "@1234abcd:test", ProfileUpdateAction.UPDATE.value, field_name), + ) + + fields_updates = self.get_success( + self.store.get_profile_updates_for_fields( + from_id=1, + to_id=2, + field_names=[field_name], + ) + ) + self.assertEqual( + fields_updates[0], + ProfileUpdate( + stream_id=2, + user_id="@1234abcd:test", + action=ProfileUpdateAction.UPDATE.value, + field_name=field_name, + ), + ) + + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name=field_name, + new_value="", + ) + ) + delete_updates = self.get_success( + self.store.get_updated_profile_updates( + from_id=2, + to_id=3, + limit=1, + ) + ) + self.assertEqual( + delete_updates[0], + (3, "@1234abcd:test", ProfileUpdateAction.UPDATE.value, field_name), + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_update_profile_set_field_writes_to_per_user_profile_tracking_table( + self, + ) -> None: + """Test that profiles updates get recorded in the 'per user' profile updates + stream tracking table, if MSC4429 is enabled.""" + self.register_user("roger", "password") + roger_token = self.login("roger", "password") + self.register_user("millie", "password") + millie_token = self.login("millie", "password") + room_id = self.helper.create_room_as( + room_creator=self.frank.to_string(), + tok=self.frank_token, + ) + self.helper.join(room_id, "@roger:test", tok=roger_token) + self.helper.join(room_id, "@millie:test", tok=millie_token) + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name="m.status", + new_value='{"text": "Holiday"}', + ) + ) + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id="@roger:test", + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=3, + user_id="@millie:test", + action="joined_room", + field_name=None, + ), + ProfileUpdate( + stream_id=4, + user_id=self.frank.to_string(), + action="update", + field_name="m.status", + ), + ], + ) + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id="@millie:test", + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=4, + user_id=self.frank.to_string(), + action="update", + field_name="m.status", + ), + ], + ) + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id=self.frank.to_string(), + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=2, + user_id="@roger:test", + action="joined_room", + field_name=None, + ), + ProfileUpdate( + stream_id=3, + user_id="@millie:test", + action="joined_room", + field_name=None, + ), + ProfileUpdate( + stream_id=4, + user_id=self.frank.to_string(), + action="update", + field_name="m.status", + ), + ], + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_previous_profile_updates_stream_rows_cleared_if_no_longer_sharing_a_room( + self, + ) -> None: + """Test that previous profile update stream rows are removed for a user if + the user no longer shares rooms with another user, if MSC4429 is enabled. + + This test ensures that when a user leaves a room, we clear all old profile + update rows of users who the user no longer shares rooms with, to avoid + leaking any further profile field updates from those users. + """ + self.register_user("roger", "password") + roger_token = self.login("roger", "password") + self.register_user("millie", "password") + millie_token = self.login("millie", "password") + room_id = self.helper.create_room_as( + room_creator=self.frank.to_string(), + tok=self.frank_token, + ) + room_with_millie_id = self.helper.create_room_as( + room_creator=self.frank.to_string(), + tok=self.frank_token, + ) + self.helper.join(room_id, "@roger:test", tok=roger_token) + self.helper.join(room_with_millie_id, "@millie:test", tok=millie_token) + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name="m.status", + new_value='{"text": "Holiday"}', + ) + ) + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id="@roger:test", + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=4, + user_id=self.frank.to_string(), + action="update", + field_name="m.status", + ), + ], + ) + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id="@millie:test", + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=4, + user_id=self.frank.to_string(), + action="update", + field_name="m.status", + ), + ], + ) + + # Leave room and verify only the "left room" exists for roger + self.helper.leave(room_id, self.frank.to_string(), tok=self.frank_token) + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id="@roger:test", + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=5, + user_id=self.frank.to_string(), + action="left_room", + field_name=None, + ), + ], + ) + + # Sanity check we didn't clear any rows for millie + per_user_updates = self.get_success( + self.store.get_profile_updates_for_user_and_fields( + from_id=0, + to_id=10, + user_id="@millie:test", + field_names={"m.status"}, + ) + ) + self.assertEqual( + per_user_updates, + [ + ProfileUpdate( + stream_id=4, + user_id=self.frank.to_string(), + action="update", + field_name="m.status", + ), + ], + ) + + @parameterized.expand( + [ + ["displayname", "Frank"], + ["avatar_url", "mxc://foobar"], + ["m.status", '{"text": "Holiday", "emoji": "🏖"}'], + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_update_profile_notifies_notifier_on_set_field( + self, + field_name: str, + new_value: str, + ) -> None: + """Test that profile updates wake up the profile updates stream on profile + field updates, if MSC4429 is enabled.""" + self.helper.create_room_as( + room_creator=self.frank.to_string(), + tok=self.frank_token, + ) + self.get_success( + self.handler.set_field( + target_user=self.frank, + requester=synapse.types.create_requester(self.frank), + field_name=field_name, + new_value=new_value, + ) + ) + calls_found = [ + call + for call in self.on_new_event.mock_calls + if call.args[0] == StreamKeyType.PROFILE_UPDATES + ] + self.assertEqual(len(calls_found), 1) + def test_background_update_room_membership_on_set_displayname(self) -> None: """Test that `set_displayname` returns immediately and that room membership updates are still done in background.""" diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index d2b2523321b..96c47191a93 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -18,7 +18,7 @@ # # from http import HTTPStatus -from typing import Collection, ContextManager +from typing import Collection, ContextManager, cast from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class @@ -43,6 +43,7 @@ from synapse.server import HomeServer from synapse.types import ( JsonDict, + JsonValue, MultiWriterStreamToken, RoomStreamToken, StreamKeyType, @@ -55,6 +56,8 @@ import tests.unittest import tests.utils from tests.test_utils.event_builders import make_test_pdu_event +from tests.test_utils.event_injection import inject_member_event +from tests.unittest import override_config _request_key = 0 @@ -1152,6 +1155,1031 @@ def generate_sync_config( ) +class SyncProfileUpdatesTestCase(tests.unittest.HomeserverTestCase): + """Tests Sync Handler for profile updates.""" + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.sync_handler = self.hs.get_sync_handler() + self.profile_handler = self.hs.get_profile_handler() + self.store = self.hs.get_datastores().main + self.user = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.other_user = self.register_user("other_user", "password") + self.other_tok = self.login("other_user", "password") + self.joined_room = self.helper.create_room_as(self.user, tok=self.tok) + self.get_success( + self.store.set_profile_field( + user_id=UserID.from_string(self.user), + field_name="m.status", + new_value={"text": "Swimming in the Great Lakes!", "emoji": "🏊"}, + ) + ) + self.helper.join( + room=self.joined_room, user=self.other_user, tok=self.other_tok + ) + + def test_initial_sync_no_profile_updates_if_not_enabled(self) -> None: + """Test that without MSC4429 enabled the initial sync response does not + contain any profile updates.""" + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + self.user, + ), + request_key=generate_request_key(), + ) + ) + self.assertEqual(initial_result.profile_updates, {}) + + @override_config({"include_profile_updates_in_sync": True}) + def test_initial_sync_no_profile_updates_if_not_filtered_for(self) -> None: + """Test that with MSC4429 enabled the initial sync response does not + contain any profile updates, if fields are not filtered for.""" + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + ), + request_key=generate_request_key(), + ) + ) + self.assertEqual( + initial_result.profile_updates, + {}, + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_initial_sync_responds_with_tracked_profile_updates(self) -> None: + """Test that with MSC4429 enabled the initial sync response does + contain profile updates for users who share rooms, for the fields the + client requests. This response should include our syncing user.""" + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + # Also set a field the client doesn't want + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="displayname", + new_value="New displayname", + ) + ) + + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": {"ids": ["m.status"]} + }, + ), + ), + request_key=generate_request_key(), + ) + ) + assert initial_result.profile_updates[self.user] is not None + assert initial_result.profile_updates["@other_user:test"] is not None + self.assertEqual( + initial_result.profile_updates["@other_user:test"]["m.status"], + {"text": "On holiday", "emoji": "🏖"}, + ) + self.assertFalse( + "displayname" in initial_result.profile_updates["@other_user:test"].keys(), + ) + self.assertCountEqual( + initial_result.profile_updates.keys(), + [ + self.user, + "@other_user:test", + ], + ) + + @parameterized.expand( + [ + True, + False, + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_initial_sync_does_not_include_untracked_users_profile_updates( + self, is_lazy: bool + ) -> None: + """Test that with MSC4429 enabled the initial sync response does not + contain profile updates for users who do not share rooms.""" + third_user = self.register_user("third_user", "password") + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(third_user), + requester=create_requester(third_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + + requester = create_requester(self.user) + filter_json: dict[str, dict] = { + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + } + } + if is_lazy: + filter_json["room"] = { + "state": { + "lazy_load_members": True, + }, + } + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + self.assertIsNone(initial_result.profile_updates.get(third_user)) + + @override_config({"include_profile_updates_in_sync": True}) + def test_initial_sync_lazy_loading_responds_with_only_profiles_with_events( + self, + ) -> None: + """Test that with MSC4429 enabled the initial sync lazy loading response does + contain profile updates for events in the timeline. + + This test ensures lazy loading sync only returns profiles that we also have + events for in the sync response. The second room in this test has the most + recent events from "third_user" and thus we don't get the profile of + "other_user" down the line, who is in the the same rooms as the syncer, + but not in the second room. + """ + third_user = self.register_user("third_user", "password") + third_tok = self.login("third_user", "password") + self.helper.join( + room=self.joined_room, + user=third_user, + tok=third_tok, + ) + + requester = create_requester(self.user) + + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + # Check that lazy-loading filters out profile updates as well on initial sync. + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(third_user), + requester=create_requester(third_user), + field_name="m.status", + new_value={"text": "On fire", "emoji": "🔥"}, + ) + ) + self.helper.send_messages( + room_id=self.joined_room, num_events=1, tok=self.other_tok + ) + self.helper.send_messages( + room_id=self.joined_room, num_events=10, tok=third_tok + ) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + }, + "room": { + "state": { + "lazy_load_members": True, + }, + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + # Only third_user is returned, as lazy loading filters out the events from + # the other users + self.assertCountEqual( + initial_result.profile_updates.keys(), + [ + "@third_user:test", + ], + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_incremental_sync_sends_down_profile_update_diffs( + self, + ) -> None: + """Test that with MSC4429 enabled the incremental sync response does + contain profile update diffs.""" + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + } + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + # Set a field the client didn't ask for + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="uninterestingfield", + new_value="Content", + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + } + }, + ), + ), + request_key=generate_request_key(), + ) + ) + assert incremental_result.profile_updates["@other_user:test"] is not None + self.assertEqual( + incremental_result.profile_updates["@other_user:test"]["m.status"], + {"text": "On holiday", "emoji": "🏖"}, + ) + # We only send diffs in incremental sync for profile field updates + self.assertFalse( + "displayname" + in incremental_result.profile_updates["@other_user:test"].keys(), + ) + # The client didn't ask for this field + self.assertFalse( + "uninterestingfield" + in incremental_result.profile_updates["@other_user:test"].keys(), + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_incremental_sync_does_not_filter_profile_updates_when_lazy_loading( + self, + ) -> None: + """Test that with MSC4429 enabled the incremental sync lazy loading response + does contain profile updates even if the user would be filtered out by lazy + loading. + """ + third_user = self.register_user("third_user", "password") + third_tok = self.login("third_user", "password") + self.helper.join( + room=self.joined_room, + user=third_user, + tok=third_tok, + ) + + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname"] + } + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(third_user), + requester=create_requester(third_user), + field_name="m.status", + new_value={"text": "On fire", "emoji": "🔥"}, + ) + ) + self.helper.send_messages( + room_id=self.joined_room, num_events=1, tok=self.other_tok + ) + self.helper.send_messages( + room_id=self.joined_room, num_events=10, tok=third_tok + ) + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(third_user), + requester=create_requester(third_user), + field_name="uninterestingfield", + new_value="Content", + ) + ) + # Join a federated user to the room + self.get_success( + inject_member_event( + self.hs, + self.joined_room, + "@federateduser:federatedhs", + "join", + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname"] + }, + "room": { + "state": { + "lazy_load_members": True, + }, + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + + # Ensure our federated user is filtered out, even though they have an + # event in the joined room timeline + self.assertFalse( + "@federateduser:federatedhs" in incremental_result.profile_updates.keys() + ) + + # Lazy loading only filters initial sync profile updates. Incremental syncs + # should include all tracked profile updates for the syncing user. + self.assertCountEqual( + incremental_result.profile_updates.keys(), + [ + "@other_user:test", + "@third_user:test", + ], + ) + assert incremental_result.profile_updates["@other_user:test"] is not None + + # This is a field update, so should be here + self.assertEqual( + incremental_result.profile_updates["@other_user:test"]["m.status"], + {"text": "On holiday", "emoji": "🏖"}, + ) + + # We don't have events for this user in this response, so their full profile + # is not included + self.assertFalse( + "displayname" + in incremental_result.profile_updates["@other_user:test"].keys(), + ) + assert incremental_result.profile_updates["@third_user:test"] is not None + + # This user has events in the timeline, thus the fields the client asked for + # are included + self.assertEqual( + incremental_result.profile_updates["@third_user:test"]["m.status"], + {"text": "On fire", "emoji": "🔥"}, + ) + self.assertFalse( + "uninterestingfield" + in incremental_result.profile_updates["@third_user:test"].keys(), + ) + self.assertEqual( + incremental_result.profile_updates["@third_user:test"]["displayname"], + "third_user", + ) + + @parameterized.expand( + [ + [True, True], + [False, False], + [True, False], + [False, True], + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_sync_filters_out_profile_updates_from_federated_users( + self, + is_initial: bool, + is_lazy: bool, + ) -> None: + """Test that with MSC4429 enabled any sync response + doesn't contain federated users even if there are timeline events from them. + """ + # Join a federated user to the room, causing a membership event into + # the joined rooms sync response + self.get_success( + inject_member_event( + self.hs, + self.joined_room, + "@federateduser1:federatedhs", + "join", + ) + ) + requester = create_requester(self.user) + filter_json: dict[str, dict] = { + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + }, + } + if is_lazy: + filter_json["room"] = { + "state": { + "lazy_load_members": True, + }, + } + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + # Ensure our federated user is filtered out, even though they have an + # event in the joined room timeline + self.assertFalse( + "@federateduser1:federatedhs" in initial_result.profile_updates.keys() + ) + if not is_initial: + # Join another federated user to the room, causing a membership event into + # the joined rooms sync response + self.get_success( + inject_member_event( + self.hs, + self.joined_room, + "@federateduser2:federatedhs", + "join", + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + + # Ensure our federated user is filtered out, even though they have an + # event in the joined room timeline + self.assertFalse( + "@federateduser2:federatedhs" + in incremental_result.profile_updates.keys() + ) + + @parameterized.expand( + [ + [True, True], + [False, False], + [True, False], + [False, True], + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_sync_profile_updates_works_correctly_with_falsey_values( + self, + is_initial: bool, + is_lazy: bool, + ) -> None: + """Test that with MSC4429 enabled a sync response correctly includes falsey + profile field values. + """ + requester = create_requester(self.user) + filter_json: dict[str, dict] = { + "org.matrix.msc4429.profile_fields": {"ids": ["falseyvaluefield"]}, + } + if is_lazy: + filter_json["room"] = { + "state": { + "lazy_load_members": True, + }, + } + for value in [False, 0, "", [], {}, None]: + if is_initial: + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="falseyvaluefield", + new_value=cast(JsonValue | dict[str, JsonValue], value), + ) + ) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + if is_initial: + assert initial_result.profile_updates["@other_user:test"] is not None + self.assertEqual( + initial_result.profile_updates["@other_user:test"][ + "falseyvaluefield" + ], + value, + ) + else: + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="falseyvaluefield", + new_value=cast(JsonValue | dict[str, JsonValue], value), + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + assert ( + incremental_result.profile_updates["@other_user:test"] is not None + ) + self.assertEqual( + incremental_result.profile_updates["@other_user:test"][ + "falseyvaluefield" + ], + value, + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_incremental_sync_lazy_loading_cache_filters_recently_sent_profiles_and_fields( + self, + ) -> None: + """Test that with MSC4429 enabled the incremental sync lazy loading response + filters out unchanged profiles or fields we have recently sent to the client. + """ + requester = create_requester(self.user) + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.other_user), + requester=create_requester(self.other_user), + field_name="sooninterestingfield", + new_value="Content", + ) + ) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.helper.send_messages( + room_id=self.joined_room, + num_events=1, + tok=self.other_tok, + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + }, + "room": { + "state": { + "lazy_load_members": True, + }, + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + # Lazy loading incremental sync should include profiles from events + self.assertCountEqual( + incremental_result.profile_updates.keys(), + [ + "@other_user:test", + ], + ) + assert incremental_result.profile_updates["@other_user:test"] is not None + self.assertEqual( + set(incremental_result.profile_updates["@other_user:test"].keys()), + {"avatar_url", "displayname"}, + ) + + # If we have more events from the other_user, and do another lazy sync, + # we don't expect the full profile to be sent again due to our cache. + self.helper.send_messages( + room_id=self.joined_room, num_events=1, tok=self.other_tok + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=incremental_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + }, + "room": { + "state": { + "lazy_load_members": True, + }, + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.assertCountEqual( + incremental_result.profile_updates.keys(), + [], + ) + # However, if we again add an event, we do expect any fields the client didn't + # previously ask for to be there. + self.helper.send_messages( + room_id=self.joined_room, num_events=1, tok=self.other_tok + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=incremental_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": [ + "m.status", + "displayname", + "avatar_url", + "sooninterestingfield", + ] + }, + "room": { + "state": { + "lazy_load_members": True, + }, + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.assertCountEqual( + incremental_result.profile_updates.keys(), + [ + "@other_user:test", + ], + ) + assert incremental_result.profile_updates["@other_user:test"] is not None + self.assertEqual( + set(incremental_result.profile_updates["@other_user:test"].keys()), + {"sooninterestingfield"}, + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_incremental_sync_sends_down_null_profile_if_user_no_longer_sharing_rooms( + self, + ) -> None: + """Test that with MSC4429 enabled the incremental sync response + includes a 'null' for users who are no longer sharing rooms. + """ + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + } + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.helper.leave( + room=self.joined_room, user=self.other_user, tok=self.other_tok + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["m.status", "displayname", "avatar_url"] + } + }, + ), + ), + request_key=generate_request_key(), + ) + ) + self.assertIsNone( + incremental_result.profile_updates["@other_user:test"], + ) + + @override_config({"include_profile_updates_in_sync": True}) + def test_incremental_sync_sends_down_all_requested_fields_for_users_who_have_joined( + self, + ) -> None: + """Test that with MSC4429 enabled the incremental sync response + includes all the requested fields of a user who has joined a room with the + syncing user. + """ + requester = create_requester(self.user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["displayname", "avatar_url"] + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["displayname", "avatar_url"] + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + + third_user = self.register_user("third_user", "password") + third_tok = self.login("third_user", "password") + self.helper.join( + room=self.joined_room, + user=third_user, + tok=third_tok, + ) + # Set a status field we don't except to see in sync + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(third_user), + requester=create_requester(third_user), + field_name="m.status", + new_value={"text": "On fire", "emoji": "🔥"}, + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=incremental_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json={ + "org.matrix.msc4429.profile_fields": { + "ids": ["displayname", "avatar_url"] + }, + }, + ), + ), + request_key=generate_request_key(), + ) + ) + assert incremental_result.profile_updates["@third_user:test"] is not None + self.assertCountEqual( + incremental_result.profile_updates.keys(), + [third_user], + ) + self.assertEqual( + incremental_result.profile_updates["@third_user:test"]["displayname"], + "third_user", + ) + self.assertIsNone( + incremental_result.profile_updates["@third_user:test"]["avatar_url"], + ) + self.assertFalse( + "m.status" in incremental_result.profile_updates["@third_user:test"].keys(), + ) + + @parameterized.expand( + [ + True, + False, + ] + ) + @override_config({"include_profile_updates_in_sync": True}) + def test_incremental_sync_includes_own_profile_updates(self, is_lazy: bool) -> None: + """Test that with MSC4429 enabled the incremental sync response includes + ones own profile updates.""" + requester = create_requester(self.user) + filter_json: dict[str, dict] = { + "org.matrix.msc4429.profile_fields": {"ids": ["m.status", "avatar_url"]} + } + if is_lazy: + filter_json["room"] = { + "state": { + "lazy_load_members": True, + }, + } + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + self.get_success( + self.profile_handler.set_field( + target_user=UserID.from_string(self.user), + requester=requester, + field_name="m.status", + new_value={"text": "On holiday", "emoji": "🏖"}, + ) + ) + incremental_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + since_token=initial_result.next_batch, + sync_config=generate_sync_config( + user_id=self.user, + filter_collection=FilterCollection( + hs=self.hs, + filter_json=filter_json, + ), + ), + request_key=generate_request_key(), + ) + ) + assert incremental_result.profile_updates["@user:test"] is not None + self.assertEqual( + incremental_result.profile_updates["@user:test"]["m.status"], + {"text": "On holiday", "emoji": "🏖"}, + ) + # We didn't ask for displayname + self.assertFalse( + "displayname" in incremental_result.profile_updates["@user:test"].keys(), + ) + + class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler state behavior when using `use_state_after.""" diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index c4e4170c6f9..4deb3c29f41 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2549,7 +2549,7 @@ def test_timestamp_to_event(self) -> None: def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -2563,7 +2563,7 @@ def test_topo_token_is_accepted(self) -> None: def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 8d4892ae91f..793a66f3434 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2248,7 +2248,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2259,7 +2259,7 @@ def test_topo_token_is_accepted(self) -> None: self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index dbaf2986975..6f96c1c4b05 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -18,14 +18,19 @@ # [This file includes modifications made by New Vector Limited] # # +import itertools +from unittest.mock import patch from twisted.internet.testing import MemoryReactor +from synapse.api.constants import ProfileUpdateAction from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.profile import PRUNE_PROFILE_UPDATES_AGE from synapse.storage.engines import PostgresEngine from synapse.types import UserID from synapse.util.clock import Clock +from synapse.util.duration import Duration from tests import unittest @@ -132,3 +137,116 @@ def f(txn: LoggingTransaction) -> None: ) self.assertEqual(len(res), len(expected_values)) self.assertEqual(res, expected_values) + + @patch("synapse.storage.databases.main.profile.PRUNE_PROFILE_UPDATES_BATCH_SIZE", 5) + def test_prune_profile_updates(self) -> None: + """Test that old entries in the `profile_updates` and `profile_updates_per_user` + tables are pruned properly.""" + + # Create a generator for field names so we can easily create many unique + # field names without having to keep track of the count ourselves. + field_name_gen = (f"field{i}" for i in itertools.count()) + + def get_profile_updates_status() -> tuple[int, str]: + """Helper function to get the count of entries in the + `profile_updates` table.""" + return self.get_success( + self.store.db_pool.simple_select_one( + table="profile_updates", + keyvalues={}, + retcols=("COUNT(*)", "MIN(field_name)"), + ) + ) + + def get_profile_updates_per_user_status() -> tuple[int]: + """Helper function to get the count of entries in the + `profile_updates_per_user` table.""" + return self.get_success( + self.store.db_pool.simple_select_one( + table="profile_updates_per_user", + keyvalues={}, + retcols=("COUNT(*)",), + ) + ) + + # First add some entries + for _ in range(10): + stream_id = self.get_success( + self.store.add_profile_updates( + user_id=UserID.from_string("@user:test"), + updated_fields={next(field_name_gen)}, + action=ProfileUpdateAction.UPDATE, + ) + ) + self.get_success( + self.store.track_profile_updates_per_user( + stream_id=stream_id, + user_ids={"@alice:test", "@bob:test"}, + ) + ) + + # Advance the reactor a while, but not long enough to trigger pruning. + self.reactor.advance(Duration(hours=1).as_secs()) + + # The `profile_updates_per_user` table should now have 10 * 2 entries. + per_user_count = get_profile_updates_per_user_status() + self.assertEqual(per_user_count[0], 20) + # The `profile_updates` table should have 10 entries. + # and the minimum field name should be `field0`. + updates_count, min_field_name = get_profile_updates_status() + self.assertEqual(updates_count, 10) + self.assertEqual(min_field_name, "field0") + + # Now we add some more entries + for _ in range(10): + stream_id = self.get_success( + self.store.add_profile_updates( + user_id=UserID.from_string("@user:test"), + updated_fields={next(field_name_gen)}, + action=ProfileUpdateAction.UPDATE, + ) + ) + self.get_success( + self.store.track_profile_updates_per_user( + stream_id=stream_id, + user_ids={"@alice:test", "@bob:test"}, + ) + ) + + # Advance the reactor a while more, so that the first batch of entries is + # now old enough to be pruned. + self.reactor.advance( + (PRUNE_PROFILE_UPDATES_AGE - Duration(minutes=30)).as_secs() + ) + + # Advance repeatedly a bit so that the pruning process can run to completion. + for _ in range(10): + self.reactor.advance(Duration(milliseconds=110).as_secs()) + + # Check that the old entries have been pruned, and the new entries are still there. + # The `profile_updates_per_user` table should now have 10 * 2 entries. + per_user_count = get_profile_updates_per_user_status() + self.assertEqual(per_user_count[0], 20) + # The `profile_updates` table should have 10 entries. + # and the minimum field name should be `field10`. + updates_count, min_field_name = get_profile_updates_status() + self.assertEqual(updates_count, 10) + self.assertEqual(min_field_name, "field10") + + # We should always keep the most recent entries, even if they are old enough to be pruned. + self.reactor.advance( + (PRUNE_PROFILE_UPDATES_AGE + Duration(minutes=30)).as_secs() + ) + + # Advance repeatedly a bit so that the pruning process can run to completion. + for _ in range(10): + self.reactor.advance(Duration(milliseconds=110).as_secs()) + + # The `profile_updates_per_user` table should now have 2 entries. + per_user_count = get_profile_updates_per_user_status() + self.assertEqual(per_user_count[0], 2) + # The `profile_updates` table should have 1 entry. + # and the minimum field name should be `field19`. + updates_count, min_field_name = get_profile_updates_status() + self.assertEqual(updates_count, 1) + self.assertEqual(min_field_name, "field19")