diff --git a/changelog.d/19871.misc b/changelog.d/19871.misc new file mode 100644 index 00000000000..be10ee05403 --- /dev/null +++ b/changelog.d/19871.misc @@ -0,0 +1 @@ +Update `HomeserverTestCase.get_success(...)` and friends to drive async Rust (Tokio runtime/thread pool). diff --git a/synapse/http/client.py b/synapse/http/client.py index 05c5f13a874..78f03ae58a9 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -87,8 +87,7 @@ from synapse.metrics import SERVER_NAME_LABEL from synapse.types import ISynapseReactor, StrSequence from synapse.util.async_helpers import timeout_deferred -from synapse.util.clock import Clock -from synapse.util.duration import Duration +from synapse.util.clock import CLOCK_SCHEDULE_EPSILON, Clock from synapse.util.json import json_decoder if TYPE_CHECKING: @@ -163,11 +162,6 @@ def _is_ip_blocked( return False -# The delay used by the scheduler to schedule tasks "as soon as possible", while -# still allowing other tasks to run between runs. -_EPSILON = Duration(microseconds=1) - - def _make_scheduler(clock: Clock) -> Callable[[Callable[[], object]], IDelayedCall]: """Makes a schedular suitable for a Cooperator using the given reactor. @@ -176,7 +170,7 @@ def _make_scheduler(clock: Clock) -> Callable[[Callable[[], object]], IDelayedCa def _scheduler(x: Callable[[], object]) -> IDelayedCall: return clock.call_later( - _EPSILON, + CLOCK_SCHEDULE_EPSILON, x, ) diff --git a/synapse/util/clock.py b/synapse/util/clock.py index 7232a1331c8..8c056757323 100644 --- a/synapse/util/clock.py +++ b/synapse/util/clock.py @@ -62,6 +62,19 @@ logging.setLoggerClass(original_logger_class) +CLOCK_SCHEDULE_EPSILON = Duration(microseconds=1) +""" +The smallest value we can use that will schedule tasks "as soon as possible", while +still allowing other tasks to run between runs. + +This should be a non-zero value as the Twisted Reactor API does not specify how calls +get scheduled. 
If we used `0`, a weird reactor implementation could run it immediately +or run it in any order with the other calls that are scheduled now. + +We want the semantics of "run this in the next reactor iteration". +""" + + def _try_wakeup_deferred(d: Deferred) -> None: """Try to wake up a deferred, but ignore any exceptions raised by the callback. This is useful when we want to wake up a deferred that may have diff --git a/tests/app/test_homeserver_shutdown.py b/tests/app/test_homeserver_shutdown.py index 0f5d1c73387..20d314cb682 100644 --- a/tests/app/test_homeserver_shutdown.py +++ b/tests/app/test_homeserver_shutdown.py @@ -76,6 +76,13 @@ async def shutdown() -> None: self.get_success(shutdown()) + # XXX: There can be a few already dispatched database queries (from normal + # background tasks in Synapse) and the threadless `ThreadPool` that we use in + # tests uses *untracked* clock calls to pass database results back so `shutdown` + # doesn't cancel those calls. This is a quirk of our test infrastructure + # (threadless `ThreadPool`) so this kind of "hack" is fine. + self.reactor.advance(0) + # Cleanup the internal reference in our test case del self.hs @@ -106,7 +113,7 @@ def test_clean_homeserver_shutdown_mid_background_updates(self) -> None: # Pump the background updates by a single iteration, just to ensure any extra # resources it uses have been started. store = weakref.proxy(self.hs.get_datastores().main) - self.get_success(store.db_pool.updates.do_next_background_update(False), by=0.1) + self.get_success(store.db_pool.updates.do_next_background_update(False)) hs_ref = weakref.ref(self.hs) @@ -127,6 +134,13 @@ async def shutdown() -> None: self.get_success(shutdown()) + # XXX: There can be a few already dispatched database queries (from normal + # background tasks in Synapse) and the threadless `ThreadPool` that we use in + # tests uses *untracked* clock calls to pass database results back so `shutdown` + # doesn't cancel those calls.
This is a quirk of our test infrastructure + # (threadless `ThreadPool`) so this kind of "hack" is fine. + self.reactor.advance(0) + # Cleanup the internal reference in our test case del self.hs diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 6bc935f2720..01561b0d413 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -499,7 +499,7 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: res = key_json[testverifykey_id] self.assertIsNotNone(res) assert res is not None - self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.added_ts, self.clock.time_msec()) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) # we expect it to be encoded as canonical json *before* it hits the db @@ -614,7 +614,7 @@ def test_get_keys_from_perspectives(self) -> None: res = key_json[testverifykey_id] self.assertIsNotNone(res) assert res is not None - self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.added_ts, self.clock.time_msec()) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) @@ -732,7 +732,7 @@ def test_get_perspectives_own_key(self) -> None: res = key_json[testverifykey_id] self.assertIsNotNone(res) assert res is not None - self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.added_ts, self.clock.time_msec()) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 794c0a3185f..0c7edbaa2da 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -357,7 +357,6 @@ def create_invite() -> EventBase: event.room_version, ), exc=LimitExceededError, - by=0.5, ) def _build_and_send_join_event( diff --git 
a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 6af2c068a45..995a1134b2b 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -21,15 +21,13 @@ import json import threading -import time from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any, ClassVar, Coroutine, Generator, TypeVar, Union +from typing import Any, ClassVar, TypeVar from unittest.mock import AsyncMock, Mock from urllib.parse import parse_qs from parameterized.parameterized import parameterized_class -from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor from synapse.api.auth.mas import MasDelegatedAuth @@ -204,31 +202,6 @@ class MasAuthDelegation(HomeserverTestCase): def device_scope(self) -> str: return self.device_scope_prefix + DEVICE - def till_deferred_has_result( - self, - awaitable: Union[ - "Coroutine[Deferred[Any], Any, T]", - "Generator[Deferred[Any], Any, T]", - "Deferred[T]", - ], - ) -> "Deferred[T]": - """Wait until a deferred has a result. - - This is useful because the Rust HTTP client will resolve the deferred - using reactor.callFromThread, which are only run when we call - reactor.advance. 
- """ - deferred = ensureDeferred(awaitable) - tries = 0 - while not deferred.called: - time.sleep(0.1) - self.reactor.advance(0) - tries += 1 - if tries > 100: - raise Exception("Timed out waiting for deferred to resolve") - - return deferred - def default_config(self) -> dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL @@ -278,11 +251,7 @@ def test_simple_introspection(self) -> None: "expires_in": 60, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.user.to_string(), USER_ID) self.assertEqual(requester.device_id, DEVICE) @@ -301,11 +270,7 @@ def test_unexpiring_token(self) -> None: "username": USERNAME, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.user.to_string(), USER_ID) self.assertEqual(requester.device_id, DEVICE) @@ -326,9 +291,7 @@ def test_inexistent_device(self) -> None: } failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), InvalidClientTokenError, ) self.assertEqual(failure.value.code, 401) @@ -343,9 +306,7 @@ def test_inexistent_user(self) -> None: } failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), AuthError, ) # This is a 500, it should never happen really @@ -361,9 +322,7 @@ def test_missing_scope(self) -> None: } failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), InvalidClientTokenError, 
) self.assertEqual(failure.value.code, 401) @@ -372,9 +331,7 @@ def test_invalid_response(self) -> None: self.server.introspection_response = {} failure = self.get_failure( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ), + self._auth.get_user_by_access_token("some_token"), SynapseError, ) self.assertEqual(failure.value.code, 503) @@ -389,11 +346,7 @@ def test_device_id_in_body(self) -> None: "device_id": DEVICE, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.device_id, DEVICE) @@ -406,11 +359,7 @@ def test_admin_scope(self) -> None: "expires_in": 60, } - requester = self.get_success( - self.till_deferred_has_result( - self._auth.get_user_by_access_token("some_token") - ) - ) + requester = self.get_success(self._auth.get_user_by_access_token("some_token")) self.assertEqual(requester.user.to_string(), USER_ID) self.assertTrue(self.get_success(self._auth.is_server_admin(requester))) @@ -435,17 +384,15 @@ def test_cached_expired_introspection(self) -> None: request.requestHeaders.getRawHeaders = mock_getRawHeaders() # The first CS-API request causes a successful introspection - self.get_success( - self.till_deferred_has_result(self._auth.get_user_by_req(request)) - ) + self.get_success(self._auth.get_user_by_req(request)) self.assertEqual(self.server.calls, 1) # Sleep for 60 seconds so the token expires. 
self.reactor.advance(60.0) # Now the CS-API request fails because the token expired - self.assertFailure( - self.till_deferred_has_result(self._auth.get_user_by_req(request)), + self.get_failure( + self._auth.get_user_by_req(request), InvalidClientTokenError, ) # Ensure another introspection request was not sent diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 62b84c77a4d..b81c2954dc2 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -960,7 +960,7 @@ def test_exchange_code_jwt_key(self) -> None: # advance the clock a bit before we start, so we aren't working with zero # timestamps. self.reactor.advance(1000) - start_time = self.reactor.seconds() + start_time_s = int(self.reactor.seconds()) ret = self.get_success(self.provider._exchange_code(code, code_verifier="")) self.assertEqual(ret, token) @@ -981,8 +981,8 @@ def test_exchange_code_jwt_key(self) -> None: self.assertEqual(claims["aud"], ISSUER) self.assertEqual(claims["iss"], "DEFGHI") self.assertEqual(claims["sub"], CLIENT_ID) - self.assertEqual(claims["iat"], start_time) - self.assertGreater(claims["exp"], start_time) + self.assertEqual(claims["iat"], start_time_s) + self.assertGreater(claims["exp"], start_time_s) # check the rest of the POSTed data self.assertEqual(args["grant_type"], ["authorization_code"]) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 44f1e6432d6..5d02c701614 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -951,8 +951,7 @@ def test_external_process_timeout(self) -> None: self.get_success( worker_presence_handler.user_syncing( self.user_id, self.device_id, True, PresenceState.ONLINE - ), - by=0.1, + ) ) # Check that if we wait a while without telling the handler the user has @@ -1270,8 +1269,7 @@ def test_set_presence_from_syncing_multi_device( "dev-1", affect_presence=dev_1_state != PresenceState.OFFLINE, presence_state=dev_1_state, - ), - by=0.01, + ) ) # 2. 
Wait half the idle timer. @@ -1285,8 +1283,7 @@ def test_set_presence_from_syncing_multi_device( "dev-2", affect_presence=dev_2_state != PresenceState.OFFLINE, presence_state=dev_2_state, - ), - by=0.01, + ) ) # 4. Assert the expected presence state. @@ -1311,8 +1308,7 @@ def test_set_presence_from_syncing_multi_device( "dev-3", affect_presence=True, presence_state=PresenceState.ONLINE, - ), - by=0.01, + ) ): pass @@ -1507,8 +1503,7 @@ def test_set_presence_from_non_syncing_multi_device( "dev-1", affect_presence=dev_1_state != PresenceState.OFFLINE, presence_state=dev_1_state, - ), - by=0.1, + ) ) # 2. Sync with the second device. @@ -1518,8 +1513,7 @@ def test_set_presence_from_non_syncing_multi_device( "dev-2", affect_presence=dev_2_state != PresenceState.OFFLINE, presence_state=dev_2_state, - ), - by=0.1, + ) ) # 3. Assert the expected presence state. @@ -1625,8 +1619,7 @@ def test_set_presence_from_syncing_keeps_busy( self.get_success( worker_to_sync_against.get_presence_handler().user_syncing( self.user_id, self.device_id, True, PresenceState.ONLINE - ), - by=0.1, + ) ) # Check against the main process that the user's presence did not change. 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5152e8fc536..561b45827fd 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -200,7 +200,7 @@ async def slow_update_membership(*args: Any, **kwargs: Any) -> tuple[str, int]: self.assertEqual(membership[state_tuple].content["displayname"], "Frank") # Let's be sure we are over the delay introduced by slow_update_membership - self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1) + self.reactor.advance(Duration(milliseconds=20).as_secs()) membership = self.get_success( self.storage_controllers.state.get_current_state( @@ -278,7 +278,7 @@ async def potentially_slow_update_membership( # Let's be sure we are over the delay introduced by slow_update_membership # and that the task was not executed as expected - self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1) + self.reactor.advance(Duration(milliseconds=20).as_secs()) membership = self.get_success( self.storage_controllers.state.get_current_state( @@ -299,8 +299,10 @@ async def potentially_slow_update_membership( ) ) + # Wait for the `TaskScheduler.SCHEDULE_INTERVAL` + self.reactor.advance(Duration(minutes=1).as_secs()) # Let's be sure we are over the delay introduced by slow_update_membership - self.get_success(self.clock.sleep(Duration(milliseconds=20)), by=1) + self.reactor.advance(Duration(milliseconds=20).as_secs()) # Updates should have been resumed from room 2 after the restart # so room 1 should not have been updated this time diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index d5b95e4ef6b..0a7475856a8 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -71,7 +71,6 @@ def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> Non action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}}) @@ -213,7 
+212,6 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None: remote_room_hosts=[self.OTHER_SERVER_NAME], ), LimitExceededError, - by=0.5, ) # TODO: test that remote joins to a room are rate limited. @@ -281,7 +279,6 @@ def test_local_users_joining_on_another_worker_contribute_to_rate_limit( action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) # Try to join as Chris on the original worker. Should get denied because Alice @@ -294,7 +291,6 @@ def test_local_users_joining_on_another_worker_contribute_to_rate_limit( action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index eea88cd136b..acb88343f2e 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -146,7 +146,7 @@ def test_send_email(self) -> None: ) # the message should now get delivered - self.get_success(d, by=0.1) + self.get_success(d) # check it arrived self.assertEqual(len(message_delivery.messages), 1) @@ -213,7 +213,7 @@ def test_send_email_force_tls(self) -> None: ) # the message should now get delivered - self.get_success(d, by=0.1) + self.get_success(d) # check it arrived self.assertEqual(len(message_delivery.messages), 1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 623eef0ecb6..0bbe0845470 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -248,6 +248,14 @@ def test_started_typing_remote_send(self) -> None: ) ) + # Wait for the EDU to get pushed out over federation + # + # `started_typing` is fire-and-forget and handles the remote federation part as + # part of a background process which isn't waited on. 
+ # + # We're specifically waiting for the database queries in the background process + self.reactor.advance(0) + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", @@ -367,6 +375,14 @@ def test_stopped_typing(self) -> None: [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] ) + # Wait for the EDU to get pushed out over federation + # + # `stopped_typing` is fire-and-forget and handles the remote federation part as + # part of a background process which isn't waited on. + # + # We're specifically waiting for the database queries in the background process + self.reactor.advance(0) + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f50fa1f4a02..dc6738ca286 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -555,7 +555,15 @@ def test_process_join_after_server_leaves_room(self) -> None: # Process the leave and join in one go. dir_handler.update_user_directory = True dir_handler.notify_new_event() - self.wait_for_background_updates() + + # Wait for the user directory to update + # + # `notify_new_event` is fire-and-forget and the actual changes happen as part of + # a background process loop which isn't waited on. + # + # We're specifically waiting for the database queries in the `notify_new_event` + # background process. + self.reactor.advance(0) # The user sharing tables should have been updated. public3 = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) @@ -1124,7 +1132,6 @@ def test_local_user_leaving_room_remains_in_user_directory(self) -> None: # Alice leaves the other. She should still be in the directory. 
self.helper.leave(room2, alice, tok=alice_token) - self.wait_for_background_updates() users, in_public, in_private = self.get_success( self.user_dir_helper.get_tables() ) diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index f25b507aac5..855a623ec09 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -132,12 +132,7 @@ async def test_ensure_media() -> None: # This uses a real blocking threadpool so we have to wait for it to be # actually done :/ - x = defer.ensureDeferred(test_ensure_media()) - - # Hotloop until the threadpool does its job... - self.wait_on_thread(x) - - self.get_success(x) + self.get_success(test_ensure_media()) @attr.s(auto_attribs=True, slots=True, frozen=True) diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index a8eb7fc523c..d35191e654c 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -147,6 +147,12 @@ def test_wait_for_stream_position(self) -> None: # ... but worker1 finishing (and so sending an update) should. self.get_success(ctx_worker1.__aexit__(None, None, None)) + # Wait for the stream position to be replicated to the master process + # + # Replication travels over `FakeTransport` and we're specifically flushing the + # write + self.reactor.advance(0) + self.assertTrue(d.called) def test_wait_for_stream_position_rdata(self) -> None: @@ -206,6 +212,12 @@ def test_wait_for_stream_position_rdata(self) -> None: # Finish the context manager, triggering the data to be sent to master. self.get_success(ctx_worker1.__aexit__(None, None, None)) + # Wait for the stream position to be replicated to the master process + # + # Replication travels over `FakeTransport` and we're specifically flushing the + # write + self.reactor.advance(0) + # Master should get told about `next_token2`, so the deferred should # resolve. 
self.assertTrue(d.called) diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index e6b9ea53832..c8de7b1fad6 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -81,6 +81,14 @@ def test_federation_ack_sent(self) -> None: ) ) + # Wait for the FEDERATION_ACK to be sent + # + # `on_rdata` handles this as part of a fire-and-forget background process (see + # `FederationSenderHandler.update_token`) + # + # We're specifically waiting for the database queries in the background process + self.reactor.advance(0) + # now check that the FEDERATION_ACK was sent mock_connection.send_command.assert_called_once() cmd = mock_connection.send_command.call_args[0][0] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index add00453b6d..bce199c564f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ from synapse.server import HomeServer from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.types import JsonDict, UserID, create_requester -from synapse.util.clock import Clock +from synapse.util.clock import CLOCK_SCHEDULE_EPSILON, Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -5850,10 +5850,16 @@ def test_redact_messages_all_rooms(self) -> None: self.assertEqual(channel.code, 200) id = channel.json_body.get("redact_id") - # Need 1 tick as we send 1 replication request per original event - # and each wait must be >= `_EPSILON` from `http/client.py` + # `/redact` just schedules a background task that runs in the background + # (fire-and-forget) so we need to do the waiting here. + # + # Need 1 tick as we send 1 replication request for the redaction of each + # original event. 
The replication request body is streamed by a `Cooperator` + # that uses the clock to schedule each chunk at a tiny *non-zero* delay + # (`CLOCK_SCHEDULE_EPSILON`), so we need to actually advance the clock for it to + # fire. for _ in range(len(original_event_ids)): - self.reactor.advance(0.001) + self.reactor.advance(CLOCK_SCHEDULE_EPSILON.as_secs()) # Verify the HTTP `redact_status` endpoint reports completion. channel2 = self.make_request( diff --git a/tests/server.py b/tests/server.py index ce5eaad63da..eea11b301f9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -255,7 +255,7 @@ def registerProducer(self, producer: IProducer, streaming: bool) -> None: def _produce() -> None: if self._producer: self._producer.resumeProducing() - self._reactor.callLater(0.1, _produce) + self._reactor.callLater(0.0, _produce) if not streaming: self._reactor.callLater(0.0, _produce) @@ -940,7 +940,7 @@ def _produce() -> None: # mypy ignored here because: # - this is part of the test infrastructure (outside of Synapse) so tracking # these calls for for homeserver shutdown doesn't make sense. 
- d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) # type: ignore[call-later-not-tracked,call-overload] + d.addCallback(lambda x: self._reactor.callLater(0.0, _produce)) # type: ignore[call-later-not-tracked,call-overload] if not streaming: # mypy ignored here because: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e3f79d76707..139906e97ca 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -59,8 +59,8 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main async def update(self, progress: JsonDict, count: int) -> int: - duration_ms = 10 - await self.clock.sleep(Duration(milliseconds=count * duration_ms)) + fake_work_duration = Duration(seconds=1) + await self.clock.sleep(fake_work_duration) progress = {"my_key": progress["my_key"] + 1} await self.store.db_pool.runInteraction( "update_progress", @@ -86,10 +86,15 @@ def test_do_background_update(self) -> None: self.update_handler.side_effect = self.update self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=0.02, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `fake_work_duration` + self.reactor.advance(Duration(seconds=1).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # on the first call, we should get run with the default background update size @@ -143,10 +148,15 @@ def test_background_update_default_batch_set_by_config(self) -> None: self.update_handler.side_effect = self.update self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=0.01, - ) 
+ background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `fake_work_duration` + self.reactor.advance(Duration(seconds=1).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # on the first call, we should get run with the default background update size specified in the config @@ -265,10 +275,15 @@ def test_background_update_duration_set_in_config(self) -> None: self.update_handler.side_effect = self.update self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=0.02, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `fake_work_duration` + self.reactor.advance(Duration(seconds=1).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # the first update was run with the default batch size, this should be run with 500ms as the @@ -298,9 +313,6 @@ def test_background_update_min_batch_set_in_config(self) -> None: """ Test that the minimum batch size set in the config is used """ - # a very long-running individual update - duration_ms = 50 - self.get_success( self.store.db_pool.simple_insert( "background_updates", @@ -310,7 +322,8 @@ def test_background_update_min_batch_set_in_config(self) -> None: # Run the update with the long-running update item async def update_long(progress: JsonDict, count: int) -> int: - await self.clock.sleep(Duration(milliseconds=count * duration_ms)) + very_long_fake_work_duration = Duration(seconds=5) + await self.clock.sleep(very_long_fake_work_duration) progress = {"my_key": 
progress["my_key"] + 1} await self.store.db_pool.runInteraction( "update_progress", @@ -322,10 +335,15 @@ async def update_long(progress: JsonDict, count: int) -> int: self.update_handler.side_effect = update_long self.update_handler.reset_mock() - res = self.get_success( - self.updates.do_next_background_update(False), - by=1, - ) + background_update_d = ensureDeferred( + self.updates.do_next_background_update(False) + ) + # Wait for database queries to run in `do_next_background_update(...)` so the + # background update actually gets scheduled + self.reactor.advance(0) + # Wait for the actual background update `very_long_fake_work_duration` + self.reactor.advance(Duration(seconds=5).as_secs()) + res = self.get_success(background_update_d) self.assertFalse(res) # the first update was run with the default batch size, this should be run with minimum batch size diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 175a5ffc788..d09437c080b 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -755,7 +755,7 @@ def test_background_update_single_large_room(self) -> None: ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(False) ) # Ensure that we did actually take multiple iterations to process the @@ -814,7 +814,7 @@ def test_background_update_multiple_large_room(self) -> None: ): iterations += 1 self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + self.store.db_pool.updates.do_next_background_update(False) ) # Ensure that we did actually take multiple iterations to process the diff --git a/tests/synapse_rust/test_http_client.py b/tests/synapse_rust/test_http_client.py index 56fab3a0e1d..845fe2b5033 100644 --- a/tests/synapse_rust/test_http_client.py +++ b/tests/synapse_rust/test_http_client.py @@ -15,9 +15,8 @@ import threading import time from 
http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any, Coroutine, Generator, TypeVar, Union +from typing import Any, TypeVar -from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor from synapse.logging.context import ( @@ -118,31 +117,6 @@ def tearDown(self) -> None: for callbable, args, kwargs in triggers: callbable(*args, **kwargs) - def till_deferred_has_result( - self, - awaitable: Union[ - "Coroutine[Deferred[Any], Any, T]", - "Generator[Deferred[Any], Any, T]", - "Deferred[T]", - ], - ) -> "Deferred[T]": - """Wait until a deferred has a result. - - This is useful because the Rust HTTP client will resolve the deferred - using reactor.callFromThread, which are only run when we call - reactor.advance. - """ - deferred = ensureDeferred(awaitable) - tries = 0 - while not deferred.called: - time.sleep(0.1) - self.reactor.advance(0) - tries += 1 - if tries > 100: - raise Exception("Timed out waiting for deferred to resolve") - - return deferred - def _check_current_logcontext(self, expected_logcontext_string: str) -> None: context = current_context() assert isinstance(context, LoggingContext) or isinstance(context, _Sentinel), ( @@ -168,7 +142,7 @@ async def do_request() -> None: raw_response = json_decoder.decode(resp_body.decode("utf-8")) self.assertEqual(raw_response, {"ok": True}) - self.get_success(self.till_deferred_has_result(do_request())) + self.get_success(do_request()) self.assertEqual(self.server.calls, 1) def test_request_response_limit_exceeded(self) -> None: @@ -183,8 +157,8 @@ async def do_request() -> None: response_limit=1, ) - self.assertFailure( - self.till_deferred_has_result(do_request()), + self.get_failure( + do_request(), RuntimeError, ) self.assertEqual(self.server.calls, 1) @@ -227,8 +201,15 @@ async def do_request() -> None: # Now wait for the function under test to have run with PreserveLoggingContext(): while not callback_finished: - # await 
self.hs.get_clock().sleep(0) - time.sleep(0.1) + # Allow the async Rust to run + # + # Suspend execution of this thread to allow the Tokio thread + # pool to do work. + time.sleep(0) + # Advance the Twisted reactor and run any scheduled callbacks + # + # In terms of other threads, they may have scheduled something on the + # reactor to run (like `reactor.callFromThread(...)`) self.reactor.advance(0) # check that the logcontext is left in a sane state. diff --git a/tests/unittest.py b/tests/unittest.py index 93131521d03..5f7d0b3abf2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -48,6 +48,7 @@ import unpaddedbase64 from typing_extensions import Concatenate, ParamSpec +from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.testing import MemoryReactor, MemoryReactorClock from twisted.python.failure import Failure @@ -76,7 +77,7 @@ from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult from synapse.types import ISynapseReactor, JsonDict, Requester, UserID, create_requester -from synapse.util.clock import Clock +from synapse.util.clock import CLOCK_SCHEDULE_EPSILON, Clock from synapse.util.httpresourcetree import create_resource_tree from tests.server import ( @@ -474,27 +475,13 @@ def tearDown(self) -> None: # Reset to not use frozen dicts. events.USE_FROZEN_DICTS = False - def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None: - """ - Wait until a Deferred is done, where it's waiting on a real thread.
- """ - start_time = time.time() - - while not deferred.called: - if start_time + timeout < time.time(): - raise ValueError("Timed out waiting for threadpool") - self.reactor.advance(0.01) - time.sleep(0.01) - def wait_for_background_updates(self) -> None: """Block until all background database updates have completed.""" store = self.hs.get_datastores().main while not self.get_success( store.db_pool.updates.has_completed_background_updates() ): - self.get_success( - store.db_pool.updates.do_next_background_update(False), by=0.1 - ) + self.get_success(store.db_pool.updates.do_next_background_update(False)) def make_homeserver( self, reactor: ThreadedMemoryReactorClock, clock: Clock @@ -736,21 +723,165 @@ def pump(self, by: float = 0.0) -> None: # whole chain to completion. self.reactor.pump([by] * 100) - def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV: + def _wait_for_deferred( + self, + d: "Deferred[Any]", + ) -> None: + """ + Wait for the deferred to finish or raise. + + Does not advance time in the Twisted reactor clock but will loop 100 times + waiting for a result. The loop 1) allows `clock.call_later` scheduled callbacks + to run if they are scheduled to run now and 2) will also allow other threads to + make progress. This could be things spawned on the Twisted reactor threadpool or + Tokio runtime (async Rust code). + + Args: + d: Twisted Deferred + + Raises: + defer.TimeoutError: If the timeout expires before the deferred completes. + """ + # Wait until the deferred has a result + # + # Checking `d.called` by itself is not sufficient by itself as this is possible: + # + # If you have a first `Deferred` `D1`, you can add a callback which returns + # another `Deferred` `D2`, and `D2` must then complete before any further + # callbacks on `D1` will execute (and later callbacks on `D1` get the *result* + # of `D2` rather than `D2` itself). 
+ # + # So, `D1` might have `called=True` (as in, it has started running its + # callbacks), but any new callbacks added to `D1` won't get run until `D2` + # completes. Fortunately, we can detect this by checking `d.paused`. + loop_count = 0 + while not d.called or d.paused: + # 100 loops is arbitrary but based on previous code which used to "pump" and + # advance the reactor 100 times. This also makes the assumption that any + # work on other threads will finish before we give up after sleeping ~0.1s + # of real-time (100 * 0.001). + if loop_count > 100: + raise defer.TimeoutError("Timed out waiting for deferred to finish") + + # Suspend execution of this thread to allow other threads to do work. This + # could be things spawned on the Twisted reactor threadpool or Tokio thread + # pool (async Rust code). + # + # Note: Python has a default thread switch interval (5ms for CPython) (see + # `sys.setswitchinterval(interval)`) but we still want this here as we're + # able to preempt and cause the thread context switch to happen faster. + # Also, without any real-time sleeping, this function would complete before + # the 5ms switch ever happened. + # + # After a few cycles, we use `time.sleep(0.001)` instead of `time.sleep(0)` + # to avoid tightlooping on the main thread (CPU 100%) because it's wasteful + # and may starve out other threads. 10 is arbitrary but many cases will have + # none or only a few round-trips so we can just try to go as fast as + # possible. + if loop_count < 10: + time.sleep(0) + else: + time.sleep(0.001) + + # Advance the Twisted reactor and run any scheduled callbacks + # + # In terms of other threads, they may have scheduled something on the + # reactor to run (like `reactor.callFromThread(...)`) + # + # Ideally, we'd advance by `0` but the `Cooperator` used in our HTTP clients + # use `CLOCK_SCHEDULE_EPSILON` and we want to make usage in downstream tests + # as simple as possible.
A common use case this helps with is anything that + # needs to make an HTTP request (like a replication request) + self.reactor.advance(CLOCK_SCHEDULE_EPSILON.as_secs()) + + loop_count += 1 + + def get_success( + self, + d: Awaitable[TV], + ) -> TV: + """ + Get the success result of an awaitable. + + Does not advance time in the Twisted reactor clock but will loop 100 times + waiting for a result. The loop 1) allows `clock.call_later` scheduled callbacks + to run if they are scheduled to run now and 2) will also allow other threads to + make progress. This could be things spawned on the Twisted reactor threadpool or + Tokio runtime (async Rust code). + + If you need to advance the Twisted reactor by an actual time increment, you can + use the following pattern: + ```python + # We use `ensureDeferred(...)` as a `Deferred` can run in the background on its own (unlike a Python coroutine) + task_d = ensureDeferred(my_async_task()) + # Please explain why/what scheduled call you're trying to trigger + self.reactor.advance(Duration(seconds=1).as_secs()) + result = self.get_success(task_d) + ``` + + Args: + d: awaitable + + Raises: + defer.TimeoutError: If the timeout expires before the awaitable completes. + SynchronousTestCase.failureException: If the awaitable has a failure result or has no result + (although you would probably run into `defer.TimeoutError` in that case). + """ deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] - self.pump(by=by) + self._wait_for_deferred(deferred) + return self.successResultOf(deferred) def get_failure( - self, d: Awaitable[Any], exc: type[_ExcType], by: float = 0.0 + self, + d: Awaitable[Any], + exc: type[_ExcType], ) -> _TypedFailure[_ExcType]: """ - Run a Deferred and get a Failure from it. The failure must be of the type `exc`. + Get the failure result of an awaitable. The failure must be of the type `exc`. + + Does not advance time in the Twisted reactor clock but will loop 100 times + waiting for a result.
The loop 1) allows `clock.call_later` scheduled callbacks + to run if they are scheduled to run now and 2) will also allow other threads to + make progress. This could be things spawned on the Twisted reactor threadpool or + Tokio runtime (async Rust code). + + If you need to advance the Twisted reactor by an actual time increment, you can + use the following pattern: + ```python + # We use `ensureDeferred(...)` as a `Deferred` can run in the background on its own (unlike a Python coroutine) + task_d = ensureDeferred(my_async_task()) + # Please explain why/what scheduled call you're trying to trigger + self.reactor.advance(Duration(seconds=1).as_secs()) + failure = self.get_failure(task_d, exc) + ``` + + Args: + d: awaitable + exc: Exception type to expect + + Raises: + defer.TimeoutError: If the timeout expires before the awaitable completes. + SynchronousTestCase.failureException: If the awaitable has a success result, + or has an unexpected failure result, or has no result (although you would + probably run into `defer.TimeoutError` in that case). """ deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type] - self.pump(by) + self._wait_for_deferred(deferred) + return self.failureResultOf(deferred, exc) + # FIXME: Remove as this has the exact same semantics as `get_success()`. In + # https://github.com/matrix-org/synapse/pull/8402#discussion_r495992506 where it was + # introduced, it was claimed that "get_success fails the test if the deferred fails + # rather than raising, which I find a bit unintuitive." but `get_success()` actually + # does raise "@raise SynchronousTestCase.failureException : If the + # L{Deferred} has no result or has a failure + # result." at least in today's world.
+ # + # As another alternative, we could also just update `get_success(...)` to have this + # behavior as the default, see + # https://github.com/element-hq/synapse/pull/19871#discussion_r3483616710 def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: """Drive deferred to completion and return result or raise exception on failure. diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 94c1d778e63..cab9695d33b 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -260,7 +260,7 @@ async def _incrementing_running_task( await self.task_scheduler.update_task( task.id, result={"counter": current_counter} ) - await self.hs.get_clock().sleep(Duration(microseconds=1)) + await self.hs.get_clock().sleep(Duration(seconds=1)) return TaskStatus.COMPLETE, None, None # type: ignore[unreachable]