diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000000..5cbc2d3f5f --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +# do not notify until at least 100 builds have been uploaded from the CI pipeline +# you can also set after_n_builds on comments independently +comment: + after_n_builds: 100 diff --git a/.evergreen/generated_configs/functions.yml b/.evergreen/generated_configs/functions.yml index 58bffbf922..8906e82792 100644 --- a/.evergreen/generated_configs/functions.yml +++ b/.evergreen/generated_configs/functions.yml @@ -111,6 +111,8 @@ functions: - LOAD_BALANCER - LOCAL_ATLAS - NO_EXT + - PYMONGO_BUILD_RUST + - PYMONGO_USE_RUST type: test - command: expansions.update params: @@ -152,6 +154,8 @@ functions: - IS_WIN32 - REQUIRE_FIPS - TEST_MIN_DEPS + - PYMONGO_BUILD_RUST + - PYMONGO_USE_RUST type: test - command: subprocess.exec params: @@ -250,6 +254,7 @@ functions: working_dir: src include_expansions_in_env: - TOOLCHAIN_VERSION + - COVERAGE type: test # Upload coverage codecov @@ -268,7 +273,7 @@ functions: - github_pr_number - github_pr_head_branch - github_author - - is_patch + - requester - branch_name type: test diff --git a/.evergreen/generated_configs/tasks.yml b/.evergreen/generated_configs/tasks.yml index 60ee6ed135..87d5c2261c 100644 --- a/.evergreen/generated_configs/tasks.yml +++ b/.evergreen/generated_configs/tasks.yml @@ -75,7 +75,7 @@ tasks: SUB_TEST_NAME: session-creds TOOLCHAIN_VERSION: 3.14t tags: [auth-aws, auth-aws-session-creds, free-threaded] - - name: test-auth-aws-rapid-web-identity-python3.14 + - name: test-auth-aws-rapid-web-identity-python3.14-cov commands: - func: run server vars: @@ -87,7 +87,8 @@ tasks: TEST_NAME: auth_aws SUB_TEST_NAME: web-identity TOOLCHAIN_VERSION: "3.14" - tags: [auth-aws, auth-aws-web-identity] + COVERAGE: "1" + tags: [auth-aws, auth-aws-web-identity, pr] - name: test-auth-aws-rapid-web-identity-session-name-python3.14 commands: - func: run server @@ -904,7 +905,7 @@ tasks: - 
ocsp-ecdsa - rapid - ocsp-staple - - name: test-ocsp-ecdsa-valid-cert-server-staples-latest-python3.14 + - name: test-ocsp-ecdsa-valid-cert-server-staples-latest-python3.14-cov commands: - func: run tests vars: @@ -913,11 +914,13 @@ tasks: TEST_NAME: ocsp TOOLCHAIN_VERSION: "3.14" VERSION: latest + COVERAGE: "1" tags: - ocsp - ocsp-ecdsa - latest - ocsp-staple + - pr - name: test-ocsp-ecdsa-invalid-cert-server-staples-v4.4-python3.10-min-deps commands: - func: run tests @@ -1928,7 +1931,7 @@ tasks: - ocsp-rsa - rapid - ocsp-staple - - name: test-ocsp-rsa-valid-cert-server-staples-latest-python3.14 + - name: test-ocsp-rsa-valid-cert-server-staples-latest-python3.14-cov commands: - func: run tests vars: @@ -1937,11 +1940,13 @@ tasks: TEST_NAME: ocsp TOOLCHAIN_VERSION: "3.14" VERSION: latest + COVERAGE: "1" tags: - ocsp - ocsp-rsa - latest - ocsp-staple + - pr - name: test-ocsp-rsa-invalid-cert-server-staples-v4.4-python3.10-min-deps commands: - func: run tests @@ -2554,6 +2559,21 @@ tasks: - func: attach benchmark test results - func: send dashboard data tags: [perf] + - name: perf-8.0-standalone-ssl-rust + commands: + - func: run server + vars: + VERSION: v8.0-perf + SSL: ssl + - func: run tests + vars: + TEST_NAME: perf + SUB_TEST_NAME: rust + PYMONGO_BUILD_RUST: "1" + PYMONGO_USE_RUST: "1" + - func: attach benchmark test results + - func: send dashboard data + tags: [perf] - name: perf-8.0-standalone commands: - func: run server @@ -2580,6 +2600,21 @@ tasks: - func: attach benchmark test results - func: send dashboard data tags: [perf] + - name: perf-8.0-standalone-rust + commands: + - func: run server + vars: + VERSION: v8.0-perf + SSL: nossl + - func: run tests + vars: + TEST_NAME: perf + SUB_TEST_NAME: rust + PYMONGO_BUILD_RUST: "1" + PYMONGO_USE_RUST: "1" + - func: attach benchmark test results + - func: send dashboard data + tags: [perf] # Search index tests - name: test-search-index-helpers @@ -2615,20 +2650,18 @@ tasks: - replica_set-auth-nossl - async - 
free-threaded - - name: test-server-version-python3.13-sync-auth-nossl-replica-set-cov + - name: test-server-version-python3.13-sync-auth-nossl-replica-set commands: - func: run server vars: AUTH: auth SSL: nossl TOPOLOGY: replica_set - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: nossl TOPOLOGY: replica_set - COVERAGE: "1" TOOLCHAIN_VERSION: "3.13" TEST_NAME: default_sync tags: @@ -2636,20 +2669,18 @@ tasks: - python-3.13 - replica_set-auth-nossl - sync - - name: test-server-version-python3.12-async-auth-ssl-replica-set-cov + - name: test-server-version-python3.12-async-auth-ssl-replica-set commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: replica_set - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: replica_set - COVERAGE: "1" TOOLCHAIN_VERSION: "3.12" TEST_NAME: default_async tags: @@ -2657,20 +2688,18 @@ tasks: - python-3.12 - replica_set-auth-ssl - async - - name: test-server-version-python3.11-sync-auth-ssl-replica-set-cov + - name: test-server-version-python3.11-sync-auth-ssl-replica-set commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: replica_set - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: replica_set - COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" TEST_NAME: default_sync tags: @@ -2743,20 +2772,18 @@ tasks: - python-pypy3.11 - replica_set-noauth-ssl - async - - name: test-server-version-python3.14-sync-noauth-ssl-replica-set-cov + - name: test-server-version-python3.14-sync-noauth-ssl-replica-set commands: - func: run server vars: AUTH: noauth SSL: ssl TOPOLOGY: replica_set - COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: ssl TOPOLOGY: replica_set - COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" TEST_NAME: default_sync tags: @@ -2764,20 +2791,18 @@ tasks: - python-3.14 - replica_set-noauth-ssl - sync - - name: test-server-version-python3.14-async-auth-nossl-sharded-cluster-cov + - name: test-server-version-python3.14-async-auth-nossl-sharded-cluster commands: - func: run 
server vars: AUTH: auth SSL: nossl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: nossl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" TEST_NAME: default_async tags: @@ -2829,20 +2854,18 @@ tasks: - sharded_cluster-auth-ssl - async - pr - - name: test-server-version-python3.11-async-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.11-async-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" TEST_NAME: default_async tags: @@ -2850,20 +2873,18 @@ tasks: - python-3.11 - sharded_cluster-auth-ssl - async - - name: test-server-version-python3.12-async-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.12-async-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.12" TEST_NAME: default_async tags: @@ -2871,20 +2892,18 @@ tasks: - python-3.12 - sharded_cluster-auth-ssl - async - - name: test-server-version-python3.13-async-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.13-async-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.13" TEST_NAME: default_async tags: @@ -2892,20 +2911,18 @@ tasks: - python-3.13 - sharded_cluster-auth-ssl - async - - name: test-server-version-python3.14-async-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.14-async-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: 
sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" TEST_NAME: default_async tags: @@ -2976,20 +2993,18 @@ tasks: - sharded_cluster-auth-ssl - sync - pr - - name: test-server-version-python3.11-sync-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.11-sync-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" TEST_NAME: default_sync tags: @@ -2997,20 +3012,18 @@ tasks: - python-3.11 - sharded_cluster-auth-ssl - sync - - name: test-server-version-python3.12-sync-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.12-sync-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.12" TEST_NAME: default_sync tags: @@ -3018,20 +3031,18 @@ tasks: - python-3.12 - sharded_cluster-auth-ssl - sync - - name: test-server-version-python3.13-sync-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.13-sync-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.13" TEST_NAME: default_sync tags: @@ -3039,20 +3050,18 @@ tasks: - python-3.13 - sharded_cluster-auth-ssl - sync - - name: test-server-version-python3.14-sync-auth-ssl-sharded-cluster-cov + - name: test-server-version-python3.14-sync-auth-ssl-sharded-cluster commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" TEST_NAME: default_sync tags: @@ -3099,20 +3108,18 @@ tasks: - python-pypy3.11 - 
sharded_cluster-auth-ssl - sync - - name: test-server-version-python3.12-async-noauth-nossl-sharded-cluster-cov + - name: test-server-version-python3.12-async-noauth-nossl-sharded-cluster commands: - func: run server vars: AUTH: noauth SSL: nossl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: nossl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.12" TEST_NAME: default_async tags: @@ -3120,20 +3127,18 @@ tasks: - python-3.12 - sharded_cluster-noauth-nossl - async - - name: test-server-version-python3.11-sync-noauth-nossl-sharded-cluster-cov + - name: test-server-version-python3.11-sync-noauth-nossl-sharded-cluster commands: - func: run server vars: AUTH: noauth SSL: nossl TOPOLOGY: sharded_cluster - COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: nossl TOPOLOGY: sharded_cluster - COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" TEST_NAME: default_sync tags: @@ -3141,7 +3146,7 @@ tasks: - python-3.11 - sharded_cluster-noauth-nossl - sync - - name: test-server-version-python3.10-async-noauth-ssl-sharded-cluster-min-deps-cov + - name: test-server-version-python3.10-async-noauth-ssl-sharded-cluster-min-deps commands: - func: run server vars: @@ -3149,14 +3154,12 @@ tasks: SSL: ssl TOPOLOGY: sharded_cluster TEST_MIN_DEPS: "1" - COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: ssl TOPOLOGY: sharded_cluster TEST_MIN_DEPS: "1" - COVERAGE: "1" TOOLCHAIN_VERSION: "3.10" TEST_NAME: default_async tags: @@ -3183,20 +3186,18 @@ tasks: - python-pypy3.11 - sharded_cluster-noauth-ssl - sync - - name: test-server-version-python3.13-async-auth-nossl-standalone-cov + - name: test-server-version-python3.13-async-auth-nossl-standalone commands: - func: run server vars: AUTH: auth SSL: nossl TOPOLOGY: standalone - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: nossl TOPOLOGY: standalone - COVERAGE: "1" TOOLCHAIN_VERSION: "3.13" TEST_NAME: default_async tags: @@ -3204,20 +3205,18 @@ tasks: - python-3.13 - 
standalone-auth-nossl - async - - name: test-server-version-python3.12-sync-auth-nossl-standalone-cov + - name: test-server-version-python3.12-sync-auth-nossl-standalone commands: - func: run server vars: AUTH: auth SSL: nossl TOPOLOGY: standalone - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: nossl TOPOLOGY: standalone - COVERAGE: "1" TOOLCHAIN_VERSION: "3.12" TEST_NAME: default_sync tags: @@ -3225,20 +3224,18 @@ tasks: - python-3.12 - standalone-auth-nossl - sync - - name: test-server-version-python3.11-async-auth-ssl-standalone-cov + - name: test-server-version-python3.11-async-auth-ssl-standalone commands: - func: run server vars: AUTH: auth SSL: ssl TOPOLOGY: standalone - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: standalone - COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" TEST_NAME: default_async tags: @@ -3246,7 +3243,7 @@ tasks: - python-3.11 - standalone-auth-ssl - async - - name: test-server-version-python3.10-sync-auth-ssl-standalone-min-deps-cov + - name: test-server-version-python3.10-sync-auth-ssl-standalone-min-deps commands: - func: run server vars: @@ -3254,14 +3251,12 @@ tasks: SSL: ssl TOPOLOGY: standalone TEST_MIN_DEPS: "1" - COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: standalone TEST_MIN_DEPS: "1" - COVERAGE: "1" TOOLCHAIN_VERSION: "3.10" TEST_NAME: default_sync tags: @@ -3293,18 +3288,20 @@ tasks: - standalone-noauth-nossl - async - pr - - name: test-server-version-pypy3.11-sync-noauth-nossl-standalone + - name: test-server-version-pypy3.11-sync-noauth-nossl-standalone-cov commands: - func: run server vars: AUTH: noauth SSL: nossl TOPOLOGY: standalone + COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: nossl TOPOLOGY: standalone + COVERAGE: "1" TOOLCHAIN_VERSION: pypy3.11 TEST_NAME: default_sync tags: @@ -3313,20 +3310,18 @@ tasks: - standalone-noauth-nossl - sync - pr - - name: test-server-version-python3.14-async-noauth-ssl-standalone-cov + - name: 
test-server-version-python3.14-async-noauth-ssl-standalone commands: - func: run server vars: AUTH: noauth SSL: ssl TOPOLOGY: standalone - COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: ssl TOPOLOGY: standalone - COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" TEST_NAME: default_async tags: @@ -4082,7 +4077,7 @@ tasks: - standalone-noauth-nossl - async - pypy - - name: test-standard-latest-python3.12-async-noauth-ssl-replica-set + - name: test-standard-latest-python3.12-async-noauth-ssl-replica-set-cov commands: - func: run server vars: @@ -4090,12 +4085,14 @@ tasks: SSL: ssl TOPOLOGY: replica_set VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: ssl TOPOLOGY: replica_set VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.12" TEST_NAME: default_async tags: @@ -4128,7 +4125,7 @@ tasks: - replica_set-noauth-ssl - async - pypy - - name: test-standard-latest-python3.13-async-auth-ssl-sharded-cluster + - name: test-standard-latest-python3.13-async-auth-ssl-sharded-cluster-cov commands: - func: run server vars: @@ -4136,12 +4133,14 @@ tasks: SSL: ssl TOPOLOGY: sharded_cluster VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.13" TEST_NAME: default_async tags: @@ -4151,7 +4150,7 @@ tasks: - sharded_cluster-auth-ssl - async - pr - - name: test-standard-latest-python3.11-async-noauth-nossl-standalone + - name: test-standard-latest-python3.11-async-noauth-nossl-standalone-cov commands: - func: run server vars: @@ -4159,12 +4158,14 @@ tasks: SSL: nossl TOPOLOGY: standalone VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: nossl TOPOLOGY: standalone VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" TEST_NAME: default_async tags: @@ -4174,7 +4175,7 @@ tasks: - standalone-noauth-nossl - async - pr - - name: test-standard-latest-python3.14-async-noauth-nossl-standalone + - name: 
test-standard-latest-python3.14-async-noauth-nossl-standalone-cov commands: - func: run server vars: @@ -4182,12 +4183,14 @@ tasks: SSL: nossl TOPOLOGY: standalone VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: nossl TOPOLOGY: standalone VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" TEST_NAME: default_async tags: @@ -4829,7 +4832,7 @@ tasks: - python-3.13 - standalone-noauth-nossl - noauth - - name: test-non-standard-latest-python3.14t-noauth-ssl-replica-set + - name: test-non-standard-latest-python3.14t-noauth-ssl-replica-set-cov commands: - func: run server vars: @@ -4837,12 +4840,14 @@ tasks: SSL: ssl TOPOLOGY: replica_set VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: ssl TOPOLOGY: replica_set VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: 3.14t tags: - test-non-standard @@ -4874,7 +4879,7 @@ tasks: - replica_set-noauth-ssl - noauth - pypy - - name: test-non-standard-latest-python3.14-auth-ssl-sharded-cluster + - name: test-non-standard-latest-python3.14-auth-ssl-sharded-cluster-cov commands: - func: run server vars: @@ -4882,12 +4887,14 @@ tasks: SSL: ssl TOPOLOGY: sharded_cluster VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.14" tags: - test-non-standard @@ -4896,7 +4903,7 @@ tasks: - sharded_cluster-auth-ssl - auth - pr - - name: test-non-standard-latest-python3.13-noauth-nossl-standalone + - name: test-non-standard-latest-python3.13-noauth-nossl-standalone-cov commands: - func: run server vars: @@ -4904,12 +4911,14 @@ tasks: SSL: nossl TOPOLOGY: standalone VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: noauth SSL: nossl TOPOLOGY: standalone VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.13" tags: - test-non-standard @@ -5007,7 +5016,7 @@ tasks: - pypy # Test numpy tests - - name: test-numpy-python3.10 + - name: test-numpy-python3.10-python3.10 commands: - 
func: test numpy vars: @@ -5017,16 +5026,18 @@ tasks: - vector - python-3.10 - test-numpy - - name: test-numpy-python3.14 + - name: test-numpy-python3.14-python3.14-cov commands: - func: test numpy vars: TOOLCHAIN_VERSION: "3.14" + COVERAGE: "1" tags: - binary - vector - python-3.14 - test-numpy + - pr # Test standard auth tests - name: test-standard-auth-v4.2-python3.10-auth-ssl-sharded-cluster-min-deps @@ -5290,7 +5301,7 @@ tasks: - sharded_cluster-auth-ssl - auth - pypy - - name: test-standard-auth-latest-python3.11-auth-ssl-sharded-cluster + - name: test-standard-auth-latest-python3.11-auth-ssl-sharded-cluster-cov commands: - func: run server vars: @@ -5298,12 +5309,14 @@ tasks: SSL: ssl TOPOLOGY: sharded_cluster VERSION: latest + COVERAGE: "1" - func: run tests vars: AUTH: auth SSL: ssl TOPOLOGY: sharded_cluster VERSION: latest + COVERAGE: "1" TOOLCHAIN_VERSION: "3.11" tags: - test-standard-auth diff --git a/.evergreen/generated_configs/variants.yml b/.evergreen/generated_configs/variants.yml index edca050240..f37c2b0efd 100644 --- a/.evergreen/generated_configs/variants.yml +++ b/.evergreen/generated_configs/variants.yml @@ -368,7 +368,6 @@ buildvariants: run_on: - rhel87-small expansions: - COVERAGE: "1" NO_EXT: "1" # No server tests @@ -420,6 +419,8 @@ buildvariants: run_on: - ubuntu2204-small batchtime: 1440 + expansions: + COVERAGE: "1" tags: [pr] - name: auth-oidc-macos tasks: @@ -477,6 +478,40 @@ buildvariants: expansions: SUB_TEST_NAME: pyopenssl + # Rust tests + - name: test-with-rust-extension + tasks: + - name: .test-standard .server-latest .pr + display_name: Test with Rust Extension + run_on: + - rhel87-small + expansions: + PYMONGO_BUILD_RUST: "1" + PYMONGO_USE_RUST: "1" + tags: [rust, pr] + - name: test-with-rust-extension---macos-arm64 + tasks: + - name: .test-standard .server-latest !.pr + display_name: Test with Rust Extension - macOS ARM64 + run_on: + - macos-14-arm64 + batchtime: 10080 + expansions: + PYMONGO_BUILD_RUST: "1" + 
PYMONGO_USE_RUST: "1" + tags: [rust] + - name: test-with-rust-extension---windows + tasks: + - name: .test-standard .server-latest !.pr + display_name: Test with Rust Extension - Windows + run_on: + - windows-64-vsMulti-small + batchtime: 10080 + expansions: + PYMONGO_BUILD_RUST: "1" + PYMONGO_USE_RUST: "1" + tags: [rust] + # Search index tests - name: search-index-helpers-rhel8 tasks: diff --git a/.evergreen/resync-specs.sh b/.evergreen/resync-specs.sh index d2bd89c781..4bb9c86304 100755 --- a/.evergreen/resync-specs.sh +++ b/.evergreen/resync-specs.sh @@ -94,6 +94,9 @@ do change-streams|change_streams) cpjson change-streams/tests/ change_streams/ ;; + client-backpressure|client_backpressure) + cpjson client-backpressure/tests client-backpressure + ;; client-side-encryption|csfle|fle) cpjson client-side-encryption/tests/ client-side-encryption/spec cpjson client-side-encryption/corpus/ client-side-encryption/corpus diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 095b7938dc..0785bcf01d 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -38,6 +38,7 @@ trap "cleanup_tests" SIGINT ERR # Start the test runner. echo "Running tests with UV_PYTHON=${UV_PYTHON:-}..." +echo "UV_ARGS=${UV_ARGS}" uv run ${UV_ARGS} --reinstall-package pymongo .evergreen/scripts/run_tests.py "$@" echo "Running tests with UV_PYTHON=${UV_PYTHON:-}... done." 
diff --git a/.evergreen/scripts/configure-env.sh b/.evergreen/scripts/configure-env.sh index 8dc328aab3..101812ede6 100755 --- a/.evergreen/scripts/configure-env.sh +++ b/.evergreen/scripts/configure-env.sh @@ -14,6 +14,7 @@ fi PROJECT_DIRECTORY="$(pwd)" DRIVERS_TOOLS="$(dirname $PROJECT_DIRECTORY)/drivers-tools" CARGO_HOME=${CARGO_HOME:-${DRIVERS_TOOLS}/.cargo} +RUSTUP_HOME=${RUSTUP_HOME:-${CARGO_HOME}} UV_TOOL_DIR=$PROJECT_DIRECTORY/.local/uv/tools UV_CACHE_DIR=$PROJECT_DIRECTORY/.local/uv/cache DRIVERS_TOOLS_BINARIES="$DRIVERS_TOOLS/.bin" @@ -27,13 +28,14 @@ else PYMONGO_BIN_DIR=$HOME/cli_bin fi -PATH_EXT="$MONGODB_BINARIES:$DRIVERS_TOOLS_BINARIES:$PYMONGO_BIN_DIR:\$PATH" +PATH_EXT="$MONGODB_BINARIES:$DRIVERS_TOOLS_BINARIES:$PYMONGO_BIN_DIR:$CARGO_HOME/bin:\$PATH" # Python has cygwin path problems on Windows. Detect prospective mongo-orchestration home directory if [ "Windows_NT" = "${OS:-}" ]; then # Magic variable in cygwin DRIVERS_TOOLS=$(cygpath -m $DRIVERS_TOOLS) PROJECT_DIRECTORY=$(cygpath -m $PROJECT_DIRECTORY) CARGO_HOME=$(cygpath -m $CARGO_HOME) + RUSTUP_HOME=$(cygpath -m $RUSTUP_HOME) UV_TOOL_DIR=$(cygpath -m "$UV_TOOL_DIR") UV_CACHE_DIR=$(cygpath -m "$UV_CACHE_DIR") DRIVERS_TOOLS_BINARIES=$(cygpath -m "$DRIVERS_TOOLS_BINARIES") @@ -62,6 +64,7 @@ export DRIVERS_TOOLS_BINARIES="$DRIVERS_TOOLS_BINARIES" export PROJECT_DIRECTORY="$PROJECT_DIRECTORY" export CARGO_HOME="$CARGO_HOME" +export RUSTUP_HOME="$RUSTUP_HOME" export UV_TOOL_DIR="$UV_TOOL_DIR" export UV_CACHE_DIR="$UV_CACHE_DIR" export UV_TOOL_BIN_DIR="$DRIVERS_TOOLS_BINARIES" diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index 3375b9e14e..20da339ac3 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -318,7 +318,7 @@ def create_green_framework_variants(): def create_no_c_ext_variants(): host = DEFAULT_HOST tasks = [".test-standard"] - expansions = dict(COVERAGE="1") + expansions = dict() 
handle_c_ext(C_EXTS[0], expansions) display_name = get_variant_name("No C Ext", host) return [create_variant(tasks, display_name, host=host, expansions=expansions)] @@ -344,8 +344,12 @@ def create_test_numpy_tasks(): tasks = [] for python in MIN_MAX_PYTHON: tags = ["binary", "vector", f"python-{python}", "test-numpy"] - task_name = get_task_name("test-numpy", python=python) - test_func = FunctionCall(func="test numpy", vars=dict(TOOLCHAIN_VERSION=python)) + vars = dict(TOOLCHAIN_VERSION=python) + if python == MIN_MAX_PYTHON[-1]: + tags.append("pr") + vars["COVERAGE"] = "1" + task_name = get_task_name("test-numpy", python=python, **vars) + test_func = FunctionCall(func="test numpy", vars=vars) tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func])) return tasks @@ -397,6 +401,7 @@ def create_oidc_auth_variants(): tags=["pr"], host=host, batchtime=BATCHTIME_DAY, + expansions=dict(COVERAGE="1"), ) ) return variants @@ -596,7 +601,7 @@ def create_server_version_tasks(): expansions["TEST_MIN_DEPS"] = "1" if "t" in python: tags.append("free-threaded") - if python not in PYPYS and "t" not in python: + if "pr" in tags: expansions["COVERAGE"] = "1" name = get_task_name( "test-server-version", @@ -661,6 +666,8 @@ def create_test_non_standard_tasks(): expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology, VERSION=version) if python == ALL_PYTHONS[0]: expansions["TEST_MIN_DEPS"] = "1" + elif pr: + expansions["COVERAGE"] = "1" name = get_task_name("test-non-standard", python=python, **expansions) server_func = FunctionCall(func="run server", vars=expansions) test_vars = expansions.copy() @@ -703,6 +710,8 @@ def create_test_standard_auth_tasks(): expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology, VERSION=version) if python == ALL_PYTHONS[0]: expansions["TEST_MIN_DEPS"] = "1" + elif pr: + expansions["COVERAGE"] = "1" name = get_task_name("test-standard-auth", python=python, **expansions) server_func = FunctionCall(func="run server", vars=expansions) 
test_vars = expansions.copy() @@ -741,6 +750,8 @@ def create_standard_tasks(): expansions = dict(AUTH=auth, SSL=ssl, TOPOLOGY=topology, VERSION=version) if python == ALL_PYTHONS[0]: expansions["TEST_MIN_DEPS"] = "1" + elif pr: + expansions["COVERAGE"] = "1" name = get_task_name("test-standard", python=python, sync=sync, **expansions) server_func = FunctionCall(func="run server", vars=expansions) test_vars = expansions.copy() @@ -810,8 +821,11 @@ def create_aws_tasks(): if "t" in python: tags.append("free-threaded") test_vars = dict(TEST_NAME="auth_aws", SUB_TEST_NAME=test_type, TOOLCHAIN_VERSION=python) - if python == ALL_PYTHONS[0]: + if python == MIN_MAX_PYTHON[0]: test_vars["TEST_MIN_DEPS"] = "1" + elif python == MIN_MAX_PYTHON[-1]: + tags.append("pr") + test_vars["COVERAGE"] = "1" name = get_task_name(f"{base_name}-{test_type}", **test_vars) test_func = FunctionCall(func="run tests", vars=test_vars) funcs = [server_func, assume_func, test_func] @@ -849,11 +863,11 @@ def create_oidc_tasks(): tasks = [] for sub_test in ["default", "azure", "gcp", "eks", "aks", "gke"]: vars = dict(TEST_NAME="auth_oidc", SUB_TEST_NAME=sub_test) - test_func = FunctionCall(func="run tests", vars=vars) - task_name = f"test-auth-oidc-{sub_test}" tags = ["auth_oidc"] if sub_test != "default": tags.append("auth_oidc_remote") + test_func = FunctionCall(func="run tests", vars=vars) + task_name = get_task_name(f"test-auth-oidc-{sub_test}", **vars) tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func])) return tasks @@ -903,14 +917,14 @@ def _create_ocsp_tasks(algo, variant, server_type, base_task_name): ) if python == ALL_PYTHONS[0]: vars["TEST_MIN_DEPS"] = "1" - test_func = FunctionCall(func="run tests", vars=vars) - tags = ["ocsp", f"ocsp-{algo}", version] if "disableStapling" not in variant: tags.append("ocsp-staple") - if algo == "valid-cert-server-staples" and version == "latest": + if base_task_name == "valid-cert-server-staples" and version == "latest": 
tags.append("pr") - + if "TEST_MIN_DEPS" not in vars: + vars["COVERAGE"] = "1" + test_func = FunctionCall(func="run tests", vars=vars) task_name = get_task_name(f"test-ocsp-{algo}-{base_task_name}", **vars) tasks.append(EvgTask(name=task_name, tags=tags, commands=[test_func])) @@ -958,11 +972,15 @@ def create_search_index_tasks(): def create_perf_tasks(): tasks = [] - for version, ssl, sync in product(["8.0"], ["ssl", "nossl"], ["sync", "async"]): + for version, ssl, sync in product(["8.0"], ["ssl", "nossl"], ["sync", "async", "rust"]): vars = dict(VERSION=f"v{version}-perf", SSL=ssl) server_func = FunctionCall(func="run server", vars=vars) - vars = dict(TEST_NAME="perf", SUB_TEST_NAME=sync) - test_func = FunctionCall(func="run tests", vars=vars) + test_vars = dict(TEST_NAME="perf", SUB_TEST_NAME=sync) + # Enable Rust for rust perf tests + if sync == "rust": + test_vars["PYMONGO_BUILD_RUST"] = "1" + test_vars["PYMONGO_USE_RUST"] = "1" + test_func = FunctionCall(func="run tests", vars=test_vars) attach_func = FunctionCall(func="attach benchmark test results") send_func = FunctionCall(func="send dashboard data") task_name = f"perf-{version}-standalone" @@ -970,6 +988,8 @@ def create_perf_tasks(): task_name += "-ssl" if sync == "async": task_name += "-async" + elif sync == "rust": + task_name += "-rust" tags = ["perf"] commands = [server_func, test_func, attach_func, send_func] tasks.append(EvgTask(name=task_name, tags=tags, commands=commands)) @@ -1087,7 +1107,7 @@ def create_upload_coverage_codecov_func(): "github_pr_number", "github_pr_head_branch", "github_author", - "is_patch", + "requester", "branch_name", ] args = [ @@ -1189,6 +1209,8 @@ def create_run_server_func(): "LOAD_BALANCER", "LOCAL_ATLAS", "NO_EXT", + "PYMONGO_BUILD_RUST", + "PYMONGO_USE_RUST", ] args = [".evergreen/just.sh", "run-server", "${TEST_NAME}"] sub_cmd = get_subprocess_exec(include_expansions_in_env=includes, args=args) @@ -1222,6 +1244,8 @@ def create_run_tests_func(): "IS_WIN32", 
"REQUIRE_FIPS", "TEST_MIN_DEPS", + "PYMONGO_BUILD_RUST", + "PYMONGO_USE_RUST", ] args = [".evergreen/just.sh", "setup-tests", "${TEST_NAME}", "${SUB_TEST_NAME}"] setup_cmd = get_subprocess_exec(include_expansions_in_env=includes, args=args) @@ -1230,7 +1254,7 @@ def create_run_tests_func(): def create_test_numpy_func(): - includes = ["TOOLCHAIN_VERSION"] + includes = ["TOOLCHAIN_VERSION", "COVERAGE"] test_cmd = get_subprocess_exec( include_expansions_in_env=includes, args=[".evergreen/just.sh", "test-numpy"] ) @@ -1283,6 +1307,55 @@ def create_send_dashboard_data_func(): return "send dashboard data", cmds +def create_rust_variants(): + """Create build variants that test with Rust extension alongside C extension.""" + variants = [] + + # Test Rust on Linux (primary platform) - runs on PRs + # Run standard tests with Rust enabled (both sync and async) + variant = create_variant( + [".test-standard .server-latest .pr"], + "Test with Rust Extension", + host=DEFAULT_HOST, + tags=["rust", "pr"], + expansions=dict( + PYMONGO_BUILD_RUST="1", + PYMONGO_USE_RUST="1", + ), + ) + variants.append(variant) + + # Test on macOS ARM64 (important for M1/M2 Macs) + variant = create_variant( + [".test-standard .server-latest !.pr"], + "Test with Rust Extension - macOS ARM64", + host=HOSTS["macos-arm64"], + tags=["rust"], + batchtime=BATCHTIME_WEEK, + expansions=dict( + PYMONGO_BUILD_RUST="1", + PYMONGO_USE_RUST="1", + ), + ) + variants.append(variant) + + # Test on Windows (important for cross-platform compatibility) + variant = create_variant( + [".test-standard .server-latest !.pr"], + "Test with Rust Extension - Windows", + host=HOSTS["win64"], + tags=["rust"], + batchtime=BATCHTIME_WEEK, + expansions=dict( + PYMONGO_BUILD_RUST="1", + PYMONGO_USE_RUST="1", + ), + ) + variants.append(variant) + + return variants + + mod = sys.modules[__name__] write_variants_to_file(mod) write_tasks_to_file(mod) diff --git a/.evergreen/scripts/install-dependencies.sh 
b/.evergreen/scripts/install-dependencies.sh index 8df2af79ca..3acc996e1f 100755 --- a/.evergreen/scripts/install-dependencies.sh +++ b/.evergreen/scripts/install-dependencies.sh @@ -30,7 +30,7 @@ fi # Ensure just is installed. if ! command -v just &>/dev/null; then - uv tool install rust-just + uv tool install rust-just || uv tool install --force rust-just fi popd > /dev/null diff --git a/.evergreen/scripts/install-rust.sh b/.evergreen/scripts/install-rust.sh new file mode 100755 index 0000000000..34d97c80ef --- /dev/null +++ b/.evergreen/scripts/install-rust.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Install Rust toolchain for building the Rust BSON extension. +set -eu + +echo "Installing Rust toolchain..." + +# Check if Rust is already installed +if command -v cargo &> /dev/null; then + echo "Rust is already installed:" + rustc --version + cargo --version + echo "Updating Rust toolchain..." + rustup update stable +else + echo "Rust not found. Installing Rust..." + + # Install Rust using rustup + if [ "Windows_NT" = "${OS:-}" ]; then + # Windows installation + curl --proto '=https' --tlsv1.2 -sSf https://win.rustup.rs/x86_64 -o rustup-init.exe + ./rustup-init.exe -y --default-toolchain stable + rm rustup-init.exe + + # Add to PATH for current session + export PATH="$HOME/.cargo/bin:$PATH" + else + # Unix-like installation (Linux, macOS) + # Ensure CARGO_HOME is exported so rustup uses it + export CARGO_HOME="${CARGO_HOME:-$HOME/.cargo}" + export RUSTUP_HOME="${RUSTUP_HOME:-${CARGO_HOME}}" + + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable + + # Source cargo env from the installation location + # On CI, CARGO_HOME is set to ${DRIVERS_TOOLS}/.cargo by configure-env.sh + CARGO_ENV_PATH="${CARGO_HOME}/env" + + if [ -f "${CARGO_ENV_PATH}" ]; then + source "${CARGO_ENV_PATH}" + else + echo "Error: Cargo env file not found at ${CARGO_ENV_PATH}" + echo "CARGO_HOME=${CARGO_HOME}" + echo "RUSTUP_HOME=${RUSTUP_HOME}" + echo 
"HOME=${HOME}" + exit 1 + fi + fi + + echo "Rust installation complete:" + rustc --version + cargo --version +fi + +# Ensure default toolchain is set (needed for rustup to work properly) +echo "Setting default toolchain to stable..." +rustup default stable + +# Install maturin if not already installed +if ! command -v maturin &> /dev/null; then + echo "Installing maturin..." + # Use pip instead of cargo to avoid yanked dependency issues + # (e.g., maturin 1.12.2 depends on cargo-xwin which has yanked xwin versions) + pip install maturin + echo "maturin installation complete:" + maturin --version +else + echo "maturin is already installed:" + maturin --version +fi + +echo "Rust toolchain setup complete." diff --git a/.evergreen/scripts/resync-all-specs.py b/.evergreen/scripts/resync-all-specs.py index 1996d5d634..16782de9a8 100644 --- a/.evergreen/scripts/resync-all-specs.py +++ b/.evergreen/scripts/resync-all-specs.py @@ -7,6 +7,8 @@ from argparse import Namespace from subprocess import CalledProcessError +JIRA_FILTER = "https://jira.mongodb.org/issues/?jql=labels%20%3D%20automated-sync%20AND%20status%20!%3D%20Closed" + def resync_specs(directory: pathlib.Path, errored: dict[str, str]) -> None: """Actually sync the specs""" @@ -117,6 +119,7 @@ def write_summary(errored: dict[str, str], new: list[str], filename: str | None) pr_body += "\n -".join(new) pr_body += "\n" if pr_body != "": + pr_body = f"Jira tickets: {JIRA_FILTER}\n\n" + pr_body if filename is None: print(f"\n{pr_body}") else: diff --git a/.evergreen/scripts/run_server.py b/.evergreen/scripts/run_server.py index a35fbb57a8..9757eb3a4f 100644 --- a/.evergreen/scripts/run_server.py +++ b/.evergreen/scripts/run_server.py @@ -12,7 +12,7 @@ def set_env(name: str, value: Any = "1") -> None: def start_server(): opts, extra_opts = get_test_options( - "Run a MongoDB server. All given flags will be passed to run-orchestration.sh in DRIVERS_TOOLS.", + "Run a MongoDB server. 
All given flags will be passed to run-mongodb.sh in DRIVERS_TOOLS.", require_sub_test_name=False, allow_extra_opts=True, ) @@ -51,7 +51,7 @@ def start_server(): elif opts.quiet: extra_opts.append("-q") - cmd = ["bash", f"{DRIVERS_TOOLS}/.evergreen/run-orchestration.sh", *extra_opts] + cmd = ["bash", f"{DRIVERS_TOOLS}/.evergreen/run-mongodb.sh", "start", *extra_opts] run_command(cmd, cwd=DRIVERS_TOOLS) diff --git a/.evergreen/scripts/run_tests.py b/.evergreen/scripts/run_tests.py index 9c8101c5b1..d98d181987 100644 --- a/.evergreen/scripts/run_tests.py +++ b/.evergreen/scripts/run_tests.py @@ -4,7 +4,9 @@ import logging import os import platform +import shlex import shutil +import subprocess import sys from datetime import datetime from pathlib import Path @@ -151,6 +153,30 @@ def run() -> None: if os.environ.get("PYMONGOCRYPT_LIB"): handle_pymongocrypt() + # Check if Rust extension is being used + LOGGER.info(f"PYMONGO_USE_RUST={os.environ.get('PYMONGO_USE_RUST', 'not set')}") + LOGGER.info(f"PYMONGO_BUILD_RUST={os.environ.get('PYMONGO_BUILD_RUST', 'not set')}") + + if os.environ.get("PYMONGO_USE_RUST") or os.environ.get("PYMONGO_BUILD_RUST"): + try: + import bson + + impl = bson.get_bson_implementation() + has_rust = bson.has_rust() + has_c = bson.has_c() + + LOGGER.info(f"BSON implementation in use: {impl}") + LOGGER.info(f"Has Rust: {has_rust}, Has C: {has_c}") + + if impl == "rust": + LOGGER.info("✓ Rust extension is ACTIVE") + elif impl == "c": + LOGGER.info("✓ C extension is ACTIVE") + else: + LOGGER.info("✓ Pure Python implementation is ACTIVE") + except Exception as e: + LOGGER.warning(f"Could not check BSON implementation: {e}") + LOGGER.info(f"Test setup:\n{AUTH=}\n{SSL=}\n{UV_ARGS=}\n{TEST_ARGS=}") # Record the start time for a perf test. 
@@ -202,6 +228,16 @@ def run() -> None: if os.environ.get("DEBUG_LOG"): TEST_ARGS.extend(f"-o log_cli_level={logging.DEBUG}".split()) + if os.environ.get("COVERAGE"): + binary = sys.executable.replace(os.sep, "/") + cmd = f"{binary} -m coverage run -m pytest {' '.join(TEST_ARGS)} {' '.join(sys.argv[1:])}" + result = subprocess.run(shlex.split(cmd), check=False) # noqa: S603 + cmd = f"{binary} -m coverage report" + subprocess.run(shlex.split(cmd), check=False) # noqa: S603 + if result.returncode != 0: + print(result.stderr) + sys.exit(result.returncode) + # Run local tests. ret = pytest.main(TEST_ARGS + sys.argv[1:]) if ret != 0: diff --git a/.evergreen/scripts/setup-dev-env.sh b/.evergreen/scripts/setup-dev-env.sh index fa5f86d798..2fec5c66ac 100755 --- a/.evergreen/scripts/setup-dev-env.sh +++ b/.evergreen/scripts/setup-dev-env.sh @@ -22,6 +22,11 @@ bash $HERE/install-dependencies.sh # Handle the value for UV_PYTHON. . $HERE/setup-uv-python.sh +# Show Rust toolchain status for debugging +echo "Rust toolchain: $(rustc --version 2>/dev/null || echo 'not found')" +echo "Cargo: $(cargo --version 2>/dev/null || echo 'not found')" +echo "Maturin: $(maturin --version 2>/dev/null || echo 'not found')" + # Only run the next part if not running on CI. if [ -z "${CI:-}" ]; then # Add the default install path to the path if needed. diff --git a/.evergreen/scripts/setup-tests.sh b/.evergreen/scripts/setup-tests.sh index 858906a39e..0bb19402f0 100755 --- a/.evergreen/scripts/setup-tests.sh +++ b/.evergreen/scripts/setup-tests.sh @@ -13,6 +13,8 @@ set -eu # MONGODB_API_VERSION The mongodb api version to use in tests. # MONGODB_URI If non-empty, use as the MONGODB_URI in tests. # USE_ACTIVE_VENV If non-empty, use the active virtual environment. +# PYMONGO_BUILD_RUST If non-empty, build and test with Rust extension. +# PYMONGO_USE_RUST If non-empty, use the Rust extension for tests. 
SCRIPT_DIR=$(dirname ${BASH_SOURCE:-$0}) @@ -21,6 +23,12 @@ if [ -f $SCRIPT_DIR/env.sh ]; then source $SCRIPT_DIR/env.sh fi +# Install Rust toolchain if building Rust extension +if [ -n "${PYMONGO_BUILD_RUST:-}" ]; then + echo "PYMONGO_BUILD_RUST is set, installing Rust toolchain..." + bash $SCRIPT_DIR/install-rust.sh +fi + echo "Setting up tests with args \"$*\"..." uv run ${USE_ACTIVE_VENV:+--active} "$SCRIPT_DIR/setup_tests.py" "$@" echo "Setting up tests with args \"$*\"... done." diff --git a/.evergreen/scripts/setup_tests.py b/.evergreen/scripts/setup_tests.py index 939423ffcc..756802eaab 100644 --- a/.evergreen/scripts/setup_tests.py +++ b/.evergreen/scripts/setup_tests.py @@ -32,6 +32,8 @@ "UV_PYTHON", "REQUIRE_FIPS", "IS_WIN32", + "PYMONGO_USE_RUST", + "PYMONGO_BUILD_RUST", ] # Map the test name to test extra. @@ -153,6 +155,10 @@ def handle_test_env() -> None: # Start compiling the args we'll pass to uv. UV_ARGS = ["--extra test --no-group dev"] + # If USE_ACTIVE_VENV is set, add --active to UV_ARGS so run-tests.sh uses the active venv. + if is_set("USE_ACTIVE_VENV"): + UV_ARGS.append("--active") + test_title = test_name if sub_test_name: test_title += f" {sub_test_name}" @@ -324,7 +330,8 @@ def handle_test_env() -> None: version = os.environ.get("VERSION", "latest") cmd = [ "bash", - f"{DRIVERS_TOOLS}/.evergreen/run-orchestration.sh", + f"{DRIVERS_TOOLS}/.evergreen/run-mongodb.sh", + "start", "--ssl", "--version", version, @@ -431,6 +438,9 @@ def handle_test_env() -> None: # We do not want the default client_context to be initialized. write_env("DISABLE_CONTEXT") + if test_name == "numpy": + UV_ARGS.append("--with numpy") + if test_name == "perf": data_dir = ROOT / "specifications/source/benchmarking/data" if not data_dir.exists(): @@ -447,7 +457,7 @@ def handle_test_env() -> None: # PYTHON-4769 Run perf_test.py directly otherwise pytest's test collection negatively # affects the benchmark results. 
- if sub_test_name == "sync": + if sub_test_name == "sync" or sub_test_name == "rust": TEST_ARGS = f"test/performance/perf_test.py {TEST_ARGS}" else: TEST_ARGS = f"test/performance/async_perf_test.py {TEST_ARGS}" @@ -458,12 +468,14 @@ def handle_test_env() -> None: # Keep in sync with combine-coverage.sh. # coverage >=5 is needed for relative_files=true. UV_ARGS.append("--group coverage") - TEST_ARGS = f"{TEST_ARGS} --cov" write_env("COVERAGE") if opts.green_framework: framework = opts.green_framework or os.environ["GREEN_FRAMEWORK"] UV_ARGS.append(f"--group {framework}") + if framework == "gevent" and opts.test_min_deps: + # PYTHON-5729. This can be removed when the min supported gevent is moved to 25.9.1. + UV_ARGS.append('--with "setuptools==81.0"') else: TEST_ARGS = f"-v --durations=5 {TEST_ARGS}" @@ -471,6 +483,10 @@ def handle_test_env() -> None: if TEST_SUITE: TEST_ARGS = f"-m {TEST_SUITE} {TEST_ARGS}" + # For test_bson, run the specific test file + if test_name == "test_bson": + TEST_ARGS = f"test/test_bson.py {TEST_ARGS}" + write_env("TEST_ARGS", TEST_ARGS) write_env("UV_ARGS", " ".join(UV_ARGS)) diff --git a/.evergreen/scripts/stop-server.sh b/.evergreen/scripts/stop-server.sh index 7599387f5f..045a655cbd 100755 --- a/.evergreen/scripts/stop-server.sh +++ b/.evergreen/scripts/stop-server.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Stop a server that was started using run-orchestration.sh in DRIVERS_TOOLS. +# Stop a server that was started using run-mongodb.sh in DRIVERS_TOOLS. 
set -eu HERE=$(dirname ${BASH_SOURCE:-$0}) @@ -11,4 +11,4 @@ if [ -f $HERE/env.sh ]; then source $HERE/env.sh fi -bash ${DRIVERS_TOOLS}/.evergreen/stop-orchestration.sh +bash ${DRIVERS_TOOLS}/.evergreen/run-mongodb.sh stop diff --git a/.evergreen/scripts/upload-codecov.sh b/.evergreen/scripts/upload-codecov.sh index 75bd9d9e21..5c1d84c55c 100755 --- a/.evergreen/scripts/upload-codecov.sh +++ b/.evergreen/scripts/upload-codecov.sh @@ -8,18 +8,20 @@ ROOT=$(dirname "$(dirname $HERE)") pushd $ROOT > /dev/null export FNAME=coverage.xml - -if [ -n "${is_patch:-}" ]; then - echo "This is a patch build, not running codecov" - exit 0 -fi +REQUESTER=${requester:-} if [ ! -f ".coverage" ]; then echo "There are no coverage results, not running codecov" exit 0 fi -echo "Uploading..." +if [[ "${REQUESTER}" == "github_pr" || "${REQUESTER}" == "commit" ]]; then + echo "Uploading codecov for $REQUESTER..." +else + echo "Error: requester must be 'github_pr' or 'commit', got '${REQUESTER}'" >&2 + exit 1 +fi + printf 'sha: %s\n' "$github_commit" printf 'flag: %s-%s\n' "$build_variant" "$task_name" printf 'file: %s\n' "$FNAME" @@ -40,18 +42,16 @@ codecov_args=( if [ -n "${github_pr_number:-}" ]; then printf 'branch: %s:%s\n' "$github_author" "$github_pr_head_branch" printf 'pr: %s\n' "$github_pr_number" - uv tool run --from codecov-cli codecovcli \ "${codecov_args[@]}" \ --pr "${github_pr_number}" \ --branch "${github_author}:${github_pr_head_branch}" else printf 'branch: %s\n' "$branch_name" - uv tool run --from codecov-cli codecovcli \ "${codecov_args[@]}" \ --branch "${branch_name}" fi -echo "Uploading...done." +echo "Uploading codecov for $REQUESTER... done." 
popd > /dev/null diff --git a/.evergreen/scripts/utils.py b/.evergreen/scripts/utils.py index 2bc9c720d2..ae0242fd5b 100644 --- a/.evergreen/scripts/utils.py +++ b/.evergreen/scripts/utils.py @@ -44,6 +44,8 @@ class Distro: "mockupdb": "mockupdb", "ocsp": "ocsp", "perf": "perf", + "numpy": "", + "test_bson": "", } # Tests that require a sub test suite. @@ -51,7 +53,7 @@ class Distro: EXTRA_TESTS = ["mod_wsgi", "aws_lambda", "doctest"] -# Tests that do not use run-orchestration directly. +# Tests that do not use run-mongodb directly. NO_RUN_ORCHESTRATION = [ "auth_oidc", "atlas_connect", diff --git a/.evergreen/spec-patch/PYTHON-5759.patch b/.evergreen/spec-patch/PYTHON-5759.patch new file mode 100644 index 0000000000..3b19ed065e --- /dev/null +++ b/.evergreen/spec-patch/PYTHON-5759.patch @@ -0,0 +1,460 @@ +diff --git a/test/client-side-encryption/spec/unified/accessToken-azure.json b/test/client-side-encryption/spec/unified/accessToken-azure.json +new file mode 100644 +index 00000000..510d8795 +--- /dev/null ++++ b/test/client-side-encryption/spec/unified/accessToken-azure.json +@@ -0,0 +1,186 @@ ++{ ++ "description": "accessToken-azure", ++ "schemaVersion": "1.28", ++ "runOnRequirements": [ ++ { ++ "minServerVersion": "4.1.10", ++ "csfle": { ++ "minLibmongocryptVersion": "1.6.0" ++ } ++ } ++ ], ++ "createEntities": [ ++ { ++ "client": { ++ "id": "client", ++ "autoEncryptOpts": { ++ "keyVaultNamespace": "keyvault.datakeys", ++ "kmsProviders": { ++ "azure": { ++ "accessToken": { ++ "$$placeholder": 1 ++ } ++ } ++ } ++ } ++ } ++ }, ++ { ++ "database": { ++ "id": "db", ++ "client": "client", ++ "databaseName": "db" ++ } ++ }, ++ { ++ "collection": { ++ "id": "coll", ++ "database": "db", ++ "collectionName": "coll" ++ } ++ }, ++ { ++ "clientEncryption": { ++ "id": "clientEncryption", ++ "clientEncryptionOpts": { ++ "keyVaultClient": "client", ++ "keyVaultNamespace": "keyvault.datakeys", ++ "kmsProviders": { ++ "azure": { ++ "accessToken": { ++ "$$placeholder": 1 ++ } 
++ } ++ } ++ } ++ } ++ } ++ ], ++ "initialData": [ ++ { ++ "databaseName": "db", ++ "collectionName": "coll", ++ "documents": [], ++ "createOptions": { ++ "validator": { ++ "$jsonSchema": { ++ "properties": { ++ "secret": { ++ "encrypt": { ++ "keyId": [ ++ { ++ "$binary": { ++ "base64": "AZURE+AAAAAAAAAAAAAAAA==", ++ "subType": "04" ++ } ++ } ++ ], ++ "bsonType": "string", ++ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" ++ } ++ } ++ }, ++ "bsonType": "object" ++ } ++ } ++ } ++ }, ++ { ++ "databaseName": "keyvault", ++ "collectionName": "datakeys", ++ "documents": [ ++ { ++ "_id": { ++ "$binary": { ++ "base64": "AZURE+AAAAAAAAAAAAAAAA==", ++ "subType": "04" ++ } ++ }, ++ "keyAltNames": [ ++ "my-key" ++ ], ++ "keyMaterial": { ++ "$binary": { ++ "base64": "n+HWZ0ZSVOYA3cvQgP7inN4JSXfOH85IngmeQxRpQHjCCcqT3IFqEWNlrsVHiz3AELimHhX4HKqOLWMUeSIT6emUDDoQX9BAv8DR1+E1w4nGs/NyEneac78EYFkK3JysrFDOgl2ypCCTKAypkn9CkAx1if4cfgQE93LW4kczcyHdGiH36CIxrCDGv1UzAvERN5Qa47DVwsM6a+hWsF2AAAJVnF0wYLLJU07TuRHdMrrphPWXZsFgyV+lRqJ7DDpReKNO8nMPLV/mHqHBHGPGQiRdb9NoJo8CvokGz4+KE8oLwzKf6V24dtwZmRkrsDV4iOhvROAzz+Euo1ypSkL3mw==", ++ "subType": "00" ++ } ++ }, ++ "creationDate": { ++ "$date": { ++ "$numberLong": "1552949630483" ++ } ++ }, ++ "updateDate": { ++ "$date": { ++ "$numberLong": "1552949630483" ++ } ++ }, ++ "status": { ++ "$numberInt": "0" ++ }, ++ "masterKey": { ++ "provider": "azure", ++ "keyVaultEndpoint": "key-vault-csfle.vault.azure.net", ++ "keyName": "key-name-csfle" ++ } ++ } ++ ] ++ } ++ ], ++ "tests": [ ++ { ++ "description": "Auto encrypt using access token Azure credentials", ++ "operations": [ ++ { ++ "name": "insertOne", ++ "arguments": { ++ "document": { ++ "_id": 1, ++ "secret": "string0" ++ } ++ }, ++ "object": "coll" ++ } ++ ], ++ "outcome": [ ++ { ++ "documents": [ ++ { ++ "_id": 1, ++ "secret": { ++ "$binary": { ++ "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==", ++ "subType": "06" 
++ } ++ } ++ } ++ ], ++ "collectionName": "coll", ++ "databaseName": "db" ++ } ++ ] ++ }, ++ { ++ "description": "Explicit encrypt using access token Azure credentials", ++ "operations": [ ++ { ++ "name": "encrypt", ++ "object": "clientEncryption", ++ "arguments": { ++ "value": "string0", ++ "opts": { ++ "keyAltName": "my-key", ++ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" ++ } ++ }, ++ "expectResult": { ++ "$binary": { ++ "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==", ++ "subType": "06" ++ } ++ } ++ } ++ ] ++ } ++ ] ++} +diff --git a/test/client-side-encryption/spec/unified/accessToken-gcp.json b/test/client-side-encryption/spec/unified/accessToken-gcp.json +new file mode 100644 +index 00000000..f5cf8914 +--- /dev/null ++++ b/test/client-side-encryption/spec/unified/accessToken-gcp.json +@@ -0,0 +1,188 @@ ++{ ++ "description": "accessToken-gcp", ++ "schemaVersion": "1.28", ++ "runOnRequirements": [ ++ { ++ "minServerVersion": "4.1.10", ++ "csfle": { ++ "minLibmongocryptVersion": "1.6.0" ++ } ++ } ++ ], ++ "createEntities": [ ++ { ++ "client": { ++ "id": "client", ++ "autoEncryptOpts": { ++ "keyVaultNamespace": "keyvault.datakeys", ++ "kmsProviders": { ++ "gcp": { ++ "accessToken": { ++ "$$placeholder": 1 ++ } ++ } ++ } ++ } ++ } ++ }, ++ { ++ "database": { ++ "id": "db", ++ "client": "client", ++ "databaseName": "db" ++ } ++ }, ++ { ++ "collection": { ++ "id": "coll", ++ "database": "db", ++ "collectionName": "coll" ++ } ++ }, ++ { ++ "clientEncryption": { ++ "id": "clientEncryption", ++ "clientEncryptionOpts": { ++ "keyVaultClient": "client", ++ "keyVaultNamespace": "keyvault.datakeys", ++ "kmsProviders": { ++ "gcp": { ++ "accessToken": { ++ "$$placeholder": 1 ++ } ++ } ++ } ++ } ++ } ++ } ++ ], ++ "initialData": [ ++ { ++ "databaseName": "db", ++ "collectionName": "coll", ++ "documents": [], ++ "createOptions": { ++ "validator": { ++ "$jsonSchema": { ++ "properties": { 
++ "secret": { ++ "encrypt": { ++ "keyId": [ ++ { ++ "$binary": { ++ "base64": "GCP+AAAAAAAAAAAAAAAAAA==", ++ "subType": "04" ++ } ++ } ++ ], ++ "bsonType": "string", ++ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" ++ } ++ } ++ }, ++ "bsonType": "object" ++ } ++ } ++ } ++ }, ++ { ++ "databaseName": "keyvault", ++ "collectionName": "datakeys", ++ "documents": [ ++ { ++ "_id": { ++ "$binary": { ++ "base64": "GCP+AAAAAAAAAAAAAAAAAA==", ++ "subType": "04" ++ } ++ }, ++ "keyAltNames": [ ++ "my-key" ++ ], ++ "keyMaterial": { ++ "$binary": { ++ "base64": "CiQAIgLj0WyktnB4dfYHo5SLZ41K4ASQrjJUaSzl5vvVH0G12G0SiQEAjlV8XPlbnHDEDFbdTO4QIe8ER2/172U1ouLazG0ysDtFFIlSvWX5ZnZUrRMmp/R2aJkzLXEt/zf8Mn4Lfm+itnjgo5R9K4pmPNvvPKNZX5C16lrPT+aA+rd+zXFSmlMg3i5jnxvTdLHhg3G7Q/Uv1ZIJskKt95bzLoe0tUVzRWMYXLIEcohnQg==", ++ "subType": "00" ++ } ++ }, ++ "creationDate": { ++ "$date": { ++ "$numberLong": "1552949630483" ++ } ++ }, ++ "updateDate": { ++ "$date": { ++ "$numberLong": "1552949630483" ++ } ++ }, ++ "status": { ++ "$numberInt": "0" ++ }, ++ "masterKey": { ++ "provider": "gcp", ++ "projectId": "devprod-drivers", ++ "location": "global", ++ "keyRing": "key-ring-csfle", ++ "keyName": "key-name-csfle" ++ } ++ } ++ ] ++ } ++ ], ++ "tests": [ ++ { ++ "description": "Auto encrypt using access token GCP credentials", ++ "operations": [ ++ { ++ "name": "insertOne", ++ "arguments": { ++ "document": { ++ "_id": 1, ++ "secret": "string0" ++ } ++ }, ++ "object": "coll" ++ } ++ ], ++ "outcome": [ ++ { ++ "documents": [ ++ { ++ "_id": 1, ++ "secret": { ++ "$binary": { ++ "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==", ++ "subType": "06" ++ } ++ } ++ } ++ ], ++ "collectionName": "coll", ++ "databaseName": "db" ++ } ++ ] ++ }, ++ { ++ "description": "Explicit encrypt using access token GCP credentials", ++ "operations": [ ++ { ++ "name": "encrypt", ++ "object": "clientEncryption", ++ "arguments": { ++ "value": 
"string0", ++ "opts": { ++ "keyAltName": "my-key", ++ "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic" ++ } ++ }, ++ "expectResult": { ++ "$binary": { ++ "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==", ++ "subType": "06" ++ } ++ } ++ } ++ ] ++ } ++ ] ++} +diff --git a/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json +new file mode 100644 +index 00000000..8fe5c150 +--- /dev/null ++++ b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-azure-accessToken-type.json +@@ -0,0 +1,31 @@ ++{ ++ "description": "clientEncryptionOpts-kmsProviders-azure-accessToken-type", ++ "schemaVersion": "1.28", ++ "createEntities": [ ++ { ++ "client": { ++ "id": "client0" ++ } ++ }, ++ { ++ "clientEncryption": { ++ "id": "clientEncryption0", ++ "clientEncryptionOpts": { ++ "keyVaultClient": "client0", ++ "keyVaultNamespace": "keyvault.datakeys", ++ "kmsProviders": { ++ "azure": { ++ "accessToken": 0 ++ } ++ } ++ } ++ } ++ } ++ ], ++ "tests": [ ++ { ++ "description": "", ++ "operations": [] ++ } ++ ] ++} +diff --git a/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json +new file mode 100644 +index 00000000..2284e26c +--- /dev/null ++++ b/test/unified-test-format/invalid/clientEncryptionOpts-kmsProviders-gcp-accessToken-type.json +@@ -0,0 +1,31 @@ ++{ ++ "description": "clientEncryptionOpts-kmsProviders-gcp-accessToken-type", ++ "schemaVersion": "1.28", ++ "createEntities": [ ++ { ++ "client": { ++ "id": "client0" ++ } ++ }, ++ { ++ "clientEncryption": { ++ "id": "clientEncryption0", ++ "clientEncryptionOpts": { ++ "keyVaultClient": "client0", ++ "keyVaultNamespace": "keyvault.datakeys", ++ "kmsProviders": { ++ 
"gcp": { ++ "accessToken": 0 ++ } ++ } ++ } ++ } ++ } ++ ], ++ "tests": [ ++ { ++ "description": "", ++ "operations": [] ++ } ++ ] ++} diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000..a8943d11ac --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,3 @@ +Please see [AGENTS.md](../AGENTS.md). + +Follow the repository instructions defined in `AGENTS.md` when working in this codebase. diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 2e1a132c95..b1f0987b3c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -6,8 +6,8 @@ If you are an external contributor and there is no JIRA ticket associated with y for the PR title. A MongoDB employee will create a JIRA ticket and edit the name and links as appropriate. Note on AI Contributions: -We do not accept pull requests that are primarily or substantially generated by AI tools (ChatGPT, Copilot, etc.). -All contributions must be written and understood by human contributors. +We only accept pull requests that are authored and submitted by human contributors who fully understand the changes they are proposing. +All contributions must be written and understood by human contributors. Please read about our policy in our contributing guide. --> [JIRA TICKET] diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 645cb74de2..579bcc5f49 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -61,7 +61,7 @@ jobs: - name: Set up QEMU if: runner.os == 'Linux' - uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 with: # setup-qemu-action by default uses `tonistiigi/binfmt:latest` image, # which is out of date. This causes seg faults during build. 
@@ -92,7 +92,7 @@ jobs: # Free-threading builds: ls wheelhouse/*cp314t*.whl - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: wheel-${{ matrix.buildplat[1] }} path: ./wheelhouse/*.whl @@ -125,7 +125,7 @@ jobs: cd .. python -c "from pymongo import has_c; assert has_c()" - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: "sdist" path: ./dist/*.tar.gz @@ -136,13 +136,13 @@ jobs: name: Download Wheels steps: - name: Download all workflow run artifacts - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 - name: Flatten directory working-directory: . run: | find . -mindepth 2 -type f -exec mv {} . \; find . -type d -empty -delete - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: all-dist-${{ github.run_id }} path: "./*" diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index e3dd1edb1c..4387303224 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -75,7 +75,7 @@ jobs: id-token: write steps: - name: Download all the dists - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: name: all-dist-${{ github.run_id }} path: dist/ diff --git a/.github/workflows/sbom.yml b/.github/workflows/sbom.yml index 69a07c8be2..ce7308cf1b 100644 --- a/.github/workflows/sbom.yml +++ b/.github/workflows/sbom.yml @@ -67,7 +67,7 @@ jobs: run: rm -rf .venv .venv-sbom sbom-requirements.txt - name: Upload SBOM artifact - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: sbom path: sbom.json diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 388f68bbe5..0c75fe8235 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -26,7 +26,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + 
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "3.10" @@ -61,23 +61,40 @@ jobs: os: [ubuntu-latest] python-version: ["3.10", "pypy-3.11", "3.13t"] mongodb-version: ["8.0"] + extension: ["c", "rust"] + exclude: + # Don't test Rust with pypy + - python-version: "pypy-3.11" + extension: "rust" + # Don't test Rust with free-threaded Python (not yet supported) + - python-version: "3.13t" + extension: "rust" - name: CPython ${{ matrix.python-version }}-${{ matrix.os }} + name: CPython ${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.extension }} + continue-on-error: ${{ matrix.extension == 'rust' }} steps: - uses: actions/checkout@v6 with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: ${{ matrix.python-version }} + - name: Install Rust toolchain + if: matrix.extension == 'rust' + uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9 # stable + with: + toolchain: stable - id: setup-mongodb uses: mongodb-labs/drivers-evergreen-tools@master with: version: "${{ matrix.mongodb-version }}" - name: Run tests run: uv run --extra test pytest -v + env: + PYMONGO_BUILD_RUST: ${{ matrix.extension == 'rust' && '1' || '' }} + PYMONGO_USE_RUST: ${{ matrix.extension == 'rust' && '1' || '' }} coverage: # This enables a coverage report for a given PR, which will be augmented by @@ -90,7 +107,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "3.10" @@ -118,7 +135,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: 
astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "3.10" @@ -143,7 +160,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "3.10" @@ -162,7 +179,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "3.10" @@ -184,7 +201,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "${{matrix.python}}" @@ -205,7 +222,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: enable-cache: true python-version: "3.10" @@ -245,7 +262,7 @@ jobs: run: | pip install build python -m build --sdist - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: "sdist" path: dist/*.tar.gz @@ -257,7 +274,7 @@ jobs: timeout-minutes: 20 steps: - name: Download sdist - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: path: sdist/ - name: Unpack SDist @@ -295,7 +312,7 @@ jobs: with: persist-credentials: false - name: Install uv - uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7 with: python-version: "3.9" - id: setup-mongodb diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index 26f75fa792..6a642977be 100644 --- 
a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -18,4 +18,4 @@ jobs: with: persist-credentials: false - name: Run zizmor 🌈 - uses: zizmorcore/zizmor-action@135698455da5c3b3e55f73f4419e481ab68cdd95 # v0.4.1 + uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2 diff --git a/.gitignore b/.gitignore index cb4940a55e..e6910999bd 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,8 @@ test/lambda/*.json xunit-results/ coverage.xml server.log +.coverage + +# Rust build artifacts +target/ +Cargo.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2b9d9a17a..c1351a3813 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -103,7 +103,8 @@ repos: # - test/test_bson.py:267: isnt ==> isn't # - test/versioned-api/crud-api-version-1-strict.json:514: nin ==> inn, min, bin, nine # - test/test_client.py:188: te ==> the, be, we, to - args: ["-L", "fle,fo,infinit,isnt,nin,te,aks"] + # - README.md:534: crate ==> create (Rust terminology - a crate is a Rust package) + args: ["-L", "fle,fo,infinit,isnt,nin,te,aks,crate"] - repo: local hooks: diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..b67cb49aca --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,44 @@ +When reviewing code, focus on: + +## Security Critical Issues +- Check for hardcoded secrets, API keys, or credentials. +- Check for instances of potential method call injection, dynamic code execution, symbol injection or other code injection vulnerabilities. + +## Performance Red Flags +- Spot inefficient loops and algorithmic issues. +- Check for memory leaks and resource cleanup. + +## Code Quality Essentials +- Methods should be focused and appropriately sized. If a method is doing too much, suggest refactorings to split it up. +- Use clear, descriptive naming conventions. +- Avoid encapsulation violations and ensure proper separation of concerns. 
+- All public classes, modules, and methods should have clear documentation in Sphinx format. + +## PyMongo-specific Concerns +- Do not review files within `pymongo/synchronous` or files in `test/` that also have a file of the same name in `test/asynchronous` unless the reviewed changes include a `_IS_SYNC` statement. PyMongo generates these files from `pymongo/asynchronous` and `test/asynchronous` using `tools/synchro.py`. +- All asynchronous functions must not call any blocking I/O. + +## Review Style +- Be specific and actionable in feedback. +- Explain the "why" behind recommendations. +- Acknowledge good patterns when you see them. +- Ask clarifying questions when code intent is unclear. + +Always prioritize security vulnerabilities and performance issues that could impact users. + +Always suggest changes to improve readability and testability. For example, this suggestion seeks to make the code more readable, reusable, and testable: + +```python +# Instead of: +if user.email and "@" in user.email and len(user.email) > 5: + submit_button.enabled = True +else: + submit_button.enabled = False + +# Consider: +def valid_email(email): + return email and "@" in email and len(email) > 5 + + +submit_button.enabled = valid_email(user.email) +``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eb1c35fc8b..77888eb087 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,49 +85,53 @@ likelihood for getting review sooner shoots up. - `versionadded:: 3.11` - `versionchanged:: 3.5` -**Pull Request Template Breakdown** +### AI-Generated Contributions Policy -- **Github PR Title** +#### Our Stance - - The PR Title format should always be - `[JIRA-ID] : Jira Title or Blurb Summary`. +We only accept pull requests that are authored and submitted by human contributors who fully understand the changes they are proposing. Pull requests that are not clearly owned and understood by a human contributor may be closed. 
**All contributions must be submitted, reviewed, and understood by human contributors.** -- **JIRA LINK** +##### Why This Policy Exists -- Convenient link to the associated JIRA ticket. +At MongoDB, we understand the power and prevalence of AI tools in software development. With that being said, many MongoDB libraries are foundational tools used in production systems worldwide. The nature of these libraries requires: -- **Summary** +- **Deep domain expertise**: MongoDB's wire protocol, BSON specification, connection pooling, authentication mechanisms, and concurrency patterns require an understanding that AI alone cannot substantiate. - - Small blurb on why this is needed. The JIRA task should have - the more in-depth description, but this should still, at a - high level, give anyone looking an understanding of why the - PR has been checked in. +- **Long-term maintainability**: Contributors need to be able to explain *why* code is written a certain way, explain design decisions, and be available to iterate on their contributions. -- **Changes in this PR** +- **Security responsibility**: Authentication, credential handling, and TLS implementation cannot be left to probabilistic code generation. - - The explicit code changes that this PR is introducing. This - should be more specific than just the task name. (Unless the - task name is very clear). +##### What This Means for Contributors -- **Test Plan** +**Required:** - - Everything needs a test description. Describe what you did - to validate your changes actually worked; if you did - nothing, then document you did not test it. Aim to make - these steps reproducible by other engineers, specifically - with your primary reviewer in mind. +- Full understanding of every line of code you submit +- Ability to explain and defend your implementation choices +- Willingness to iterate and maintain your contributions -- **Screenshots** +**Encouraged:** - - Any images that provide more context to the PR. 
Usually, - these just coincide with the test plan. +- Using AI assistants as learning tools to understand concepts +- IDE autocomplete features that suggest standard patterns +- AI help for brainstorming approaches (but write the code yourself) +- Writing code using AI tools, reviewing each line and revising code as necessary. -- **Callouts or follow-up items** +**Not allowed:** - - This is a good place for identifying "to-dos" that you've - placed in the code (Must have an accompanying JIRA Ticket). - - Potential bugs that you are unsure how to test in the code. - - Opinions you want to receive about your code. +- Submitting PRs generated solely by AI tools +- Copy-pasting AI-generated code without full understanding + +##### Disclosure + +If you used AI assistance in any way during your contribution, please disclose what the AI assistant was used for in your PR description. We would love to know what tools developers have found useful in iterating in their day to day. + +##### Questions? + +If you're unsure whether your contribution complies with this policy, please ask for guidance within the scope of the PR and clarify any uncertainty. We're happy to guide contributors toward successful contributions. + +--- + +*This policy helps us maintain the reliability, security, and trustworthiness that production applications depend on. Thank you for understanding and for contributing thoughtfully to PyMongo.* ## Running Linters @@ -197,7 +201,7 @@ the pages will re-render and the browser will automatically refresh. version of Python, set `UV_PYTHON` before running `just install`. - Ensure you have started the appropriate Mongo Server(s). You can run `just run-server` with optional args to set up the server. All given options will be passed to - [`run-orchestration.sh`](https://github.com/mongodb-labs/drivers-evergreen-tools/blob/master/.evergreen/run-orchestration.sh). 
Run `$DRIVERS_TOOLS/.evergreen/run-orchestration.sh -h` + [`run-mongodb.sh`](https://github.com/mongodb-labs/drivers-evergreen-tools/blob/master/.evergreen/run-mongodb.sh). Run `$DRIVERS_TOOLS/.evergreen/run-mongodb.sh start -h` for a full list of options. - Run `just test` or `pytest` to run all of the tests. - Append `test/.py::::` to run @@ -205,6 +209,7 @@ the pages will re-render and the browser will automatically refresh. and the `` to test a full module. For example: `just test test/test_change_stream.py::TestUnifiedChangeStreamsErrors::test_change_stream_errors_on_ElectionInProgress`. - Use the `-k` argument to select tests by pattern. +- Run `just test-coverage` to run tests with coverage and display a report. After running tests with coverage, use `just coverage-html` to generate an HTML report in `htmlcov/index.html`. ## Running tests that require secrets, services, or other configuration @@ -396,7 +401,7 @@ To run any of the test suites with minimum supported dependencies, pass `--test- - If adding new tests files that should only be run for that test suite, add a pytest marker to the file and add to the list of pytest markers in `pyproject.toml`. Then add the test suite to the `TEST_SUITE_MAP` in `.evergreen/scripts/utils.py`. If for some reason it is not a pytest-runnable test, add it to the list of `EXTRA_TESTS` instead. -- If the test uses Atlas or otherwise doesn't use `run-orchestration.sh`, add it to the `NO_RUN_ORCHESTRATION` list in +- If the test uses Atlas or otherwise doesn't use `run-mongodb.sh`, add it to the `NO_RUN_ORCHESTRATION` list in `.evergreen/scripts/utils.py`. - If there is something special required to run the local server or there is an extra flag that should always be set like `AUTH`, add that logic to `.evergreen/scripts/run_server.py`. 
diff --git a/README.md b/README.md index c807733e5b..703e000dc5 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![Python Versions](https://img.shields.io/pypi/pyversions/pymongo)](https://pypi.org/project/pymongo) [![Monthly Downloads](https://static.pepy.tech/badge/pymongo/month)](https://pepy.tech/project/pymongo) [![API Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/api?badge=stable) +[![codecov](https://codecov.io/gh/mongodb/mongo-python-driver/graph/badge.svg?branch=master)](https://codecov.io/gh/mongodb/mongo-python-driver) ## About @@ -215,4 +216,4 @@ pip install -e ".[test]" pytest ``` -For more advanced testing scenarios, see the [contributing guide](./CONTRIBUTING.md#running-tests-locally). +For more advanced testing scenarios, see the [contributing guide](https://github.com/mongodb/mongo-python-driver/blob/master/CONTRIBUTING.md#running-tests-locally). diff --git a/bson/__init__.py b/bson/__init__.py index ebb1bd0ccc..59b84e4d19 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -72,6 +72,7 @@ from __future__ import annotations import datetime +import importlib.util import itertools import os import re @@ -143,12 +144,79 @@ from bson.raw_bson import RawBSONDocument from bson.typings import _DocumentType, _ReadableBuffer +# Try to import C and Rust extensions +_cbson = None +_rbson = None +_HAS_C = False +_HAS_RUST = False + +# Use importlib to avoid circular import issues +_spec = None try: - from bson import _cbson # type: ignore[attr-defined] + # Check if already loaded (e.g., when reloading bson module) + if "bson._cbson" in sys.modules: + _cbson = sys.modules["bson._cbson"] + if hasattr(_cbson, "_bson_to_dict"): + _HAS_C = True + else: + _spec = importlib.util.find_spec("bson._cbson") + if _spec and _spec.loader: + _cbson = importlib.util.module_from_spec(_spec) + _spec.loader.exec_module(_cbson) + if hasattr(_cbson, "_bson_to_dict"): + _HAS_C = True + else: 
+ _cbson = None +except (ImportError, AttributeError): + pass - _USE_C = True -except ImportError: - _USE_C = False +try: + # Check if already loaded (e.g., when reloading bson module) + if "bson._rbson" in sys.modules: + _rbson = sys.modules["bson._rbson"] + if hasattr(_rbson, "_bson_to_dict"): + _HAS_RUST = True + else: + _spec = importlib.util.find_spec("bson._rbson") + if _spec and _spec.loader: + _rbson = importlib.util.module_from_spec(_spec) + _spec.loader.exec_module(_rbson) + if hasattr(_rbson, "_bson_to_dict"): + _HAS_RUST = True + else: + _rbson = None +except (ImportError, AttributeError): + pass + +# Clean up the spec variable to avoid polluting the module namespace +del _spec + +# Determine which extension to use at runtime +# Priority: PYMONGO_USE_RUST env var > C extension (default) > pure Python +_USE_RUST_RUNTIME = os.environ.get("PYMONGO_USE_RUST", "").lower() in ("1", "true", "yes") + +# Decide which extension to actually use +_USE_C = False +_USE_RUST = False + +if _USE_RUST_RUNTIME: + if _HAS_RUST: + # User requested Rust and it's available - use Rust, not C + _USE_RUST = True + elif _HAS_C: + # User requested Rust but it's not available - warn and use C + import warnings + + warnings.warn( + "PYMONGO_USE_RUST is set but Rust extension is not available. 
" + "Falling back to C extension.", + stacklevel=2, + ) + _USE_C = True +else: + # User didn't request Rust - use C by default if available + if _HAS_C: + _USE_C = True __all__ = [ "ALL_UUID_SUBTYPES", @@ -209,6 +277,8 @@ "is_valid", "BSON", "has_c", + "has_rust", + "get_bson_implementation", "DatetimeConversion", "DatetimeMS", ] @@ -543,7 +613,7 @@ def _element_to_dict( ) -> Tuple[str, Any, int]: return cast( "Tuple[str, Any, int]", - _cbson._element_to_dict(data, position, obj_end, opts, raw_array), + _cbson._element_to_dict(data, position, obj_end, opts, raw_array), # type: ignore[union-attr] ) else: @@ -634,8 +704,13 @@ def _bson_to_dict(data: Any, opts: CodecOptions[_DocumentType]) -> _DocumentType raise InvalidBSON(str(exc_value)).with_traceback(exc_tb) from None -if _USE_C: - _bson_to_dict = _cbson._bson_to_dict +# Save reference to Python implementation before overriding +_bson_to_dict_python = _bson_to_dict + +if _USE_RUST: + _bson_to_dict = _rbson._bson_to_dict # type: ignore[union-attr] +elif _USE_C: + _bson_to_dict = _cbson._bson_to_dict # type: ignore[union-attr] _PACK_FLOAT = struct.Struct(" lis if _USE_C: - _decode_all = _cbson._decode_all + _decode_all = _cbson._decode_all # type: ignore[union-attr] @overload @@ -1223,7 +1300,7 @@ def _array_of_documents_to_buffer(data: Union[memoryview, bytes]) -> bytes: if _USE_C: - _array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer + _array_of_documents_to_buffer = _cbson._array_of_documents_to_buffer # type: ignore[union-attr] def _convert_raw_document_lists_to_streams(document: Any) -> None: @@ -1470,7 +1547,30 @@ def decode( # type:ignore[override] def has_c() -> bool: """Is the C extension installed?""" - return _USE_C + return _HAS_C + + +def has_rust() -> bool: + """Is the Rust extension installed? + + .. versionadded:: 5.0 + """ + return _HAS_RUST + + +def get_bson_implementation() -> str: + """Get the name of the BSON implementation being used. 
+ + Returns one of: 'rust', 'c', or 'python'. + + .. versionadded:: 5.0 + """ + if _USE_RUST: + return "rust" + elif _USE_C: + return "c" + else: + return "python" def _after_fork() -> None: diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 7d184641c5..034490f558 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -356,7 +356,8 @@ static PyObject* datetime_ms_from_millis(PyObject* self, long long millis){ if (!(ll_millis = PyLong_FromLongLong(millis))){ return NULL; } - dt = PyObject_CallFunctionObjArgs(state->DatetimeMS, ll_millis, NULL); + PyObject* args[1] = {ll_millis}; + dt = PyObject_Vectorcall(state->DatetimeMS, args, 1, NULL); Py_DECREF(ll_millis); return dt; } @@ -401,7 +402,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o int64_t min_millis_offset = 0; int64_t max_millis_offset = 0; if (options->tz_aware && options->tzinfo && options->tzinfo != Py_None) { - PyObject* utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->min_datetime, NULL); + PyObject* utcoffset_args[2] = {options->tzinfo, state->min_datetime}; + PyObject* utcoffset = PyObject_VectorcallMethod( + state->_utcoffset_str, utcoffset_args, 2, NULL); if (utcoffset == NULL) { return 0; } @@ -420,7 +423,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o (PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000); } Py_DECREF(utcoffset); - utcoffset = PyObject_CallMethodObjArgs(options->tzinfo, state->_utcoffset_str, state->max_datetime, NULL); + utcoffset_args[1] = state->max_datetime; + utcoffset = PyObject_VectorcallMethod( + state->_utcoffset_str, utcoffset_args, 2, NULL); if (utcoffset == NULL) { return 0; } @@ -481,7 +486,9 @@ static PyObject* decode_datetime(PyObject* self, long long millis, const codec_o /* convert to local time */ if (options->tzinfo != Py_None) { - PyObject* temp = PyObject_CallMethodObjArgs(value, state->_astimezone_str, options->tzinfo, NULL); + 
PyObject* astimezone_args[2] = {value, options->tzinfo}; + PyObject* temp = PyObject_VectorcallMethod( + state->_astimezone_str, astimezone_args, 2, NULL); Py_DECREF(value); value = temp; } @@ -688,7 +695,8 @@ static int _load_python_objects(PyObject* module) { return 1; } - compiled = PyObject_CallFunction(re_compile, "O", empty_string); + PyObject* compile_args[1] = {empty_string}; + compiled = PyObject_Vectorcall(re_compile, compile_args, 1, NULL); Py_DECREF(re_compile); if (compiled == NULL) { state->REType = NULL; @@ -711,13 +719,19 @@ static long _type_marker(PyObject* object, PyObject* _type_marker_str) { PyObject* type_marker = NULL; long type = 0; - if (PyObject_HasAttr(object, _type_marker_str)) { - type_marker = PyObject_GetAttr(object, _type_marker_str); - if (type_marker == NULL) { + #if PY_VERSION_HEX >= 0x030D0000 + // 3.13 + if (PyObject_GetOptionalAttr(object, _type_marker_str, &type_marker) == -1) { return -1; } - } - + # else + if (PyObject_HasAttr(object, _type_marker_str)) { + type_marker = PyObject_GetAttr(object, _type_marker_str); + if (type_marker == NULL) { + return -1; + } + } + #endif /* * Python objects with broken __getattr__ implementations could return * arbitrary types for a call to PyObject_GetAttrString. For example @@ -814,6 +828,7 @@ int convert_codec_options(PyObject* self, PyObject* options_obj, codec_options_t } options->is_raw_bson = (101 == type_marker); + options->is_dict_class = (options->document_class == (PyObject*)&PyDict_Type); options->options_obj = options_obj; Py_INCREF(options->options_obj); @@ -1013,10 +1028,20 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, } /* * Use _type_marker attribute instead of PyObject_IsInstance for better perf. + * + * Skip _type_marker lookup for common built-in types + * that we know don't have a _type_marker attribute. This avoids the overhead + * of PyObject_HasAttr/PyObject_GetAttr calls for the most common cases. 
*/ - type = _type_marker(value, state->_type_marker_str); - if (type < 0) { - return 0; + if (PyUnicode_CheckExact(value) || PyLong_CheckExact(value) || PyFloat_CheckExact(value) || + PyBool_Check(value) || PyDict_CheckExact(value) || PyList_CheckExact(value) || + PyTuple_CheckExact(value) || PyBytes_CheckExact(value) || value == Py_None) { + type = 0; + } else { + type = _type_marker(value, state->_type_marker_str); + if (type < 0) { + return 0; + } } switch (type) { @@ -1227,7 +1252,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, case 100: { /* DBRef */ - PyObject* as_doc = PyObject_CallMethodObjArgs(value, state->_as_doc_str, NULL); + PyObject* as_doc_args[1] = {value}; + PyObject* as_doc = PyObject_VectorcallMethod( + state->_as_doc_str, as_doc_args, 1, NULL); if (!as_doc) { return 0; } @@ -1383,7 +1410,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, return write_unicode(buffer, value); } else if (PyDateTime_Check(value)) { long long millis; - PyObject* utcoffset = PyObject_CallMethodObjArgs(value, state->_utcoffset_str , NULL); + PyObject* utcoffset_args[1] = {value}; + PyObject* utcoffset = PyObject_VectorcallMethod( + state->_utcoffset_str, utcoffset_args, 1, NULL); if (utcoffset == NULL) return 0; if (utcoffset != Py_None) { @@ -1422,7 +1451,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, if (!(uuid_rep_obj = PyLong_FromLong(options->uuid_rep))) { return 0; } - binary_value = PyObject_CallMethodObjArgs(state->Binary, state->_from_uuid_str, value, uuid_rep_obj, NULL); + PyObject* from_uuid_args[3] = {state->Binary, value, uuid_rep_obj}; + binary_value = PyObject_VectorcallMethod( + state->_from_uuid_str, from_uuid_args, 3, NULL); Py_DECREF(uuid_rep_obj); if (binary_value == NULL) { @@ -1452,7 +1483,8 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, if (converter != NULL) { /* Transform types that have a registered converter. 
* A new reference is created upon transformation. */ - new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); + PyObject* converter_args[1] = {value}; + new_value = PyObject_Vectorcall(converter, converter_args, 1, NULL); if (new_value == NULL) { return 0; } @@ -1466,8 +1498,9 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, /* Try the fallback encoder if one is provided and we have not already * attempted to use the fallback encoder. */ if (!in_fallback_call && options->type_registry.has_fallback_encoder) { - new_value = PyObject_CallFunctionObjArgs( - options->type_registry.fallback_encoder, value, NULL); + PyObject* fallback_args[1] = {value}; + new_value = PyObject_Vectorcall( + options->type_registry.fallback_encoder, fallback_args, 1, NULL); if (new_value == NULL) { // propagate any exception raised by the callback return 0; @@ -1668,7 +1701,8 @@ void handle_invalid_doc_error(PyObject* dict) { goto cleanup; } // Add doc to the error instance as a property. 
- new_evalue = PyObject_CallFunctionObjArgs(InvalidDocument, new_msg, dict, NULL); + PyObject* exc_args[2] = {new_msg, dict}; + new_evalue = PyObject_Vectorcall(InvalidDocument, exc_args, 2, NULL); Py_DECREF(evalue); Py_DECREF(etype); etype = InvalidDocument; @@ -1944,7 +1978,8 @@ static PyObject *_dbref_hook(PyObject* self, PyObject* value) { PyMapping_DelItem(value, state->_dollar_db_str); } - ret = PyObject_CallFunctionObjArgs(state->DBRef, ref, id, database, value, NULL); + PyObject* dbref_args[4] = {ref, id, database, value}; + ret = PyObject_Vectorcall(state->DBRef, dbref_args, 4, NULL); Py_DECREF(value); } else { ret = value; @@ -2160,7 +2195,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, goto uuiderror; } - binary_value = PyObject_CallFunction(state->Binary, "(Oi)", data, subtype); + PyObject* subtype_obj = PyLong_FromLong(subtype); + if (!subtype_obj) { + goto uuiderror; + } + PyObject* binary_args[2] = {data, subtype_obj}; + binary_value = PyObject_Vectorcall(state->Binary, binary_args, 2, NULL); + Py_DECREF(subtype_obj); if (binary_value == NULL) { goto uuiderror; } @@ -2175,7 +2216,9 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, if (!uuid_rep_obj) { goto uuiderror; } - value = PyObject_CallMethodObjArgs(binary_value, state->_as_uuid_str, uuid_rep_obj, NULL); + PyObject* as_uuid_args[2] = {binary_value, uuid_rep_obj}; + value = PyObject_VectorcallMethod( + state->_as_uuid_str, as_uuid_args, 2, NULL); Py_DECREF(uuid_rep_obj); } @@ -2194,7 +2237,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, Py_DECREF(data); goto invalid; } - value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL); + PyObject* binary_args[2] = {data, st}; + value = PyObject_Vectorcall(state->Binary, binary_args, 2, NULL); Py_DECREF(st); Py_DECREF(data); if (!value) { @@ -2215,7 +2259,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, 
if (max < 12) { goto invalid; } - value = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12); + PyObject* oid_bytes = PyBytes_FromStringAndSize(buffer + *position, 12); + if (!oid_bytes) { + goto invalid; + } + PyObject* oid_args[1] = {oid_bytes}; + value = PyObject_Vectorcall(state->ObjectId, oid_args, 1, NULL); + Py_DECREF(oid_bytes); *position += 12; break; } @@ -2294,7 +2344,14 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, } *position += (unsigned)flags_length + 1; - value = PyObject_CallFunction(state->Regex, "Oi", pattern, flags); + PyObject* flags_obj = PyLong_FromLong(flags); + if (!flags_obj) { + Py_DECREF(pattern); + goto invalid; + } + PyObject* regex_args[2] = {pattern, flags_obj}; + value = PyObject_Vectorcall(state->Regex, regex_args, 2, NULL); + Py_DECREF(flags_obj); Py_DECREF(pattern); break; } @@ -2327,13 +2384,21 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, } *position += coll_length; - id = PyObject_CallFunction(state->ObjectId, "y#", buffer + *position, (Py_ssize_t)12); + PyObject* oid_bytes = PyBytes_FromStringAndSize(buffer + *position, 12); + if (!oid_bytes) { + Py_DECREF(collection); + goto invalid; + } + PyObject* oid_args[1] = {oid_bytes}; + id = PyObject_Vectorcall(state->ObjectId, oid_args, 1, NULL); + Py_DECREF(oid_bytes); if (!id) { Py_DECREF(collection); goto invalid; } *position += 12; - value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL); + PyObject* dbref_args[2] = {collection, id}; + value = PyObject_Vectorcall(state->DBRef, dbref_args, 2, NULL); Py_DECREF(collection); Py_DECREF(id); break; @@ -2363,7 +2428,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, goto invalid; } *position += value_length; - value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL); + PyObject* code_args[1] = {code}; + value = PyObject_Vectorcall(state->Code, code_args, 1, NULL); 
Py_DECREF(code); break; } @@ -2429,7 +2495,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, } *position += scope_size; - value = PyObject_CallFunctionObjArgs(state->Code, code, scope, NULL); + PyObject* code_scope_args[2] = {code, scope}; + value = PyObject_Vectorcall(state->Code, code_scope_args, 2, NULL); Py_DECREF(code); Py_DECREF(scope); break; @@ -2459,7 +2526,19 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, memcpy(&time, buffer + *position + 4, 4); inc = BSON_UINT32_FROM_LE(inc); time = BSON_UINT32_FROM_LE(time); - value = PyObject_CallFunction(state->Timestamp, "II", time, inc); + PyObject* time_obj = PyLong_FromUnsignedLong(time); + if (!time_obj) { + goto invalid; + } + PyObject* inc_obj = PyLong_FromUnsignedLong(inc); + if (!inc_obj) { + Py_DECREF(time_obj); + goto invalid; + } + PyObject* ts_args[2] = {time_obj, inc_obj}; + value = PyObject_Vectorcall(state->Timestamp, ts_args, 2, NULL); + Py_DECREF(time_obj); + Py_DECREF(inc_obj); *position += 8; break; } @@ -2471,7 +2550,13 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, } memcpy(&ll, buffer + *position, 8); ll = (int64_t)BSON_UINT64_FROM_LE(ll); - value = PyObject_CallFunction(state->BSONInt64, "L", ll); + PyObject* ll_obj = PyLong_FromLongLong(ll); + if (!ll_obj) { + goto invalid; + } + PyObject* int64_args[1] = {ll_obj}; + value = PyObject_Vectorcall(state->BSONInt64, int64_args, 1, NULL); + Py_DECREF(ll_obj); *position += 8; break; } @@ -2484,19 +2569,21 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, if (!_bytes_obj) { goto invalid; } - value = PyObject_CallMethodObjArgs(state->Decimal128, state->_from_bid_str, _bytes_obj, NULL); + PyObject* dec128_args[2] = {state->Decimal128, _bytes_obj}; + value = PyObject_VectorcallMethod( + state->_from_bid_str, dec128_args, 2, NULL); Py_DECREF(_bytes_obj); *position += 16; break; } case 255: { - value = 
PyObject_CallFunctionObjArgs(state->MinKey, NULL); + value = PyObject_Vectorcall(state->MinKey, NULL, 0, NULL); break; } case 127: { - value = PyObject_CallFunctionObjArgs(state->MaxKey, NULL); + value = PyObject_Vectorcall(state->MaxKey, NULL, 0, NULL); break; } default: @@ -2548,7 +2635,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, } converter = PyDict_GetItem(options->type_registry.decoder_map, value_type); if (converter != NULL) { - PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); + PyObject* converter_args[1] = {value}; + PyObject* new_value = PyObject_Vectorcall(converter, converter_args, 1, NULL); Py_DECREF(value_type); Py_DECREF(value); return new_value; @@ -2716,11 +2804,20 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string, unsigned max, const codec_options_t* options) { unsigned position = 0; - PyObject* dict = PyObject_CallObject(options->document_class, NULL); + PyObject* dict; + int raw_array = 0; + + /* Use PyDict_New() directly when document_class is dict. + * This avoids the overhead of PyObject_CallObject() for the common case. */ + if (options->is_dict_class) { + dict = PyDict_New(); + } else { + dict = PyObject_CallObject(options->document_class, NULL); + } if (!dict) { return NULL; } - int raw_array = 0; + while (position < max) { PyObject* name = NULL; PyObject* value = NULL; @@ -2735,7 +2832,24 @@ static PyObject* _elements_to_dict(PyObject* self, const char* string, position = (unsigned)new_position; } - PyObject_SetItem(dict, name, value); + /* Use PyDict_SetItem() when document_class is dict. + * PyDict_SetItem() is faster than PyObject_SetItem() because it + * avoids method lookup overhead. 
*/ + if (options->is_dict_class) { + if (PyDict_SetItem(dict, name, value) < 0) { + Py_DECREF(name); + Py_DECREF(value); + Py_DECREF(dict); + return NULL; + } + } else { + if (PyObject_SetItem(dict, name, value) < 0) { + Py_DECREF(name); + Py_DECREF(value); + Py_DECREF(dict); + return NULL; + } + } Py_DECREF(name); Py_DECREF(value); } @@ -2747,9 +2861,14 @@ static PyObject* elements_to_dict(PyObject* self, const char* string, const codec_options_t* options) { PyObject* result; if (options->is_raw_bson) { - return PyObject_CallFunction( - options->document_class, "y#O", - string, max, options->options_obj); + PyObject* bson_bytes = PyBytes_FromStringAndSize(string, max); + if (!bson_bytes) { + return NULL; + } + PyObject* raw_args[2] = {bson_bytes, options->options_obj}; + result = PyObject_Vectorcall(options->document_class, raw_args, 2, NULL); + Py_DECREF(bson_bytes); + return result; } if (Py_EnterRecursiveCall(" while decoding a BSON document")) return NULL; diff --git a/bson/_cbsonmodule.h b/bson/_cbsonmodule.h index 3be2b74427..a9bee24b8d 100644 --- a/bson/_cbsonmodule.h +++ b/bson/_cbsonmodule.h @@ -72,6 +72,7 @@ typedef struct codec_options_t { unsigned char datetime_conversion; PyObject* options_obj; unsigned char is_raw_bson; + unsigned char is_dict_class; } codec_options_t; /* C API functions */ diff --git a/bson/_rbson/Cargo.toml b/bson/_rbson/Cargo.toml new file mode 100644 index 0000000000..05ea598953 --- /dev/null +++ b/bson/_rbson/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "bson-rbson" +version = "0.1.0" +edition = "2021" + +[lib] +name = "_rbson" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.23", features = ["extension-module", "abi3-py39"] } +bson = "2.13" +serde = "1.0" +once_cell = "1.20" + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +strip = true diff --git a/bson/_rbson/README.md b/bson/_rbson/README.md new file mode 100644 index 0000000000..69e1e0e166 --- /dev/null +++ b/bson/_rbson/README.md @@ 
-0,0 +1,441 @@ +# Rust BSON Extension Module + +⚠️ **NOT PRODUCTION READY** - This is an experimental implementation with incomplete feature support and performance limitations. See [Test Status](#test-status) and [Performance Analysis](#performance-analysis) sections below. + +This directory contains a Rust-based implementation of BSON encoding/decoding for PyMongo, developed as part of [PYTHON-5683](https://jira.mongodb.org/browse/PYTHON-5683). + +## Overview + +The Rust extension (`_rbson`) provides a **partial implementation** of the C extension (`_cbson`) interface, implemented in Rust using: +- **PyO3**: Python bindings for Rust +- **bson crate**: MongoDB's official Rust BSON library +- **Maturin**: Build tool for Rust Python extensions + +## Test Status + +### ✅ Core BSON Tests: 86 passed, 2 skipped +The basic BSON encoding/decoding functionality works correctly (`test/test_bson.py`). + +### ⏭️ Skipped Tests: ~85 tests across multiple test files +The following features are **not implemented** and tests are skipped when using the Rust extension: + +#### Custom Type Encoders (test/test_custom_types.py) +- **`TypeEncoder` and `TypeRegistry`** - Custom type encoding/decoding +- **`FallbackEncoder`** - Fallback encoding for unknown types +- **Tests skipped**: All tests in `TestBSONFallbackEncoder`, `TestCustomPythonBSONTypeToBSONMonolithicCodec`, `TestCustomPythonBSONTypeToBSONMultiplexedCodec` +- **Reason**: Rust extension doesn't support custom type encoders or fallback encoders + +#### RawBSONDocument (test/test_raw_bson.py) +- **`RawBSONDocument` codec options** - Raw BSON document handling +- **Tests skipped**: All tests in `TestRawBSONDocument` +- **Reason**: Rust extension doesn't implement RawBSONDocument codec options + +#### DBRef Edge Cases (test/test_dbref.py) +- **DBRef validation and edge cases** +- **Tests skipped**: Some DBRef tests +- **Reason**: Incomplete DBRef handling in Rust extension + +#### Type Checking (test/test_typing.py) +- **Type 
hints and mypy validation** +- **Tests skipped**: Some typing tests +- **Reason**: Type checking issues with Rust extension + +### Skip Mechanism +Tests are skipped using the `@skip_if_rust_bson` pytest marker defined in `test/__init__.py`: +```python +skip_if_rust_bson = pytest.mark.skipif( + _use_rust_bson(), reason="Rust BSON extension does not support this feature" +) +``` + +This marker is applied to test classes and methods that use unimplemented features. + +## Implementation History + +This implementation was developed through [PR #2695](https://github.com/mongodb/mongo-python-driver/pull/2695) to investigate using Rust as an alternative to C for Python extension modules. + +### Key Milestones + +1. **Initial Implementation** - Basic BSON type support with core functionality +2. **Performance Optimizations** - Type caching, fast paths for common types, direct byte operations +3. **Modular Refactoring** - Split monolithic lib.rs into 6 well-organized modules +4. **Test Integration** - Added skip markers for unimplemented features (~85 tests skipped) + +## Features + +### Supported BSON Types + +The Rust extension supports basic BSON types: +- **Primitives**: Double, String, Int32, Int64, Boolean, Null +- **Complex Types**: Document, Array, Binary, ObjectId, DateTime +- **Special Types**: Regex, Code, Timestamp, Decimal128, MinKey, MaxKey +- **Deprecated Types**: DBPointer (decodes to DBRef) + +### CodecOptions Support + +**Partial** support for PyMongo's `CodecOptions`: +- ✅ `document_class` - Custom document classes (basic support) +- ✅ `tz_aware` - Timezone-aware datetime handling +- ✅ `tzinfo` - Timezone conversion +- ✅ `uuid_representation` - UUID encoding/decoding modes +- ✅ `datetime_conversion` - DateTime handling modes (AUTO, CLAMP, MS) +- ✅ `unicode_decode_error_handler` - UTF-8 error handling +- ❌ `type_registry` - Custom type encoders/decoders (NOT IMPLEMENTED) +- ❌ RawBSONDocument support (NOT IMPLEMENTED) + +### Runtime Selection + +The Rust 
extension can be enabled via environment variable: +```bash +export PYMONGO_USE_RUST=1 +python your_script.py +``` + +Without this variable, PyMongo uses the C extension by default. + +## Performance Analysis + +### Current Performance: ~0.21x (5x slower than C) + +**Benchmark Results** (from PR #2695): +``` +Simple documents: C: 100% | Rust: 21% +Mixed types: C: 100% | Rust: 20% +Nested documents: C: 100% | Rust: 18% +Lists: C: 100% | Rust: 22% +``` + +### Root Cause: Architectural Difference + +The performance gap is due to a fundamental architectural difference: + +**C Extension Architecture:** +``` +Python objects → BSON bytes (direct) +``` +- Writes BSON bytes directly from Python objects +- No intermediate data structures +- Minimal memory allocations + +**Rust Extension Architecture:** +``` +Python objects → Rust Bson enum → BSON bytes +``` +- Converts Python objects to Rust `Bson` enum +- Then serializes `Bson` to bytes +- Extra conversion layer adds overhead + +### Optimization Attempts + +Multiple optimization strategies were attempted in PR #2695: + +1. **Type Caching** - Cache frequently used Python types (UUID, datetime, etc.) +2. **Fast Paths** - Special handling for common types (int, str, bool, None) +3. **Direct Byte Writing** - Write BSON bytes directly without intermediate `Document` +4. **PyDict Fast Path** - Use `PyDict_Next` for efficient dict iteration + +**Result**: These optimizations improved performance from ~0.15x to ~0.21x, but the fundamental architectural difference remains. 
+ +## Comparison with Copilot POC (PR #2689) + +The current implementation evolved significantly from the initial Copilot-generated proof-of-concept in PR #2689: + +### Copilot POC (PR #2689) - Initial Spike +**Status**: 53/88 tests passing (60%) + +**Build System**: `cargo build --release` (manual copy of .so file) +- Used raw `cargo` commands +- Manual file copying to project root +- No wheel generation +- Located in `rust/` directory + +**What it had:** +- ✅ Basic BSON type support (int, float, string, bool, bytes, dict, list, null) +- ✅ ObjectId, DateTime, Regex encoding/decoding +- ✅ Binary, Code, Timestamp, Decimal128, MinKey, MaxKey support +- ✅ DBRef and DBPointer decoding +- ✅ Int64 type marker support +- ✅ Basic CodecOptions (tz_aware, uuid_representation) +- ✅ Buffer protocol support (memoryview, array) +- ✅ _id field ordering at top level +- ✅ Benchmark scripts and performance analysis +- ✅ Comprehensive documentation (RUST_SPIKE_RESULTS.md) +- ✅ **Same Rust architecture**: PyO3 0.27 + bson 2.13 crate (Python → Bson enum → bytes) + +**What it lacked:** +- ❌ Only 60% test pass rate (53/88 tests) +- ❌ Incomplete datetime handling (no DATETIME_CLAMP, DATETIME_AUTO, DATETIME_MS modes) +- ❌ Missing unicode_decode_error_handler support +- ❌ No document_class support from CodecOptions +- ❌ No tzinfo conversion support +- ❌ Missing BSON validation (size checks, null terminator) +- ❌ No performance optimizations (type caching, fast paths) +- ❌ Located in `rust/` directory instead of `bson/_rbson/` + +**Performance Claims**: 2.89x average speedup over C (from benchmarks in POC) + +**Why the POC appeared faster:** +The Copilot POC's claimed 2.89x speedup was likely due to: +1. **Limited test scope** - Benchmarks only tested simple documents that passed (53/88 tests) +2. **Missing validation** - No BSON size checks, null terminator validation, or extra bytes detection +3. 
**Incomplete CodecOptions** - Skipped expensive operations like: + - Timezone conversions (`tzinfo` with `astimezone()`) + - DateTime mode handling (CLAMP, AUTO, MS) + - Unicode error handler fallbacks to Python + - Custom document_class instantiation +4. **Optimistic measurements** - May have measured only the fast path without edge cases +5. **Different test methodology** - POC used custom benchmarks vs production testing with full PyMongo test suite + +When these missing features were added to achieve 100% compatibility, the true performance cost of the Rust `Bson` enum architecture became apparent. + +### Current Implementation (PR #2695) - Experimental +**Status**: 86/88 core BSON tests passing, ~85 feature tests skipped + +**Build System**: `maturin build --release` (proper wheel generation) +- Uses Maturin for proper Python packaging +- Generates wheels with correct metadata +- Extracts .so file to `bson/` directory +- Located in `bson/_rbson/` directory (proper module structure) + +**Improvements over Copilot POC:** +- ✅ **Core BSON functionality** (86/88 tests passing in test_bson.py) +- ✅ **Basic CodecOptions support**: + - `document_class` - Custom document classes (basic support) + - `tzinfo` - Timezone conversion with astimezone() + - `datetime_conversion` - All modes (AUTO, CLAMP, MS) + - `unicode_decode_error_handler` - Fallback to Python for non-strict handlers +- ✅ **BSON validation** (size checks, null terminator, extra bytes detection) +- ✅ **Performance optimizations**: + - Type caching (UUID, datetime, Pattern, etc.) 
+ - Fast paths for common types (int, str, bool, None) + - Direct byte operations where possible + - PyDict fast path with pre-allocation +- ✅ **Modular code structure** (6 well-organized Rust modules) +- ✅ **Proper module structure** (`bson/_rbson/` with build.sh and maturin) +- ✅ **Runtime selection** via PYMONGO_USE_RUST environment variable +- ✅ **Test skip markers** for unimplemented features +- ✅ **Same Rust architecture**: PyO3 0.23 + bson 2.13 crate (Python → Bson enum → bytes) + +**Missing Features** (see [Test Status](#test-status)): +- ❌ **Custom type encoders** (`TypeEncoder`, `TypeRegistry`, `FallbackEncoder`) +- ❌ **RawBSONDocument** codec options +- ❌ **Some DBRef edge cases** +- ❌ **Complete type checking support** + +**Performance Reality**: ~0.21x (5x slower than C) - see Performance Analysis section + +**Key Insights**: +1. **Same Architecture, Different Results**: Both implementations use the same Rust architecture (PyO3 + bson crate with intermediate `Bson` enum), so the build system (cargo vs maturin) is not the cause of the performance difference. +2. **Incomplete Implementation**: The current implementation has ~85 tests skipped due to unimplemented features (custom type encoders, RawBSONDocument, etc.). This is an experimental implementation, not production-ready. +3. **The Fundamental Issue**: The Rust architecture (Python → Bson enum → bytes) has inherent performance limitations compared to the C extension's direct byte-writing approach. + +## Direct Byte-Writing Performance Results + +### Implementation: `_dict_to_bson_direct()` + +A new implementation has been added that writes BSON bytes directly from Python objects without converting to `Bson` enum types first. This eliminates the intermediate conversion layer. 
+ +**Architecture Comparison:** +``` +Regular: Python objects → Rust Bson enum → BSON bytes +Direct: Python objects → BSON bytes (no intermediate types) +``` + +### Benchmark Results + +Comprehensive benchmarks on realistic document types show **consistent 2x speedup**: + +| Document Type | Regular (ops/sec) | Direct (ops/sec) | Speedup | +|--------------|-------------------|------------------|---------| +| User Profile | 99,970 | 208,658 | **2.09x** | +| E-commerce Order | 93,578 | 165,636 | **1.77x** | +| IoT Sensor Data | 136,824 | 312,058 | **2.28x** | +| Blog Post | 65,782 | 134,154 | **2.04x** | + +**Average Speedup: 2.04x** (range: 1.77x - 2.28x) + +### Performance by Document Composition + +| Document Type | Regular (ops/sec) | Direct (ops/sec) | Speedup | +|--------------|-------------------|------------------|---------| +| Simple types (int, str, float, bool, None) | 177,588 | 800,670 | **4.51x** | +| Mixed types | 223,856 | 342,305 | **1.53x** | +| Nested documents | 130,884 | 287,758 | **2.20x** | +| BSON-specific types only | 342,059 | 304,844 | 0.89x | + +### Key Findings + +1. **Massive speedup for simple types**: 4.51x faster for documents with Python native types +2. **Consistent 2x improvement for real-world documents**: All realistic mixed-type documents show 1.77x - 2.28x speedup +3. **Slight slowdown for pure BSON types**: Documents with only BSON-specific types (ObjectId, Binary, etc.) are 10% slower due to extra Python attribute lookups +4. **100% correctness**: All outputs verified to be byte-identical to the regular implementation + +### Why Direct Byte-Writing is Faster + +1. **Eliminates heap allocations**: No need to create intermediate `Bson` enum values +2. **Reduces function call overhead**: Writes bytes immediately instead of going through `python_to_bson()` → `write_bson_value()` +3. 
**Better for common types**: Python's native types (int, str, float, bool) can be written directly without any conversion + +### Implementation Details + +The direct approach is implemented in these functions: +- `_dict_to_bson_direct()` - Public API function +- `write_document_bytes_direct()` - Writes document structure directly +- `write_element_direct()` - Writes individual elements without Bson conversion +- `write_bson_type_direct()` - Handles BSON-specific types directly + +### Usage + +```python +from bson import _rbson +from bson.codec_options import DEFAULT_CODEC_OPTIONS + +# Use direct byte-writing approach +doc = {"name": "John", "age": 30, "score": 95.5} +bson_bytes = _rbson._dict_to_bson_direct(doc, False, DEFAULT_CODEC_OPTIONS) +``` + + + +## Steps to Achieve Performance Parity with C Extensions + +Based on the analysis in PR #2695 and the direct byte-writing results, here are the steps needed to match C extension performance: + +### 1. ✅ Eliminate Intermediate Bson Enum (High Impact) - COMPLETED +**Current**: Python → Bson → bytes +**Target**: Python → bytes (direct) + +**Status**: ✅ **Implemented as `_dict_to_bson_direct()`** + +**Actual Impact**: **2.04x average speedup** on realistic documents (range: 1.77x - 2.28x) + +This brings the Rust extension from ~0.21x (5x slower than C) to **~0.43x (2.3x slower than C)** - a significant improvement! + +### 2. Optimize Python API Calls (Medium Impact) +- Reduce `getattr()` calls by caching attribute lookups +- Use `PyDict_GetItem` instead of `dict.get_item()` +- Minimize Python exception handling overhead +- Use `PyTuple_GET_ITEM` for tuple access + +**Estimated Impact**: 1.2-1.5x performance improvement + +### 3. Memory Allocation Optimization (Low-Medium Impact) +- Pre-allocate buffers based on estimated document size +- Reuse buffers across multiple encode operations +- Use arena allocation for temporary objects + +**Estimated Impact**: 1.1-1.3x performance improvement + +### 4. 
SIMD Optimizations (Low Impact) +- Use SIMD for byte copying operations +- Vectorize validation checks +- Optimize string encoding/decoding + +**Estimated Impact**: 1.05-1.1x performance improvement + +### Combined Potential (Updated with Direct Byte-Writing Results) +With direct byte-writing implemented: +- **Before**: 0.21x (5x slower than C) +- **After direct byte-writing**: 0.43x (2.3x slower than C) ✅ +- **With all optimizations**: 0.43x × 1.3 × 1.2 × 1.05 = **~0.71x** (1.4x slower than C) +- **Optimistic target**: Could potentially reach **~0.9x - 1.0x** (parity with C) + +The direct byte-writing approach has already delivered the largest performance gain (2x). Additional optimizations could close the remaining gap to C extension performance. + +## Building + +```bash +cd bson/_rbson +./build.sh +``` + +Or using maturin directly: +```bash +maturin develop --release +``` + +## Testing + +Run the core BSON test suite with the Rust extension: +```bash +PYMONGO_USE_RUST=1 python -m pytest test/test_bson.py -v +# Expected: 86 passed, 2 skipped +``` + +Run all tests (including skipped tests): +```bash +PYMONGO_USE_RUST=1 python -m pytest test/ -v +# Expected: Many tests passed, ~85 tests skipped due to unimplemented features +``` + +Run performance benchmarks: +```bash +# Quick benchmark run +FASTBENCH=1 python test/performance/perf_test.py -v + +# With Rust extension enabled +PYMONGO_USE_RUST=1 FASTBENCH=1 python test/performance/perf_test.py -v + +# Full benchmark setup (see test/performance/perf_test.py for details) +python -m pip install simplejson +git clone --depth 1 https://github.com/mongodb/specifications.git +cd specifications/source/benchmarking/data +tar xf extended_bson.tgz +tar xf parallel.tgz +tar xf single_and_multi_document.tgz +cd - +export TEST_PATH="specifications/source/benchmarking/data" +export OUTPUT_FILE="results.json" +python test/performance/perf_test.py -v +``` + +## Module Structure + +The Rust codebase is organized into 6 
well-structured modules (refactored from a single 3,117-line file): + +- **`lib.rs`** (76 lines) - Module exports and public API +- **`types.rs`** (266 lines) - Type cache and BSON type markers +- **`errors.rs`** (56 lines) - Error handling utilities +- **`utils.rs`** (154 lines) - Utility functions (datetime, regex, validation) +- **`encode.rs`** (1,545 lines) - BSON encoding functions +- **`decode.rs`** (1,141 lines) - BSON decoding functions + +This modular structure improves: +- Code organization and maintainability +- Compilation times (parallel module compilation) +- Code navigation and testing +- Clear separation of concerns + +## Conclusion + +The Rust extension demonstrates that: +1. ✅ **Rust can provide basic BSON encoding/decoding functionality** +2. ❌ **Complete feature parity with C extension is not achieved** (~85 tests skipped) +3. ❌ **Performance parity with C requires bypassing the `bson` crate** +4. ❌ **The engineering effort may not justify the benefits** + +### Recommendation + +⚠️ **NOT PRODUCTION READY** - The Rust extension is **experimental** and has significant limitations: + +**Missing Features:** +- Custom type encoders (`TypeEncoder`, `TypeRegistry`, `FallbackEncoder`) +- RawBSONDocument codec options +- Some DBRef edge cases +- Complete type checking support + +**Performance Issues:** +- ~5x slower than C extension (0.21x performance) +- Even with direct byte-writing optimizations, still ~2.3x slower (0.43x performance) + +**Use Cases for Rust Extension:** +- **Experimental/research purposes only** +- Testing Rust-Python interop with PyO3 +- Platforms where C compilation is difficult (with caveats about missing features) +- Future exploration if `bson` crate performance improves + +**For production use, the C extension (`_cbson`) is strongly recommended.** + +For more details, see: +- [PYTHON-5683 JIRA ticket](https://jira.mongodb.org/browse/PYTHON-5683) +- [PR #2695](https://github.com/mongodb/mongo-python-driver/pull/2695) diff --git 
a/bson/_rbson/build.sh b/bson/_rbson/build.sh new file mode 100755 index 0000000000..af73121cb1 --- /dev/null +++ b/bson/_rbson/build.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# Build script for Rust BSON extension POC +# +# This script builds the Rust extension and makes it available for testing +# alongside the existing C extension. +set -eu + +HERE=$(dirname ${BASH_SOURCE:-$0}) +HERE="$( cd -- "$HERE" > /dev/null 2>&1 && pwd )" +BSON_DIR=$(dirname "$HERE") + +echo "=== Building Rust BSON Extension POC ===" +echo "" + +# Check if Rust is installed +if ! command -v cargo &>/dev/null; then + echo "Error: Rust is not installed" + echo "" + echo "Install Rust with:" + echo " curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh" + echo "" + exit 1 +fi + +echo "Rust toolchain found: $(rustc --version)" + +# Check if maturin is installed +if ! command -v maturin &>/dev/null; then + echo "maturin not found, installing..." + pip install maturin +fi + +echo "maturin found: $(maturin --version)" +echo "" + +# Build the extension +echo "Building Rust extension..." +cd "$HERE" + +# Build wheel to a temporary directory +TEMP_DIR=$(mktemp -d) +trap 'rm -rf "$TEMP_DIR"' EXIT + +maturin build --release --out "$TEMP_DIR" + +# Extract the .so file from the wheel +echo "Extracting extension from wheel..." 
+WHEEL_FILE=$(ls "$TEMP_DIR"/*.whl | head -1) + +if [ -z "$WHEEL_FILE" ]; then + echo "Error: No wheel file found" + exit 1 +fi + +# Wheels are zip files - extract the .so file +python -c " +import zipfile +import sys +from pathlib import Path + +wheel_path = Path(sys.argv[1]) +bson_dir = Path(sys.argv[2]) + +with zipfile.ZipFile(wheel_path, 'r') as whl: + for name in whl.namelist(): + if name.endswith(('.so', '.pyd')) and '_rbson' in name: + # Extract to bson/ directory + so_data = whl.read(name) + so_name = Path(name).name + target = bson_dir / so_name + target.write_bytes(so_data) + print(f'Installed to {target}') + sys.exit(0) + +print('Error: Could not find .so file in wheel') +sys.exit(1) +" "$WHEEL_FILE" "$BSON_DIR" + +echo "" +echo "Build complete!" +echo "" +echo "Test the extension with:" +echo " python -c 'from bson import _rbson; print(_rbson._test_rust_extension())'" +echo "" diff --git a/bson/_rbson/src/decode.rs b/bson/_rbson/src/decode.rs new file mode 100644 index 0000000000..d9e536a932 --- /dev/null +++ b/bson/_rbson/src/decode.rs @@ -0,0 +1,1140 @@ +// Copyright 2025-present MongoDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! BSON decoding functions +//! +//! This module contains all functions for decoding BSON bytes to Python objects. 
+ +use bson::{doc, Bson, Document}; +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{IntoPyDict, PyAny, PyBytes, PyDict, PyList, PyString}; +use std::io::Cursor; + +use crate::errors::{invalid_bson_error, invalid_document_error}; +use crate::types::{TYPE_CACHE}; +use crate::utils::{str_flags_to_int}; + +#[pyfunction] +#[pyo3(signature = (data, _codec_options))] +pub fn _bson_to_dict( + py: Python, + data: &Bound<'_, PyAny>, + _codec_options: &Bound<'_, PyAny>, +) -> PyResult> { + let codec_options = Some(_codec_options); + // Accept bytes, bytearray, memoryview, and other buffer protocol objects + // Try to get bytes using the buffer protocol + let bytes: Vec = if let Ok(b) = data.extract::>() { + b + } else if let Ok(bytes_obj) = data.downcast::() { + bytes_obj.as_bytes().to_vec() + } else { + // Try to use buffer protocol for memoryview, array, mmap, etc. + match data.call_method0("__bytes__") { + Ok(bytes_result) => { + if let Ok(bytes_obj) = bytes_result.downcast::() { + bytes_obj.as_bytes().to_vec() + } else { + return Err(PyTypeError::new_err("data must be bytes, bytearray, memoryview, or buffer protocol object")); + } + } + Err(_) => { + // Try tobytes() method (for array.array) + match data.call_method0("tobytes") { + Ok(bytes_result) => { + if let Ok(bytes_obj) = bytes_result.downcast::() { + bytes_obj.as_bytes().to_vec() + } else { + return Err(PyTypeError::new_err("data must be bytes, bytearray, memoryview, or buffer protocol object")); + } + } + Err(_) => { + // Try read() method (for mmap) + match data.call_method0("read") { + Ok(bytes_result) => { + if let Ok(bytes_obj) = bytes_result.downcast::() { + bytes_obj.as_bytes().to_vec() + } else { + return Err(PyTypeError::new_err("data must be bytes, bytearray, memoryview, or buffer protocol object")); + } + } + Err(_) => { + return Err(PyTypeError::new_err("data must be bytes, bytearray, memoryview, or buffer protocol object")); + } + } + } + } + } + } + }; + + 
// Validate BSON document structure + // Minimum size is 5 bytes (4 bytes for size + 1 byte for null terminator) + if bytes.len() < 5 { + return Err(invalid_bson_error(py, "not enough data for a BSON document".to_string())); + } + + // Check that the size field matches the actual data length + let size = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; + if size != bytes.len() { + if size < bytes.len() { + return Err(invalid_bson_error(py, "bad eoo".to_string())); + } else { + return Err(invalid_bson_error(py, "invalid message size".to_string())); + } + } + + // Check that the document ends with a null terminator + if bytes[bytes.len() - 1] != 0 { + return Err(invalid_bson_error(py, "bad eoo".to_string())); + } + + // Check minimum size + if size < 5 { + return Err(invalid_bson_error(py, "invalid message size".to_string())); + } + + // Extract unicode_decode_error_handler from codec_options + let unicode_error_handler = if let Some(opts) = codec_options { + opts.getattr("unicode_decode_error_handler") + .ok() + .and_then(|h| h.extract::().ok()) + .unwrap_or_else(|| "strict".to_string()) + } else { + "strict".to_string() + }; + + // Try direct byte reading for better performance + // If we encounter an unsupported type, fall back to Document-based approach + match read_document_from_bytes(py, &bytes, 0, codec_options) { + Ok(dict) => return Ok(dict), + Err(e) => { + let error_msg = format!("{}", e); + + // If we got a UTF-8 error and have a non-strict error handler, use Python fallback + if error_msg.contains("utf-8") && unicode_error_handler != "strict" { + let decode_func = TYPE_CACHE.get_bson_to_dict_python(py)?; + let py_data = PyBytes::new_bound(py, &bytes); + let py_opts = if let Some(opts) = codec_options { + opts.clone().into_py(py).into_bound(py) + } else { + py.None().into_bound(py) + }; + return Ok(decode_func.bind(py).call1((py_data, py_opts))?.into()); + } + + // If we got an unsupported type error, fall back to Document-based 
approach + if error_msg.contains("Unsupported BSON type") || error_msg.contains("Detected unknown BSON type") { + // Fall through to old implementation below + } else { + // For other errors, propagate them + return Err(e); + } + } + } + + // Fallback: Use Document-based approach for documents with unsupported types + let cursor = Cursor::new(&bytes); + let doc_result = Document::from_reader(cursor); + + if let Err(ref e) = doc_result { + let error_msg = format!("{}", e); + if error_msg.contains("utf-8") && unicode_error_handler != "strict" { + let decode_func = TYPE_CACHE.get_bson_to_dict_python(py)?; + let py_data = PyBytes::new_bound(py, &bytes); + let py_opts = if let Some(opts) = codec_options { + opts.clone().into_py(py).into_bound(py) + } else { + py.None().into_bound(py) + }; + return Ok(decode_func.bind(py).call1((py_data, py_opts))?.into()); + } + } + + let doc = doc_result.map_err(|e| { + let error_msg = format!("{}", e); + + // Try to match C extension error format for unknown BSON types + // C extension: "type b'\\x14' for fieldname 'foo'" + // Rust bson: "error at key \"foo\": malformed value: \"invalid tag: 20\"" + if error_msg.contains("invalid tag:") { + // Extract the tag number and field name + if let Some(tag_start) = error_msg.find("invalid tag: ") { + let tag_str = &error_msg[tag_start + 13..]; + if let Some(tag_end) = tag_str.find('"') { + if let Ok(tag_num) = tag_str[..tag_end].parse::() { + if let Some(key_start) = error_msg.find("error at key \"") { + let key_str = &error_msg[key_start + 14..]; + if let Some(key_end) = key_str.find('"') { + let field_name = &key_str[..key_end]; + + // If the field name is numeric (array index), try to find the parent field name + let actual_field_name = if field_name.chars().all(|c| c.is_ascii_digit()) { + // Try to find the parent field name by parsing the BSON + find_parent_field_for_unknown_type(&bytes, tag_num).unwrap_or(field_name) + } else { + field_name + }; + + let formatted_msg = format!("type 
b'\\x{:02x}' for fieldname '{}'", tag_num, actual_field_name); + return invalid_bson_error(py, formatted_msg); + } + } + } + } + } + } + + invalid_bson_error(py, format!("invalid bson: {}", error_msg)) + })?; + bson_doc_to_python_dict(py, &doc, codec_options) + + // Old path using Document::from_reader (kept as fallback, but not used) + /* + let cursor = Cursor::new(&bytes); + let doc_result = Document::from_reader(cursor); + + // If we got a UTF-8 error and have a non-strict error handler, use Python fallback + if let Err(ref e) = doc_result { + let error_msg = format!("{}", e); + if error_msg.contains("utf-8") && unicode_error_handler != "strict" { + // Use Python's fallback implementation which handles unicode_decode_error_handler + let bson_module = py.import("bson")?; + let decode_func = bson_module.getattr("_bson_to_dict_python")?; + let py_data = PyBytes::new(py, &bytes); + let py_opts = if let Some(opts) = codec_options { + opts.clone().into_py(py).into_bound(py) + } else { + py.None().into_bound(py) + }; + return Ok(decode_func.call1((py_data, py_opts))?.into()); + } + } + */ +} + +/// Process a single item from a mapping's items() iterator + +fn read_document_from_bytes( + py: Python, + bytes: &[u8], + offset: usize, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + read_document_from_bytes_with_parent(py, bytes, offset, codec_options, None) +} + + +fn read_document_from_bytes_with_parent( + py: Python, + bytes: &[u8], + offset: usize, + codec_options: Option<&Bound<'_, PyAny>>, + parent_field_name: Option<&str>, +) -> PyResult> { + // Read document size + if bytes.len() < offset + 4 { + return Err(invalid_bson_error(py, "not enough data for a BSON document".to_string())); + } + + let size = i32::from_le_bytes([ + bytes[offset], + bytes[offset + 1], + bytes[offset + 2], + bytes[offset + 3], + ]) as usize; + + if offset + size > bytes.len() { + return Err(invalid_bson_error(py, "invalid message size".to_string())); + } + + // Get 
document_class from codec_options, default to dict + let dict: Bound<'_, PyAny> = if let Some(opts) = codec_options { + let document_class = opts.getattr("document_class")?; + document_class.call0()? + } else { + PyDict::new(py).into_any() + }; + + // Read elements + let mut pos = offset + 4; // Skip size field + let end = offset + size - 1; // -1 for null terminator + + // Track if this might be a DBRef (has $ref and $id fields) + let mut has_ref = false; + let mut has_id = false; + + while pos < end { + // Read type byte + let type_byte = bytes[pos]; + pos += 1; + + if type_byte == 0 { + break; // End of document + } + + // Read key (null-terminated string) + let key_start = pos; + while pos < bytes.len() && bytes[pos] != 0 { + pos += 1; + } + + if pos >= bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: unexpected end of data".to_string())); + } + + let key = std::str::from_utf8(&bytes[key_start..pos]) + .map_err(|e| invalid_bson_error(py, format!("invalid bson: invalid UTF-8 in key: {}", e)))?; + + pos += 1; // Skip null terminator + + // Track DBRef fields + if key == "$ref" { + has_ref = true; + } else if key == "$id" { + has_id = true; + } + + // Determine the field name to use for error reporting + // If the key is numeric (array index) and we have a parent field name, use the parent + let error_field_name = if let Some(parent) = parent_field_name { + if key.chars().all(|c| c.is_ascii_digit()) { + parent + } else { + key + } + } else { + key + }; + + // Read value based on type + let (value, new_pos) = read_bson_value(py, bytes, pos, type_byte, codec_options, error_field_name)?; + pos = new_pos; + + dict.set_item(key, value)?; + } + + // Validate that we consumed exactly the right number of bytes + // pos should be at end (which is offset + size - 1) + // and the next byte should be the null terminator + if pos != end { + return Err(invalid_bson_error(py, "invalid length or type code".to_string())); + } + + // Verify null terminator + if 
bytes[pos] != 0 { + return Err(invalid_bson_error(py, "invalid length or type code".to_string())); + } + + // If this looks like a DBRef, convert it to a DBRef object + if has_ref && has_id { + return convert_dict_to_dbref(py, &dict, codec_options); + } + + Ok(dict.into()) +} + +/// Read a single BSON value from bytes + +fn read_bson_value( + py: Python, + bytes: &[u8], + pos: usize, + type_byte: u8, + codec_options: Option<&Bound<'_, PyAny>>, + field_name: &str, +) -> PyResult<(Py, usize)> { + match type_byte { + 0x01 => { + // Double + if pos + 8 > bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: not enough data for double".to_string())); + } + let value = f64::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + bytes[pos + 4], bytes[pos + 5], bytes[pos + 6], bytes[pos + 7], + ]); + Ok((value.into_py(py), pos + 8)) + } + 0x02 => { + // String + if pos + 4 > bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: not enough data for string length".to_string())); + } + let str_len = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]) as isize; + + // String length must be at least 1 (for null terminator) + if str_len < 1 { + return Err(invalid_bson_error(py, "invalid bson: bad string length".to_string())); + } + + let str_start = pos + 4; + let str_end = str_start + (str_len as usize) - 1; // -1 for null terminator + + if str_end >= bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: bad string length".to_string())); + } + + // Validate that the null terminator is actually present + if bytes[str_end] != 0 { + return Err(invalid_bson_error(py, "invalid bson: bad string length".to_string())); + } + + let s = std::str::from_utf8(&bytes[str_start..str_end]) + .map_err(|e| invalid_bson_error(py, format!("invalid bson: invalid UTF-8 in string: {}", e)))?; + + Ok((s.into_py(py), str_end + 1)) // +1 to skip null terminator + } + 0x03 => { + // Embedded document + let doc = 
read_document_from_bytes(py, bytes, pos, codec_options)?; + let size = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]) as usize; + Ok((doc, pos + size)) + } + 0x04 => { + // Array + let arr = read_array_from_bytes(py, bytes, pos, codec_options, field_name)?; + let size = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]) as usize; + Ok((arr, pos + size)) + } + 0x08 => { + // Boolean + if pos >= bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: not enough data for boolean".to_string())); + } + let value = bytes[pos] != 0; + Ok((value.into_py(py), pos + 1)) + } + 0x0A => { + // Null + Ok((py.None(), pos)) + } + 0x10 => { + // Int32 + if pos + 4 > bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: not enough data for int32".to_string())); + } + let value = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]); + Ok((value.into_py(py), pos + 4)) + } + 0x12 => { + // Int64 - return as Int64 type to preserve type information + if pos + 8 > bytes.len() { + return Err(invalid_bson_error(py, "invalid bson: not enough data for int64".to_string())); + } + let value = i64::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + bytes[pos + 4], bytes[pos + 5], bytes[pos + 6], bytes[pos + 7], + ]); + + // Use cached Int64 class + let int64_class = TYPE_CACHE.get_int64_class(py)?; + let int64_obj = int64_class.bind(py).call1((value,))?; + + Ok((int64_obj.into(), pos + 8)) + } + _ => { + // For unknown BSON types, raise an error with the correct field name + // Match C extension error format: "Detected unknown BSON type b'\xNN' for fieldname 'foo'" + let error_msg = format!( + "Detected unknown BSON type b'\\x{:02x}' for fieldname '{}'. 
Are you using the latest driver version?", + type_byte, field_name + ); + Err(invalid_bson_error(py, error_msg)) + } + } +} + + +fn read_array_from_bytes( + py: Python, + bytes: &[u8], + offset: usize, + codec_options: Option<&Bound<'_, PyAny>>, + parent_field_name: &str, +) -> PyResult> { + // Arrays are encoded as documents with numeric keys + // We need to read it as a document and convert to a list + // Pass the parent field name so that errors in array elements report the array field name + let doc_dict = read_document_from_bytes_with_parent(py, bytes, offset, codec_options, Some(parent_field_name))?; + + // Convert dict to list (keys should be "0", "1", "2", ...) + let dict = doc_dict.bind(py); + let items = dict.call_method0("items")?; + let mut pairs: Vec<(usize, Py)> = Vec::new(); + + for item in items.iter()? { + let item = item?; + let tuple = item.downcast::()?; + let key: String = tuple.get_item(0)?.extract()?; + let value = tuple.get_item(1)?; + let index: usize = key.parse() + .map_err(|_| PyErr::new::( + "Invalid array index" + ))?; + pairs.push((index, value.into_py(py))); + } + + // Sort by index and extract values + pairs.sort_by_key(|(idx, _)| *idx); + let values: Vec> = pairs.into_iter().map(|(_, v)| v).collect(); + + Ok(pyo3::types::PyList::new(py, values)?.into_py(py)) +} + +/// Find the parent field name for an unknown type in an array + +fn find_parent_field_for_unknown_type(bytes: &[u8], unknown_type: u8) -> Option<&str> { + // Parse the BSON to find the field that contains the unknown type + // We're looking for an array field that contains an element with the unknown type + + if bytes.len() < 5 { + return None; + } + + let size = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize; + if size > bytes.len() { + return None; + } + + let mut pos = 4; // Skip size field + let end = size - 1; // -1 for null terminator + + while pos < end && pos < bytes.len() { + let type_byte = bytes[pos]; + pos += 1; + + if type_byte == 0 { 
+ break; + } + + // Read field name + let key_start = pos; + while pos < bytes.len() && bytes[pos] != 0 { + pos += 1; + } + + if pos >= bytes.len() { + return None; + } + + let key = match std::str::from_utf8(&bytes[key_start..pos]) { + Ok(k) => k, + Err(_) => return None, + }; + + pos += 1; // Skip null terminator + + // Check if this is an array (type 0x04) + if type_byte == 0x04 { + // Read array size + if pos + 4 > bytes.len() { + return None; + } + let array_size = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]) as usize; + + // Check if the array contains the unknown type + let array_start = pos; + let array_end = pos + array_size; + if array_end > bytes.len() { + return None; + } + + // Scan the array for the unknown type + let mut array_pos = array_start + 4; // Skip array size + while array_pos < array_end - 1 { + let elem_type = bytes[array_pos]; + if elem_type == 0 { + break; + } + + if elem_type == unknown_type { + // Found it! Return the array field name + return Some(key); + } + + array_pos += 1; + + // Skip element name + while array_pos < bytes.len() && bytes[array_pos] != 0 { + array_pos += 1; + } + if array_pos >= bytes.len() { + return None; + } + array_pos += 1; + + // We can't easily skip the value without parsing it fully, + // so just break here and return the key if we found the type + break; + } + + pos += array_size; + } else { + // Skip other types - we need to know their sizes + match type_byte { + 0x01 => pos += 8, // Double + 0x02 => { // String + if pos + 4 > bytes.len() { + return None; + } + let str_len = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]) as usize; + pos += 4 + str_len; + } + 0x03 | 0x04 => { // Document or Array + if pos + 4 > bytes.len() { + return None; + } + let doc_size = i32::from_le_bytes([ + bytes[pos], bytes[pos + 1], bytes[pos + 2], bytes[pos + 3], + ]) as usize; + pos += doc_size; + } + 0x08 => pos += 1, // Boolean + 0x0A => {}, 
// Null + 0x10 => pos += 4, // Int32 + 0x12 => pos += 8, // Int64 + _ => return None, // Unknown type, can't continue + } + } + } + + None +} + +/// Decode BSON bytes to a Python dictionary +/// This is the main entry point matching the C extension API + +fn bson_to_python( + py: Python, + bson: &Bson, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + match bson { + Bson::Null => Ok(py.None()), + Bson::Boolean(v) => Ok((*v).into_py(py)), + Bson::Int32(v) => Ok((*v as i64).into_py(py)), + Bson::Int64(v) => { + // Return bson.int64.Int64 object instead of plain Python int + let int64_class = TYPE_CACHE.get_int64_class(py)?; + let int64_obj = int64_class.bind(py).call1((*v,))?; + Ok(int64_obj.into()) + } + Bson::Double(v) => Ok((*v).into_py(py)), + Bson::String(v) => Ok(v.into_py(py)), + Bson::Binary(v) => decode_binary(py, v, codec_options), + Bson::Document(v) => bson_doc_to_python_dict(py, v, codec_options), + Bson::Array(v) => { + let list = pyo3::types::PyList::empty(py); + for item in v { + list.append(bson_to_python(py, item, codec_options)?)?; + } + Ok(list.into()) + } + Bson::ObjectId(v) => { + // Use cached ObjectId class + let objectid_class = TYPE_CACHE.get_objectid_class(py)?; + + // Create ObjectId from bytes + let bytes = PyBytes::new_bound(py, &v.bytes()); + let objectid = objectid_class.bind(py).call1((bytes,))?; + Ok(objectid.into()) + } + Bson::DateTime(v) => decode_datetime(py, v, codec_options), + Bson::RegularExpression(v) => { + // Use cached Regex class + let regex_class = TYPE_CACHE.get_regex_class(py)?; + + // Convert BSON regex options to Python flags + let flags = str_flags_to_int(&v.options); + + // Create Regex(pattern, flags) + let regex = regex_class.bind(py).call1((v.pattern.clone(), flags))?; + Ok(regex.into()) + } + Bson::JavaScriptCode(v) => { + // Use cached Code class + let code_class = TYPE_CACHE.get_code_class(py)?; + + // Create Code(code) + let code = code_class.bind(py).call1((v,))?; + Ok(code.into()) + } + 
Bson::JavaScriptCodeWithScope(v) => { + // Use cached Code class + let code_class = TYPE_CACHE.get_code_class(py)?; + + // Convert scope to Python dict + let scope_dict = bson_doc_to_python_dict(py, &v.scope, codec_options)?; + + // Create Code(code, scope) + let code = code_class.bind(py).call1((v.code.clone(), scope_dict))?; + Ok(code.into()) + } + Bson::Timestamp(v) => { + // Use cached Timestamp class + let timestamp_class = TYPE_CACHE.get_timestamp_class(py)?; + + // Create Timestamp(time, inc) + let timestamp = timestamp_class.bind(py).call1((v.time, v.increment))?; + Ok(timestamp.into()) + } + Bson::Decimal128(v) => { + // Use cached Decimal128 class + let decimal128_class = TYPE_CACHE.get_decimal128_class(py)?; + + // Create Decimal128 from bytes + let bytes = PyBytes::new_bound(py, &v.bytes()); + + // Use from_bid class method + let decimal128 = decimal128_class.bind(py).call_method1("from_bid", (bytes,))?; + Ok(decimal128.into()) + } + Bson::MaxKey => { + // Use cached MaxKey class + let maxkey_class = TYPE_CACHE.get_maxkey_class(py)?; + + // Create MaxKey instance + let maxkey = maxkey_class.bind(py).call0()?; + Ok(maxkey.into()) + } + Bson::MinKey => { + // Use cached MinKey class + let minkey_class = TYPE_CACHE.get_minkey_class(py)?; + + // Create MinKey instance + let minkey = minkey_class.bind(py).call0()?; + Ok(minkey.into()) + } + Bson::Symbol(v) => { + // Symbol is deprecated but we need to support decoding it + Ok(PyString::new(py, v).into()) + } + Bson::Undefined => { + // Undefined is deprecated, return None + Ok(py.None()) + } + Bson::DbPointer(v) => { + // DBPointer is deprecated, decode to DBRef + // The DbPointer struct has private fields, so we need to use Debug to extract them + let debug_str = format!("{:?}", v); + + // Parse the debug string: DbPointer { namespace: "...", id: ObjectId("...") } + // Extract namespace and ObjectId hex string + let namespace_start = debug_str.find("namespace: \"").map(|i| i + 12); + let namespace_end = 
debug_str.find("\", id:"); + let oid_start = debug_str.find("ObjectId(\"").map(|i| i + 10); + let oid_end = debug_str.rfind("\")"); + + if let (Some(ns_start), Some(ns_end), Some(oid_start), Some(oid_end)) = + (namespace_start, namespace_end, oid_start, oid_end) { + let namespace = &debug_str[ns_start..ns_end]; + let oid_hex = &debug_str[oid_start..oid_end]; + + // Use cached DBRef and ObjectId classes + let dbref_class = TYPE_CACHE.get_dbref_class(py)?; + let objectid_class = TYPE_CACHE.get_objectid_class(py)?; + + // Create ObjectId from hex string + let objectid = objectid_class.bind(py).call1((oid_hex,))?; + + // Create DBRef(collection, id) + let dbref = dbref_class.bind(py).call1((namespace, objectid))?; + Ok(dbref.into()) + } else { + Err(invalid_document_error(py, format!( + "invalid bson: Failed to parse DBPointer: {:?}", + v + ))) + } + } + _ => Err(invalid_document_error(py, format!( + "invalid bson: Unsupported BSON type for Python conversion: {:?}", + bson + ))), + } +} + + +fn bson_doc_to_python_dict( + py: Python, + doc: &Document, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + // Check if this document is a DBRef (has $ref and $id fields) + if doc.contains_key("$ref") && doc.contains_key("$id") { + return decode_dbref(py, doc, codec_options); + } + + // Get document_class from codec_options, default to dict + let dict: Bound<'_, PyAny> = if let Some(opts) = codec_options { + let document_class = opts.getattr("document_class")?; + document_class.call0()? 
+ } else { + PyDict::new(py).into_any() + }; + + for (key, value) in doc { + let py_value = bson_to_python(py, value, codec_options)?; + dict.set_item(key, py_value)?; + } + + Ok(dict.into()) +} + +/// Convert a Python dict that looks like a DBRef to a DBRef object + +fn convert_dict_to_dbref( + py: Python, + dict: &Bound<'_, PyAny>, + _codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + // Check if $ref field exists + if !dict.call_method1("__contains__", ("$ref",))?.extract::()? { + return Err(PyErr::new::("DBRef missing $ref field")); + } + let collection = dict.call_method1("get", ("$ref",))?; + let collection_str: String = collection.extract()?; + + // Check if $id field exists (value can be None) + if !dict.call_method1("__contains__", ("$id",))?.extract::()? { + return Err(PyErr::new::("DBRef missing $id field")); + } + let id_obj = dict.call_method1("get", ("$id",))?; + + // Use cached DBRef class + let dbref_class = TYPE_CACHE.get_dbref_class(py)?; + + // Get optional $db field + let database_opt = dict.call_method1("get", ("$db",))?; + + // Build kwargs for extra fields (anything other than $ref, $id, $db) + let kwargs = PyDict::new(py); + let items = dict.call_method0("items")?; + for item in items.try_iter()? 
{ + let item = item?; + let tuple = item.downcast::()?; + let key: String = tuple.get_item(0)?.extract()?; + if key != "$ref" && key != "$id" && key != "$db" { + let value = tuple.get_item(1)?; + kwargs.set_item(key, value)?; + } + } + + // Create DBRef with positional args and kwargs + if !database_opt.is_none() { + let database_str: String = database_opt.extract()?; + let dbref = dbref_class.bind(py).call((collection_str, id_obj, database_str), Some(&kwargs))?; + return Ok(dbref.into()); + } + + let dbref = dbref_class.bind(py).call((collection_str, id_obj), Some(&kwargs))?; + Ok(dbref.into()) +} + + +fn decode_dbref( + py: Python, + doc: &Document, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + let collection = if let Some(Bson::String(s)) = doc.get("$ref") { + s.clone() + } else { + return Err(invalid_document_error(py, "Invalid document: DBRef $ref field must be a string".to_string())); + }; + + let id_bson = doc.get("$id").ok_or_else(|| invalid_document_error(py, "Invalid document: DBRef missing $id field".to_string()))?; + let id_py = bson_to_python(py, id_bson, codec_options)?; + + // Use cached DBRef class + let dbref_class = TYPE_CACHE.get_dbref_class(py)?; + + // Get optional $db field + let database_arg = if let Some(db_bson) = doc.get("$db") { + if let Bson::String(database) = db_bson { + Some(database.clone()) + } else { + None + } + } else { + None + }; + + // Collect any extra fields (not $ref, $id, or $db) as kwargs + let kwargs = PyDict::new(py); + for (key, value) in doc { + if key != "$ref" && key != "$id" && key != "$db" { + let py_value = bson_to_python(py, value, codec_options)?; + kwargs.set_item(key, py_value)?; + } + } + + // Create DBRef with positional args and kwargs + if let Some(database) = database_arg { + let dbref = dbref_class.bind(py).call((collection, id_py, database), Some(&kwargs))?; + Ok(dbref.into()) + } else { + let dbref = dbref_class.bind(py).call((collection, id_py), Some(&kwargs))?; + Ok(dbref.into()) + 
} +} + + +fn decode_binary( + py: Python, + v: &bson::Binary, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + let subtype = match &v.subtype { + bson::spec::BinarySubtype::Generic => 0u8, + bson::spec::BinarySubtype::Function => 1u8, + bson::spec::BinarySubtype::BinaryOld => 2u8, + bson::spec::BinarySubtype::UuidOld => 3u8, + bson::spec::BinarySubtype::Uuid => 4u8, + bson::spec::BinarySubtype::Md5 => 5u8, + bson::spec::BinarySubtype::Encrypted => 6u8, + bson::spec::BinarySubtype::Column => 7u8, + bson::spec::BinarySubtype::Sensitive => 8u8, + bson::spec::BinarySubtype::Vector => 9u8, + bson::spec::BinarySubtype::Reserved(s) => *s, + bson::spec::BinarySubtype::UserDefined(s) => *s, + _ => { + return Err(invalid_document_error(py, + "invalid bson: Encountered unknown binary subtype that cannot be converted".to_string(), + )); + } + }; + + // Check for UUID subtypes (3 and 4) + if subtype == 3 || subtype == 4 { + let should_decode_as_uuid = if let Some(opts) = codec_options { + if let Ok(uuid_rep) = opts.getattr("uuid_representation") { + if let Ok(rep_value) = uuid_rep.extract::() { + // Decode as UUID if representation is not UNSPECIFIED (0) + rep_value != 0 + } else { + true + } + } else { + true + } + } else { + true + }; + + if should_decode_as_uuid { + // Decode as UUID using cached class + let uuid_class = TYPE_CACHE.get_uuid_class(py)?; + let bytes_obj = PyBytes::new_bound(py, &v.bytes); + let kwargs = [("bytes", bytes_obj)].into_py_dict_bound(py); + let uuid_obj = uuid_class.bind(py).call((), Some(&kwargs))?; + return Ok(uuid_obj.into()); + } + } + + if subtype == 0 { + Ok(PyBytes::new_bound(py, &v.bytes).into()) + } else { + // Use cached Binary class + let binary_class = TYPE_CACHE.get_binary_class(py)?; + + // Create Binary(data, subtype) + let bytes = PyBytes::new_bound(py, &v.bytes); + let binary = binary_class.bind(py).call1((bytes, subtype))?; + Ok(binary.into()) + } +} + + +fn decode_datetime( + py: Python, + v: &bson::DateTime, + 
codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult> { + // Check datetime_conversion from codec_options + // DATETIME_CLAMP = 2, DATETIME_MS = 3, DATETIME_AUTO = 4 + let datetime_conversion = if let Some(opts) = codec_options { + if let Ok(dt_conv) = opts.getattr("datetime_conversion") { + // Extract the enum value as an integer + if let Ok(conv_int) = dt_conv.call_method0("__int__") { + conv_int.extract::().unwrap_or(4) + } else { + 4 + } + } else { + 4 + } + } else { + 4 + }; + + // Python datetime range: datetime.min to datetime.max + // Min: -62135596800000 ms (year 1) + // Max: 253402300799999 ms (year 9999) + const DATETIME_MIN_MS: i64 = -62135596800000; + const DATETIME_MAX_MS: i64 = 253402300799999; + + // Extremely out of range values (beyond what can be represented) + // These should raise InvalidBSON with a helpful error message + const EXTREME_MIN_MS: i64 = -2i64.pow(52); // -4503599627370496 + const EXTREME_MAX_MS: i64 = 2i64.pow(52); // 4503599627370496 + + let mut millis = v.timestamp_millis(); + let is_out_of_range = millis < DATETIME_MIN_MS || millis > DATETIME_MAX_MS; + let is_extremely_out_of_range = millis <= EXTREME_MIN_MS || millis >= EXTREME_MAX_MS; + + // If extremely out of range, raise InvalidBSON with suggestion + if is_extremely_out_of_range { + let error_msg = format!( + "Value {} is too large or too small to be a valid BSON datetime. \ + (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or \ + MongoClient(datetime_conversion='DATETIME_AUTO')). 
See: \ + https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes", + millis + ); + return Err(invalid_bson_error(py, error_msg)); + } + + // If DATETIME_MS (3), always return DatetimeMS object + if datetime_conversion == 3 { + let datetime_ms_class = TYPE_CACHE.get_datetime_ms_class(py)?; + let datetime_ms = datetime_ms_class.bind(py).call1((millis,))?; + return Ok(datetime_ms.into()); + } + + // If DATETIME_AUTO (4) and out of range, return DatetimeMS + if datetime_conversion == 4 && is_out_of_range { + let datetime_ms_class = TYPE_CACHE.get_datetime_ms_class(py)?; + let datetime_ms = datetime_ms_class.bind(py).call1((millis,))?; + return Ok(datetime_ms.into()); + } + + // Track the original millis value before clamping for timezone conversion + let original_millis = millis; + + // If DATETIME_CLAMP (2), clamp to valid datetime range + if datetime_conversion == 2 { + if millis < DATETIME_MIN_MS { + millis = DATETIME_MIN_MS; + } else if millis > DATETIME_MAX_MS { + millis = DATETIME_MAX_MS; + } + } else if is_out_of_range { + // For other modes, raise error if out of range + return Err(PyErr::new::( + "date value out of range" + )); + } + + // Check if tz_aware is False in codec_options + let tz_aware = if let Some(opts) = codec_options { + if let Ok(tz_aware_val) = opts.getattr("tz_aware") { + tz_aware_val.extract::().unwrap_or(true) + } else { + true + } + } else { + true + }; + + // Convert to Python datetime using cached class + let datetime_class = TYPE_CACHE.get_datetime_class(py)?; + + // Convert milliseconds to seconds and microseconds + let seconds = millis / 1000; + let microseconds = (millis % 1000) * 1000; + + if tz_aware { + // Return timezone-aware datetime with UTC timezone using cached utc + let utc = TYPE_CACHE.get_utc(py)?; + + // Construct datetime from epoch using timedelta to avoid platform-specific limitations + // This works on all platforms including Windows for dates 
outside fromtimestamp() range + let epoch = datetime_class.bind(py).call1((1970, 1, 1, 0, 0, 0, 0, utc.bind(py)))?; + let datetime_module = py.import_bound("datetime")?; + let timedelta_class = datetime_module.getattr("timedelta")?; + + // Create timedelta for seconds and microseconds + let kwargs = [("seconds", seconds), ("microseconds", microseconds)].into_py_dict_bound(py); + let delta = timedelta_class.call((), Some(&kwargs))?; + let dt_final = epoch.call_method1("__add__", (delta,))?; + + // Convert to local timezone if tzinfo is provided in codec_options + if let Some(opts) = codec_options { + if let Ok(tzinfo) = opts.getattr("tzinfo") { + if !tzinfo.is_none() { + // Call astimezone(tzinfo) to convert to the specified timezone + // This might fail with OverflowError if the datetime is at the boundary + match dt_final.call_method1("astimezone", (&tzinfo,)) { + Ok(local_dt) => return Ok(local_dt.into()), + Err(e) => { + // If OverflowError during clamping, return datetime.min or datetime.max with the target tzinfo + if e.is_instance_of::(py) && datetime_conversion == 2 { + // Check if dt_final is at datetime.min or datetime.max + let datetime_min = datetime_class.bind(py).getattr("min")?; + let datetime_max = datetime_class.bind(py).getattr("max")?; + + // Compare year to determine if we're at min or max + let year = dt_final.getattr("year")?.extract::()?; + + if year == 1 { + // At datetime.min, return datetime.min.replace(tzinfo=tzinfo) + let kwargs = [("tzinfo", &tzinfo)].into_py_dict_bound(py); + let dt_with_tz = datetime_min.call_method("replace", (), Some(&kwargs))?; + return Ok(dt_with_tz.into()); + } else { + // At datetime.max, return datetime.max.replace(tzinfo=tzinfo, microsecond=999000) + let microsecond = 999000i32.into_py(py).into_bound(py); + let kwargs = [("tzinfo", &tzinfo), ("microsecond", µsecond)].into_py_dict_bound(py); + let dt_with_tz = datetime_max.call_method("replace", (), Some(&kwargs))?; + return Ok(dt_with_tz.into()); + } + } else { 
+ return Err(e); + } + } + } + } + } + } + + Ok(dt_final.into()) + } else { + // Return naive datetime (no timezone) + // Construct datetime from epoch using timedelta to avoid platform-specific limitations + let epoch = datetime_class.bind(py).call1((1970, 1, 1, 0, 0, 0, 0))?; + let datetime_module = py.import_bound("datetime")?; + let timedelta_class = datetime_module.getattr("timedelta")?; + + // Create timedelta for seconds and microseconds + let kwargs = [("seconds", seconds), ("microseconds", microseconds)].into_py_dict_bound(py); + let delta = timedelta_class.call((), Some(&kwargs))?; + let naive_dt = epoch.call_method1("__add__", (delta,))?; + Ok(naive_dt.into()) + } +} diff --git a/bson/_rbson/src/encode.rs b/bson/_rbson/src/encode.rs new file mode 100644 index 0000000000..45c3ce40da --- /dev/null +++ b/bson/_rbson/src/encode.rs @@ -0,0 +1,1543 @@ +// Copyright 2025-present MongoDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! BSON encoding functions +//! +//! This module contains all functions for encoding Python objects to BSON bytes. 
+ +use bson::{doc, Bson, Document}; +use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::{IntoPyDict, PyAny, PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple}; +use std::io::Cursor; + +use crate::errors::{invalid_document_error, invalid_document_error_with_doc}; +use crate::types::{ + TYPE_CACHE, BINARY_TYPE_MARKER, CODE_TYPE_MARKER, DATETIME_TYPE_MARKER, DBPOINTER_TYPE_MARKER, + DBREF_TYPE_MARKER, DECIMAL128_TYPE_MARKER, INT64_TYPE_MARKER, MAXKEY_TYPE_MARKER, + MINKEY_TYPE_MARKER, OBJECTID_TYPE_MARKER, REGEX_TYPE_MARKER, SYMBOL_TYPE_MARKER, + TIMESTAMP_TYPE_MARKER, +}; +use crate::utils::{datetime_to_millis, int_flags_to_str, validate_key, write_cstring, write_string}; + +#[pyfunction] +#[pyo3(signature = (obj, check_keys, _codec_options))] +pub fn _dict_to_bson( + py: Python, + obj: &Bound<'_, PyAny>, + check_keys: bool, + _codec_options: &Bound<'_, PyAny>, +) -> PyResult> { + let codec_options = Some(_codec_options); + + // Use python_mapping_to_bson_doc for efficient encoding + // This uses items() method and efficient tuple extraction + // See PR #2695 for implementation details and performance analysis + let doc = python_mapping_to_bson_doc(obj, check_keys, codec_options, true) + .map_err(|e| { + // Match C extension behavior: TypeError for non-mapping types, InvalidDocument for encoding errors + let err_str = e.to_string(); + + // If it's a TypeError about mapping type, pass it through unchanged (matches C extension) + if err_str.contains("encoder expected a mapping type") { + return e; + } + + // For other errors, wrap in InvalidDocument with document property + if err_str.contains("cannot encode object:") || err_str.contains("Object must be a dict") { + // Strip "InvalidDocument: " prefix if present, then add "Invalid document: " + let msg = if let Some(stripped) = err_str.strip_prefix("InvalidDocument: ") { + format!("Invalid document: {}", stripped) + } else { + format!("Invalid document: {}", 
err_str) + }; + invalid_document_error_with_doc(py, msg, obj) + } else { + e + } + })?; + + // Use to_writer() to write directly to buffer + // This is faster than bson::to_vec() which creates an intermediate Vec + let mut buf = Vec::new(); + doc.to_writer(&mut buf) + .map_err(|e| invalid_document_error(py, format!("Failed to serialize BSON: {}", e)))?; + + Ok(PyBytes::new(py, &buf).into()) +} + +/// Encode a Python dictionary to BSON bytes WITHOUT using Bson types +/// This version writes bytes directly from Python objects for better performance +#[pyfunction] +#[pyo3(signature = (obj, check_keys, _codec_options))] +pub fn _dict_to_bson_direct( + py: Python, + obj: &Bound<'_, PyAny>, + check_keys: bool, + _codec_options: &Bound<'_, PyAny>, +) -> PyResult> { + let codec_options = Some(_codec_options); + + // Write directly to bytes without converting to Bson types + let mut buf = Vec::new(); + write_document_bytes_direct(&mut buf, obj, check_keys, codec_options, true) + .map_err(|e| { + // Match C extension behavior: TypeError for non-mapping types, InvalidDocument for encoding errors + let err_str = e.to_string(); + + // If it's a TypeError about mapping type, pass it through unchanged (matches C extension) + if err_str.contains("encoder expected a mapping type") { + return e; + } + + // For other errors, wrap in InvalidDocument with document property + if err_str.contains("cannot encode object:") { + let msg = format!("Invalid document: {}", err_str); + invalid_document_error_with_doc(py, msg, obj) + } else { + e + } + })?; + + Ok(PyBytes::new(py, &buf).into()) +} + +/// Read a BSON document directly from bytes and convert to Python dict + +fn write_document_bytes( + buf: &mut Vec, + obj: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, + is_top_level: bool, +) -> PyResult<()> { + use std::io::Write; + + // Reserve space for document size (will be filled in at the end) + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 
4]); + + // Handle _id field first if this is top-level + let mut id_written = false; + + // FAST PATH: Check if it's a PyDict first (most common case) + if let Ok(dict) = obj.downcast::() { + // First pass: write _id if present at top level + if is_top_level { + if let Some(id_value) = dict.get_item("_id")? { + write_element(buf, "_id", &id_value, check_keys, codec_options)?; + id_written = true; + } + } + + // Second pass: write all other fields + for (key, value) in dict { + let key_str: String = key.extract()?; + + // Skip _id if we already wrote it + if is_top_level && id_written && key_str == "_id" { + continue; + } + + // Validate key + validate_key(&key_str, check_keys)?; + + write_element(buf, &key_str, &value, check_keys, codec_options)?; + } + } else { + // SLOW PATH: Use items() method for SON, OrderedDict, etc. + if let Ok(items_method) = obj.getattr("items") { + if let Ok(items_result) = items_method.call0() { + // Collect items into a vector + let items: Vec<(String, Bound<'_, PyAny>)> = if let Ok(items_list) = items_result.downcast::() { + items_list.iter() + .map(|item| { + let tuple = item.downcast::()?; + let key: String = tuple.get_item(0)?.extract()?; + let value = tuple.get_item(1)?; + Ok((key, value)) + }) + .collect::>>()? 
+ } else { + return Err(PyTypeError::new_err("items() must return a list")); + }; + + // First pass: write _id if present at top level + if is_top_level { + for (key, value) in &items { + if key == "_id" { + write_element(buf, "_id", value, check_keys, codec_options)?; + id_written = true; + break; + } + } + } + + // Second pass: write all other fields + for (key, value) in items { + // Skip _id if we already wrote it + if is_top_level && id_written && key == "_id" { + continue; + } + + // Validate key + validate_key(&key, check_keys)?; + + write_element(buf, &key, &value, check_keys, codec_options)?; + } + } else { + return Err(PyTypeError::new_err("items() call failed")); + } + } else { + return Err(PyTypeError::new_err(format!("encoder expected a mapping type but got: {}", obj))); + } + } + + // Write null terminator + buf.push(0); + + // Write document size at the beginning + let doc_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&doc_size.to_le_bytes()); + + Ok(()) +} + +fn write_document_bytes_direct( + buf: &mut Vec, + obj: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, + is_top_level: bool, +) -> PyResult<()> { + // Reserve space for document size (will be filled in at the end) + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); + + // Handle _id field first if this is top-level + let mut id_written = false; + + // FAST PATH: Check if it's a PyDict first (most common case) + if let Ok(dict) = obj.downcast::() { + // First pass: write _id if present at top level + if is_top_level { + if let Some(id_value) = dict.get_item("_id")? 
{ + write_element_direct(buf, "_id", &id_value, check_keys, codec_options)?; + id_written = true; + } + } + + // Second pass: write all other fields + for (key, value) in dict { + let key_str: String = key.extract()?; + + // Skip _id if we already wrote it + if is_top_level && id_written && key_str == "_id" { + continue; + } + + // Validate key + validate_key(&key_str, check_keys)?; + + write_element_direct(buf, &key_str, &value, check_keys, codec_options)?; + } + } else { + // SLOW PATH: Use items() method for SON, OrderedDict, etc. + if let Ok(items_method) = obj.getattr("items") { + if let Ok(items_result) = items_method.call0() { + // Collect items into a vector + let items: Vec<(String, Bound<'_, PyAny>)> = if let Ok(items_list) = items_result.downcast::() { + items_list.iter() + .map(|item| { + let tuple = item.downcast::()?; + let key: String = tuple.get_item(0)?.extract()?; + let value = tuple.get_item(1)?; + Ok((key, value)) + }) + .collect::>>()? + } else { + return Err(PyTypeError::new_err("items() must return a list")); + }; + + // First pass: write _id if present at top level + if is_top_level { + for (key, value) in &items { + if key == "_id" { + write_element_direct(buf, "_id", value, check_keys, codec_options)?; + id_written = true; + break; + } + } + } + + // Second pass: write all other fields + for (key, value) in items { + // Skip _id if we already wrote it + if is_top_level && id_written && key == "_id" { + continue; + } + + // Validate key + validate_key(&key, check_keys)?; + + write_element_direct(buf, &key, &value, check_keys, codec_options)?; + } + } else { + return Err(PyTypeError::new_err("items() call failed")); + } + } else { + return Err(PyTypeError::new_err(format!("encoder expected a mapping type but got: {}", obj))); + } + } + + // Write null terminator + buf.push(0); + + // Write document size at the beginning + let doc_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&doc_size.to_le_bytes()); + + 
Ok(()) +} + +fn write_element( + buf: &mut Vec, + key: &str, + value: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + use pyo3::types::{PyList, PyLong, PyTuple}; + use std::io::Write; + + // FAST PATH: Check for common Python types FIRST + if value.is_none() { + // Type 0x0A: Null + buf.push(0x0A); + write_cstring(buf, key); + return Ok(()); + } else if let Ok(v) = value.extract::() { + // Type 0x08: Boolean + buf.push(0x08); + write_cstring(buf, key); + buf.push(if v { 1 } else { 0 }); + return Ok(()); + } else if value.is_instance_of::() { + // Try i32 first, then i64 + if let Ok(v) = value.extract::() { + // Type 0x10: Int32 + buf.push(0x10); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + return Ok(()); + } else if let Ok(v) = value.extract::() { + // Type 0x12: Int64 + buf.push(0x12); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + return Ok(()); + } else { + return Err(PyErr::new::( + "MongoDB can only handle up to 8-byte ints" + )); + } + } else if let Ok(v) = value.extract::() { + // Type 0x01: Double + buf.push(0x01); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + return Ok(()); + } else if let Ok(v) = value.extract::() { + // Type 0x02: String + buf.push(0x02); + write_cstring(buf, key); + write_string(buf, &v); + return Ok(()); + } + + // Check for dict/list BEFORE converting to Bson (much faster for nested structures) + if let Ok(dict) = value.downcast::() { + // Type 0x03: Embedded document + buf.push(0x03); + write_cstring(buf, key); + write_document_bytes(buf, value, check_keys, codec_options, false)?; + return Ok(()); + } else if let Ok(list) = value.downcast::() { + // Type 0x04: Array + buf.push(0x04); + write_cstring(buf, key); + write_array_bytes(buf, list, check_keys, codec_options)?; + return Ok(()); + } else if let Ok(tuple) = value.downcast::() { + // Type 0x04: Array (tuples are treated as arrays) + 
buf.push(0x04); + write_cstring(buf, key); + write_tuple_bytes(buf, tuple, check_keys, codec_options)?; + return Ok(()); + } else if value.hasattr("items")? { + // Type 0x03: Embedded document (SON, OrderedDict, etc.) + buf.push(0x03); + write_cstring(buf, key); + write_document_bytes(buf, value, check_keys, codec_options, false)?; + return Ok(()); + } + + // SLOW PATH: Handle BSON types and other Python types + // Convert to Bson and then write + let bson_value = python_to_bson(value.clone(), check_keys, codec_options)?; + write_bson_value(buf, key, &bson_value)?; + + Ok(()) +} + +fn write_element_direct( + buf: &mut Vec, + key: &str, + value: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + use pyo3::types::{PyList, PyLong, PyTuple}; + let py = value.py(); + + // FAST PATH: Check for common Python types FIRST + if value.is_none() { + // Type 0x0A: Null + buf.push(0x0A); + write_cstring(buf, key); + return Ok(()); + } else if let Ok(v) = value.extract::() { + // Type 0x08: Boolean + buf.push(0x08); + write_cstring(buf, key); + buf.push(if v { 1 } else { 0 }); + return Ok(()); + } else if value.is_instance_of::() { + // Try i32 first, then i64 + if let Ok(v) = value.extract::() { + // Type 0x10: Int32 + buf.push(0x10); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + return Ok(()); + } else if let Ok(v) = value.extract::() { + // Type 0x12: Int64 + buf.push(0x12); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + return Ok(()); + } else { + return Err(PyErr::new::( + "MongoDB can only handle up to 8-byte ints" + )); + } + } else if let Ok(v) = value.extract::() { + // Type 0x01: Double + buf.push(0x01); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + return Ok(()); + } else if let Ok(v) = value.extract::() { + // Type 0x02: String + buf.push(0x02); + write_cstring(buf, key); + write_string(buf, &v); + return Ok(()); + } + + // Check for 
dict/list BEFORE checking BSON types + if let Ok(dict) = value.downcast::() { + // Type 0x03: Embedded document + buf.push(0x03); + write_cstring(buf, key); + write_document_bytes_direct(buf, value, check_keys, codec_options, false)?; + return Ok(()); + } else if let Ok(list) = value.downcast::() { + // Type 0x04: Array + buf.push(0x04); + write_cstring(buf, key); + write_array_bytes_direct(buf, list, check_keys, codec_options)?; + return Ok(()); + } else if let Ok(tuple) = value.downcast::() { + // Type 0x04: Array (tuples are treated as arrays) + buf.push(0x04); + write_cstring(buf, key); + write_tuple_bytes_direct(buf, tuple, check_keys, codec_options)?; + return Ok(()); + } + + // Check for BSON types with _type_marker and write directly + if let Ok(type_marker) = value.getattr("_type_marker") { + if let Ok(marker) = type_marker.extract::() { + return write_bson_type_direct(buf, key, value, marker, check_keys, codec_options); + } + } + + // Check for bytes (Python bytes type) + if let Ok(bytes_data) = value.extract::>() { + // Type 0x05: Binary (subtype 0 for generic binary) + buf.push(0x05); + write_cstring(buf, key); + buf.extend_from_slice(&(bytes_data.len() as i32).to_le_bytes()); + buf.push(0); // subtype 0 + buf.extend_from_slice(&bytes_data); + return Ok(()); + } + + // Check for mapping types (SON, OrderedDict, etc.) + if value.hasattr("items")? 
{ + // Type 0x03: Embedded document + buf.push(0x03); + write_cstring(buf, key); + write_document_bytes_direct(buf, value, check_keys, codec_options, false)?; + return Ok(()); + } + + Err(PyErr::new::( + format!("cannot encode object: {:?}", value) + )) +} + +fn write_bson_type_direct( + buf: &mut Vec, + key: &str, + value: &Bound<'_, PyAny>, + marker: i32, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + match marker { + BINARY_TYPE_MARKER => { + // Type 0x05: Binary + let subtype: u8 = value.getattr("subtype")?.extract()?; + let bytes_data: Vec = value.extract()?; + buf.push(0x05); + write_cstring(buf, key); + buf.extend_from_slice(&(bytes_data.len() as i32).to_le_bytes()); + buf.push(subtype); + buf.extend_from_slice(&bytes_data); + Ok(()) + } + OBJECTID_TYPE_MARKER => { + // Type 0x07: ObjectId + let binary: Vec = value.getattr("binary")?.extract()?; + if binary.len() != 12 { + return Err(PyErr::new::( + "ObjectId must be 12 bytes" + )); + } + buf.push(0x07); + write_cstring(buf, key); + buf.extend_from_slice(&binary); + Ok(()) + } + DATETIME_TYPE_MARKER => { + // Type 0x09: DateTime (UTC datetime as milliseconds since epoch) + let millis: i64 = value.getattr("_value")?.extract()?; + buf.push(0x09); + write_cstring(buf, key); + buf.extend_from_slice(&millis.to_le_bytes()); + Ok(()) + } + REGEX_TYPE_MARKER => { + // Type 0x0B: Regular expression + let pattern_obj = value.getattr("pattern")?; + let pattern: String = if let Ok(s) = pattern_obj.extract::() { + s + } else if let Ok(b) = pattern_obj.extract::>() { + String::from_utf8_lossy(&b).to_string() + } else { + return Err(PyErr::new::( + "Regex pattern must be str or bytes" + )); + }; + + let flags_obj = value.getattr("flags")?; + let flags_str = if let Ok(flags_int) = flags_obj.extract::() { + int_flags_to_str(flags_int) + } else { + flags_obj.extract::().unwrap_or_default() + }; + + buf.push(0x0B); + write_cstring(buf, key); + write_cstring(buf, &pattern); + 
write_cstring(buf, &flags_str); + Ok(()) + } + CODE_TYPE_MARKER => { + // Type 0x0D: JavaScript code or 0x0F: JavaScript code with scope + let code_str: String = value.extract()?; + + if let Ok(scope_obj) = value.getattr("scope") { + if !scope_obj.is_none() { + // Type 0x0F: Code with scope + buf.push(0x0F); + write_cstring(buf, key); + + // Reserve space for total size + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); + + // Write code string + write_string(buf, &code_str); + + // Write scope document + write_document_bytes_direct(buf, &scope_obj, check_keys, codec_options, false)?; + + // Write total size + let total_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&total_size.to_le_bytes()); + + return Ok(()); + } + } + + // Type 0x0D: Code without scope + buf.push(0x0D); + write_cstring(buf, key); + write_string(buf, &code_str); + Ok(()) + } + TIMESTAMP_TYPE_MARKER => { + // Type 0x11: Timestamp + let time: u32 = value.getattr("time")?.extract()?; + let inc: u32 = value.getattr("inc")?.extract()?; + buf.push(0x11); + write_cstring(buf, key); + buf.extend_from_slice(&inc.to_le_bytes()); + buf.extend_from_slice(&time.to_le_bytes()); + Ok(()) + } + INT64_TYPE_MARKER => { + // Type 0x12: Int64 + let val: i64 = value.extract()?; + buf.push(0x12); + write_cstring(buf, key); + buf.extend_from_slice(&val.to_le_bytes()); + Ok(()) + } + DECIMAL128_TYPE_MARKER => { + // Type 0x13: Decimal128 + let bid: Vec = value.getattr("bid")?.extract()?; + if bid.len() != 16 { + return Err(PyErr::new::( + "Decimal128 must be 16 bytes" + )); + } + buf.push(0x13); + write_cstring(buf, key); + buf.extend_from_slice(&bid); + Ok(()) + } + MAXKEY_TYPE_MARKER => { + // Type 0x7F: MaxKey + buf.push(0x7F); + write_cstring(buf, key); + Ok(()) + } + MINKEY_TYPE_MARKER => { + // Type 0xFF: MinKey + buf.push(0xFF); + write_cstring(buf, key); + Ok(()) + } + _ => { + Err(PyErr::new::( + format!("Unknown BSON type marker: {}", marker) + )) + } + } +} + 
+ +fn write_array_bytes( + buf: &mut Vec, + list: &Bound<'_, pyo3::types::PyList>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + // Arrays are encoded as documents with numeric string keys ("0", "1", "2", ...) + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); // Reserve space for size + + for (i, item) in list.iter().enumerate() { + write_element(buf, &i.to_string(), &item, check_keys, codec_options)?; + } + + buf.push(0); // null terminator + + let arr_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&arr_size.to_le_bytes()); + + Ok(()) +} + +fn write_tuple_bytes( + buf: &mut Vec, + tuple: &Bound<'_, pyo3::types::PyTuple>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + // Arrays are encoded as documents with numeric string keys ("0", "1", "2", ...) + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); // Reserve space for size + + for (i, item) in tuple.iter().enumerate() { + write_element(buf, &i.to_string(), &item, check_keys, codec_options)?; + } + + buf.push(0); // null terminator + + let arr_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&arr_size.to_le_bytes()); + + Ok(()) +} + +fn write_array_bytes_direct( + buf: &mut Vec, + list: &Bound<'_, pyo3::types::PyList>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + // Arrays are encoded as documents with numeric string keys ("0", "1", "2", ...) 
+ let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); // Reserve space for size + + for (i, item) in list.iter().enumerate() { + write_element_direct(buf, &i.to_string(), &item, check_keys, codec_options)?; + } + + buf.push(0); // null terminator + + let arr_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&arr_size.to_le_bytes()); + + Ok(()) +} + +fn write_tuple_bytes_direct( + buf: &mut Vec, + tuple: &Bound<'_, pyo3::types::PyTuple>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + // Arrays are encoded as documents with numeric string keys ("0", "1", "2", ...) + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); // Reserve space for size + + for (i, item) in tuple.iter().enumerate() { + write_element_direct(buf, &i.to_string(), &item, check_keys, codec_options)?; + } + + buf.push(0); // null terminator + + let arr_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&arr_size.to_le_bytes()); + + Ok(()) +} + +fn write_bson_value(buf: &mut Vec, key: &str, value: &Bson) -> PyResult<()> { + use std::io::Write; + + match value { + Bson::Double(v) => { + buf.push(0x01); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + } + Bson::String(v) => { + buf.push(0x02); + write_cstring(buf, key); + write_string(buf, v); + } + Bson::Document(doc) => { + buf.push(0x03); + write_cstring(buf, key); + // Serialize the document + let mut doc_buf = Vec::new(); + doc.to_writer(&mut doc_buf) + .map_err(|e| PyErr::new::( + format!("Failed to encode nested document: {}", e) + ))?; + buf.extend_from_slice(&doc_buf); + } + Bson::Array(arr) => { + buf.push(0x04); + write_cstring(buf, key); + // Arrays are encoded as documents with numeric string keys + let size_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); + + for (i, item) in arr.iter().enumerate() { + write_bson_value(buf, &i.to_string(), item)?; + } + + buf.push(0); // null terminator + 
+ let arr_size = (buf.len() - size_pos) as i32; + buf[size_pos..size_pos + 4].copy_from_slice(&arr_size.to_le_bytes()); + } + Bson::Binary(bin) => { + buf.push(0x05); + write_cstring(buf, key); + buf.extend_from_slice(&(bin.bytes.len() as i32).to_le_bytes()); + buf.push(bin.subtype.into()); + buf.extend_from_slice(&bin.bytes); + } + Bson::ObjectId(oid) => { + buf.push(0x07); + write_cstring(buf, key); + buf.extend_from_slice(&oid.bytes()); + } + Bson::Boolean(v) => { + buf.push(0x08); + write_cstring(buf, key); + buf.push(if *v { 1 } else { 0 }); + } + Bson::DateTime(dt) => { + buf.push(0x09); + write_cstring(buf, key); + buf.extend_from_slice(&dt.timestamp_millis().to_le_bytes()); + } + Bson::Null => { + buf.push(0x0A); + write_cstring(buf, key); + } + Bson::RegularExpression(regex) => { + buf.push(0x0B); + write_cstring(buf, key); + write_cstring(buf, ®ex.pattern); + write_cstring(buf, ®ex.options); + } + Bson::Int32(v) => { + buf.push(0x10); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + } + Bson::Timestamp(ts) => { + buf.push(0x11); + write_cstring(buf, key); + buf.extend_from_slice(&ts.time.to_le_bytes()); + buf.extend_from_slice(&ts.increment.to_le_bytes()); + } + Bson::Int64(v) => { + buf.push(0x12); + write_cstring(buf, key); + buf.extend_from_slice(&v.to_le_bytes()); + } + Bson::Decimal128(dec) => { + buf.push(0x13); + write_cstring(buf, key); + buf.extend_from_slice(&dec.bytes()); + } + _ => { + return Err(PyErr::new::( + format!("Unsupported BSON type: {:?}", value) + )); + } + } + + Ok(()) +} + +/// Encode a Python dictionary to BSON bytes + +fn python_to_bson( + obj: Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult { + let py = obj.py(); + + // Check if this is a BSON type with a _type_marker FIRST + // This must come before string/int checks because Code inherits from str, Int64 inherits from int, etc. 
+ if let Ok(type_marker) = obj.getattr("_type_marker") { + if let Ok(marker) = type_marker.extract::() { + return handle_bson_type_marker(obj, marker, check_keys, codec_options); + } + } + + // FAST PATH: Check for common Python types (int, str, float, bool, None) + // This avoids expensive module/attribute lookups for the majority of values + use pyo3::types::PyLong; + + if obj.is_none() { + return Ok(Bson::Null); + } else if let Ok(v) = obj.extract::() { + return Ok(Bson::Boolean(v)); + } else if obj.is_instance_of::() { + // It's a Python int - try to fit it in i32 or i64 + if let Ok(v) = obj.extract::() { + return Ok(Bson::Int32(v)); + } else if let Ok(v) = obj.extract::() { + return Ok(Bson::Int64(v)); + } else { + // Integer doesn't fit in i64 - raise OverflowError + return Err(PyErr::new::( + "MongoDB can only handle up to 8-byte ints" + )); + } + } else if let Ok(v) = obj.extract::() { + return Ok(Bson::Double(v)); + } else if let Ok(v) = obj.extract::() { + return Ok(Bson::String(v)); + } + + // Check for Python UUID objects (uuid.UUID) - use cached type + if let Ok(uuid_class) = TYPE_CACHE.get_uuid_class(py) { + if obj.is_instance(&uuid_class.bind(py))? { + // Check uuid_representation from codec_options + let uuid_representation = if let Some(opts) = codec_options { + if let Ok(uuid_rep) = opts.getattr("uuid_representation") { + uuid_rep.extract::().unwrap_or(0) + } else { + 0 + } + } else { + 0 + }; + + // UNSPECIFIED = 0, cannot encode native UUID + if uuid_representation == 0 { + return Err(PyErr::new::( + "cannot encode native uuid.UUID with UuidRepresentation.UNSPECIFIED. \ + UUIDs can be manually converted to bson.Binary instances using \ + bson.Binary.from_uuid() or a different UuidRepresentation can be \ + configured. See the documentation for UuidRepresentation for more information." 
+ )); + } + + // Convert UUID to Binary with appropriate subtype based on representation + // UNSPECIFIED = 0, PYTHON_LEGACY = 3, STANDARD = 4, JAVA_LEGACY = 5, CSHARP_LEGACY = 6 + let uuid_bytes: Vec = obj.getattr("bytes")?.extract()?; + let subtype = match uuid_representation { + 3 => bson::spec::BinarySubtype::UuidOld, // PYTHON_LEGACY (subtype 3) + 4 => bson::spec::BinarySubtype::Uuid, // STANDARD (subtype 4) + 5 => bson::spec::BinarySubtype::UuidOld, // JAVA_LEGACY (subtype 3) + 6 => bson::spec::BinarySubtype::UuidOld, // CSHARP_LEGACY (subtype 3) + _ => bson::spec::BinarySubtype::Uuid, // Default to STANDARD + }; + + return Ok(Bson::Binary(bson::Binary { + subtype, + bytes: uuid_bytes, + })); + } + } + + // Check for compiled regex Pattern objects - use cached type + if let Ok(pattern_class) = TYPE_CACHE.get_pattern_class(py) { + if obj.is_instance(&pattern_class.bind(py))? { + // Extract pattern and flags from re.Pattern + if obj.hasattr("pattern")? && obj.hasattr("flags")? { + let pattern_obj = obj.getattr("pattern")?; + let pattern: String = if let Ok(s) = pattern_obj.extract::() { + s + } else if let Ok(b) = pattern_obj.extract::>() { + // Pattern is bytes, convert to string + String::from_utf8_lossy(&b).to_string() + } else { + return Err(invalid_document_error(py, + "Invalid document: Regex pattern must be str or bytes".to_string())); + }; + let flags: i32 = obj.getattr("flags")?.extract()?; + let flags_str = int_flags_to_str(flags); + return Ok(Bson::RegularExpression(bson::Regex { + pattern, + options: flags_str, + })); + } + } + } + + // Check for Python datetime objects - use cached type + if let Ok(datetime_class) = TYPE_CACHE.get_datetime_class(py) { + if obj.is_instance(&datetime_class.bind(py))? 
{ + // Convert Python datetime to milliseconds since epoch (inline) + let millis = datetime_to_millis(py, &obj)?; + return Ok(Bson::DateTime(bson::DateTime::from_millis(millis))); + } + } + + // Handle remaining Python types (bytes, lists, dicts) + handle_remaining_python_types(obj, check_keys, codec_options) +} + + +fn python_mapping_to_bson_doc( + obj: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, + is_top_level: bool, +) -> PyResult { + let mut doc = Document::new(); + let mut has_id = false; + let mut id_value: Option = None; + + // FAST PATH: Check if it's a PyDict first (most common case) + // Iterate directly over dict items - much faster than calling items() + if let Ok(dict) = obj.downcast::() { + for (key, value) in dict { + // Check if key is bytes - this is not allowed + if key.extract::>().is_ok() { + let py = obj.py(); + let key_repr = key.repr()?.to_string(); + return Err(invalid_document_error(py, + format!("documents must have only string keys, key was {}", key_repr))); + } + + // Extract key as string + let key_str: String = if let Ok(s) = key.extract::() { + s + } else { + let py = obj.py(); + return Err(invalid_document_error(py, + format!("Dictionary keys must be strings, got {}", + key.get_type().name()?))); + }; + + // Check keys if requested + if check_keys { + if key_str.starts_with('$') { + let py = obj.py(); + return Err(invalid_document_error(py, + format!("key '{}' must not start with '$'", key_str))); + } + if key_str.contains('.') { + let py = obj.py(); + return Err(invalid_document_error(py, + format!("key '{}' must not contain '.'", key_str))); + } + } + + let bson_value = python_to_bson(value, check_keys, codec_options)?; + + // Handle _id field ordering + if key_str == "_id" { + has_id = true; + id_value = Some(bson_value); + } else { + doc.insert(key_str, bson_value); + } + } + + // Insert _id first if present and at top level + if has_id { + if let Some(id_val) = id_value { + if is_top_level 
{ + // At top level, move _id to the front + let mut new_doc = Document::new(); + new_doc.insert("_id", id_val); + for (k, v) in doc { + new_doc.insert(k, v); + } + return Ok(new_doc); + } else { + // Not at top level, just insert _id in normal position + doc.insert("_id", id_val); + } + } + } + + return Ok(doc); + } + + // SLOW PATH: Fall back to mapping protocol for SON, OrderedDict, etc. + // Use items() method for efficient iteration + if let Ok(items_method) = obj.getattr("items") { + if let Ok(items_result) = items_method.call0() { + // Try to downcast to PyList or PyTuple first for efficient iteration + if let Ok(items_list) = items_result.downcast::() { + for item in items_list { + process_mapping_item( + &item, + &mut doc, + &mut has_id, + &mut id_value, + check_keys, + codec_options, + )?; + } + } else if let Ok(items_tuple) = items_result.downcast::() { + for item in items_tuple { + process_mapping_item( + &item, + &mut doc, + &mut has_id, + &mut id_value, + check_keys, + codec_options, + )?; + } + } else { + // Fall back to generic iteration using PyIterator + let py = obj.py(); + let iter = items_result.call_method0("__iter__")?; + loop { + match iter.call_method0("__next__") { + Ok(item) => { + process_mapping_item( + &item, + &mut doc, + &mut has_id, + &mut id_value, + check_keys, + codec_options, + )?; + } + Err(e) => { + // Check if it's StopIteration + if e.is_instance_of::(py) { + break; + } else { + return Err(e); + } + } + } + } + } + + // Insert _id first if present and at top level + if has_id { + if let Some(id_val) = id_value { + if is_top_level { + // At top level, move _id to the front + let mut new_doc = Document::new(); + new_doc.insert("_id", id_val); + for (k, v) in doc { + new_doc.insert(k, v); + } + return Ok(new_doc); + } else { + // Not at top level, just insert _id in normal position + doc.insert("_id", id_val); + } + } + } + + return Ok(doc); + } + } + + // Match C extension behavior: raise TypeError for non-mapping types + 
Err(PyTypeError::new_err(format!("encoder expected a mapping type but got: {}", obj))) +} + +/// Extract a single item from a PyDict and return (key, value) + +fn process_mapping_item( + item: &Bound<'_, PyAny>, + doc: &mut Document, + has_id: &mut bool, + id_value: &mut Option, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<()> { + // Each item should be a tuple (key, value) + // Use extract to get a tuple of (PyObject, PyObject) + let (key, value): (Bound<'_, PyAny>, Bound<'_, PyAny>) = item.extract()?; + + // Check if key is bytes - this is not allowed + if key.extract::>().is_ok() { + let py = item.py(); + let key_repr = key.repr()?.to_string(); + return Err(invalid_document_error(py, + format!("documents must have only string keys, key was {}", key_repr))); + } + + // Convert key to string + let key_str: String = if let Ok(s) = key.extract::() { + s + } else { + let py = item.py(); + return Err(invalid_document_error(py, + format!("Dictionary keys must be strings, got {}", + key.get_type().name()?))); + }; + + // Check keys if requested + if check_keys { + if key_str.starts_with('$') { + let py = item.py(); + return Err(invalid_document_error(py, + format!("key '{}' must not start with '$'", key_str))); + } + if key_str.contains('.') { + let py = item.py(); + return Err(invalid_document_error(py, + format!("key '{}' must not contain '.'", key_str))); + } + } + + let bson_value = python_to_bson(value, check_keys, codec_options)?; + + // Always store _id field, but it will be reordered at top level only + if key_str == "_id" { + *has_id = true; + *id_value = Some(bson_value); + } else { + doc.insert(key_str, bson_value); + } + + Ok(()) +} + +/// Convert a Python mapping (dict, SON, OrderedDict, etc.) 
to a BSON Document +/// HYBRID APPROACH: Fast path for PyDict, items() method for other mappings + +fn extract_dict_item( + key: &Bound<'_, PyAny>, + value: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<(String, Bson)> { + let py = key.py(); + + // Keys must be strings (not bytes, not other types) + let key_str: String = if let Ok(s) = key.extract::() { + s + } else { + // Get a string representation of the key for the error message + let key_repr = if let Ok(b) = key.extract::>() { + format!("b'{}'", String::from_utf8_lossy(&b)) + } else { + format!("{}", key) + }; + return Err(invalid_document_error(py, format!( + "Invalid document: documents must have only string keys, key was {}", + key_repr + ))); + }; + + // Check for null bytes in key (always invalid) + if key_str.contains('\0') { + return Err(invalid_document_error(py, format!( + "Invalid document: Key names must not contain the NULL byte" + ))); + } + + // Check keys if requested (but not for _id) + if check_keys && key_str != "_id" { + if key_str.starts_with('$') { + return Err(invalid_document_error(py, format!( + "Invalid document: key '{}' must not start with '$'", + key_str + ))); + } + if key_str.contains('.') { + return Err(invalid_document_error(py, format!( + "Invalid document: key '{}' must not contain '.'", + key_str + ))); + } + } + + let bson_value = python_to_bson(value.clone(), check_keys, codec_options)?; + + Ok((key_str, bson_value)) +} + + +fn extract_mapping_item( + item: &Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult<(String, Bson)> { + // Each item should be a tuple (key, value) + let (key, value): (Bound<'_, PyAny>, Bound<'_, PyAny>) = item.extract()?; + + // Keys must be strings (not bytes, not other types) + let py = item.py(); + let key_str: String = if let Ok(s) = key.extract::() { + s + } else { + // Get a string representation of the key for the error message + let key_repr = 
if let Ok(b) = key.extract::>() { + format!("b'{}'", String::from_utf8_lossy(&b)) + } else { + format!("{}", key) + }; + return Err(invalid_document_error(py, format!( + "Invalid document: documents must have only string keys, key was {}", + key_repr + ))); + }; + + // Check for null bytes in key (always invalid) + if key_str.contains('\0') { + return Err(invalid_document_error(py, format!( + "Invalid document: Key names must not contain the NULL byte" + ))); + } + + // Check keys if requested (but not for _id) + if check_keys && key_str != "_id" { + if key_str.starts_with('$') { + return Err(invalid_document_error(py, format!( + "Invalid document: key '{}' must not start with '$'", + key_str + ))); + } + if key_str.contains('.') { + return Err(invalid_document_error(py, format!( + "Invalid document: key '{}' must not contain '.'", + key_str + ))); + } + } + + let bson_value = python_to_bson(value, check_keys, codec_options)?; + + Ok((key_str, bson_value)) +} + + +fn handle_bson_type_marker( + obj: Bound<'_, PyAny>, + marker: i32, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult { + match marker { + BINARY_TYPE_MARKER => { + // Binary object + let subtype: u8 = obj.getattr("subtype")?.extract()?; + let bytes: Vec = obj.extract()?; + + let bson_subtype = match subtype { + 0 => bson::spec::BinarySubtype::Generic, + 1 => bson::spec::BinarySubtype::Function, + 2 => bson::spec::BinarySubtype::BinaryOld, + 3 => bson::spec::BinarySubtype::UuidOld, + 4 => bson::spec::BinarySubtype::Uuid, + 5 => bson::spec::BinarySubtype::Md5, + 6 => bson::spec::BinarySubtype::Encrypted, + 7 => bson::spec::BinarySubtype::Column, + 8 => bson::spec::BinarySubtype::Sensitive, + 9 => bson::spec::BinarySubtype::Vector, + 10..=127 => bson::spec::BinarySubtype::Reserved(subtype), + 128..=255 => bson::spec::BinarySubtype::UserDefined(subtype), + }; + + Ok(Bson::Binary(bson::Binary { + subtype: bson_subtype, + bytes, + })) + } + OBJECTID_TYPE_MARKER => { + // ObjectId 
object - get the binary representation + let binary: Vec = obj.getattr("binary")?.extract()?; + if binary.len() != 12 { + return Err(invalid_document_error(obj.py(), "Invalid document: ObjectId must be 12 bytes".to_string())); + } + let mut oid_bytes = [0u8; 12]; + oid_bytes.copy_from_slice(&binary); + Ok(Bson::ObjectId(bson::oid::ObjectId::from_bytes(oid_bytes))) + } + DATETIME_TYPE_MARKER => { + // DateTime/DatetimeMS object - get milliseconds since epoch + if let Ok(value) = obj.getattr("_value") { + // Check that __int__() returns an actual integer, not a float + if let Ok(int_result) = obj.call_method0("__int__") { + // Check if the result is a float (which would be invalid) + if int_result.is_instance_of::() { + return Err(PyTypeError::new_err( + "DatetimeMS.__int__() must return an integer, not float" + )); + } + } + + let millis: i64 = value.extract()?; + Ok(Bson::DateTime(bson::DateTime::from_millis(millis))) + } else { + Err(invalid_document_error(obj.py(), + "Invalid document: DateTime object must have _value attribute".to_string(), + )) + } + } + REGEX_TYPE_MARKER => { + // Regex object - pattern can be str or bytes + let pattern_obj = obj.getattr("pattern")?; + let pattern: String = if let Ok(s) = pattern_obj.extract::() { + s + } else if let Ok(b) = pattern_obj.extract::>() { + // Pattern is bytes, convert to string (lossy for non-UTF8) + String::from_utf8_lossy(&b).to_string() + } else { + return Err(invalid_document_error(obj.py(), + "Invalid document: Regex pattern must be str or bytes".to_string())); + }; + + let flags_obj = obj.getattr("flags")?; + + // Flags can be an int or a string + let flags_str = if let Ok(flags_int) = flags_obj.extract::() { + int_flags_to_str(flags_int) + } else { + flags_obj.extract::().unwrap_or_default() + }; + + Ok(Bson::RegularExpression(bson::Regex { + pattern, + options: flags_str, + })) + } + CODE_TYPE_MARKER => { + // Code object - inherits from str + let code_str: String = obj.extract()?; + + // Check if there's 
a scope + if let Ok(scope_obj) = obj.getattr("scope") { + if !scope_obj.is_none() { + // Code with scope + let scope_doc = python_mapping_to_bson_doc(&scope_obj, check_keys, codec_options, false)?; + return Ok(Bson::JavaScriptCodeWithScope(bson::JavaScriptCodeWithScope { + code: code_str, + scope: scope_doc, + })); + } + } + + // Code without scope + Ok(Bson::JavaScriptCode(code_str)) + } + TIMESTAMP_TYPE_MARKER => { + // Timestamp object + let time: u32 = obj.getattr("time")?.extract()?; + let inc: u32 = obj.getattr("inc")?.extract()?; + Ok(Bson::Timestamp(bson::Timestamp { + time, + increment: inc, + })) + } + INT64_TYPE_MARKER => { + // Int64 object - extract the value and encode as BSON Int64 + let value: i64 = obj.extract()?; + Ok(Bson::Int64(value)) + } + DECIMAL128_TYPE_MARKER => { + // Decimal128 object + let bid: Vec = obj.getattr("bid")?.extract()?; + if bid.len() != 16 { + return Err(invalid_document_error(obj.py(), "Invalid document: Decimal128 must be 16 bytes".to_string())); + } + let mut bytes = [0u8; 16]; + bytes.copy_from_slice(&bid); + Ok(Bson::Decimal128(bson::Decimal128::from_bytes(bytes))) + } + MAXKEY_TYPE_MARKER => { + Ok(Bson::MaxKey) + } + MINKEY_TYPE_MARKER => { + Ok(Bson::MinKey) + } + DBREF_TYPE_MARKER => { + // DBRef object - use as_doc() method + if let Ok(as_doc_method) = obj.getattr("as_doc") { + if let Ok(doc_obj) = as_doc_method.call0() { + let dbref_doc = python_mapping_to_bson_doc(&doc_obj, check_keys, codec_options, false)?; + return Ok(Bson::Document(dbref_doc)); + } + } + + // Fallback: manually construct the document + let mut dbref_doc = Document::new(); + let collection: String = obj.getattr("collection")?.extract()?; + dbref_doc.insert("$ref", collection); + + let id_obj = obj.getattr("id")?; + let id_bson = python_to_bson(id_obj, check_keys, codec_options)?; + dbref_doc.insert("$id", id_bson); + + if let Ok(database_obj) = obj.getattr("database") { + if !database_obj.is_none() { + let database: String = 
database_obj.extract()?; + dbref_doc.insert("$db", database); + } + } + + Ok(Bson::Document(dbref_doc)) + } + _ => { + // Unknown type marker, fall through to remaining types + handle_remaining_python_types(obj, check_keys, codec_options) + } + } +} + + +fn handle_remaining_python_types( + obj: Bound<'_, PyAny>, + check_keys: bool, + codec_options: Option<&Bound<'_, PyAny>>, +) -> PyResult { + use pyo3::types::PyList; + use pyo3::types::PyTuple; + + // FAST PATH: Check for PyList first (most common sequence type) + if let Ok(list) = obj.downcast::() { + let mut arr = Vec::with_capacity(list.len()); + for item in list { + arr.push(python_to_bson(item, check_keys, codec_options)?); + } + return Ok(Bson::Array(arr)); + } + + // FAST PATH: Check for PyTuple + if let Ok(tuple) = obj.downcast::() { + let mut arr = Vec::with_capacity(tuple.len()); + for item in tuple { + arr.push(python_to_bson(item, check_keys, codec_options)?); + } + return Ok(Bson::Array(arr)); + } + + // Check for bytes/bytearray by type (not by extract, which would match tuples) + // Raw bytes without Binary wrapper -> subtype 0 + if obj.is_instance_of::() { + let v: Vec = obj.extract()?; + return Ok(Bson::Binary(bson::Binary { + subtype: bson::spec::BinarySubtype::Generic, + bytes: v, + })); + } + + // Check for dict-like objects (SON, OrderedDict, etc.) + if obj.hasattr("items")? { + // Any object with items() method (dict, SON, OrderedDict, etc.) 
+ let doc = python_mapping_to_bson_doc(&obj, check_keys, codec_options, false)?; + return Ok(Bson::Document(doc)); + } + + // SLOW PATH: Try generic sequence extraction + if let Ok(list) = obj.extract::>>() { + // Check for sequences (lists, tuples) + let mut arr = Vec::new(); + for item in list { + arr.push(python_to_bson(item, check_keys, codec_options)?); + } + return Ok(Bson::Array(arr)); + } + + // Get object repr and type for error message + let obj_repr = obj.repr().map(|r| r.to_string()).unwrap_or_else(|_| "?".to_string()); + let obj_type = obj.get_type().to_string(); + Err(invalid_document_error(obj.py(), format!( + "cannot encode object: {}, of type: {}", + obj_repr, obj_type + ))) +} diff --git a/bson/_rbson/src/errors.rs b/bson/_rbson/src/errors.rs new file mode 100644 index 0000000000..a7b009b1f0 --- /dev/null +++ b/bson/_rbson/src/errors.rs @@ -0,0 +1,55 @@ +// Copyright 2025-present MongoDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! 
Error handling utilities for BSON operations + +use pyo3::prelude::*; +use pyo3::types::{PyAny, PyTuple}; + +use crate::types::TYPE_CACHE; + +/// Helper to create InvalidDocument exception +pub(crate) fn invalid_document_error(py: Python, msg: String) -> PyErr { + let invalid_document = TYPE_CACHE.get_invalid_document_class(py) + .expect("Failed to get InvalidDocument class"); + PyErr::from_value( + invalid_document.bind(py) + .call1((msg,)) + .expect("Failed to create InvalidDocument") + ) +} + +/// Helper to create InvalidDocument exception with document property +pub(crate) fn invalid_document_error_with_doc(py: Python, msg: String, doc: &Bound<'_, PyAny>) -> PyErr { + let invalid_document = TYPE_CACHE.get_invalid_document_class(py) + .expect("Failed to get InvalidDocument class"); + // Call with positional arguments: InvalidDocument(message, document) + let args = PyTuple::new_bound(py, &[msg.into_py(py), doc.clone().into_py(py)]); + PyErr::from_value( + invalid_document.bind(py) + .call1(args) + .expect("Failed to create InvalidDocument") + ) +} + +/// Helper to create InvalidBSON exception +pub(crate) fn invalid_bson_error(py: Python, msg: String) -> PyErr { + let invalid_bson = TYPE_CACHE.get_invalid_bson_class(py) + .expect("Failed to get InvalidBSON class"); + PyErr::from_value( + invalid_bson.bind(py) + .call1((msg,)) + .expect("Failed to create InvalidBSON") + ) +} diff --git a/bson/_rbson/src/lib.rs b/bson/_rbson/src/lib.rs new file mode 100644 index 0000000000..cb5d16ad19 --- /dev/null +++ b/bson/_rbson/src/lib.rs @@ -0,0 +1,85 @@ +// Copyright 2025-present MongoDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Rust implementation of BSON encoding/decoding functions +//! +//! ⚠️ **NOT PRODUCTION READY** - Experimental implementation with incomplete features. +//! +//! This module provides a **partial implementation** of the C extension (bson._cbson) +//! interface, implemented in Rust using PyO3 and the bson library. +//! +//! # Implementation Status +//! +//! - ✅ Core BSON encoding/decoding: 86/88 tests passing +//! - ❌ Custom type encoders: NOT IMPLEMENTED (~85 tests skipped) +//! - ❌ RawBSONDocument: NOT IMPLEMENTED +//! - ❌ Performance: ~5x slower than C extension +//! +//! # Implementation History +//! +//! This implementation was developed as part of PYTHON-5683 to investigate +//! using Rust as an alternative to C for Python extension modules. +//! +//! See PR #2695 for the complete implementation history, including: +//! - Initial implementation with core BSON functionality +//! - Performance optimizations (type caching, fast paths, direct conversions) +//! - Modular refactoring (split into 6 modules) +//! - Test skip markers for unimplemented features +//! +//! # Performance +//! +//! Current performance: ~0.21x (5x slower than C extension) +//! Root cause: Architectural difference (Python ↔ Bson ↔ bytes vs Python ↔ bytes) +//! See README.md for detailed performance analysis and optimization opportunities. +//! +//! # Module Structure +//! +//! The codebase is organized into the following modules: +//! - `types`: Type cache and BSON type markers +//! - `errors`: Error handling utilities +//! 
- `utils`: Utility functions (datetime, regex, validation, string writing) +//! - `encode`: BSON encoding functions +//! - `decode`: BSON decoding functions + +#![allow(clippy::useless_conversion)] + +mod types; +mod errors; +mod utils; +mod encode; +mod decode; + +use pyo3::prelude::*; +use pyo3::types::PyDict; + +/// Test function to verify the Rust extension is loaded +#[pyfunction] +fn _test_rust_extension(py: Python) -> PyResult { + let result = PyDict::new(py); + result.set_item("implementation", "rust")?; + result.set_item("version", "0.1.0")?; + result.set_item("status", "experimental")?; + result.set_item("pyo3_version", env!("CARGO_PKG_VERSION"))?; + Ok(result.into()) +} + +/// Python module definition +#[pymodule] +fn _rbson(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(encode::_dict_to_bson, m)?)?; + m.add_function(wrap_pyfunction!(encode::_dict_to_bson_direct, m)?)?; + m.add_function(wrap_pyfunction!(decode::_bson_to_dict, m)?)?; + m.add_function(wrap_pyfunction!(_test_rust_extension, m)?)?; + Ok(()) +} diff --git a/bson/_rbson/src/types.rs b/bson/_rbson/src/types.rs new file mode 100644 index 0000000000..763daf10ea --- /dev/null +++ b/bson/_rbson/src/types.rs @@ -0,0 +1,265 @@ +// Copyright 2025-present MongoDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Type cache for Python type objects +//! +//! This module provides a cache for Python type objects to avoid repeated imports. +//! 
This matches the C extension's approach of caching all BSON types at module initialization. + +use once_cell::sync::OnceCell; +use pyo3::prelude::*; +use pyo3::types::PyAny; + +/// Cache for Python type objects to avoid repeated imports +/// This matches the C extension's approach of caching all BSON types at module initialization +pub(crate) struct TypeCache { + // Standard library types + pub(crate) uuid_class: OnceCell, + pub(crate) datetime_class: OnceCell, + pub(crate) pattern_class: OnceCell, + + // BSON types + pub(crate) binary_class: OnceCell, + pub(crate) code_class: OnceCell, + pub(crate) objectid_class: OnceCell, + pub(crate) dbref_class: OnceCell, + pub(crate) regex_class: OnceCell, + pub(crate) timestamp_class: OnceCell, + pub(crate) int64_class: OnceCell, + pub(crate) decimal128_class: OnceCell, + pub(crate) minkey_class: OnceCell, + pub(crate) maxkey_class: OnceCell, + pub(crate) datetime_ms_class: OnceCell, + + // Utility objects + pub(crate) utc: OnceCell, + pub(crate) calendar_timegm: OnceCell, + + // Error classes + pub(crate) invalid_document_class: OnceCell, + pub(crate) invalid_bson_class: OnceCell, + + // Fallback decoder + pub(crate) bson_to_dict_python: OnceCell, +} + +pub(crate) static TYPE_CACHE: TypeCache = TypeCache { + uuid_class: OnceCell::new(), + datetime_class: OnceCell::new(), + pattern_class: OnceCell::new(), + binary_class: OnceCell::new(), + code_class: OnceCell::new(), + objectid_class: OnceCell::new(), + dbref_class: OnceCell::new(), + regex_class: OnceCell::new(), + timestamp_class: OnceCell::new(), + int64_class: OnceCell::new(), + decimal128_class: OnceCell::new(), + minkey_class: OnceCell::new(), + maxkey_class: OnceCell::new(), + datetime_ms_class: OnceCell::new(), + utc: OnceCell::new(), + calendar_timegm: OnceCell::new(), + invalid_document_class: OnceCell::new(), + invalid_bson_class: OnceCell::new(), + bson_to_dict_python: OnceCell::new(), +}; + +impl TypeCache { + /// Get or initialize the UUID class + pub(crate) 
fn get_uuid_class(&self, py: Python) -> PyResult> { + Ok(self.uuid_class.get_or_try_init(|| { + py.import_bound("uuid")? + .getattr("UUID") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the datetime class + pub(crate) fn get_datetime_class(&self, py: Python) -> PyResult> { + Ok(self.datetime_class.get_or_try_init(|| { + py.import_bound("datetime")? + .getattr("datetime") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the regex Pattern class + pub(crate) fn get_pattern_class(&self, py: Python) -> PyResult> { + Ok(self.pattern_class.get_or_try_init(|| { + py.import_bound("re")? + .getattr("Pattern") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Binary class + pub(crate) fn get_binary_class(&self, py: Python) -> PyResult> { + Ok(self.binary_class.get_or_try_init(|| { + py.import_bound("bson.binary")? + .getattr("Binary") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Code class + pub(crate) fn get_code_class(&self, py: Python) -> PyResult> { + Ok(self.code_class.get_or_try_init(|| { + py.import_bound("bson.code")? + .getattr("Code") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the ObjectId class + pub(crate) fn get_objectid_class(&self, py: Python) -> PyResult> { + Ok(self.objectid_class.get_or_try_init(|| { + py.import_bound("bson.objectid")? + .getattr("ObjectId") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the DBRef class + pub(crate) fn get_dbref_class(&self, py: Python) -> PyResult> { + Ok(self.dbref_class.get_or_try_init(|| { + py.import_bound("bson.dbref")? + .getattr("DBRef") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Regex class + pub(crate) fn get_regex_class(&self, py: Python) -> PyResult> { + Ok(self.regex_class.get_or_try_init(|| { + py.import_bound("bson.regex")? 
+ .getattr("Regex") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Timestamp class + pub(crate) fn get_timestamp_class(&self, py: Python) -> PyResult> { + Ok(self.timestamp_class.get_or_try_init(|| { + py.import_bound("bson.timestamp")? + .getattr("Timestamp") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Int64 class + pub(crate) fn get_int64_class(&self, py: Python) -> PyResult> { + Ok(self.int64_class.get_or_try_init(|| { + py.import_bound("bson.int64")? + .getattr("Int64") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Decimal128 class + pub(crate) fn get_decimal128_class(&self, py: Python) -> PyResult> { + Ok(self.decimal128_class.get_or_try_init(|| { + py.import_bound("bson.decimal128")? + .getattr("Decimal128") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the MinKey class + pub(crate) fn get_minkey_class(&self, py: Python) -> PyResult> { + Ok(self.minkey_class.get_or_try_init(|| { + py.import_bound("bson.min_key")? + .getattr("MinKey") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the MaxKey class + pub(crate) fn get_maxkey_class(&self, py: Python) -> PyResult> { + Ok(self.maxkey_class.get_or_try_init(|| { + py.import_bound("bson.max_key")? + .getattr("MaxKey") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the DatetimeMS class + pub(crate) fn get_datetime_ms_class(&self, py: Python) -> PyResult> { + Ok(self.datetime_ms_class.get_or_try_init(|| { + py.import_bound("bson.datetime_ms")? + .getattr("DatetimeMS") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the UTC timezone object + pub(crate) fn get_utc(&self, py: Python) -> PyResult> { + Ok(self.utc.get_or_try_init(|| { + py.import_bound("bson.tz_util")? 
+ .getattr("utc") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize calendar.timegm function + pub(crate) fn get_calendar_timegm(&self, py: Python) -> PyResult> { + Ok(self.calendar_timegm.get_or_try_init(|| { + py.import_bound("calendar")? + .getattr("timegm") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize InvalidDocument exception class + pub(crate) fn get_invalid_document_class(&self, py: Python) -> PyResult> { + Ok(self.invalid_document_class.get_or_try_init(|| { + py.import_bound("bson.errors")? + .getattr("InvalidDocument") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize InvalidBSON exception class + pub(crate) fn get_invalid_bson_class(&self, py: Python) -> PyResult> { + Ok(self.invalid_bson_class.get_or_try_init(|| { + py.import_bound("bson.errors")? + .getattr("InvalidBSON") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } + + /// Get or initialize the Python fallback decoder + pub(crate) fn get_bson_to_dict_python(&self, py: Python) -> PyResult> { + Ok(self.bson_to_dict_python.get_or_try_init(|| { + py.import_bound("bson")? 
+ .getattr("_bson_to_dict_python") + .map(|c| c.unbind()) + })?.clone_ref(py)) + } +} + +// Type markers for BSON objects +pub(crate) const BINARY_TYPE_MARKER: i32 = 5; +pub(crate) const OBJECTID_TYPE_MARKER: i32 = 7; +pub(crate) const DATETIME_TYPE_MARKER: i32 = 9; +pub(crate) const REGEX_TYPE_MARKER: i32 = 11; +pub(crate) const CODE_TYPE_MARKER: i32 = 13; +pub(crate) const SYMBOL_TYPE_MARKER: i32 = 14; +pub(crate) const DBPOINTER_TYPE_MARKER: i32 = 15; +pub(crate) const TIMESTAMP_TYPE_MARKER: i32 = 17; +pub(crate) const INT64_TYPE_MARKER: i32 = 18; +pub(crate) const DECIMAL128_TYPE_MARKER: i32 = 19; +pub(crate) const DBREF_TYPE_MARKER: i32 = 100; +pub(crate) const MAXKEY_TYPE_MARKER: i32 = 127; +pub(crate) const MINKEY_TYPE_MARKER: i32 = 255; diff --git a/bson/_rbson/src/utils.rs b/bson/_rbson/src/utils.rs new file mode 100644 index 0000000000..85eaefa5dc --- /dev/null +++ b/bson/_rbson/src/utils.rs @@ -0,0 +1,153 @@ +// Copyright 2025-present MongoDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! 
Utility functions for BSON operations + +use pyo3::prelude::*; +use pyo3::types::PyAny; + +use crate::types::TYPE_CACHE; + +/// Convert Python datetime to milliseconds since epoch UTC +/// This is equivalent to Python's bson.datetime_ms._datetime_to_millis() +pub(crate) fn datetime_to_millis(py: Python, dtm: &Bound<'_, PyAny>) -> PyResult { + // Get datetime components + let year: i32 = dtm.getattr("year")?.extract()?; + let month: i32 = dtm.getattr("month")?.extract()?; + let day: i32 = dtm.getattr("day")?.extract()?; + let hour: i32 = dtm.getattr("hour")?.extract()?; + let minute: i32 = dtm.getattr("minute")?.extract()?; + let second: i32 = dtm.getattr("second")?.extract()?; + let microsecond: i32 = dtm.getattr("microsecond")?.extract()?; + + // Check if datetime has timezone offset + let utcoffset = dtm.call_method0("utcoffset")?; + let offset_seconds: i64 = if !utcoffset.is_none() { + // Get total_seconds() from timedelta + let total_seconds: f64 = utcoffset.call_method0("total_seconds")?.extract()?; + total_seconds as i64 + } else { + 0 + }; + + // Calculate seconds since epoch using the same algorithm as Python's calendar.timegm + // This is: (year - 1970) * 365.25 days + month/day adjustments + time + // We'll use Python's calendar.timegm for accuracy + let timegm = TYPE_CACHE.get_calendar_timegm(py)?; + + // Create a time tuple (year, month, day, hour, minute, second, weekday, yearday, isdst) + // We need timetuple() method + let timetuple = dtm.call_method0("timetuple")?; + let seconds_since_epoch: i64 = timegm.bind(py).call1((timetuple,))?.extract()?; + + // Adjust for timezone offset (subtract to get UTC) + let utc_seconds = seconds_since_epoch - offset_seconds; + + // Convert to milliseconds and add microseconds + let millis = utc_seconds * 1000 + (microsecond / 1000) as i64; + + Ok(millis) +} + +/// Convert Python regex flags (int) to BSON regex options (string) +pub(crate) fn int_flags_to_str(flags: i32) -> String { + let mut options = String::new(); 
+ + // Python re module flags to BSON regex options: + // re.IGNORECASE = 2 -> 'i' + // re.MULTILINE = 8 -> 'm' + // re.DOTALL = 16 -> 's' + // re.VERBOSE = 64 -> 'x' + // Note: re.LOCALE and re.UNICODE are Python-specific + + if flags & 2 != 0 { + options.push('i'); + } + if flags & 4 != 0 { + options.push('l'); // Preserved for round-trip compatibility + } + if flags & 8 != 0 { + options.push('m'); + } + if flags & 16 != 0 { + options.push('s'); + } + if flags & 32 != 0 { + options.push('u'); // Preserved for round-trip compatibility + } + if flags & 64 != 0 { + options.push('x'); + } + + options +} + +/// Convert BSON regex options (string) to Python regex flags (int) +pub(crate) fn str_flags_to_int(options: &str) -> i32 { + let mut flags = 0; + + for ch in options.chars() { + match ch { + 'i' => flags |= 2, // re.IGNORECASE + 'l' => flags |= 4, // re.LOCALE + 'm' => flags |= 8, // re.MULTILINE + 's' => flags |= 16, // re.DOTALL + 'u' => flags |= 32, // re.UNICODE + 'x' => flags |= 64, // re.VERBOSE + _ => {} // Ignore unknown flags + } + } + + flags +} + +/// Validate a document key +pub(crate) fn validate_key(key: &str, check_keys: bool) -> PyResult<()> { + // Check for null bytes (always invalid) + if key.contains('\0') { + return Err(PyErr::new::( + "Key names must not contain the NULL byte" + )); + } + + // Check keys if requested (but not for _id) + if check_keys && key != "_id" { + if key.starts_with('$') { + return Err(PyErr::new::( + format!("key '{}' must not start with '$'", key) + )); + } + if key.contains('.') { + return Err(PyErr::new::( + format!("key '{}' must not contain '.'", key) + )); + } + } + + Ok(()) +} + +/// Write a C-style null-terminated string +pub(crate) fn write_cstring(buf: &mut Vec, s: &str) { + buf.extend_from_slice(s.as_bytes()); + buf.push(0); +} + +/// Write a BSON string (int32 length + string + null terminator) +pub(crate) fn write_string(buf: &mut Vec, s: &str) { + let len = (s.len() + 1) as i32; // +1 for null terminator + 
buf.extend_from_slice(&len.to_le_bytes()); + buf.extend_from_slice(s.as_bytes()); + buf.push(0); +} diff --git a/bson/son.py b/bson/son.py index 8fd4f95cd2..ccb6bdb273 100644 --- a/bson/son.py +++ b/bson/son.py @@ -22,6 +22,7 @@ import copy import re +import warnings from collections.abc import Mapping as _Mapping from typing import ( Any, @@ -99,13 +100,28 @@ def __iter__(self) -> Iterator[_Key]: yield from self.__keys def has_key(self, key: _Key) -> bool: + warnings.warn( + "SON.has_key() is deprecated, use the in operator instead", + DeprecationWarning, + stacklevel=2, + ) return key in self.__keys def iterkeys(self) -> Iterator[_Key]: + warnings.warn( + "SON.iterkeys() is deprecated, use the keys() method instead", + DeprecationWarning, + stacklevel=2, + ) return self.__iter__() # fourth level uses definitions from lower levels def itervalues(self) -> Iterator[_Value]: + warnings.warn( + "SON.itervalues() is deprecated, use the values() method instead", + DeprecationWarning, + stacklevel=2, + ) for _, v in self.items(): yield v diff --git a/doc/changelog.rst b/doc/changelog.rst index 571ce3b63e..a1178ef3db 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,23 @@ Changelog ========= +Changes in Version 4.17.0 (2026/XX/XX) +-------------------------------------- + +PyMongo 4.17 brings a number of changes including: + +- ``has_key``, ``iterkeys`` and ``itervalues`` in :class:`bson.son.SON` have + been deprecated and will be removed in PyMongo 5.0. These methods were + deprecated in favor of the standard dictionary containment operator ``in`` + and the ``keys()`` and ``values()`` methods, respectively. + +- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods + that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation. 
+ See for examples and more information. +- Added support for MongoDB's Intelligent Workload Management (IWM) and ingress connection rate limiting features. + The driver now gracefully handles write-blocking scenarios and optimizes connection establishment during high-load conditions to maintain application availability. + See and for more information. + Changes in Version 4.16.0 (2026/01/07) -------------------------------------- diff --git a/hatch_build.py b/hatch_build.py index 40271972dd..0d69a1bca1 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -2,8 +2,12 @@ from __future__ import annotations import os +import shutil import subprocess import sys +import tempfile +import warnings +import zipfile from pathlib import Path from hatchling.builders.hooks.plugin.interface import BuildHookInterface @@ -12,6 +16,116 @@ class CustomHook(BuildHookInterface): """The pymongo build hook.""" + def _build_rust_extension(self, here: Path, *, required: bool = False) -> bool: + """Build the Rust BSON extension if Rust toolchain is available. + + Args: + here: The root directory of the project. + required: If True, raise an error if the build fails. If False, issue a warning. + + Returns True if built successfully, False otherwise. + """ + # Check if Rust is available + if not shutil.which("cargo"): + msg = ( + "Rust toolchain not found. " + "Install Rust from https://rustup.rs/ to enable the Rust extension." 
+ ) + if required: + raise RuntimeError(msg) + warnings.warn( + f"{msg} Skipping Rust extension build.", + stacklevel=2, + ) + return False + + # Check if maturin is available + if not shutil.which("maturin"): + try: + # Try uv pip first, fall back to pip + if shutil.which("uv"): + subprocess.run( + ["uv", "pip", "install", "maturin"], + check=True, + capture_output=True, + ) + else: + subprocess.run( + [sys.executable, "-m", "pip", "install", "maturin"], + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + msg = f"Failed to install maturin: {e}" + if required: + raise RuntimeError(msg) from e + warnings.warn( + f"{msg}. Skipping Rust extension build.", + stacklevel=2, + ) + return False + + # Build the Rust extension + rust_dir = here / "bson" / "_rbson" + if not rust_dir.exists(): + msg = f"Rust extension directory not found: {rust_dir}" + if required: + raise RuntimeError(msg) + return False + + try: + # Build the wheel to a temporary directory + with tempfile.TemporaryDirectory() as tmpdir: + subprocess.run( + [ + "maturin", + "build", + "--release", + "--out", + tmpdir, + "--manifest-path", + str(rust_dir / "Cargo.toml"), + ], + check=True, + cwd=str(rust_dir), + ) + + # Extract the .so file from the wheel + # Find the wheel file + wheel_files = list(Path(tmpdir).glob("*.whl")) + if not wheel_files: + msg = "No wheel file generated by maturin" + if required: + raise RuntimeError(msg) + return False + + # Extract the .so file from the wheel + # The wheel contains _rbson/_rbson.abi3.so, we want bson/_rbson.abi3.so + with zipfile.ZipFile(wheel_files[0], "r") as whl: + for name in whl.namelist(): + if name.endswith((".so", ".pyd")) and "_rbson" in name: + # Extract to bson/ directory + so_data = whl.read(name) + so_name = Path(name).name # Just the filename, e.g., _rbson.abi3.so + dest = here / "bson" / so_name + dest.write_bytes(so_data) + return True + + msg = "No Rust extension binary found in wheel" + if required: + raise 
RuntimeError(msg) + return False + + except (subprocess.CalledProcessError, Exception) as e: + msg = f"Failed to build Rust extension: {e}" + if required: + raise RuntimeError(msg) from e + warnings.warn( + f"{msg}. The C extension will be used instead.", + stacklevel=2, + ) + return False + def initialize(self, version, build_data): """Initialize the hook.""" if self.target_name == "sdist": @@ -19,7 +133,32 @@ def initialize(self, version, build_data): here = Path(__file__).parent.resolve() sys.path.insert(0, str(here)) - subprocess.run([sys.executable, "_setup.py", "build_ext", "-i"], check=True) + # Build C extensions + try: + subprocess.run([sys.executable, "_setup.py", "build_ext", "-i"], check=True) + except (subprocess.CalledProcessError, FileNotFoundError) as e: + warnings.warn( + f"Failed to build C extension: {e}. " + "The package will be installed without compiled extensions.", + stacklevel=2, + ) + + # Build Rust extension (optional) + # Only build if PYMONGO_BUILD_RUST is set or Rust is available + # Skip for free-threaded Python (not yet supported) + is_free_threaded = hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() + build_rust = os.environ.get("PYMONGO_BUILD_RUST", "").lower() in ("1", "true", "yes") + if build_rust and is_free_threaded: + warnings.warn( + "Rust extension is not yet supported on free-threaded Python. Skipping build.", + stacklevel=2, + ) + elif build_rust: + # If PYMONGO_BUILD_RUST is explicitly set, the build must succeed + self._build_rust_extension(here, required=True) + elif shutil.which("cargo") and not is_free_threaded: + # If Rust is available but not explicitly requested, build is optional + self._build_rust_extension(here, required=False) # Ensure wheel is marked as binary and contains the binary files. 
build_data["infer_tag"] = True diff --git a/justfile b/justfile index 082b6ea170..29df8c75d4 100644 --- a/justfile +++ b/justfile @@ -57,11 +57,14 @@ lint-manual *args="": && resync [group('test')] test *args="-v --durations=5 --maxfail=10": && resync - uv run --extra test python -m pytest {{args}} + #!/usr/bin/env bash + set -euo pipefail + uv run ${USE_ACTIVE_VENV:+--active} --extra test python -m pytest {{args}} [group('test')] -test-numpy: && resync - uv run --extra test --with numpy python -m pytest test/test_bson.py +test-numpy *args="": && resync + just setup-tests numpy {{args}} + just run-tests test/test_bson.py [group('test')] run-tests *args: && resync @@ -79,6 +82,25 @@ teardown-tests: integration-tests: bash integration_tests/run.sh +[group('test')] +test-coverage *args="": + just setup-tests --cov + just run-tests {{args}} + +[group('coverage')] +coverage-report: + uv tool run --with "coverage[toml]" coverage report + +[group('coverage')] +coverage-html: + uv tool run --with "coverage[toml]" coverage html + @echo "Coverage report generated in htmlcov/index.html" + +[group('coverage')] +coverage-xml: + uv tool run --with "coverage[toml]" coverage xml + @echo "Coverage report generated in coverage.xml" + [group('server')] run-server *args="": bash .evergreen/scripts/run-server.sh {{args}} @@ -86,3 +108,31 @@ run-server *args="": [group('server')] stop-server: bash .evergreen/scripts/stop-server.sh + +[group('rust')] +rust-build: + cd bson/_rbson && ./build.sh + +[group('rust')] +rust-clean: + rm -f bson/_rbson*.so bson/_rbson*.pyd + cd bson/_rbson && cargo clean + +[group('rust')] +rust-rebuild: rust-clean rust-build + +[group('rust')] +rust-install: + PYMONGO_BUILD_RUST=1 pip install --force-reinstall --no-deps . + +[group('rust')] +rust-install-full: + PYMONGO_BUILD_RUST=1 pip install --force-reinstall . 
+ +[group('rust')] +rust-test: + PYMONGO_USE_RUST=1 uv run --extra test python -m pytest test/test_bson.py -v + +[group('rust')] +rust-check: + @python -c 'import os; os.environ["PYMONGO_USE_RUST"] = "1"; import bson; print("Rust extension:", bson.get_bson_implementation())' diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 151942c8a8..015947d7ef 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -59,6 +59,7 @@ InvalidOperation, NotPrimaryError, OperationFailure, + PyMongoError, WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES @@ -563,9 +564,17 @@ async def _execute_command( error, ConnectionFailure ) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError)) + retryable_label_error = isinstance( + error, PyMongoError + ) and error.has_error_label("RetryableError") + # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. 
- if retryable and (retryable_top_level_error or retryable_network_error): + if retryable and ( + retryable_top_level_error + or retryable_network_error + or retryable_label_error + ): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index a12ca1f11b..31e6ceb386 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -135,10 +135,13 @@ from __future__ import annotations +import asyncio import collections +import random import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -161,7 +164,9 @@ from pymongo.errors import ( ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, PyMongoError, WTimeoutError, @@ -181,6 +186,28 @@ _IS_SYNC = False +_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None) + + +class _AsyncBoundSessionContext: + """Context manager returned by AsyncClientSession.bind() that manages bound state.""" + + def __init__(self, session: AsyncClientSession, end_session: bool) -> None: + self._session = session + self._session_token: Optional[Token[AsyncClientSession]] = None + self._end_session = end_session + + async def __aenter__(self) -> AsyncClientSession: + self._session_token = _SESSION.set(self._session) # type: ignore[assignment] + return self._session + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._session_token: + _SESSION.reset(self._session_token) # type: ignore[arg-type] + self._session_token = None + if self._end_session: + await self._session.end_session() + class SessionOptions: """Options for a new :class:`AsyncClientSession`. 
@@ -404,6 +431,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[ self.recovery_token = None self.attempt = 0 self.client = client + self.has_completed_command = False def active(self) -> bool: return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) @@ -411,6 +439,9 @@ def active(self) -> bool: def starting(self) -> bool: return self.state == _TxnState.STARTING + def set_starting(self) -> None: + self.state = _TxnState.STARTING + @property def pinned_conn(self) -> Optional[AsyncConnection]: if self.active() and self.conn_mgr: @@ -436,6 +467,7 @@ async def reset(self) -> None: self.sharded = False self.recovery_token = None self.attempt = 0 + self.has_completed_command = False def __del__(self) -> None: if self.conn_mgr: @@ -470,11 +502,29 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 +_BACKOFF_MAX = 0.500 # 500ms max backoff +_BACKOFF_INITIAL = 0.005 # 5ms initial backoff -def _within_time_limit(start_time: float) -> bool: +def _within_time_limit(start_time: float, backoff: float = 0) -> bool: """Are we within the with_transaction retry limit?""" - return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + remaining = _csot.remaining() + if remaining is not None and remaining <= 0: + return False + return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + + +def _make_timeout_error(error: BaseException) -> PyMongoError: + """Convert error to a NetworkTimeout or ExecutionTimeout as appropriate.""" + if _csot.remaining() is not None: + timeout_error: PyMongoError = ExecutionTimeout( + str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50} + ) + else: + timeout_error = NetworkTimeout(str(error)) + if isinstance(error, PyMongoError): + timeout_error._error_labels = 
error._error_labels.copy() + return timeout_error _T = TypeVar("_T") @@ -547,6 +597,24 @@ def _check_ended(self) -> None: if self._server_session is None: raise InvalidOperation("Cannot use ended session") + def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext: + """Bind this session so it is implicitly passed to all database operations within the returned context. + + .. code-block:: python + + async with client.start_session() as s: + async with s.bind(): + # session=s is passed implicitly + await client.db.collection.insert_one({"x": 1}) + + :param end_session: Whether to end the session on exiting the returned context. Defaults to True. + If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called + once the session is no longer used. + + .. versionadded:: 4.17 + """ + return _AsyncBoundSessionContext(self, end_session) + async def __aenter__(self) -> AsyncClientSession: return self @@ -703,7 +771,17 @@ async def callback(session, custom_arg, custom_kwarg=None): https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback """ start_time = time.monotonic() + retry = 0 + last_error: Optional[BaseException] = None while True: + if retry: # Implement exponential backoff on retry. + jitter = random.random() # noqa: S311 + backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX) + if not _within_time_limit(start_time, backoff): + assert last_error is not None + raise _make_timeout_error(last_error) from last_error + await asyncio.sleep(backoff) + retry += 1 await self.start_transaction( read_concern, write_concern, read_preference, max_commit_time_ms ) @@ -711,15 +789,16 @@ async def callback(session, custom_arg, custom_kwarg=None): ret = await callback(self) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
except BaseException as exc: + last_error = exc if self.in_transaction: await self.abort_transaction() - if ( - isinstance(exc, PyMongoError) - and exc.has_error_label("TransientTransactionError") - and _within_time_limit(start_time) + if isinstance(exc, PyMongoError) and exc.has_error_label( + "TransientTransactionError" ): - # Retry the entire transaction. - continue + if _within_time_limit(start_time): + # Retry the entire transaction. + continue + raise _make_timeout_error(last_error) from exc raise if not self.in_transaction: @@ -730,17 +809,18 @@ async def callback(session, custom_arg, custom_kwarg=None): try: await self.commit_transaction() except PyMongoError as exc: - if ( - exc.has_error_label("UnknownTransactionCommitResult") - and _within_time_limit(start_time) - and not _max_time_expired_error(exc) - ): + last_error = exc + if exc.has_error_label( + "UnknownTransactionCommitResult" + ) and not _max_time_expired_error(exc): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the commit. continue - if exc.has_error_label("TransientTransactionError") and _within_time_limit( - start_time - ): + if exc.has_error_label("TransientTransactionError"): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the entire transaction. 
break raise diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 53b4992493..4fff6650f1 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -20,7 +20,6 @@ from typing import ( TYPE_CHECKING, Any, - AsyncContextManager, Callable, Coroutine, Generic, @@ -571,11 +570,6 @@ async def watch( await change_stream._initialize_cursor() return change_stream - async def _conn_for_writes( - self, session: Optional[AsyncClientSession], operation: str - ) -> AsyncContextManager[AsyncConnection]: - return await self._database.client._conn_for_writes(session, operation) - async def _command( self, conn: AsyncConnection, @@ -652,7 +646,10 @@ async def _create_helper( if "size" in options: options["size"] = float(options["size"]) cmd.update(options) - async with await self._conn_for_writes(session, operation=_Op.CREATE) as conn: + + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> None: if qev2_required and conn.max_wire_version < 21: raise ConfigurationError( "Driver support of Queryable Encryption is incompatible with server. " @@ -669,6 +666,8 @@ async def _create_helper( session=session, ) + await self.database.client._retryable_write(False, inner, session, _Op.CREATE) + async def _create( self, options: MutableMapping[str, Any], @@ -2240,7 +2239,10 @@ async def _create_indexes( command (like maxTimeMS) can be passed as keyword arguments. 
""" names = [] - async with await self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn: + + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> list[str]: supports_quorum = conn.max_wire_version >= 9 def gen_indexes() -> Iterator[Mapping[str, Any]]: @@ -2269,7 +2271,11 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: write_concern=self._write_concern_for(session), session=session, ) - return names + return names + + return await self.database.client._retryable_write( + False, inner, session, _Op.CREATE_INDEXES + ) async def create_index( self, @@ -2422,7 +2428,6 @@ async def drop_indexes( kwargs["comment"] = comment await self._drop_index("*", session=session, **kwargs) - @_csot.apply async def drop_index( self, index_or_name: _IndexKeyHint, @@ -2490,7 +2495,10 @@ async def _drop_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - async with await self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn: + + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> None: await self._command( conn, cmd, @@ -2500,6 +2508,8 @@ async def _drop_index( session=session, ) + await self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES) + async def list_indexes( self, session: Optional[AsyncClientSession] = None, @@ -2763,17 +2773,22 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())} cmd.update(kwargs) - async with await self._conn_for_writes( - session, operation=_Op.CREATE_SEARCH_INDEXES - ) as conn: + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> list[str]: resp = await self._command( conn, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + session=session, ) return [index["name"] for index in 
resp["indexesCreated"]] + return await self.database.client._retryable_write( + False, inner, session, _Op.CREATE_SEARCH_INDEXES + ) + async def drop_search_index( self, name: str, @@ -2799,15 +2814,21 @@ async def drop_search_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - async with await self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn: + + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> None: await self._command( conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + session=session, ) + await self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES) + async def update_search_index( self, name: str, @@ -2835,15 +2856,21 @@ async def update_search_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - async with await self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn: + + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> None: await self._command( conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + session=session, ) + await self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX) + async def options( self, session: Optional[AsyncClientSession] = None, @@ -2918,6 +2945,7 @@ async def _aggregate( session, retryable=not cmd._performs_write, operation=_Op.AGGREGATE, + is_aggregate_write=cmd._performs_write, ) async def aggregate( @@ -3123,17 +3151,21 @@ async def rename( if comment is not None: cmd["comment"] = comment write_concern = self._write_concern_for_cmd(cmd, session) + client = self._database.client - async with await self._conn_for_writes(session, operation=_Op.RENAME) as conn: - async with 
self._database.client._tmp_session(session) as s: - return await conn.command( - "admin", - cmd, - write_concern=write_concern, - parse_write_concern_error=True, - session=s, - client=self._database.client, - ) + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> MutableMapping[str, Any]: + return await conn.command( + "admin", + cmd, + write_concern=write_concern, + parse_write_concern_error=True, + session=session, + client=client, + ) + + return await client._retryable_write(False, inner, session, _Op.RENAME) async def distinct( self, diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 5aa206ee24..28ed36073c 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -931,14 +931,15 @@ async def command( if read_preference is None: read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY - async with await self._client._conn_for_reads( - read_preference, session, operation=command_name - ) as ( - connection, - read_preference, - ): + + async def inner( + session: Optional[AsyncClientSession], + _server: Server, + conn: AsyncConnection, + read_preference: _ServerMode, + ) -> Union[dict[str, Any], _CodecDocumentType]: return await self._command( - connection, + conn, command, value, check, @@ -949,6 +950,10 @@ async def command( **kwargs, ) + return await self._client._retryable_read( + inner, read_preference, session, command_name, None, False, is_run_command=True + ) + @_csot.apply async def cursor_command( self, @@ -1016,17 +1021,17 @@ async def cursor_command( async with self._client._tmp_session(session) as tmp_session: opts = codec_options or DEFAULT_CODEC_OPTIONS - if read_preference is None: read_preference = ( tmp_session and tmp_session._txn_read_preference() ) or ReadPreference.PRIMARY - async with await self._client._conn_for_reads( - read_preference, tmp_session, command_name - ) as ( - conn, - 
read_preference, - ): + + async def inner( + session: Optional[AsyncClientSession], + _server: Server, + conn: AsyncConnection, + read_preference: _ServerMode, + ) -> AsyncCommandCursor[_DocumentType]: response = await self._command( conn, command, @@ -1035,7 +1040,7 @@ async def cursor_command( None, read_preference, opts, - session=tmp_session, + session=session, **kwargs, ) coll = self.get_collection("$cmd", read_preference=read_preference) @@ -1045,7 +1050,7 @@ async def cursor_command( response["cursor"], conn.address, max_await_time_ms=max_await_time_ms, - session=tmp_session, + session=session, comment=comment, ) await cmd_cursor._maybe_pin_connection(conn) @@ -1053,6 +1058,10 @@ async def cursor_command( else: raise InvalidOperation("Command does not return a cursor.") + return await self.client._retryable_read( + inner, read_preference, tmp_session, command_name, None, False + ) + async def _retryable_read_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1254,9 +1263,11 @@ async def _drop_helper( if comment is not None: command["comment"] = comment - async with await self._client._conn_for_writes(session, operation=_Op.DROP) as connection: + async def inner( + session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable_write: bool + ) -> dict[str, Any]: return await self._command( - connection, + conn, command, allowable_errors=["ns not found", 26], write_concern=self._write_concern_for(session), @@ -1264,6 +1275,8 @@ async def _drop_helper( session=session, ) + return await self.client._retryable_write(False, inner, session, _Op.DROP) + @_csot.apply async def drop_collection( self, diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index ccda16e28b..16b007c38f 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -17,8 +17,11 @@ import asyncio import builtins +import functools +import random import socket import sys +import time as time # noqa: PLC0414 # needed in 
sync version from typing import ( Any, Callable, @@ -26,6 +29,8 @@ cast, ) +from pymongo import _csot +from pymongo.common import MAX_ADAPTIVE_RETRIES from pymongo.errors import ( OperationFailure, ) @@ -38,6 +43,7 @@ def _handle_reauth(func: F) -> F: + @functools.wraps(func) async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.asynchronous.pool import AsyncConnection @@ -70,6 +76,46 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +_BACKOFF_INITIAL = 0.1 +_BACKOFF_MAX = 10 + + +def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX +) -> float: + jitter = random.random() # noqa: S311 + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter.""" + + def __init__( + self, + attempts: int = MAX_ADAPTIVE_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given attempt.""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + async def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have retry attempts remaining and the next backoff would not exceed a timeout.""" + if attempt > self.attempts: + return False + + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + return True + + async def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 4f3c43f23c..03e2d6073a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os 
+import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -65,8 +66,11 @@ from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk -from pymongo.asynchronous.client_session import _EmptyServerSession +from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import ( + _RetryPolicy, +) from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -610,8 +614,18 @@ def __init__( client to use Stable API. See `versioned API `_ for details. + | **Overload retry options:** + + - `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``. + - `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client. + If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible. + Defaults to ``False``. + .. seealso:: The MongoDB documentation on `connections `_. + .. versionchanged:: 4.17 + Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments. + .. versionchanged:: 4.5 Added the ``serverMonitoringMode`` keyword argument. 
@@ -879,11 +893,14 @@ def __init__( self._options.read_concern, ) + self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) + self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) self._opened = False self._closed = False self._loop: Optional[asyncio.AbstractEventLoop] = None + if not is_srv: self._init_background() @@ -1408,7 +1425,8 @@ def start_session( def _ensure_session( self, session: Optional[AsyncClientSession] = None ) -> Optional[AsyncClientSession]: - """If provided session is None, lend a temporary session.""" + """If provided session and bound session are None, lend a temporary session.""" + session = session or self._get_bound_session() if session: return session @@ -1990,6 +2008,8 @@ async def _retry_internal( read_pref: Optional[_ServerMode] = None, retryable: bool = False, operation_id: Optional[int] = None, + is_run_command: bool = False, + is_aggregate_write: bool = False, ) -> T: """Internal retryable helper for all client transactions. @@ -2001,6 +2021,8 @@ async def _retry_internal( :param address: Server Address, defaults to None :param read_pref: Topology of read operation, defaults to None :param retryable: If the operation should be retried once, defaults to None + :param is_run_command: If this is a runCommand operation, defaults to False + :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. 
:return: Output of the calling func() """ @@ -2015,6 +2037,8 @@ async def _retry_internal( address=address, retryable=retryable, operation_id=operation_id, + is_run_command=is_run_command, + is_aggregate_write=is_aggregate_write, ).run() async def _retryable_read( @@ -2026,6 +2050,8 @@ async def _retryable_read( address: Optional[_Address] = None, retryable: bool = True, operation_id: Optional[int] = None, + is_run_command: bool = False, + is_aggregate_write: bool = False, ) -> T: """Execute an operation with consecutive retries if possible @@ -2041,6 +2067,8 @@ async def _retryable_read( :param address: Optional address when sending a message, defaults to None :param retryable: if we should attempt retries (may not always be supported even if supplied), defaults to False + :param is_run_command: If this is a runCommand operation, defaults to False. + :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. """ # Ensure that the client supports retrying on reads and there is no session in @@ -2059,6 +2087,8 @@ async def _retryable_read( read_pref=read_pref, retryable=retryable, operation_id=operation_id, + is_run_command=is_run_command, + is_aggregate_write=is_aggregate_write, ) async def _retryable_write( @@ -2267,11 +2297,14 @@ async def _tmp_session( self, session: Optional[client_session.AsyncClientSession] ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]: """If provided session is None, lend a temporary session.""" - if session is not None: - if not isinstance(session, client_session.AsyncClientSession): - raise ValueError( - f"'session' argument must be an AsyncClientSession or None, not {type(session)}" - ) + if session is not None and not isinstance(session, client_session.AsyncClientSession): + raise ValueError( + f"'session' argument must be an AsyncClientSession or None, not {type(session)}" + ) + + # Check for a bound session. If one exists, treat it as an explicitly passed session. 
+ session = session or self._get_bound_session() + if session: # Don't call end_session. yield session return @@ -2301,6 +2334,18 @@ async def _process_response( if session is not None: session._process_response(reply) + def _get_bound_session(self) -> Optional[AsyncClientSession]: + bound_session = _SESSION.get() + if bound_session: + if bound_session.client is self: + return bound_session + else: + raise InvalidOperation( + "Only the client that created the bound session can perform operations within its context block. See for more information." + ) + else: + return None + async def server_info( self, session: Optional[client_session.AsyncClientSession] = None ) -> dict[str, Any]: @@ -2438,15 +2483,13 @@ async def drop_database( f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}" ) - async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: - await self[name]._command( - conn, - {"dropDatabase": 1, "comment": comment}, - read_preference=ReadPreference.PRIMARY, - write_concern=self._write_concern_for(session), - parse_write_concern_error=True, - session=session, - ) + await self[name].command( + {"dropDatabase": 1, "comment": comment}, + read_preference=ReadPreference.PRIMARY, + write_concern=self._write_concern_for(session), + parse_write_concern_error=True, + session=session, + ) @_csot.apply async def bulk_write( @@ -2730,12 +2773,15 @@ def __init__( address: Optional[_Address] = None, retryable: bool = False, operation_id: Optional[int] = None, + is_run_command: bool = False, + is_aggregate_write: bool = False, ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2751,6 +2797,8 @@ def __init__( self._operation = operation self._operation_id = 
operation_id self._attempt_number = 0 + self._is_run_command = is_run_command + self._is_aggregate_write = is_aggregate_write async def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2770,7 +2818,13 @@ async def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return await self._read() if self._is_read else await self._write() + res = await self._read() if self._is_read else await self._write() + # Track whether the transaction has completed a command. + # If we need to apply backpressure to the first command, + # we will need to revert back to starting state. + if self._session is not None and self._session.in_transaction: + self._session._transaction.has_completed_command = True + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2781,37 +2835,76 @@ async def run(self) -> T: # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc + + if self._is_run_command and not ( + self._client.options.retry_reads and self._client.options.retry_writes + ): + raise + if self._is_aggregate_write and not self._client.options.retry_writes: + raise + # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + overloaded = exc.has_error_label("SystemOverloadedError") + always_retryable = exc.has_error_label("RetryableError") and overloaded + if not self._client.options.retry_reads or ( + not always_retryable + and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) + ) ): 
raise self._retrying = True self._last_error = exc self._attempt_number += 1 + + # Revert back to starting state if we're in a transaction but haven't completed the first + # command. + if ( + overloaded + and self._session is not None + and self._session.in_transaction + ): + transaction = self._session._transaction + if not transaction.has_completed_command: + transaction.set_starting() + transaction.attempt = 0 else: raise # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + overloaded = exc_to_check.has_error_label("SystemOverloadedError") + always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded + + # Always retry abortTransaction and commitTransaction up to once + if self._operation not in ["abortTransaction", "commitTransaction"] and ( + not self._client.options.retry_writes + or not (self._retryable or always_retryable) + ): raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session await self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2820,14 +2913,34 @@ async def run(self) -> T: self._bulk.retrying = True else: self._retrying = True - if not 
exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc - - if self._server is not None: + # Revert back to starting state if we're in a transaction but haven't completed the first + # command. + if overloaded and self._session is not None and self._session.in_transaction: + transaction = self._session._transaction + if not transaction.has_completed_command: + transaction.set_starting() + transaction.attempt = 0 + + if self._server is not None and ( + self._client.topology_description.topology_type_name == "Sharded" + or (overloaded and self._client.options.enable_overload_retargeting) + ): self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if overloaded: + delay = self._retry_policy.backoff(self._attempt_number) + if not await self._retry_policy.should_retry(self._attempt_number, delay): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + await asyncio.sleep(delay) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2889,7 +3002,7 @@ async def _write(self) -> T: and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. 
self._check_last_error() @@ -2921,7 +3034,7 @@ async def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 4e1e7c0638..3c1a85246e 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -19,6 +19,8 @@ import contextlib import logging import os +import socket +import ssl import sys import time import weakref @@ -52,10 +54,12 @@ DocumentTooLarge, ExecutionTimeout, InvalidOperation, + NetworkTimeout, NotPrimaryError, OperationFailure, PyMongoError, WaitQueueTimeoutError, + _CertificateError, ) from pymongo.hello import Hello, HelloCompat from pymongo.helpers_shared import _get_timeout_details, format_timeout_details @@ -250,6 +254,7 @@ async def _hello( cmd = self.hello_cmd() performing_handshake = not self.performed_handshake awaitable = False + cmd["backpressure"] = True if performing_handshake: self.performed_handshake = True cmd["client"] = self.opts.metadata @@ -752,8 +757,8 @@ def __init__( # Enforces: maxConnecting # Also used for: clearing the wait queue self._max_connecting_cond = _async_create_condition(self.lock) - self._max_connecting = self.opts.max_connecting self._pending = 0 + self._max_connecting = self.opts.max_connecting self._client_id = client_id if self.enabled_for_cmap: assert self.opts._event_listeners is not None @@ -986,6 +991,21 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: self.requests -= 1 self.size_cond.notify() + def _handle_connection_error(self, error: BaseException) -> None: + # Handle system overload condition for non-sdam pools. + # Look for errors of type AutoReconnect and add error labels if appropriate. 
+ if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout): + return + assert isinstance(error, AutoReconnect) # Appease type checker. + # If the original error was a DNS, certificate, or SSL error, ignore it. + if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)): + # End of file errors are excluded, because the server may have disconnected + # during the handshake. + if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)): + return + error._add_error_label("SystemOverloadedError") + error._add_error_label("RetryableError") + async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection: """Connect to Mongo and return a new AsyncConnection. @@ -1037,10 +1057,10 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR), error=ConnectionClosedReason.ERROR, ) + self._handle_connection_error(error) if isinstance(error, (IOError, OSError, *SSLErrors)): details = _get_timeout_details(self.opts) _raise_connection_failure(self.address, error, timeout_details=details) - raise conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] @@ -1049,18 +1069,22 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A self.active_contexts.discard(tmp_context) if tmp_context.cancelled: conn.cancel_context.cancel() + completed_hello = False try: if not self.is_sdam: await conn.hello() + completed_hello = True self.is_writable = conn.is_writable if handler: handler.contribute_socket(conn, completed_handshake=False) await conn.authenticate() # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
- except BaseException: + except BaseException as e: async with self.lock: self.active_contexts.discard(conn.cancel_context) + if not completed_hello: + self._handle_connection_error(e) await conn.close_conn(ConnectionClosedReason.ERROR) raise @@ -1389,8 +1413,8 @@ async def _perished(self, conn: AsyncConnection) -> bool: :class:`~pymongo.errors.AutoReconnect` exceptions on server hiccups, etc. We only check if the socket was closed by an external error if it has been > 1 second since the socket was checked into the - pool, to keep performance reasonable - we can't avoid AutoReconnects - completely anyway. + pool to keep performance reasonable - + we can't avoid AutoReconnects completely anyway. """ idle_time_seconds = conn.idle_time_seconds() # If socket is idle, open a new one. @@ -1401,8 +1425,9 @@ async def _perished(self, conn: AsyncConnection) -> bool: await conn.close_conn(ConnectionClosedReason.IDLE) return True - if self._check_interval_seconds is not None and ( - self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds + check_interval_seconds = self._check_interval_seconds + if check_interval_seconds is not None and ( + check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds ): if conn.conn_closed(): await conn.close_conn(ConnectionClosedReason.ERROR) diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index c171848cac..01e346bfa8 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -913,7 +913,9 @@ async def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None # Clear the pool. await server.reset(service_id) elif isinstance(error, ConnectionFailure): - if isinstance(error, WaitQueueTimeoutError): + if isinstance(error, WaitQueueTimeoutError) or ( + error.has_error_label("SystemOverloadedError") + ): return # "Client MUST replace the server's description with type Unknown # ... 
MUST NOT request an immediate check of the server." diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 8b4eea7e65..e5dc609946 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -235,6 +235,16 @@ def __init__( self.__server_monitoring_mode = options.get( "servermonitoringmode", common.SERVER_MONITORING_MODE ) + self.__max_adaptive_retries = ( + options.get("max_adaptive_retries", common.MAX_ADAPTIVE_RETRIES) + if "max_adaptive_retries" in options + else options.get("maxadaptiveretries", common.MAX_ADAPTIVE_RETRIES) + ) + self.__enable_overload_retargeting = ( + options.get("enable_overload_retargeting", common.ENABLE_OVERLOAD_RETARGETING) + if "enable_overload_retargeting" in options + else options.get("enableoverloadretargeting", common.ENABLE_OVERLOAD_RETARGETING) + ) @property def _options(self) -> Mapping[str, Any]: @@ -346,3 +356,19 @@ def server_monitoring_mode(self) -> str: .. versionadded:: 4.5 """ return self.__server_monitoring_mode + + @property + def max_adaptive_retries(self) -> int: + """The configured maxAdaptiveRetries option. + + .. versionadded:: 4.17 + """ + return self.__max_adaptive_retries + + @property + def enable_overload_retargeting(self) -> bool: + """The configured enableOverloadRetargeting option. + + .. versionadded:: 4.17 + """ + return self.__enable_overload_retargeting diff --git a/pymongo/common.py b/pymongo/common.py index e23adac426..ea349b3d23 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -140,6 +140,12 @@ # Default value for serverMonitoringMode SERVER_MONITORING_MODE = "auto" # poll/stream/auto +# Default value for max adaptive retries +MAX_ADAPTIVE_RETRIES = 2 + +# Default value for enableOverloadRetargeting +ENABLE_OVERLOAD_RETARGETING = False + # Auth mechanism properties that must raise an error instead of warning if they invalidate. 
_MECH_PROP_MUST_RAISE = ["CANONICALIZE_HOST_NAME"] @@ -233,13 +239,6 @@ def validate_readable(option: str, value: Any) -> Optional[str]: return value -def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]: - """Validate that 'value' is a positive integer or None.""" - if value is None: - return value - return validate_positive_integer(option, value) - - def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]: """Validate that 'value' is a positive integer or 0 or None.""" if value is None: @@ -261,20 +260,6 @@ def validate_string_or_none(option: str, value: Any) -> Optional[str]: return validate_string(option, value) -def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: - """Validates that 'value' is an integer or string.""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - return int(value) - except ValueError: - return value - raise TypeError( - f"Wrong type for {option}, value must be an integer or a string, not {type(value)}" - ) - - def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: """Validates that 'value' is an integer or string.""" if isinstance(value, int): @@ -738,6 +723,8 @@ def validate_server_monitoring_mode(option: str, value: str) -> str: "srvmaxhosts": validate_non_negative_integer, "timeoutms": validate_timeoutms, "servermonitoringmode": validate_server_monitoring_mode, + "maxadaptiveretries": validate_non_negative_integer, + "enableoverloadretargeting": validate_boolean_or_string, } # Dictionary where keys are the names of URI options specific to pymongo, @@ -771,6 +758,8 @@ def validate_server_monitoring_mode(option: str, value: str) -> str: "server_selector": validate_is_callable_or_none, "auto_encryption_opts": validate_auto_encryption_opts_or_none, "authoidcallowedhosts": validate_list, + "max_adaptive_retries": validate_non_negative_integer, + "enable_overload_retargeting": 
validate_boolean_or_string, } # Dictionary where keys are any URI option name, and values are the @@ -817,16 +806,6 @@ def validate_server_monitoring_mode(option: str, value: str) -> str: "waitqueuetimeoutms", ] -_AUTH_OPTIONS = frozenset(["authmechanismproperties"]) - - -def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: - """Validate optional authentication parameters.""" - lower, value = validate(option, value) - if lower not in _AUTH_OPTIONS: - raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}") - return option, value - def _get_validator( key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index a606d028e1..1134594ae9 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -59,6 +59,7 @@ InvalidOperation, NotPrimaryError, OperationFailure, + PyMongoError, WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES @@ -561,9 +562,17 @@ def _execute_command( error, ConnectionFailure ) and not isinstance(error, (NotPrimaryError, WaitQueueTimeoutError)) + retryable_label_error = isinstance( + error, PyMongoError + ) and error.has_error_label("RetryableError") + # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. 
- if retryable and (retryable_top_level_error or retryable_network_error): + if retryable and ( + retryable_top_level_error + or retryable_network_error + or retryable_label_error + ): full = copy.deepcopy(full_result) _merge_command(self.ops, self.idx_offset, full, result) _throw_client_bulk_write_exception(full, self.verbose_results) diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 8755e57261..3165dd52b7 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -136,9 +136,11 @@ from __future__ import annotations import collections +import random import time import uuid from collections.abc import Mapping as _Mapping +from contextvars import ContextVar, Token from typing import ( TYPE_CHECKING, Any, @@ -159,7 +161,9 @@ from pymongo.errors import ( ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, PyMongoError, WTimeoutError, @@ -180,6 +184,28 @@ _IS_SYNC = True +_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None) + + +class _BoundSessionContext: + """Context manager returned by ClientSession.bind() that manages bound state.""" + + def __init__(self, session: ClientSession, end_session: bool) -> None: + self._session = session + self._session_token: Optional[Token[ClientSession]] = None + self._end_session = end_session + + def __enter__(self) -> ClientSession: + self._session_token = _SESSION.set(self._session) # type: ignore[assignment] + return self._session + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._session_token: + _SESSION.reset(self._session_token) # type: ignore[arg-type] + self._session_token = None + if self._end_session: + self._session.end_session() + class SessionOptions: """Options for a new :class:`ClientSession`. 
@@ -403,6 +429,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any]) self.recovery_token = None self.attempt = 0 self.client = client + self.has_completed_command = False def active(self) -> bool: return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) @@ -410,6 +437,9 @@ def active(self) -> bool: def starting(self) -> bool: return self.state == _TxnState.STARTING + def set_starting(self) -> None: + self.state = _TxnState.STARTING + @property def pinned_conn(self) -> Optional[Connection]: if self.active() and self.conn_mgr: @@ -435,6 +465,7 @@ def reset(self) -> None: self.sharded = False self.recovery_token = None self.attempt = 0 + self.has_completed_command = False def __del__(self) -> None: if self.conn_mgr: @@ -469,11 +500,29 @@ def _max_time_expired_error(exc: PyMongoError) -> bool: # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 +_BACKOFF_MAX = 0.500 # 500ms max backoff +_BACKOFF_INITIAL = 0.005 # 5ms initial backoff -def _within_time_limit(start_time: float) -> bool: +def _within_time_limit(start_time: float, backoff: float = 0) -> bool: """Are we within the with_transaction retry limit?""" - return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + remaining = _csot.remaining() + if remaining is not None and remaining <= 0: + return False + return time.monotonic() + backoff - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT + + +def _make_timeout_error(error: BaseException) -> PyMongoError: + """Convert error to a NetworkTimeout or ExecutionTimeout as appropriate.""" + if _csot.remaining() is not None: + timeout_error: PyMongoError = ExecutionTimeout( + str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50} + ) + else: + timeout_error = NetworkTimeout(str(error)) + if isinstance(error, PyMongoError): + timeout_error._error_labels = 
error._error_labels.copy() + return timeout_error _T = TypeVar("_T") @@ -546,6 +595,24 @@ def _check_ended(self) -> None: if self._server_session is None: raise InvalidOperation("Cannot use ended session") + def bind(self, end_session: bool = True) -> _BoundSessionContext: + """Bind this session so it is implicitly passed to all database operations within the returned context. + + .. code-block:: python + + with client.start_session() as s: + with s.bind(): + # session=s is passed implicitly + client.db.collection.insert_one({"x": 1}) + + :param end_session: Whether to end the session on exiting the returned context. Defaults to True. + If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called + once the session is no longer used. + + .. versionadded:: 4.17 + """ + return _BoundSessionContext(self, end_session) + def __enter__(self) -> ClientSession: return self @@ -702,21 +769,32 @@ def callback(session, custom_arg, custom_kwarg=None): https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#handling-errors-inside-the-callback """ start_time = time.monotonic() + retry = 0 + last_error: Optional[BaseException] = None while True: + if retry: # Implement exponential backoff on retry. + jitter = random.random() # noqa: S311 + backoff = jitter * min(_BACKOFF_INITIAL * (1.5**retry), _BACKOFF_MAX) + if not _within_time_limit(start_time, backoff): + assert last_error is not None + raise _make_timeout_error(last_error) from last_error + time.sleep(backoff) + retry += 1 self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
except BaseException as exc: + last_error = exc if self.in_transaction: self.abort_transaction() - if ( - isinstance(exc, PyMongoError) - and exc.has_error_label("TransientTransactionError") - and _within_time_limit(start_time) + if isinstance(exc, PyMongoError) and exc.has_error_label( + "TransientTransactionError" ): - # Retry the entire transaction. - continue + if _within_time_limit(start_time): + # Retry the entire transaction. + continue + raise _make_timeout_error(last_error) from exc raise if not self.in_transaction: @@ -727,17 +805,18 @@ def callback(session, custom_arg, custom_kwarg=None): try: self.commit_transaction() except PyMongoError as exc: - if ( - exc.has_error_label("UnknownTransactionCommitResult") - and _within_time_limit(start_time) - and not _max_time_expired_error(exc) - ): + last_error = exc + if exc.has_error_label( + "UnknownTransactionCommitResult" + ) and not _max_time_expired_error(exc): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the commit. continue - if exc.has_error_label("TransientTransactionError") and _within_time_limit( - start_time - ): + if exc.has_error_label("TransientTransactionError"): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the entire transaction. 
break raise diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index edc6047330..1057151e59 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -21,7 +21,6 @@ TYPE_CHECKING, Any, Callable, - ContextManager, Generic, Iterable, Iterator, @@ -572,11 +571,6 @@ def watch( change_stream._initialize_cursor() return change_stream - def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> ContextManager[Connection]: - return self._database.client._conn_for_writes(session, operation) - def _command( self, conn: Connection, @@ -653,7 +647,10 @@ def _create_helper( if "size" in options: options["size"] = float(options["size"]) cmd.update(options) - with self._conn_for_writes(session, operation=_Op.CREATE) as conn: + + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> None: if qev2_required and conn.max_wire_version < 21: raise ConfigurationError( "Driver support of Queryable Encryption is incompatible with server. " @@ -670,6 +667,8 @@ def _create_helper( session=session, ) + self.database.client._retryable_write(False, inner, session, _Op.CREATE) + def _create( self, options: MutableMapping[str, Any], @@ -2237,7 +2236,10 @@ def _create_indexes( command (like maxTimeMS) can be passed as keyword arguments. 
""" names = [] - with self._conn_for_writes(session, operation=_Op.CREATE_INDEXES) as conn: + + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> list[str]: supports_quorum = conn.max_wire_version >= 9 def gen_indexes() -> Iterator[Mapping[str, Any]]: @@ -2266,7 +2268,9 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: write_concern=self._write_concern_for(session), session=session, ) - return names + return names + + return self.database.client._retryable_write(False, inner, session, _Op.CREATE_INDEXES) def create_index( self, @@ -2419,7 +2423,6 @@ def drop_indexes( kwargs["comment"] = comment self._drop_index("*", session=session, **kwargs) - @_csot.apply def drop_index( self, index_or_name: _IndexKeyHint, @@ -2487,7 +2490,10 @@ def _drop_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._conn_for_writes(session, operation=_Op.DROP_INDEXES) as conn: + + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> None: self._command( conn, cmd, @@ -2497,6 +2503,8 @@ def _drop_index( session=session, ) + self.database.client._retryable_write(False, inner, session, _Op.DROP_INDEXES) + def list_indexes( self, session: Optional[ClientSession] = None, @@ -2760,15 +2768,22 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: cmd = {"createSearchIndexes": self.name, "indexes": list(gen_indexes())} cmd.update(kwargs) - with self._conn_for_writes(session, operation=_Op.CREATE_SEARCH_INDEXES) as conn: + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> list[str]: resp = self._command( conn, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + session=session, ) return [index["name"] for index in resp["indexesCreated"]] + return self.database.client._retryable_write( + False, inner, session, _Op.CREATE_SEARCH_INDEXES + ) + def drop_search_index( self, name: str, @@ -2794,15 
+2809,21 @@ def drop_search_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._conn_for_writes(session, operation=_Op.DROP_SEARCH_INDEXES) as conn: + + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> None: self._command( conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + session=session, ) + self.database.client._retryable_write(False, inner, session, _Op.DROP_SEARCH_INDEXES) + def update_search_index( self, name: str, @@ -2830,15 +2851,21 @@ def update_search_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._conn_for_writes(session, operation=_Op.UPDATE_SEARCH_INDEX) as conn: + + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> None: self._command( conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + session=session, ) + self.database.client._retryable_write(False, inner, session, _Op.UPDATE_SEARCH_INDEX) + def options( self, session: Optional[ClientSession] = None, @@ -2911,6 +2938,7 @@ def _aggregate( session, retryable=not cmd._performs_write, operation=_Op.AGGREGATE, + is_aggregate_write=cmd._performs_write, ) def aggregate( @@ -3116,17 +3144,21 @@ def rename( if comment is not None: cmd["comment"] = comment write_concern = self._write_concern_for_cmd(cmd, session) + client = self._database.client - with self._conn_for_writes(session, operation=_Op.RENAME) as conn: - with self._database.client._tmp_session(session) as s: - return conn.command( - "admin", - cmd, - write_concern=write_concern, - parse_write_concern_error=True, - session=s, - client=self._database.client, - ) + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> MutableMapping[str, Any]: + return conn.command( + "admin", 
+ cmd, + write_concern=write_concern, + parse_write_concern_error=True, + session=session, + client=client, + ) + + return client._retryable_write(False, inner, session, _Op.RENAME) def distinct( self, diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index a453a94265..1956795bcb 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -931,12 +931,15 @@ def command( if read_preference is None: read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY - with self._client._conn_for_reads(read_preference, session, operation=command_name) as ( - connection, - read_preference, - ): + + def inner( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> Union[dict[str, Any], _CodecDocumentType]: return self._command( - connection, + conn, command, value, check, @@ -947,6 +950,10 @@ def command( **kwargs, ) + return self._client._retryable_read( + inner, read_preference, session, command_name, None, False, is_run_command=True + ) + @_csot.apply def cursor_command( self, @@ -1014,15 +1021,17 @@ def cursor_command( with self._client._tmp_session(session) as tmp_session: opts = codec_options or DEFAULT_CODEC_OPTIONS - if read_preference is None: read_preference = ( tmp_session and tmp_session._txn_read_preference() ) or ReadPreference.PRIMARY - with self._client._conn_for_reads(read_preference, tmp_session, command_name) as ( - conn, - read_preference, - ): + + def inner( + session: Optional[ClientSession], + _server: Server, + conn: Connection, + read_preference: _ServerMode, + ) -> CommandCursor[_DocumentType]: response = self._command( conn, command, @@ -1031,7 +1040,7 @@ def cursor_command( None, read_preference, opts, - session=tmp_session, + session=session, **kwargs, ) coll = self.get_collection("$cmd", read_preference=read_preference) @@ -1041,7 +1050,7 @@ def cursor_command( response["cursor"], conn.address, 
max_await_time_ms=max_await_time_ms, - session=tmp_session, + session=session, comment=comment, ) cmd_cursor._maybe_pin_connection(conn) @@ -1049,6 +1058,10 @@ def cursor_command( else: raise InvalidOperation("Command does not return a cursor.") + return self.client._retryable_read( + inner, read_preference, tmp_session, command_name, None, False + ) + def _retryable_read_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1247,9 +1260,11 @@ def _drop_helper( if comment is not None: command["comment"] = comment - with self._client._conn_for_writes(session, operation=_Op.DROP) as connection: + def inner( + session: Optional[ClientSession], conn: Connection, _retryable_write: bool + ) -> dict[str, Any]: return self._command( - connection, + conn, command, allowable_errors=["ns not found", 26], write_concern=self._write_concern_for(session), @@ -1257,6 +1272,8 @@ def _drop_helper( session=session, ) + return self.client._retryable_write(False, inner, session, _Op.DROP) + @_csot.apply def drop_collection( self, diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 1fff9a0f23..bbe8963fe7 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -17,8 +17,11 @@ import asyncio import builtins +import functools +import random import socket import sys +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -26,6 +29,8 @@ cast, ) +from pymongo import _csot +from pymongo.common import MAX_ADAPTIVE_RETRIES from pymongo.errors import ( OperationFailure, ) @@ -38,6 +43,7 @@ def _handle_reauth(func: F) -> F: + @functools.wraps(func) def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.message import _BulkWriteContext @@ -70,6 +76,46 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +_BACKOFF_INITIAL = 0.1 +_BACKOFF_MAX = 10 + + +def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, 
max_delay: float = _BACKOFF_MAX +) -> float: + jitter = random.random() # noqa: S311 + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter.""" + + def __init__( + self, + attempts: int = MAX_ADAPTIVE_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given attempt.""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have retry attempts remaining and the next backoff would not exceed a timeout.""" + if attempt > self.attempts: + return False + + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + return True + + def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index cd0d19141f..c049dcaeae 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -108,8 +109,11 @@ from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk -from pymongo.synchronous.client_session import _EmptyServerSession +from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import ( + _RetryPolicy, +) from pymongo.synchronous.settings 
import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -610,8 +614,18 @@ def __init__( client to use Stable API. See `versioned API `_ for details. + | **Overload retry options:** + + - `max_adaptive_retries`: (int) How many retries to allow for overload errors. Defaults to ``2``. + - `enable_overload_retargeting`: (boolean) Whether overload retargeting is enabled for this client. + If enabled, server overload errors will cause retry attempts to select a server that has not yet returned an overload error, if possible. + Defaults to ``False``. + .. seealso:: The MongoDB documentation on `connections `_. + .. versionchanged:: 4.17 + Added the ``max_adaptive_retries`` and ``enable_overload_retargeting`` URI and keyword arguments. + .. versionchanged:: 4.5 Added the ``serverMonitoringMode`` keyword argument. @@ -879,11 +893,14 @@ def __init__( self._options.read_concern, ) + self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) + self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) self._opened = False self._closed = False self._loop: Optional[asyncio.AbstractEventLoop] = None + if not is_srv: self._init_background() @@ -1406,7 +1423,8 @@ def start_session( ) def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: - """If provided session is None, lend a temporary session.""" + """If provided session and bound session are None, lend a temporary session.""" + session = session or self._get_bound_session() if session: return session @@ -1986,6 +2004,8 @@ def _retry_internal( read_pref: Optional[_ServerMode] = None, retryable: bool = False, operation_id: Optional[int] = None, + is_run_command: bool = False, + is_aggregate_write: bool = False, ) -> T: """Internal retryable helper for all client transactions. 
@@ -1997,6 +2017,8 @@ def _retry_internal( :param address: Server Address, defaults to None :param read_pref: Topology of read operation, defaults to None :param retryable: If the operation should be retried once, defaults to None + :param is_run_command: If this is a runCommand operation, defaults to False + :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. :return: Output of the calling func() """ @@ -2011,6 +2033,8 @@ def _retry_internal( address=address, retryable=retryable, operation_id=operation_id, + is_run_command=is_run_command, + is_aggregate_write=is_aggregate_write, ).run() def _retryable_read( @@ -2022,6 +2046,8 @@ def _retryable_read( address: Optional[_Address] = None, retryable: bool = True, operation_id: Optional[int] = None, + is_run_command: bool = False, + is_aggregate_write: bool = False, ) -> T: """Execute an operation with consecutive retries if possible @@ -2037,6 +2063,8 @@ def _retryable_read( :param address: Optional address when sending a message, defaults to None :param retryable: if we should attempt retries (may not always be supported even if supplied), defaults to False + :param is_run_command: If this is a runCommand operation, defaults to False. + :param is_aggregate_write: If this is a aggregate operation with a write, defaults to False. 
""" # Ensure that the client supports retrying on reads and there is no session in @@ -2055,6 +2083,8 @@ def _retryable_read( read_pref=read_pref, retryable=retryable, operation_id=operation_id, + is_run_command=is_run_command, + is_aggregate_write=is_aggregate_write, ) def _retryable_write( @@ -2263,11 +2293,14 @@ def _tmp_session( self, session: Optional[client_session.ClientSession] ) -> Generator[Optional[client_session.ClientSession], None]: """If provided session is None, lend a temporary session.""" - if session is not None: - if not isinstance(session, client_session.ClientSession): - raise ValueError( - f"'session' argument must be a ClientSession or None, not {type(session)}" - ) + if session is not None and not isinstance(session, client_session.ClientSession): + raise ValueError( + f"'session' argument must be a ClientSession or None, not {type(session)}" + ) + + # Check for a bound session. If one exists, treat it as an explicitly passed session. + session = session or self._get_bound_session() + if session: # Don't call end_session. yield session return @@ -2295,6 +2328,18 @@ def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSe if session is not None: session._process_response(reply) + def _get_bound_session(self) -> Optional[ClientSession]: + bound_session = _SESSION.get() + if bound_session: + if bound_session.client is self: + return bound_session + else: + raise InvalidOperation( + "Only the client that created the bound session can perform operations within its context block. See for more information." + ) + else: + return None + def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]: """Get information about the MongoDB server we're connected to. 
@@ -2428,15 +2473,13 @@ def drop_database( f"name_or_database must be an instance of str or a Database, not {type(name)}" ) - with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn: - self[name]._command( - conn, - {"dropDatabase": 1, "comment": comment}, - read_preference=ReadPreference.PRIMARY, - write_concern=self._write_concern_for(session), - parse_write_concern_error=True, - session=session, - ) + self[name].command( + {"dropDatabase": 1, "comment": comment}, + read_preference=ReadPreference.PRIMARY, + write_concern=self._write_concern_for(session), + parse_write_concern_error=True, + session=session, + ) @_csot.apply def bulk_write( @@ -2720,12 +2763,15 @@ def __init__( address: Optional[_Address] = None, retryable: bool = False, operation_id: Optional[int] = None, + is_run_command: bool = False, + is_aggregate_write: bool = False, ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2741,6 +2787,8 @@ def __init__( self._operation = operation self._operation_id = operation_id self._attempt_number = 0 + self._is_run_command = is_run_command + self._is_aggregate_write = is_aggregate_write def run(self) -> T: """Runs the supplied func() and attempts a retry @@ -2760,7 +2808,13 @@ def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return self._read() if self._is_read else self._write() + res = self._read() if self._is_read else self._write() + # Track whether the transaction has completed a command. + # If we need to apply backpressure to the first command, + # we will need to revert back to starting state. 
+ if self._session is not None and self._session.in_transaction: + self._session._transaction.has_completed_command = True + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2771,37 +2825,76 @@ def run(self) -> T: # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc + + if self._is_run_command and not ( + self._client.options.retry_reads and self._client.options.retry_writes + ): + raise + if self._is_aggregate_write and not self._client.options.retry_writes: + raise + # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + overloaded = exc.has_error_label("SystemOverloadedError") + always_retryable = exc.has_error_label("RetryableError") and overloaded + if not self._client.options.retry_reads or ( + not always_retryable + and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) + ) ): raise self._retrying = True self._last_error = exc self._attempt_number += 1 + + # Revert back to starting state if we're in a transaction but haven't completed the first + # command. 
+ if ( + overloaded + and self._session is not None + and self._session.in_transaction + ): + transaction = self._session._transaction + if not transaction.has_completed_command: + transaction.set_starting() + transaction.attempt = 0 else: raise # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + overloaded = exc_to_check.has_error_label("SystemOverloadedError") + always_retryable = exc_to_check.has_error_label("RetryableError") and overloaded + + # Always retry abortTransaction and commitTransaction up to once + if self._operation not in ["abortTransaction", "commitTransaction"] and ( + not self._client.options.retry_writes + or not (self._retryable or always_retryable) + ): raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2810,14 +2903,34 @@ def run(self) -> T: self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc - - if self._server is not 
None: + # Revert back to starting state if we're in a transaction but haven't completed the first + # command. + if overloaded and self._session is not None and self._session.in_transaction: + transaction = self._session._transaction + if not transaction.has_completed_command: + transaction.set_starting() + transaction.attempt = 0 + + if self._server is not None and ( + self._client.topology_description.topology_type_name == "Sharded" + or (overloaded and self._client.options.enable_overload_retargeting) + ): self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if overloaded: + delay = self._retry_policy.backoff(self._attempt_number) + if not self._retry_policy.should_retry(self._attempt_number, delay): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + time.sleep(delay) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2879,7 +2992,7 @@ def _write(self) -> T: and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. 
self._check_last_error() @@ -2911,7 +3024,7 @@ def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 89d2080bc8..d33cb59a98 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -19,6 +19,8 @@ import contextlib import logging import os +import socket +import ssl import sys import time import weakref @@ -49,10 +51,12 @@ DocumentTooLarge, ExecutionTimeout, InvalidOperation, + NetworkTimeout, NotPrimaryError, OperationFailure, PyMongoError, WaitQueueTimeoutError, + _CertificateError, ) from pymongo.hello import Hello, HelloCompat from pymongo.helpers_shared import _get_timeout_details, format_timeout_details @@ -250,6 +254,7 @@ def _hello( cmd = self.hello_cmd() performing_handshake = not self.performed_handshake awaitable = False + cmd["backpressure"] = True if performing_handshake: self.performed_handshake = True cmd["client"] = self.opts.metadata @@ -750,8 +755,8 @@ def __init__( # Enforces: maxConnecting # Also used for: clearing the wait queue self._max_connecting_cond = _create_condition(self.lock) - self._max_connecting = self.opts.max_connecting self._pending = 0 + self._max_connecting = self.opts.max_connecting self._client_id = client_id if self.enabled_for_cmap: assert self.opts._event_listeners is not None @@ -982,6 +987,21 @@ def remove_stale_sockets(self, reference_generation: int) -> None: self.requests -= 1 self.size_cond.notify() + def _handle_connection_error(self, error: BaseException) -> None: + # Handle system overload condition for non-sdam pools. + # Look for errors of type AutoReconnect and add error labels if appropriate. + if self.is_sdam or type(error) not in (AutoReconnect, NetworkTimeout): + return + assert isinstance(error, AutoReconnect) # Appease type checker. 
+ # If the original error was a DNS, certificate, or SSL error, ignore it. + if isinstance(error.__cause__, (_CertificateError, SSLErrors, socket.gaierror)): + # End of file errors are excluded, because the server may have disconnected + # during the handshake. + if not isinstance(error.__cause__, (ssl.SSLEOFError, ssl.SSLZeroReturnError)): + return + error._add_error_label("SystemOverloadedError") + error._add_error_label("RetryableError") + def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: """Connect to Mongo and return a new Connection. @@ -1033,10 +1053,10 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR), error=ConnectionClosedReason.ERROR, ) + self._handle_connection_error(error) if isinstance(error, (IOError, OSError, *SSLErrors)): details = _get_timeout_details(self.opts) _raise_connection_failure(self.address, error, timeout_details=details) - raise conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] @@ -1045,18 +1065,22 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect self.active_contexts.discard(tmp_context) if tmp_context.cancelled: conn.cancel_context.cancel() + completed_hello = False try: if not self.is_sdam: conn.hello() + completed_hello = True self.is_writable = conn.is_writable if handler: handler.contribute_socket(conn, completed_handshake=False) conn.authenticate() # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException: + except BaseException as e: with self.lock: self.active_contexts.discard(conn.cancel_context) + if not completed_hello: + self._handle_connection_error(e) conn.close_conn(ConnectionClosedReason.ERROR) raise @@ -1385,8 +1409,8 @@ def _perished(self, conn: Connection) -> bool: :class:`~pymongo.errors.AutoReconnect` exceptions on server hiccups, etc. 
We only check if the socket was closed by an external error if it has been > 1 second since the socket was checked into the - pool, to keep performance reasonable - we can't avoid AutoReconnects - completely anyway. + pool to keep performance reasonable - + we can't avoid AutoReconnects completely anyway. """ idle_time_seconds = conn.idle_time_seconds() # If socket is idle, open a new one. @@ -1397,8 +1421,9 @@ def _perished(self, conn: Connection) -> bool: conn.close_conn(ConnectionClosedReason.IDLE) return True - if self._check_interval_seconds is not None and ( - self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds + check_interval_seconds = self._check_interval_seconds + if check_interval_seconds is not None and ( + check_interval_seconds == 0 or idle_time_seconds > check_interval_seconds ): if conn.conn_closed(): conn.close_conn(ConnectionClosedReason.ERROR) diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 38e916b1e7..ec1615f0c6 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -911,7 +911,9 @@ def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None: # Clear the pool. server.reset(service_id) elif isinstance(error, ConnectionFailure): - if isinstance(error, WaitQueueTimeoutError): + if isinstance(error, WaitQueueTimeoutError) or ( + error.has_error_label("SystemOverloadedError") + ): return # "Client MUST replace the server's description with type Unknown # ... MUST NOT request an immediate check of the server." 
diff --git a/pyproject.toml b/pyproject.toml index acc9fa5b0d..ff6754e1dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ dev = [] pip = ["pip>=20.2"] gevent = ["gevent>=21.12"] coverage = [ - "pytest-cov>=4.0.0", "coverage[toml]>=5,<=7.10.7" ] mockupdb = [ @@ -133,6 +132,7 @@ markers = [ "mockupdb: tests that rely on mockupdb", "default: default test suite", "default_async: default async test suite", + "test_bson: bson module tests", ] [tool.mypy] diff --git a/test/__init__.py b/test/__init__.py index 8540c442e0..1db3fde4b2 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -84,6 +84,22 @@ _IS_SYNC = True +# Skip tests when using Rust BSON extension for features not yet implemented +# Import pytest lazily to avoid requiring it for integration tests +try: + import pytest + + import bson + + skip_if_rust_bson = pytest.mark.skipif( + bson.get_bson_implementation() == "rust", + reason="Feature not yet implemented in Rust BSON extension", + ) +except ImportError: + # pytest not available, define a no-op decorator + def skip_if_rust_bson(func): + return func + def _connection_string(h): if h.startswith(("mongodb://", "mongodb+srv://")): diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 4dde0acf1f..a0647b0e16 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -84,6 +84,22 @@ _IS_SYNC = False +# Skip tests when using Rust BSON extension for features not yet implemented +# Import pytest lazily to avoid requiring it for integration tests +try: + import pytest + + import bson + + skip_if_rust_bson = pytest.mark.skipif( + bson.get_bson_implementation() == "rust", + reason="Feature not yet implemented in Rust BSON extension", + ) +except ImportError: + # pytest not available, define a no-op decorator + def skip_if_rust_bson(func): + return func + def _connection_string(h): if h.startswith(("mongodb://", "mongodb+srv://")): diff --git a/test/asynchronous/test_client.py 
b/test/asynchronous/test_client.py index 5511765bae..ca150ca6df 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -652,6 +652,38 @@ async def test_detected_environment_warning(self, mock_get_hosts): with self.assertWarns(UserWarning): self.simple_client(multi_host) + async def test_max_adaptive_retries(self): + # Assert that max adaptive retries defaults to 2. + c = self.simple_client(connect=False) + self.assertEqual(c.options.max_adaptive_retries, 2) + + # Assert that max adaptive retries can be configured through connection or client options. + c = self.simple_client(connect=False, max_adaptive_retries=10) + self.assertEqual(c.options.max_adaptive_retries, 10) + + c = self.simple_client(connect=False, maxAdaptiveRetries=10) + self.assertEqual(c.options.max_adaptive_retries, 10) + + c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False) + self.assertEqual(c.options.max_adaptive_retries, 10) + + async def test_enable_overload_retargeting(self): + # Assert that overload retargeting defaults to false. + c = self.simple_client(connect=False) + self.assertFalse(c.options.enable_overload_retargeting) + + # Assert that overload retargeting can be enabled through connection or client options. 
+ c = self.simple_client(connect=False, enable_overload_retargeting=True) + self.assertTrue(c.options.enable_overload_retargeting) + + c = self.simple_client(connect=False, enableOverloadRetargeting=True) + self.assertTrue(c.options.enable_overload_retargeting) + + c = self.simple_client( + host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False + ) + self.assertTrue(c.options.enable_overload_retargeting) + class TestClient(AsyncIntegrationTest): def test_multiple_uris(self): @@ -1034,7 +1066,7 @@ async def test_list_database_names(self): db_names = await self.client.list_database_names() self.assertIn("pymongo_test", db_names) self.assertIn("pymongo_test_mike", db_names) - self.assertEqual(db_names, cmd_names) + self.assertCountEqual(db_names, cmd_names) async def test_drop_database(self): with self.assertRaises(TypeError): diff --git a/test/asynchronous/test_client_backpressure.py b/test/asynchronous/test_client_backpressure.py new file mode 100644 index 0000000000..3e75ed9b0d --- /dev/null +++ b/test/asynchronous/test_client_backpressure.py @@ -0,0 +1,312 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import os +import pathlib +import sys +from time import perf_counter +from unittest.mock import patch + +from pymongo.common import MAX_ADAPTIVE_RETRIES + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + unittest, +) +from test.asynchronous.unified_format import generate_test_classes +from test.utils_shared import EventListener, OvertCommandListener + +from pymongo.errors import OperationFailure, PyMongoError + +_IS_SYNC = False + +# Mock a system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, +} + + +def get_mock_overload_error(times: int): + error = mock_overload_error.copy() + error["mode"] = {"times": times} + return error + + +class TestBackpressure(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_command(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + async with self.fail_point(fail_many): + await self.db.command("find", "t") + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. 
+ fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.command("find", "t") + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_find(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + async with self.fail_point(fail_many): + await self.db.t.find_one() + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.find_one() + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_insert_one(self): + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + async with self.fail_point(fail_many): + await self.db.t.insert_one({"x": 1}) + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.insert_one({"x": 1}) + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "RetryableError" error label. 
+ await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + async with self.fail_point(fail_many): + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_getMore(self): + coll = self.db.t + await coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": MAX_ADAPTIVE_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_many): + await cursor.to_list() + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1} + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await cursor.to_list() + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + +# Prose tests. 
+class AsyncTestClientBackpressure(AsyncIntegrationTest): + listener: EventListener + + @classmethod + def setUpClass(cls) -> None: + cls.listener = OvertCommandListener() + + @async_client_context.require_connection + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.listener.reset() + self.app_name = self.__class__.__name__.lower() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], appName=self.app_name + ) + + @patch("random.random") + @async_client_context.require_failCommand_appName + async def test_01_operation_retry_uses_exponential_backoff(self, random_func): + # Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered. + + # 1. let `client` be a `MongoClient` + client = self.client + + # 2. let `collection` be a collection + collection = client.test.test + + # 3. Now, run transactions without backoff: + + # a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff. + random_func.return_value = 0 + + # b. Configure the following failPoint: + fail_point = dict( + mode="alwaysOn", + data=dict( + failCommands=["insert"], + errorCode=2, + errorLabels=["SystemOverloadedError", "RetryableError"], + appName=self.app_name, + ), + ) + async with self.fail_point(fail_point): + # c. Execute the following command. Expect that the command errors. Measure the duration of the command execution. + start0 = perf_counter() + with self.assertRaises(OperationFailure): + await collection.insert_one({"a": 1}) + end0 = perf_counter() + + # d. Configure the random number generator used for jitter to always return `1`. + random_func.return_value = 1 + + # e. Execute step c again. + start1 = perf_counter() + with self.assertRaises(OperationFailure): + await collection.insert_one({"a": 1}) + end1 = perf_counter() + + # f. Compare the times between the two runs. + # The sum of 2 backoffs is 0.3 seconds. 
There is a 0.3-second window to account for potential variance between the two + # runs. + self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3) + + @async_client_context.require_failCommand_appName + async def test_03_overload_retries_limited(self): + # Drivers should test that overload errors are retried a maximum of two times. + + # 1. Let `client` be a `MongoClient`. + client = self.client + # 2. Let `coll` be a collection. + coll = client.pymongo_test.coll + + # 3. Configure the following failpoint: + failpoint = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["find"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + + # 4. Perform a find operation with `coll` that fails. + async with self.fail_point(failpoint): + with self.assertRaises(PyMongoError) as error: + await coll.find_one({}) + + # 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels. + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + # 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1. + self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1) + + @async_client_context.require_failCommand_appName + async def test_04_overload_retries_limited_configured(self): + # Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times. + max_retries = 1 + + # 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled. + client = await self.async_single_client( + maxAdaptiveRetries=max_retries, event_listeners=[self.listener] + ) + # 2. Let `coll` be a collection. + coll = client.pymongo_test.coll + + # 3. 
Configure the following failpoint: + failpoint = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["find"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + + # 4. Perform a find operation with `coll` that fails. + async with self.fail_point(failpoint): + with self.assertRaises(PyMongoError) as error: + await coll.find_one({}) + + # 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels. + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + # 6. Assert that the total number of started commands is max_retries + 1. + self.assertEqual(len(self.listener.started_events), max_retries + 1) + + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure") + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_client_metadata.py b/test/asynchronous/test_client_metadata.py index 45c1bd1b3b..d4f887b1fc 100644 --- a/test/asynchronous/test_client_metadata.py +++ b/test/asynchronous/test_client_metadata.py @@ -219,6 +219,19 @@ async def test_duplicate_driver_name_no_op(self): # add same metadata again await self.check_metadata_added(client, "Framework", None, None) + async def test_handshake_documents_include_backpressure(self): + # Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of + # connection establishment. + client = await self.async_rs_or_single_client("mongodb://" + self.server.address_string) + + # Send a `ping` command to the server and verify that the command succeeds. 
This ensures that a connection is + # established on all topologies. Note: MockupDB only supports standalone servers. + await client.admin.command("ping") + + # Assert that for every handshake document intercepted: + # the document has a field `backpressure` whose value is `true`. + self.assertEqual(self.handshake_req["backpressure"], True) + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 08da82762c..27c80c62ab 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -30,7 +30,12 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + skip_if_rust_bson, + unittest, +) from test.asynchronous.utils import flaky from test.utils_shared import ( AllowListEventListener, @@ -1507,6 +1512,7 @@ async def test_command_cursor_to_list_csot_applied(self): self.assertTrue(ctx.exception.timeout) +@skip_if_rust_bson class TestRawBatchCursor(AsyncIntegrationTest): async def test_find_raw(self): c = self.db.test @@ -1682,6 +1688,7 @@ async def test_monitoring(self): await cursor.close() +@skip_if_rust_bson class TestRawBatchCommandCursor(AsyncIntegrationTest): async def test_aggregate_raw(self): c = self.db.test diff --git a/test/asynchronous/test_custom_types.py b/test/asynchronous/test_custom_types.py index 82c54512cc..613705b283 100644 --- a/test/asynchronous/test_custom_types.py +++ b/test/asynchronous/test_custom_types.py @@ -28,7 +28,12 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + skip_if_rust_bson, + unittest, +) from bson import ( + _BUILT_IN_TYPES, @@ -196,12 +201,14 @@ def test_decode_file_iter(self): fileobj.close() +@skip_if_rust_bson class
TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): cls.codecopts = DECIMAL_CODECOPTS +@skip_if_rust_bson class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): @@ -211,6 +218,7 @@ def setUpClass(cls): cls.codecopts = codec_options +@skip_if_rust_bson class TestBSONFallbackEncoder(unittest.TestCase): def _get_codec_options(self, fallback_encoder): type_registry = TypeRegistry(fallback_encoder=fallback_encoder) @@ -273,6 +281,7 @@ def fallback_encoder(value): self.assertEqual(called_with, [2 << 65]) +@skip_if_rust_bson class TestBSONTypeEnDeCodecs(unittest.TestCase): def test_instantiation(self): msg = "Can't instantiate abstract class" @@ -336,6 +345,7 @@ def test_type_checks(self): self.assertFalse(issubclass(TypeEncoder, TypeDecoder)) +@skip_if_rust_bson class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): TypeA: Any TypeB: Any @@ -432,6 +442,7 @@ def test_infinite_loop_exceeds_max_recursion_depth(self): encode({"x": self.TypeA(100)}, codec_options=codecopts) +@skip_if_rust_bson class TestTypeRegistry(unittest.TestCase): types: Tuple[object, object] codecs: Tuple[Type[TypeCodec], Type[TypeCodec]] @@ -622,6 +633,7 @@ class MyType(pytype): # type: ignore run_test(TypeCodec, {"bson_type": Decimal128, "transform_bson": lambda x: x}) +@skip_if_rust_bson class TestCollectionWCustomType(AsyncIntegrationTest): async def asyncSetUp(self): await super().asyncSetUp() @@ -744,6 +756,7 @@ async def test_find_one_and__w_custom_type_decoder(self): self.assertIsNone(await c.find_one()) +@skip_if_rust_bson class TestGridFileCustomType(AsyncIntegrationTest): async def asyncSetUp(self): await super().asyncSetUp() @@ -910,6 +923,7 @@ async def run_test(doc_cls): await run_test(doc_cls) +@skip_if_rust_bson class TestCollectionChangeStreamsWCustomTypes( AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin ): @@ -929,6 
+943,7 @@ async def create_targets(self, *args, **kwargs): await self.input_target.delete_many({}) +@skip_if_rust_bson class TestDatabaseChangeStreamsWCustomTypes( AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin ): @@ -949,6 +964,7 @@ async def create_targets(self, *args, **kwargs): await self.input_target.insert_one({"data": "dummy"}) +@skip_if_rust_bson class TestClusterChangeStreamsWCustomTypes( AsyncIntegrationTest, ChangeStreamsWCustomTypesTestMixin ): diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 0bbf471d87..17a90db60f 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -25,8 +25,10 @@ from pathlib import Path from test.asynchronous.helpers import ConcurrentRunner from test.asynchronous.utils import flaky +from test.utils_shared import delay from pymongo.asynchronous.pool import AsyncConnection +from pymongo.errors import ConnectionFailure from pymongo.operations import _Op from pymongo.server_selectors import writable_server_selector @@ -70,7 +72,12 @@ ) from pymongo.hello import Hello, HelloCompat from pymongo.helpers_shared import _check_command_response, _check_write_command_response -from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.monitoring import ( + ConnectionCheckOutFailedEvent, + PoolClearedEvent, + ServerHeartbeatFailedEvent, + ServerHeartbeatStartedEvent, +) from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.topology_description import TOPOLOGY_TYPE @@ -131,6 +138,9 @@ async def got_app_error(topology, app_error): raise AssertionError except (AutoReconnect, NotPrimaryError, OperationFailure) as e: if when == "beforeHandshakeCompletes": + # The pool would have added the SystemOverloadedError in this case. 
+ if isinstance(e, AutoReconnect): + e._add_error_label("SystemOverloadedError") completed_handshake = False elif when == "afterHandshakeCompletes": completed_handshake = True @@ -439,6 +449,59 @@ async def mock_close(self, reason): AsyncConnection.close_conn = original_close +class TestPoolBackpressure(AsyncIntegrationTest): + @async_client_context.require_version_min(7, 0, 0) + async def test_connection_pool_is_not_cleared(self): + listener = CMAPListener() + + # Create a client that listens to CMAP events, with maxConnecting=100. + client = await self.async_rs_or_single_client(maxConnecting=100, event_listeners=[listener]) + + # Enable the ingress rate limiter. + await client.admin.command( + "setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True + ) + await client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20) + await client.admin.command( + "setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1 + ) + await client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1) + + # Disable the ingress rate limiter on teardown. + # Sleep for 1 second before disabling to avoid the rate limiter. + async def teardown(): + await asyncio.sleep(1) + await client.admin.command( + "setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False + ) + + self.addAsyncCleanup(teardown) + + # Make sure the collection has at least one document. + await client.test.test.delete_many({}) + await client.test.test.insert_one({}) + + # Run a slow operation to tie up the connection. + async def target(): + try: + await client.test.test.find_one({"$where": delay(0.1)}) + except ConnectionFailure: + pass + + # Run 100 parallel operations that contend for connections. 
+ tasks = [] + for _ in range(100): + tasks.append(ConcurrentRunner(target=target)) + for t in tasks: + await t.start() + for t in tasks: + await t.join() + + # Verify there were at least 10 connection checkout failed event but no pool cleared events. + self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10) + self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0) + + class TestServerMonitoringMode(AsyncIntegrationTest): @async_client_context.require_no_load_balancer async def asyncSetUp(self): diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index b1dbc73f39..9650f7043f 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -876,6 +876,8 @@ async def test_views_are_prohibited(self): class TestCorpus(AsyncEncryptionIntegrationTest): + # PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions. + @async_client_context.require_version_max(6, 99) @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") async def asyncSetUp(self): await super().asyncSetUp() @@ -1052,6 +1054,8 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): client_encrypted: AsyncMongoClient listener: OvertCommandListener + # PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions. 
+ @async_client_context.require_version_max(6, 99) async def asyncSetUp(self): await super().asyncSetUp() db = async_client_context.client.db diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index 3193d9e3d5..9db9b5ab3a 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -513,6 +513,39 @@ async def test_connection_timeout_message(self): str(error.exception), ) + @async_client_context.require_failCommand_appName + async def test_pool_backpressure_preserves_existing_connections(self): + client = await self.async_rs_or_single_client() + coll = client.pymongo_test.t + pool = await async_get_pool(client) + await coll.insert_many([{"x": 1} for _ in range(10)]) + t = SocketGetter(self.c, pool) + await t.start() + while t.state != "connection": + await asyncio.sleep(0.1) + + assert not t.sock.conn_closed() + + # Mock a session establishment overload. + mock_connection_fail = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "closeConnection": True, + }, + } + + async with self.fail_point(mock_connection_fail): + await coll.find_one({}) + + # Make sure the existing socket was not affected. 
+ assert not t.sock.conn_closed() + + # Cleanup + await t.release_conn() + await t.join() + await pool.close() + class TestPoolMaxSize(_TestPoolingBase): async def test_max_pool_size(self): diff --git a/test/asynchronous/test_raw_bson.py b/test/asynchronous/test_raw_bson.py index 70832ea668..88ba05011b 100644 --- a/test/asynchronous/test_raw_bson.py +++ b/test/asynchronous/test_raw_bson.py @@ -19,7 +19,12 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + skip_if_rust_bson, + unittest, +) from bson import Code, DBRef, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation @@ -31,6 +36,7 @@ _IS_SYNC = False +@skip_if_rust_bson class TestRawBSONDocument(AsyncIntegrationTest): # {'_id': ObjectId('556df68b6e32ab21a95e0785'), # 'name': 'Sherlock', diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index 47ac91b0f5..259cd9cff5 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -261,6 +261,128 @@ async def test_retryable_reads_are_retried_on_the_same_implicit_session(self): self.assertEqual(command_docs[0]["lsid"], command_docs[1]["lsid"]) self.assertIsNot(command_docs[0], command_docs[1]) + @async_client_context.require_replica_set + @async_client_context.require_secondaries_count(1) + @async_client_context.require_failCommand_fail_point + @async_client_context.require_version_min(4, 4, 0) + async def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled( + self + ): + listener = OvertCommandListener() + + # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled. 
+ client = await self.async_rs_or_single_client( + event_listeners=[listener], + retryReads=True, + readPreference="primaryPreferred", + enableOverloadRetargeting=True, + ) + + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. + command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorLabels": ["RetryableError", "SystemOverloadedError"], + "errorCode": 6, + }, + } + await async_set_fail_point(client, command_args) + + # 3. Reset the command event monitor to clear the fail point command from its stored events. + listener.reset() + + # 4. Execute a `find` command with `client`. + await client.t.t.find_one({}) + + # 5. Assert that one failed command event and one successful command event occurred. + self.assertEqual(len(listener.failed_events), 1) + self.assertEqual(len(listener.succeeded_events), 1) + + # 6. Assert that both events occurred on different servers. + assert listener.failed_events[0].connection_id != listener.succeeded_events[0].connection_id + + @async_client_context.require_replica_set + @async_client_context.require_secondaries_count(1) + @async_client_context.require_failCommand_fail_point + @async_client_context.require_version_min(4, 4, 0) + async def test_03_02_retryable_reads_caused_by_non_overload_errors_are_retried_on_the_same_replicaset_server( + self + ): + listener = OvertCommandListener() + + # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled. + client = await self.async_rs_or_single_client( + event_listeners=[listener], retryReads=True, readPreference="primaryPreferred" + ) + + # 2. Configure a fail point with the RetryableError error label. 
+ command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorLabels": ["RetryableError"], + "errorCode": 6, + }, + } + await async_set_fail_point(client, command_args) + + # 3. Reset the command event monitor to clear the fail point command from its stored events. + listener.reset() + + # 4. Execute a `find` command with `client`. + await client.t.t.find_one({}) + + # 5. Assert that one failed command event and one successful command event occurred. + self.assertEqual(len(listener.failed_events), 1) + self.assertEqual(len(listener.succeeded_events), 1) + + # 6. Assert that both events occurred on the same server. + assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id + + @async_client_context.require_replica_set + @async_client_context.require_secondaries_count(1) + @async_client_context.require_failCommand_fail_point + @async_client_context.require_version_min(4, 4, 0) + async def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled( + self + ): + listener = OvertCommandListener() + + # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled. + client = await self.async_rs_or_single_client( + event_listeners=[listener], + retryReads=True, + readPreference="primaryPreferred", + ) + + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. + command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorLabels": ["RetryableError", "SystemOverloadedError"], + "errorCode": 6, + }, + } + await async_set_fail_point(client, command_args) + + # 3. Reset the command event monitor to clear the fail point command from its stored events. + listener.reset() + + # 4. Execute a `find` command with `client`.
+ await client.t.t.find_one({}) + + # 5. Assert that one failed command event and one successful command event occurred. + self.assertEqual(len(listener.failed_events), 1) + self.assertEqual(len(listener.succeeded_events), 1) + + # 6. Assert that both events occurred on the same server. + assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index ddb1d39eb7..6e2072a2ad 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -43,14 +43,17 @@ from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument from bson.son import SON +from pymongo import MongoClient from pymongo.errors import ( AutoReconnect, ConnectionFailure, - OperationFailure, + NotPrimaryError, + PyMongoError, ServerSelectionTimeoutError, WriteConcernError, ) from pymongo.monitoring import ( + CommandFailedEvent, CommandSucceededEvent, ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, @@ -601,5 +604,186 @@ def raise_connection_err_select_server(*args, **kwargs): self.assertEqual(sent_txn_id, final_txn_id, msg) +class TestErrorPropagationAfterEncounteringMultipleErrors(AsyncIntegrationTest): + # Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers. + @async_client_context.require_replica_set + # Run against server versions 6.0 and above. + @async_client_context.require_version_min(6, 0) # type: ignore[untyped-decorator] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.setup_client = MongoClient(**async_client_context.default_client_options) + self.addCleanup(self.setup_client.close) + + # TODO: After PYTHON-4595 we can use async event handlers and remove this workaround. 
+ def configure_fail_point_sync(self, command_args, off=False) -> None: + cmd = {"configureFailPoint": "failCommand"} + cmd.update(command_args) + if off: + cmd["mode"] = "off" + cmd.pop("data", None) + self.setup_client.admin.command(cmd) + + async def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed( + self + ) -> None: + # Create a client with retryWrites=true. + listener = OvertCommandListener() + + # Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels. + command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "errorLabels": ["RetryableError", "SystemOverloadedError"], + "errorCode": 91, + }, + } + + # Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary). + command_args_inner = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["insert"], + "errorCode": 10107, + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + + def failed(event: CommandFailedEvent) -> None: + # Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2. + if listener.failed_events: + return + assert event.failure["code"] == 91 + self.configure_fail_point_sync(command_args_inner) + self.addCleanup(self.configure_fail_point_sync, {}, off=True) + listener.failed_events.append(event) + + listener.failed = failed + + client = await self.async_rs_client(retryWrites=True, event_listeners=[listener]) + + self.configure_fail_point_sync(command_args) + self.addCleanup(self.configure_fail_point_sync, {}, off=True) + + # Attempt an insertOne operation on any record for any database and collection. + # Expect the insertOne to fail with a server error. 
+ with self.assertRaises(NotPrimaryError) as exc: + await client.test.test.insert_one({}) + + # Assert that the error code of the server error is 10107. + assert exc.exception.errors["code"] == 10107 # type:ignore[call-overload] + + async def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed( + self + ) -> None: + # Create a client with retryWrites=true. + listener = OvertCommandListener() + + # Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels. + command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"], + "errorCode": 91, + }, + } + + # Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary) + # and a NoWritesPerformed label. + command_args_inner = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["insert"], + "errorCode": 10107, + "errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"], + }, + } + + def failed(event: CommandFailedEvent) -> None: + if listener.failed_events: + return + # Configure the 10107 fail point command only if the the failed event is for the 91 error configured in step 2. + assert event.failure["code"] == 91 + self.configure_fail_point_sync(command_args_inner) + self.addCleanup(self.configure_fail_point_sync, {}, off=True) + listener.failed_events.append(event) + + listener.failed = failed + + client = await self.async_rs_client(retryWrites=True, event_listeners=[listener]) + + self.configure_fail_point_sync(command_args) + self.addCleanup(self.configure_fail_point_sync, {}, off=True) + + # Attempt an insertOne operation on any record for any database and collection. + # Expect the insertOne to fail with a server error. 
+ with self.assertRaises(NotPrimaryError) as exc: + await client.test.test.insert_one({}) + + # Assert that the error code of the server error is 91. + assert exc.exception.errors["code"] == 91 # type:ignore[call-overload] + + async def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed( + self + ) -> None: + # Create a client with retryWrites=true. + listener = OvertCommandListener() + + # Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error + # code `91` (NotWritablePrimary) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels. + command_args_inner = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["insert"], + "errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"], + "errorCode": 91, + }, + } + + # Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and + # `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label. + command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "errorCode": 91, + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + + def failed(event: CommandFailedEvent) -> None: + # Configure the fail point command only if the failed event is for the 91 error configured in step 2. 
+ if listener.failed_events: + return + assert event.failure["code"] == 91 + self.configure_fail_point_sync(command_args_inner) + self.addCleanup(self.configure_fail_point_sync, {}, off=True) + listener.failed_events.append(event) + + listener.failed = failed + + client = await self.async_rs_client(retryWrites=True, event_listeners=[listener]) + + self.configure_fail_point_sync(command_args) + self.addCleanup(self.configure_fail_point_sync, {}, off=True) + + # Attempt an insertOne operation on any record for any database and collection. + # Expect the insertOne to fail with a server error. + with self.assertRaises(PyMongoError) as exc: + await client.test.test.insert_one({}) + + # Assert that the error code of the server error is 91. + assert exc.exception.errors["code"] == 91 + # Assert that the error does not contain the error label `NoWritesPerformed`. + assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"] + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 19ce868c56..404a69fdee 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -189,6 +189,52 @@ async def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) + # Explicit bound session + for f, args, kw in ops: + async with client.start_session() as s: + async with s.bind(): + listener.reset() + s._materialize() + last_use = s._server_session.last_use + start = time.monotonic() + self.assertLessEqual(last_use, start) + # In case "f" modifies its inputs. 
+ args = copy.copy(args) + kw = copy.copy(kw) + await f(*args, **kw) + self.assertGreaterEqual(len(listener.started_events), 1) + for event in listener.started_events: + self.assertIn( + "lsid", + event.command, + f"{f.__name__} sent no lsid with {event.command_name}", + ) + + self.assertEqual( + s.session_id, + event.command["lsid"], + f"{f.__name__} sent wrong lsid with {event.command_name}", + ) + + self.assertFalse(s.has_ended) + + self.assertTrue(s.has_ended) + with self.assertRaisesRegex(InvalidOperation, "ended session"): + async with s.bind(): + await f(*args, **kw) + + # Test a session cannot be used on another client. + async with self.client2.start_session() as s: + async with s.bind(): + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + with self.assertRaisesRegex( + InvalidOperation, + "Only the client that created the bound session can perform operations within its context block", + ): + await f(*args, **kw) + async def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. 
@@ -825,6 +871,73 @@ async def test_session_not_copyable(self): async with client.start_session() as s: self.assertRaises(TypeError, lambda: copy.copy(s)) + async def test_nested_session_binding(self): + coll = self.client.pymongo_test.test + await coll.insert_one({"x": 1}) + + session1 = self.client.start_session() + session2 = self.client.start_session() + session1._materialize() + session2._materialize() + try: + self.listener.reset() + # Uses implicit session + await coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + async with session1.bind(end_session=False): + self.listener.reset() + # Uses bound session1 + await coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + + async with session2.bind(end_session=False): + self.listener.reset() + # Uses bound session2 + await coll.find_one() + session2_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session2_lsid, session2.session_id) + self.assertNotEqual(session2_lsid, session1.session_id) + + self.listener.reset() + # Use bound session1 again + await coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + self.assertNotEqual(session1_lsid, session2.session_id) + + self.listener.reset() + # Uses implicit session + await coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + finally: + await session1.end_session() + await session2.end_session() + + async def test_session_binding_end_session(self): + coll = self.client.pymongo_test.test + 
await coll.insert_one({"x": 1}) + + async with self.client.start_session().bind() as s1: + await coll.find_one() + + self.assertTrue(s1.has_ended) + + async with self.client.start_session().bind(end_session=False) as s2: + await coll.find_one() + + self.assertFalse(s2.has_ended) + + await s2.end_session() + class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener diff --git a/test/asynchronous/test_ssl.py b/test/asynchronous/test_ssl.py index 0ce3e8bbac..7fe57e8503 100644 --- a/test/asynchronous/test_ssl.py +++ b/test/asynchronous/test_ssl.py @@ -48,19 +48,11 @@ _HAVE_PYOPENSSL = False try: - # All of these must be available to use PyOpenSSL - import OpenSSL - import requests - import service_identity - - # Ensure service_identity>=18.1 is installed - from service_identity.pyopenssl import verify_ip_address - - from pymongo.ocsp_support import _load_trusted_ca_certs + from pymongo import pyopenssl_context _HAVE_PYOPENSSL = True except ImportError: - _load_trusted_ca_certs = None # type: ignore + pass if HAVE_SSL: @@ -136,11 +128,6 @@ def test_config_ssl(self): def test_use_pyopenssl_when_available(self): self.assertTrue(HAVE_PYSSL) - @unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL") - def test_load_trusted_ca_certs(self): - trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM) - self.assertEqual(2, len(trusted_ca_certs)) - class TestSSL(AsyncIntegrationTest): saved_port: int diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 4fc20fba3b..e17bfb14c0 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -16,9 +16,13 @@ from __future__ import annotations import asyncio +import random import sys +import time from io import BytesIO +from unittest.mock import patch +import pymongo from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket from pymongo.asynchronous.pool import PoolState from pymongo.server_selectors import 
writable_server_selector @@ -45,7 +49,9 @@ CollectionInvalid, ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne @@ -434,7 +440,7 @@ async def set_fail_point(self, command_args): await self.configure_fail_point(client, command_args) @async_client_context.require_transactions - async def test_callback_raises_custom_error(self): + async def test_1_callback_raises_custom_error(self): class _MyException(Exception): pass @@ -446,7 +452,7 @@ async def raise_error(_): await s.with_transaction(raise_error) @async_client_context.require_transactions - async def test_callback_returns_value(self): + async def test_2_callback_returns_value(self): async def callback(_): return "Foo" @@ -474,7 +480,7 @@ def callback(_): self.assertEqual(await s.with_transaction(callback), "Foo") @async_client_context.require_transactions - async def test_callback_not_retried_after_timeout(self): + async def test_3_1_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -495,14 +501,16 @@ async def callback(session): listener.reset() async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. 
+ self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @async_client_context.require_test_commands @async_client_context.require_transactions - async def test_callback_not_retried_after_commit_timeout(self): + async def test_3_2_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -529,14 +537,16 @@ async def callback(session): async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @async_client_context.require_test_commands @async_client_context.require_transactions - async def test_commit_not_retried_after_timeout(self): + async def test_3_3_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -560,7 +570,7 @@ async def callback(session): async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(ConnectionFailure): + with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic @@ -568,6 +578,40 @@ async def callback(session): self.assertEqual( listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) + # Assert that the timeout error has the same labels as the error it wraps. 
+ self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult")) + + @async_client_context.require_transactions + async def test_callback_not_retried_after_csot_timeout(self): + listener = OvertCommandListener() + client = await self.async_rs_client(event_listeners=[listener]) + coll = client[self.db.name].test + + async def callback(session): + await coll.insert_one({}, session=session) + err: dict = { + "ok": 0, + "errmsg": "Transaction 7819 has been aborted.", + "code": 251, + "codeName": "NoSuchTransaction", + "errorLabels": ["TransientTransactionError"], + } + raise OperationFailure(err["errmsg"], err["code"], err) + + # Create the collection. + await coll.insert_one({}) + listener.reset() + async with client.start_session() as s: + with pymongo.timeout(1.0): + with self.assertRaises(ExecutionTimeout): + await s.with_transaction(callback) + + # At least two attempts: the original and one or more retries. + inserts = len([x for x in listener.started_command_names() if x == "insert"]) + aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"]) + + self.assertGreaterEqual(inserts, 2) + self.assertGreaterEqual(aborts, 2) # Tested here because this supports Motor's convenient transactions API. 
@async_client_context.require_transactions @@ -606,6 +650,63 @@ async def callback(session): await s.with_transaction(callback) self.assertFalse(s.in_transaction) + @async_client_context.require_test_commands + @async_client_context.require_transactions + async def test_4_retry_backoff_is_enforced(self): + client = async_client_context.client + coll = client[self.db.name].test + end = start = no_backoff_time = 0 + + # Make random.random always return 0 (no backoff) + with patch.object(random, "random", return_value=0): + # set fail point to trigger transaction failure and trigger backoff + await self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 13}, + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + + async def callback(session): + await coll.insert_one({}, session=session) + + start = time.monotonic() + async with self.client.start_session() as s: + await s.with_transaction(callback) + end = time.monotonic() + no_backoff_time = end - start + + # Make random.random always return 1 (max backoff) + with patch.object(random, "random", return_value=1): + # set fail point to trigger transaction failure and trigger backoff + await self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": { + "times": 13 + }, # sufficiently high enough such that the time effect of backoff is noticeable + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + start = time.monotonic() + async with self.client.start_session() as s: + await s.with_transaction(callback) + end = time.monotonic() + self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2 + class TestOptionsInsideTransactionProse(AsyncTransactionsBase): 
@async_client_context.require_transactions diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 6ce8f852cf..1fb93e7b86 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1464,11 +1464,6 @@ async def verify_outcome(self, spec): self.assertListEqual(sorted_expected_documents, actual_documents) async def run_scenario(self, spec, uri=None): - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - await self.kill_all_sessions() - # Handle flaky tests. flaky_tests = [ ("PYTHON-5170", ".*test_discovery_and_monitoring.*"), @@ -1504,6 +1499,15 @@ async def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") + # Kill all sessions after each test with transactions to prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. 
+ for op in spec["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addAsyncCleanup(self.kill_all_sessions) + break + # process createEntities self._uri = uri self.entity_map = EntityMapUtil(self) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 63e7e9e150..ff5f61db06 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -16,43 +16,13 @@ from __future__ import annotations import asyncio -import functools import os -import time -import unittest -from collections import abc -from inspect import iscoroutinefunction -from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.asynchronous import async_client_context from test.asynchronous.helpers import ConcurrentRunner -from test.utils_shared import ( - CMAPListener, - CompareType, - EventListener, - OvertCommandListener, - ScenarioDict, - ServerAndTopologyEventListener, - camel_to_snake, - camel_to_snake_args, - parse_spec_options, - prepare_spec_arguments, -) -from typing import List - -from bson import ObjectId, decode, encode, json_util -from bson.binary import Binary -from bson.int64 import Int64 -from bson.son import SON -from gridfs import GridFSBucket -from gridfs.asynchronous.grid_file import AsyncGridFSBucket -from pymongo.asynchronous import client_session -from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.cursor import AsyncCursor -from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from test.utils_shared import ScenarioDict + +from bson import json_util from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock -from pymongo.read_concern import ReadConcern -from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, _WriteResult -from pymongo.write_concern import WriteConcern 
_IS_SYNC = False @@ -219,597 +189,3 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) - - -class AsyncSpecRunner(AsyncIntegrationTest): - mongos_clients: List - knobs: client_knobs - listener: EventListener - - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.mongos_clients = [] - - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - self.knobs.enable() - self.targets = {} - self.listener = None # type: ignore - self.pool_listener = None - self.server_listener = None - self.maxDiff = None - - async def asyncTearDown(self) -> None: - self.knobs.disable() - - async def set_fail_point(self, command_args): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - await self.configure_fail_point(client, command_args) - - async def targeted_fail_point(self, session, fail_point): - """Run the targetedFailPoint test operation. - - Enable the fail point on the session's pinned mongos. - """ - clients = {c.address: c for c in self.mongos_clients} - client = clients[session._pinned_address] - await self.configure_fail_point(client, fail_point) - self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) - - def assert_session_pinned(self, session): - """Run the assertSessionPinned test operation. - - Assert that the given session is pinned. - """ - self.assertIsNotNone(session._transaction.pinned_address) - - def assert_session_unpinned(self, session): - """Run the assertSessionUnpinned test operation. - - Assert that the given session is not pinned. 
- """ - self.assertIsNone(session._pinned_address) - self.assertIsNone(session._transaction.pinned_address) - - async def assert_collection_exists(self, database, collection): - """Run the assertCollectionExists test operation.""" - db = self.client[database] - self.assertIn(collection, await db.list_collection_names()) - - async def assert_collection_not_exists(self, database, collection): - """Run the assertCollectionNotExists test operation.""" - db = self.client[database] - self.assertNotIn(collection, await db.list_collection_names()) - - async def assert_index_exists(self, database, collection, index): - """Run the assertIndexExists test operation.""" - coll = self.client[database][collection] - self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()]) - - async def assert_index_not_exists(self, database, collection, index): - """Run the assertIndexNotExists test operation.""" - coll = self.client[database][collection] - self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()]) - - async def wait(self, ms): - """Run the "wait" test operation.""" - await asyncio.sleep(ms / 1000.0) - - def assertErrorLabelsContain(self, exc, expected_labels): - labels = [l for l in expected_labels if exc.has_error_label(l)] - self.assertEqual(labels, expected_labels) - - def assertErrorLabelsOmit(self, exc, omit_labels): - for label in omit_labels: - self.assertFalse( - exc.has_error_label(label), msg=f"error labels should not contain {label}" - ) - - async def kill_all_sessions(self): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - try: - await client.admin.command("killAllSessions", []) - except (OperationFailure, AutoReconnect): - # "operation was interrupted" by killing the command's - # own session. - # On 8.0+ killAllSessions sometimes returns a network error. - pass - - def check_command_result(self, expected_result, result): - # Only compare the keys in the expected result. 
- filtered_result = {} - for key in expected_result: - try: - filtered_result[key] = result[key] - except KeyError: - pass - self.assertEqual(filtered_result, expected_result) - - # TODO: factor the following function with test_crud.py. - def check_result(self, expected_result, result): - if isinstance(result, _WriteResult): - for res in expected_result: - prop = camel_to_snake(res) - # SPEC-869: Only BulkWriteResult has upserted_count. - if prop == "upserted_count" and not isinstance(result, BulkWriteResult): - if result.upserted_id is not None: - upserted_count = 1 - else: - upserted_count = 0 - self.assertEqual(upserted_count, expected_result[res], prop) - elif prop == "inserted_ids": - # BulkWriteResult does not have inserted_ids. - if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), result.inserted_count) - else: - # InsertManyResult may be compared to [id1] from the - # crud spec or {"0": id1} from the retryable write spec. - ids = expected_result[res] - if isinstance(ids, dict): - ids = [ids[str(i)] for i in range(len(ids))] - - self.assertEqual(ids, result.inserted_ids, prop) - elif prop == "upserted_ids": - # Convert indexes from strings to integers. 
- ids = expected_result[res] - expected_ids = {} - for str_index in ids: - expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) - else: - self.assertEqual(getattr(result, prop), expected_result[res], prop) - - return True - else: - - def _helper(expected_result, result): - if isinstance(expected_result, abc.Mapping): - for i in expected_result.keys(): - self.assertEqual(expected_result[i], result[i]) - - elif isinstance(expected_result, list): - for i, k in zip(expected_result, result): - _helper(i, k) - else: - self.assertEqual(expected_result, result) - - _helper(expected_result, result) - return None - - def get_object_name(self, op): - """Allow subclasses to override handling of 'object' - - Transaction spec says 'object' is required. - """ - return op["object"] - - @staticmethod - def parse_options(opts): - return parse_spec_options(opts) - - async def run_operation(self, sessions, collection, operation): - original_collection = collection - name = camel_to_snake(operation["name"]) - if name == "run_command": - name = "command" - elif name == "download_by_name": - name = "open_download_stream_by_name" - elif name == "download": - name = "open_download_stream" - elif name == "map_reduce": - self.skipTest("PyMongo does not support mapReduce") - elif name == "count": - self.skipTest("PyMongo does not support count") - - database = collection.database - collection = database.get_collection(collection.name) - if "collectionOptions" in operation: - collection = collection.with_options( - **self.parse_options(operation["collectionOptions"]) - ) - - object_name = self.get_object_name(operation) - if object_name == "gridfsbucket": - # Only create the GridFSBucket when we need it (for the gridfs - # retryable reads tests). 
- obj = AsyncGridFSBucket(database, bucket_name=collection.name) - else: - objects = { - "client": database.client, - "database": database, - "collection": collection, - "testRunner": self, - } - objects.update(sessions) - obj = objects[object_name] - - # Combine arguments with options and handle special cases. - arguments = operation.get("arguments", {}) - arguments.update(arguments.pop("options", {})) - self.parse_options(arguments) - - cmd = getattr(obj, name) - - with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, in_with_transaction=True - ) - prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - - if name == "run_on_thread": - args = {"sessions": sessions, "collection": collection} - args.update(arguments) - arguments = args - - if not _IS_SYNC and iscoroutinefunction(cmd): - result = await cmd(**dict(arguments)) - else: - result = cmd(**dict(arguments)) - # Cleanup open change stream cursors. - if name == "watch": - self.addAsyncCleanup(result.close) - - if name == "aggregate": - if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: - # Read from the primary to ensure causal consistency. 
- out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY - ) - return out.find() - if "download" in name: - result = Binary(result.read()) - - if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor): - return await result.to_list() - - return result - - def allowable_errors(self, op): - """Allow encryption spec to override expected error classes.""" - return (PyMongoError,) - - async def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get("result") - if expect_error(op): - with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - await self.run_operation(sessions, collection, op.copy()) - exc = context.exception - if expect_error_message(expected_result): - if isinstance(exc, BulkWriteError): - errmsg = str(exc.details).lower() - else: - errmsg = str(exc).lower() - self.assertIn(expected_result["errorContains"].lower(), errmsg) - if expect_error_code(expected_result): - self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) - if expect_error_labels_contain(expected_result): - self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) - if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) - if expect_timeout_error(expected_result): - self.assertIsInstance(exc, PyMongoError) - if not exc.timeout: - # Re-raise the exception for better diagnostics. - raise exc - - # Reraise the exception if we're in the with_transaction - # callback. 
- if in_with_transaction: - raise context.exception - else: - result = await self.run_operation(sessions, collection, op.copy()) - if "result" in op: - if op["name"] == "runCommand": - self.check_command_result(expected_result, result) - else: - self.check_result(expected_result, result) - - async def run_operations(self, sessions, collection, ops, in_with_transaction=False): - for op in ops: - await self._run_op(sessions, collection, op, in_with_transaction) - - # TODO: factor with test_command_monitoring.py - def check_events(self, test, listener, session_ids): - events = listener.started_events - if not len(test["expectations"]): - return - - # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in events]) - self.assertEqual(len(events), len(test["expectations"]), cmds) - for i, expectation in enumerate(test["expectations"]): - event_type = next(iter(expectation)) - event = events[i] - - # The tests substitute 42 for any number other than 0. - if event.command_name == "getMore" and event.command["getMore"]: - event.command["getMore"] = Int64(42) - elif event.command_name == "killCursors": - event.command["cursors"] = [Int64(42)] - elif event.command_name == "update": - # TODO: remove this once PYTHON-1744 is done. - # Add upsert and multi fields back into expectations. - updates = expectation[event_type]["command"]["updates"] - for update in updates: - update.setdefault("upsert", False) - update.setdefault("multi", False) - - # Replace afterClusterTime: 42 with actual afterClusterTime. 
- expected_cmd = expectation[event_type]["command"] - expected_read_concern = expected_cmd.get("readConcern") - if expected_read_concern is not None: - time = expected_read_concern.get("afterClusterTime") - if time == 42: - actual_time = event.command.get("readConcern", {}).get("afterClusterTime") - if actual_time is not None: - expected_read_concern["afterClusterTime"] = actual_time - - recovery_token = expected_cmd.get("recoveryToken") - if recovery_token == 42: - expected_cmd["recoveryToken"] = CompareType(dict) - - # Replace lsid with a name like "session0" to match test. - if "lsid" in event.command: - for name, lsid in session_ids.items(): - if event.command["lsid"] == lsid: - event.command["lsid"] = name - break - - for attr, expected in expectation[event_type].items(): - actual = getattr(event, attr) - expected = wrap_types(expected) - if isinstance(expected, dict): - for key, val in expected.items(): - if val is None: - if key in actual: - self.fail(f"Unexpected key [{key}] in {actual!r}") - elif key not in actual: - self.fail(f"Expected key [{key}] in {actual!r}") - else: - self.assertEqual( - val, decode_raw(actual[key]), f"Key [{key}] in {actual}" - ) - else: - self.assertEqual(actual, expected) - - def maybe_skip_scenario(self, test): - if test.get("skipReason"): - self.skipTest(test.get("skipReason")) - - def get_scenario_db_name(self, scenario_def): - """Allow subclasses to override a test's database name.""" - return scenario_def["database_name"] - - def get_scenario_coll_name(self, scenario_def): - """Allow subclasses to override a test's collection name.""" - return scenario_def["collection_name"] - - def get_outcome_coll_name(self, outcome, collection): - """Allow subclasses to override outcome collection.""" - return collection.name - - async def run_test_ops(self, sessions, collection, test): - """Added to allow retryable writes spec to override a test's - operation. 
- """ - await self.run_operations(sessions, collection, test["operations"]) - - def parse_client_options(self, opts): - """Allow encryption spec to override a clientOptions parsing.""" - return opts - - async def setup_scenario(self, scenario_def): - """Allow specs to override a test's setup.""" - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - documents = scenario_def["data"] - - # Setup the collection with as few majority writes as possible. - db = async_client_context.client.get_database(db_name) - coll_exists = bool(await db.list_collection_names(filter={"name": coll_name})) - if coll_exists: - await db[coll_name].delete_many({}) - # Only use majority wc only on the final write. - wc = WriteConcern(w="majority") - if documents: - db.get_collection(coll_name, write_concern=wc).insert_many(documents) - elif not coll_exists: - # Ensure collection exists. - await db.create_collection(coll_name, write_concern=wc) - - async def run_scenario(self, scenario_def, test): - self.maybe_skip_scenario(test) - - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - await self.kill_all_sessions() - self.addAsyncCleanup(self.kill_all_sessions) - await self.setup_scenario(scenario_def) - database_name = self.get_scenario_db_name(scenario_def) - collection_name = self.get_scenario_coll_name(scenario_def) - # SPEC-1245 workaround StaleDbVersion on distinct - for c in self.mongos_clients: - await c[database_name][collection_name].distinct("x") - - # Configure the fail point before creating the client. 
- if "failPoint" in test: - fp = test["failPoint"] - await self.set_fail_point(fp) - self.addAsyncCleanup( - self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - listener = OvertCommandListener() - pool_listener = CMAPListener() - server_listener = ServerAndTopologyEventListener() - # Create a new client, to avoid interference from pooled sessions. - client_options = self.parse_client_options(test["clientOptions"]) - use_multi_mongos = test["useMultipleMongoses"] - host = None - if use_multi_mongos: - if async_client_context.load_balancer: - host = async_client_context.MULTI_MONGOS_LB_URI - elif async_client_context.is_mongos: - host = async_client_context.mongos_seeds() - client = await self.async_rs_client( - h=host, event_listeners=[listener, pool_listener, server_listener], **client_options - ) - self.scenario_client = client - self.listener = listener - self.pool_listener = pool_listener - self.server_listener = server_listener - - # Create session0 and session1. - sessions = {} - session_ids = {} - for i in range(2): - # Don't attempt to create sessions if they are not supported by - # the running server version. - if not async_client_context.sessions_enabled: - break - session_name = "session%d" % i - opts = camel_to_snake_args(test["sessionOptions"][session_name]) - if "default_transaction_options" in opts: - txn_opts = self.parse_options(opts["default_transaction_options"]) - txn_opts = client_session.TransactionOptions(**txn_opts) - opts["default_transaction_options"] = txn_opts - - s = client.start_session(**dict(opts)) - - sessions[session_name] = s - # Store lsid so we can access it after end_session, in check_events. 
- session_ids[session_name] = s.session_id - - self.addAsyncCleanup(end_sessions, sessions) - - collection = client[database_name][collection_name] - await self.run_test_ops(sessions, collection, test) - - await end_sessions(sessions) - - self.check_events(test, listener, session_ids) - - # Disable fail points. - if "failPoint" in test: - fp = test["failPoint"] - await self.set_fail_point( - {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - # Assert final state is expected. - outcome = test["outcome"] - expected_c = outcome.get("collection") - if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name(outcome, collection) - - # Read from the primary with local read concern to ensure causal - # consistency. - outcome_coll = async_client_context.client[collection.database.name].get_collection( - outcome_coll_name, - read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern("local"), - ) - actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() - - # The expected data needs to be the left hand side here otherwise - # CompareType(Binary) doesn't work. 
- self.assertEqual(wrap_types(expected_c["data"]), actual_data) - - -def expect_any_error(op): - if isinstance(op, dict): - return op.get("error") - - return False - - -def expect_error_message(expected_result): - if isinstance(expected_result, dict): - return isinstance(expected_result["errorContains"], str) - - return False - - -def expect_error_code(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorCodeName"] - - return False - - -def expect_error_labels_contain(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsContain"] - - return False - - -def expect_error_labels_omit(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsOmit"] - - return False - - -def expect_timeout_error(expected_result): - if isinstance(expected_result, dict): - return expected_result["isTimeoutError"] - - return False - - -def expect_error(op): - expected_result = op.get("result") - return ( - expect_any_error(op) - or expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result) - or expect_timeout_error(expected_result) - ) - - -async def end_sessions(sessions): - for s in sessions.values(): - # Aborts the transaction if it's open. 
- await s.end_session() - - -def decode_raw(val): - """Decode RawBSONDocuments in the given container.""" - if isinstance(val, (list, abc.Mapping)): - return decode(encode({"v": val}))["v"] - return val - - -TYPES = { - "binData": Binary, - "long": Int64, - "int": int, - "string": str, - "objectId": ObjectId, - "object": dict, - "array": list, -} - - -def wrap_types(val): - """Support $$type assertion in command results.""" - if isinstance(val, list): - return [wrap_types(v) for v in val] - if isinstance(val, abc.Mapping): - typ = val.get("$$type") - if typ: - if isinstance(typ, str): - types = TYPES[typ] - else: - types = tuple(TYPES[t] for t in typ) - return CompareType(types) - d = {} - for key in val: - d[key] = wrap_types(val[key]) - return d - return val diff --git a/test/client-backpressure/backpressure-connection-checkin.json b/test/client-backpressure/backpressure-connection-checkin.json new file mode 100644 index 0000000000..794951ad5f --- /dev/null +++ b/test/client-backpressure/backpressure-connection-checkin.json @@ -0,0 +1,111 @@ +{ + "description": "tests that connections are returned to the pool on retry attempts for overload errors", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "replicaset", + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "connectionCheckedOutEvent", + "connectionCheckedInEvent" + ] + } + }, + { + "client": { + "id": "fail_point_client", + "useMultipleMongoses": false + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "backpressure-connection-checkin" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "coll" + } + } + ], + "tests": [ + { + "description": "overload error retry attempts return connections to the pool", + "operations": [ + { + "name": "failPoint", + "object": 
"testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "find", + "object": "collection", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "cmap", + "events": [ + { + "connectionCheckedOutEvent": {} + }, + { + "connectionCheckedInEvent": {} + }, + { + "connectionCheckedOutEvent": {} + }, + { + "connectionCheckedInEvent": {} + }, + { + "connectionCheckedOutEvent": {} + }, + { + "connectionCheckedInEvent": {} + } + ] + } + ] + } + ] +} diff --git a/test/client-backpressure/backpressure-retry-loop.json b/test/client-backpressure/backpressure-retry-loop.json new file mode 100644 index 0000000000..a0b4877fac --- /dev/null +++ b/test/client-backpressure/backpressure-retry-loop.json @@ -0,0 +1,4553 @@ +{ + "description": "tests that operations respect overload backoff retry loop", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "replicaset", + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ], + "ignoreCommandMonitoringEvents": [ + "killCursors" + ] + } + }, + { + "client": { + "id": "internal_client", + "useMultipleMongoses": false + } + }, + { + "database": { + "id": "internal_db", + "client": "internal_client", + "databaseName": "retryable-writes-tests" + } + }, + { + "collection": { + "id": "retryable-writes-tests", + "database": "internal_db", + "collectionName": "coll" + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": 
"retryable-writes-tests" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "coll" + } + }, + { + "client": { + "id": "client_retryReads_false", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ], + "ignoreCommandMonitoringEvents": [ + "killCursors" + ], + "uriOptions": { + "retryReads": false + } + } + }, + { + "database": { + "id": "database_retryReads_false", + "client": "client_retryReads_false", + "databaseName": "retryable-writes-tests" + } + }, + { + "collection": { + "id": "collection_retryReads_false", + "database": "database_retryReads_false", + "collectionName": "coll" + } + }, + { + "client": { + "id": "client_retryWrites_false", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ], + "ignoreCommandMonitoringEvents": [ + "killCursors" + ], + "uriOptions": { + "retryWrites": false + } + } + }, + { + "database": { + "id": "database_retryWrites_false", + "client": "client_retryWrites_false", + "databaseName": "retryable-writes-tests" + } + }, + { + "collection": { + "id": "collection_retryWrites_false", + "database": "database_retryWrites_false", + "collectionName": "coll" + } + } + ], + "initialData": [ + { + "collectionName": "coll", + "databaseName": "retryable-writes-tests", + "documents": [] + } + ], + "_yamlAnchors": { + "bulWriteInsertNamespace": "retryable-writes-tests.coll" + }, + "tests": [ + { + "description": "client.listDatabases retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "listDatabases" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": 
"listDatabases", + "object": "client", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandSucceededEvent": { + "commandName": "listDatabases" + } + } + ] + } + ] + }, + { + "description": "client.listDatabases (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "listDatabases" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listDatabases", + "object": "client_retryReads_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + } + ] + } + ] + }, + { + "description": "client.listDatabaseNames retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "listDatabases" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listDatabaseNames", + "object": "client" + } + ], + "expectEvents": [ + 
{ + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandSucceededEvent": { + "commandName": "listDatabases" + } + } + ] + } + ] + }, + { + "description": "client.listDatabaseNames (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "listDatabases" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listDatabaseNames", + "object": "client_retryReads_false", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + } + ] + } + ] + }, + { + "description": "client.createChangeStream retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "client", + "arguments": { + "pipeline": [] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" 
+ } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "client.createChangeStream (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "client_retryReads_false", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "client.clientBulkWrite retries using operation loop", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "bulkWrite" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "clientBulkWrite", + "object": "client", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "retryable-writes-tests.coll", + "document": { + "_id": 8, + "x": 88 + } + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": 
"client", + "events": [ + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandFailedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandFailedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandSucceededEvent": { + "commandName": "bulkWrite" + } + } + ] + } + ] + }, + { + "description": "client.clientBulkWrite (write) does not retry if retryWrites=false", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "bulkWrite" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "clientBulkWrite", + "object": "client_retryWrites_false", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "retryable-writes-tests.coll", + "document": { + "_id": 8, + "x": 88 + } + } + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandFailedEvent": { + "commandName": "bulkWrite" + } + } + ] + } + ] + }, + { + "description": "database.aggregate read retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "database", + 
"arguments": { + "pipeline": [ + { + "$listLocalSessions": {} + }, + { + "$limit": 1 + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "database.aggregate (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "database_retryReads_false", + "arguments": { + "pipeline": [ + { + "$listLocalSessions": {} + }, + { + "$limit": 1 + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "database.listCollections retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "listCollections" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listCollections", + 
"object": "database", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandSucceededEvent": { + "commandName": "listCollections" + } + } + ] + } + ] + }, + { + "description": "database.listCollections (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "listCollections" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listCollections", + "object": "database_retryReads_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + } + ] + } + ] + }, + { + "description": "database.listCollectionNames retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "listCollections" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listCollectionNames", + "object": "database", + 
"arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandSucceededEvent": { + "commandName": "listCollections" + } + } + ] + } + ] + }, + { + "description": "database.listCollectionNames (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "listCollections" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listCollectionNames", + "object": "database_retryReads_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + } + ] + } + ] + }, + { + "description": "database.runCommand retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "ping" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "runCommand", + "object": "database", + "arguments": { + "command": { + "ping": 1 + }, + 
"commandName": "ping" + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + }, + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + }, + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandSucceededEvent": { + "commandName": "ping" + } + } + ] + } + ] + }, + { + "description": "database.runCommand (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "ping" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "runCommand", + "object": "database_retryReads_false", + "arguments": { + "command": { + "ping": 1 + }, + "commandName": "ping" + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + } + ] + } + ] + }, + { + "description": "database.runCommand (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "ping" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "runCommand", + "object": "database_retryWrites_false", + "arguments": { + "command": { + "ping": 1 + }, + "commandName": "ping" + }, + "expectError": { + 
"isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + } + ] + } + ] + }, + { + "description": "database.createChangeStream retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "database", + "arguments": { + "pipeline": [] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "database.createChangeStream (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "database_retryReads_false", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": 
"client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.aggregate read retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "collection", + "arguments": { + "pipeline": [] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.aggregate (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "collection_retryReads_false", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + 
}, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.countDocuments retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "countDocuments", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.countDocuments (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "countDocuments", + "object": "collection_retryReads_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + 
"description": "collection.estimatedDocumentCount retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "count" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "estimatedDocumentCount", + "object": "collection" + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandFailedEvent": { + "commandName": "count" + } + }, + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandFailedEvent": { + "commandName": "count" + } + }, + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandSucceededEvent": { + "commandName": "count" + } + } + ] + } + ] + }, + { + "description": "collection.estimatedDocumentCount (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "count" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "estimatedDocumentCount", + "object": "collection_retryReads_false", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandFailedEvent": { + "commandName": "count" + } + } + ] + } + ] + }, + { + "description": "collection.distinct retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": 
"internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "distinct" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "distinct", + "object": "collection", + "arguments": { + "fieldName": "x", + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandFailedEvent": { + "commandName": "distinct" + } + }, + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandFailedEvent": { + "commandName": "distinct" + } + }, + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandSucceededEvent": { + "commandName": "distinct" + } + } + ] + } + ] + }, + { + "description": "collection.distinct (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "distinct" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "distinct", + "object": "collection_retryReads_false", + "arguments": { + "fieldName": "x", + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandFailedEvent": { + "commandName": "distinct" + } + } + ] + } + ] + }, + { + "description": "collection.find retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + 
"times": 2 + }, + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "find", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "collection.find (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "find", + "object": "collection_retryReads_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "collection.findOne retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": 
"findOne", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "collection.findOne (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOne", + "object": "collection_retryReads_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "collection.listIndexes retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "listIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listIndexes", + "object": "collection" + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": 
"listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandSucceededEvent": { + "commandName": "listIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.listIndexes (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "listIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listIndexes", + "object": "collection_retryReads_false", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.listIndexNames retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "listIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listIndexNames", + "object": "collection" + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + 
} + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandSucceededEvent": { + "commandName": "listIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.listIndexNames (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "listIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listIndexNames", + "object": "collection_retryReads_false", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.createChangeStream retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "collection", + "arguments": { + "pipeline": [] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + 
"commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.createChangeStream (read) does not retry if retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "collection_retryReads_false", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryReads_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.insertOne retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "insertOne", + "object": "collection", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandSucceededEvent": { + "commandName": 
"insert" + } + } + ] + } + ] + }, + { + "description": "collection.insertOne (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "insertOne", + "object": "collection_retryWrites_false", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.insertMany retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "insertMany", + "object": "collection", + "arguments": { + "documents": [ + { + "_id": 2, + "x": 22 + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.insertMany 
(write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "insertMany", + "object": "collection_retryWrites_false", + "arguments": { + "documents": [ + { + "_id": 2, + "x": 22 + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.deleteOne retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "delete" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "deleteOne", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandSucceededEvent": { + "commandName": "delete" + } + } + ] + } + ] + }, + { + "description": "collection.deleteOne (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": 
"testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "delete" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "deleteOne", + "object": "collection_retryWrites_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + } + ] + } + ] + }, + { + "description": "collection.deleteMany retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "delete" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "deleteMany", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandSucceededEvent": { + "commandName": "delete" + } + } + ] + } + ] + }, + { + "description": "collection.deleteMany (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 
1 + }, + "data": { + "failCommands": [ + "delete" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "deleteMany", + "object": "collection_retryWrites_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + } + ] + } + ] + }, + { + "description": "collection.replaceOne retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "replaceOne", + "object": "collection", + "arguments": { + "filter": {}, + "replacement": { + "x": 22 + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandSucceededEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.replaceOne (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + 
"SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "replaceOne", + "object": "collection_retryWrites_false", + "arguments": { + "filter": {}, + "replacement": { + "x": 22 + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.updateOne retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "updateOne", + "object": "collection", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandSucceededEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.updateOne (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, 
+ { + "name": "updateOne", + "object": "collection_retryWrites_false", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.updateMany retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "updateMany", + "object": "collection", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandSucceededEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.updateMany (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "updateMany", + "object": 
"collection_retryWrites_false", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndDelete retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndDelete", + "object": "collection", + "arguments": { + "filter": {} + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandSucceededEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndDelete (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndDelete", + "object": 
"collection_retryWrites_false", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndReplace retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndReplace", + "object": "collection", + "arguments": { + "filter": {}, + "replacement": { + "x": 22 + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandSucceededEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndReplace (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndReplace", + "object": 
"collection_retryWrites_false", + "arguments": { + "filter": {}, + "replacement": { + "x": 22 + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndUpdate retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndUpdate", + "object": "collection", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandSucceededEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndUpdate (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + 
"name": "findOneAndUpdate", + "object": "collection_retryWrites_false", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.bulkWrite retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "bulkWrite", + "object": "collection", + "arguments": { + "requests": [ + { + "insertOne": { + "document": { + "_id": 2, + "x": 22 + } + } + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.bulkWrite (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + 
"name": "bulkWrite", + "object": "collection_retryWrites_false", + "arguments": { + "requests": [ + { + "insertOne": { + "document": { + "_id": 2, + "x": 22 + } + } + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.createIndex retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "createIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createIndex", + "object": "collection", + "arguments": { + "keys": { + "x": 11 + }, + "name": "x_11" + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandSucceededEvent": { + "commandName": "createIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.createIndex (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "createIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 
+ } + } + } + }, + { + "name": "createIndex", + "object": "collection_retryWrites_false", + "arguments": { + "keys": { + "x": 11 + }, + "name": "x_11" + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "createIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.dropIndex retries using operation loop", + "operations": [ + { + "name": "createIndex", + "object": "retryable-writes-tests", + "arguments": { + "keys": { + "x": 11 + }, + "name": "x_11" + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "dropIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "dropIndex", + "object": "collection", + "arguments": { + "name": "x_11" + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandSucceededEvent": { + "commandName": "dropIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.dropIndex (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "createIndex", + "object": "retryable-writes-tests", + "arguments": { + "keys": { + "x": 11 + }, + "name": "x_11" + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + 
"failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "dropIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "dropIndex", + "object": "collection_retryWrites_false", + "arguments": { + "name": "x_11" + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.dropIndexes retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "dropIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "dropIndexes", + "object": "collection" + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandSucceededEvent": { + "commandName": "dropIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.dropIndexes (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + 
"dropIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "dropIndexes", + "object": "collection_retryWrites_false", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.aggregate write retries using operation loop", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "collection", + "arguments": { + "pipeline": [ + { + "$out": "output" + } + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandSucceededEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.aggregate (write) does not retry if retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "internal_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 
+ } + } + } + }, + { + "name": "aggregate", + "object": "collection_retryWrites_false", + "arguments": { + "pipeline": [ + { + "$out": "output" + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client_retryWrites_false", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + } + ] +} diff --git a/test/client-backpressure/backpressure-retry-max-attempts.json b/test/client-backpressure/backpressure-retry-max-attempts.json new file mode 100644 index 0000000000..de52572765 --- /dev/null +++ b/test/client-backpressure/backpressure-retry-max-attempts.json @@ -0,0 +1,2569 @@ +{ + "description": "tests that operations retry at most maxAttempts=2 times", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "replicaset", + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ], + "ignoreCommandMonitoringEvents": [ + "killCursors" + ] + } + }, + { + "client": { + "id": "fail_point_client", + "useMultipleMongoses": false + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "retryable-writes-tests" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "coll" + } + } + ], + "_yamlAnchors": { + "bulkWriteInsertNamespace": "retryable-writes-tests.coll" + }, + "initialData": [ + { + "collectionName": "coll", + "databaseName": "retryable-writes-tests", + "documents": [ + { + "_id": 1, + "x": 11 + }, + { + "_id": 2, + "x": 22 + } + ] + } + ], + "tests": [ + { + "description": "client.listDatabases retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": 
"testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "listDatabases" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listDatabases", + "object": "client", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + } + ] + } + ] + }, + { + "description": "client.listDatabaseNames retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "listDatabases" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listDatabaseNames", + "object": "client", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandFailedEvent": { + "commandName": "listDatabases" + } + }, + { + "commandStartedEvent": { + "commandName": "listDatabases" + } + }, + { + 
"commandFailedEvent": { + "commandName": "listDatabases" + } + } + ] + } + ] + }, + { + "description": "client.createChangeStream retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "client", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "client.clientBulkWrite retries at most maxAttempts=2 times", + "runOnRequirements": [ + { + "minServerVersion": "8.0" + } + ], + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "bulkWrite" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "clientBulkWrite", + "object": "client", + "arguments": { + "models": [ + { + "insertOne": { + "namespace": "retryable-writes-tests.coll", + "document": { + "_id": 8, + "x": 88 + } + } + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + 
"client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandFailedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandFailedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandStartedEvent": { + "commandName": "bulkWrite" + } + }, + { + "commandFailedEvent": { + "commandName": "bulkWrite" + } + } + ] + } + ] + }, + { + "description": "database.aggregate read retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "database", + "arguments": { + "pipeline": [ + { + "$listLocalSessions": {} + }, + { + "$limit": 1 + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "database.listCollections retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "listCollections" + ], + "errorLabels": [ + "RetryableError", + 
"SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listCollections", + "object": "database", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + } + ] + } + ] + }, + { + "description": "database.listCollectionNames retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "listCollections" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listCollectionNames", + "object": "database", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + }, + { + "commandStartedEvent": { + "commandName": "listCollections" + } + }, + { + "commandFailedEvent": { + "commandName": "listCollections" + } + } + ] + } + ] + }, + { + "description": "database.runCommand retries at most maxAttempts=2 times", + "operations": [ 
+ { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "ping" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "runCommand", + "object": "database", + "arguments": { + "command": { + "ping": 1 + }, + "commandName": "ping" + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + }, + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + }, + { + "commandStartedEvent": { + "commandName": "ping" + } + }, + { + "commandFailedEvent": { + "commandName": "ping" + } + } + ] + } + ] + }, + { + "description": "database.createChangeStream retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "database", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": 
"aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.aggregate read retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "collection", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.countDocuments retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "countDocuments", + "object": "collection", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" 
+ } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.estimatedDocumentCount retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "count" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "estimatedDocumentCount", + "object": "collection", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandFailedEvent": { + "commandName": "count" + } + }, + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandFailedEvent": { + "commandName": "count" + } + }, + { + "commandStartedEvent": { + "commandName": "count" + } + }, + { + "commandFailedEvent": { + "commandName": "count" + } + } + ] + } + ] + }, + { + "description": "collection.distinct retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "distinct" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "distinct", + "object": "collection", + "arguments": { + "fieldName": "x", + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + 
"client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandFailedEvent": { + "commandName": "distinct" + } + }, + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandFailedEvent": { + "commandName": "distinct" + } + }, + { + "commandStartedEvent": { + "commandName": "distinct" + } + }, + { + "commandFailedEvent": { + "commandName": "distinct" + } + } + ] + } + ] + }, + { + "description": "collection.find retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "find", + "object": "collection", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "collection.findOne retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOne", + "object": "collection", + "arguments": { + 
"filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "collection.listIndexes retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "listIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listIndexes", + "object": "collection", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.listIndexNames retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "listIndexes" + ], + "errorLabels": [ + "RetryableError", + 
"SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "listIndexNames", + "object": "collection", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "listIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "listIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.createChangeStream retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createChangeStream", + "object": "collection", + "arguments": { + "pipeline": [] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + }, + { + "description": "collection.insertOne retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": 
"fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "insertOne", + "object": "collection", + "arguments": { + "document": { + "_id": 2, + "x": 22 + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.insertMany retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "insertMany", + "object": "collection", + "arguments": { + "documents": [ + { + "_id": 2, + "x": 22 + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + 
} + ] + } + ] + }, + { + "description": "collection.deleteOne retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "delete" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "deleteOne", + "object": "collection", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + } + ] + } + ] + }, + { + "description": "collection.deleteMany retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "delete" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "deleteMany", + "object": "collection", + "arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" 
+ } + }, + { + "commandStartedEvent": { + "commandName": "delete" + } + }, + { + "commandFailedEvent": { + "commandName": "delete" + } + } + ] + } + ] + }, + { + "description": "collection.replaceOne retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "replaceOne", + "object": "collection", + "arguments": { + "filter": {}, + "replacement": { + "x": 22 + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.updateOne retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "updateOne", + "object": "collection", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + 
"commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.updateMany retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "update" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "updateMany", + "object": "collection", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + }, + { + "commandStartedEvent": { + "commandName": "update" + } + }, + { + "commandFailedEvent": { + "commandName": "update" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndDelete retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndDelete", + "object": "collection", + 
"arguments": { + "filter": {} + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndReplace retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndReplace", + "object": "collection", + "arguments": { + "filter": {}, + "replacement": { + "x": 22 + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.findOneAndUpdate retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + 
"failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "findAndModify" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "findOneAndUpdate", + "object": "collection", + "arguments": { + "filter": {}, + "update": { + "$set": { + "x": 22 + } + } + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandStartedEvent": { + "commandName": "findAndModify" + } + }, + { + "commandFailedEvent": { + "commandName": "findAndModify" + } + } + ] + } + ] + }, + { + "description": "collection.bulkWrite retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "bulkWrite", + "object": "collection", + "arguments": { + "requests": [ + { + "insertOne": { + "document": { + "_id": 2, + "x": 22 + } + } + } + ] + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + 
"commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "collection.createIndex retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "createIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "createIndex", + "object": "collection", + "arguments": { + "keys": { + "x": 11 + }, + "name": "x_11" + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "createIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "createIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.dropIndex retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "dropIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "dropIndex", + "object": "collection", + "arguments": { + "name": "x_11" + }, + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, 
+ { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.dropIndexes retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "dropIndexes" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "dropIndexes", + "object": "collection", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandStartedEvent": { + "commandName": "dropIndexes" + } + }, + { + "commandFailedEvent": { + "commandName": "dropIndexes" + } + } + ] + } + ] + }, + { + "description": "collection.aggregate write retries at most maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "fail_point_client", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "aggregate" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "aggregate", + "object": "collection", + "arguments": { + "pipeline": [ + { + "$out": "output" + } + ] + }, 
+ "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client", + "events": [ + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + }, + { + "commandStartedEvent": { + "commandName": "aggregate" + } + }, + { + "commandFailedEvent": { + "commandName": "aggregate" + } + } + ] + } + ] + } + ] +} diff --git a/test/client-backpressure/getMore-retried.json b/test/client-backpressure/getMore-retried.json new file mode 100644 index 0000000000..d7607d694b --- /dev/null +++ b/test/client-backpressure/getMore-retried.json @@ -0,0 +1,253 @@ +{ + "description": "getMore-retried-backpressure", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4" + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent", + "commandFailedEvent", + "commandSucceededEvent" + ] + } + }, + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "database": { + "id": "db", + "client": "client0", + "databaseName": "default" + } + }, + { + "collection": { + "id": "coll", + "database": "db", + "collectionName": "default" + } + } + ], + "initialData": [ + { + "databaseName": "default", + "collectionName": "default", + "documents": [ + { + "a": 1 + }, + { + "a": 2 + }, + { + "a": 3 + } + ] + } + ], + "tests": [ + { + "description": "getMores are retried", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "getMore" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 
+ } + } + } + }, + { + "name": "find", + "object": "coll", + "arguments": { + "batchSize": 2, + "filter": {}, + "sort": { + "a": 1 + } + }, + "expectResult": [ + { + "a": 1 + }, + { + "a": 2 + }, + { + "a": 3 + } + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore" + } + }, + { + "commandFailedEvent": { + "commandName": "getMore" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore" + } + }, + { + "commandFailedEvent": { + "commandName": "getMore" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore" + } + }, + { + "commandSucceededEvent": { + "commandName": "getMore" + } + } + ] + } + ] + }, + { + "description": "getMores are retried maxAttempts=2 times", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "getMore" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 2 + } + } + } + }, + { + "name": "find", + "arguments": { + "batchSize": 2, + "filter": {} + }, + "object": "coll", + "expectError": { + "isError": true, + "isClientError": false + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore" + } + }, + { + "commandFailedEvent": { + "commandName": "getMore" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore" + } + }, + { + "commandFailedEvent": { + "commandName": "getMore" + } + }, + { + "commandStartedEvent": { + "commandName": "getMore" + } + }, + { + "commandFailedEvent": { + "commandName": 
"getMore" + } + }, + { + "commandStartedEvent": { + "commandName": "killCursors" + } + }, + { + "commandSucceededEvent": { + "commandName": "killCursors" + } + } + ] + } + ] + } + ] +} diff --git a/test/connection_monitoring/pool-create-min-size-error.json b/test/connection_monitoring/pool-create-min-size-error.json index 5c8ad02dbd..4334ce2571 100644 --- a/test/connection_monitoring/pool-create-min-size-error.json +++ b/test/connection_monitoring/pool-create-min-size-error.json @@ -9,9 +9,7 @@ ], "failPoint": { "configureFailPoint": "failCommand", - "mode": { - "times": 50 - }, + "mode": "alwaysOn", "data": { "failCommands": [ "isMaster", diff --git a/test/csot/convenient-transactions.json b/test/csot/convenient-transactions.json index f9d03429db..3400b82ba9 100644 --- a/test/csot/convenient-transactions.json +++ b/test/csot/convenient-transactions.json @@ -27,7 +27,8 @@ "awaitMinPoolSizeMS": 10000, "useMultipleMongoses": false, "observeEvents": [ - "commandStartedEvent" + "commandStartedEvent", + "commandFailedEvent" ] } }, @@ -188,6 +189,11 @@ } } }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, { "commandStartedEvent": { "commandName": "abortTransaction", @@ -206,6 +212,105 @@ ] } ] + }, + { + "description": "withTransaction surfaces a timeout after exhausting transient transaction retries, retaining the last transient error as the timeout cause.", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "insert" + ], + "blockConnection": true, + "blockTimeMS": 25, + "errorCode": 24, + "errorLabels": [ + "TransientTransactionError" + ] + } + } + } + }, + { + "name": "withTransaction", + "object": "session", + "arguments": { + "callback": [ + { + "name": "insertOne", + "object": "collection", + "arguments": { + "document": { + "_id": 1 + }, + "session": "session" + }, + 
"expectError": { + "isError": true + } + } + ] + }, + "expectError": { + "isTimeoutError": true + } + } + ], + "expectEvents": [ + { + "client": "client", + "ignoreExtraEvents": true, + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "abortTransaction" + } + }, + { + "commandFailedEvent": { + "commandName": "abortTransaction" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "abortTransaction" + } + }, + { + "commandFailedEvent": { + "commandName": "abortTransaction" + } + } + ] + } + ] } ] } diff --git a/test/discovery_and_monitoring/errors/error_handling_handshake.json b/test/discovery_and_monitoring/errors/error_handling_handshake.json index 56ca7d1132..c60ee453dd 100644 --- a/test/discovery_and_monitoring/errors/error_handling_handshake.json +++ b/test/discovery_and_monitoring/errors/error_handling_handshake.json @@ -85,7 +85,7 @@ } }, { - "description": "Mark server unknown on network timeout application error (beforeHandshakeCompletes)", + "description": "Ignore network timeout application error (beforeHandshakeCompletes)", "applicationErrors": [ { "address": "a:27017", @@ -97,14 +97,22 @@ "outcome": { "servers": { "a:27017": { - "type": "Unknown", - "topologyVersion": null, + "type": "RSPrimary", + "setName": "rs", + "topologyVersion": { + "processId": { + "$oid": "000000000000000000000001" + }, + "counter": { + "$numberLong": "1" + } + }, "pool": { - "generation": 1 + "generation": 0 } } }, - "topologyType": "ReplicaSetNoPrimary", + "topologyType": "ReplicaSetWithPrimary", "logicalSessionTimeoutMinutes": null, "setName": "rs" } diff --git a/test/discovery_and_monitoring/rs/disaggregated_storage_setversion.json b/test/discovery_and_monitoring/rs/disaggregated_storage_setversion.json 
new file mode 100644 index 0000000000..c8b41d30ca --- /dev/null +++ b/test/discovery_and_monitoring/rs/disaggregated_storage_setversion.json @@ -0,0 +1,167 @@ +{ + "description": "Static setVersion (DSC) is compatible with both pre and post DRIVERS-2412", + "uri": "mongodb://a/?replicaSet=rs", + "phases": [ + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000005" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ], + [ + "b:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": false, + "secondary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000005" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000005" + } + } + }, + { + "responses": [ + [ + "b:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000006" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "Unknown", + "setName": null, + "setVersion": null, + "electionId": null + }, + "b:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000006" + } + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": 
"rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000006" + } + } + }, + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000005" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "Unknown", + "setName": null, + "setVersion": null, + "electionId": null + }, + "b:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000006" + } + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000006" + } + } + } + ] +} diff --git a/test/discovery_and_monitoring/rs/member_list_update_with_unchanged_setversion_and_electionid.json b/test/discovery_and_monitoring/rs/member_list_update_with_unchanged_setversion_and_electionid.json new file mode 100644 index 0000000000..0045591db9 --- /dev/null +++ b/test/discovery_and_monitoring/rs/member_list_update_with_unchanged_setversion_and_electionid.json @@ -0,0 +1,227 @@ +{ + "description": "Member list is updated when setVersion and electionId remain the same", + "uri": "mongodb://a/?replicaSet=rs", + "phases": [ + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ], + [ + "b:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": false, + "secondary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + 
"a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000001" + } + } + }, + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017", + "c:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + }, + "c:27017": { + "type": "Unknown", + "setName": null, + "setVersion": null, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000001" + } + } + }, + { + "responses": [ + [ + "c:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": false, + "secondary": true, + "hosts": [ + "a:27017", + "b:27017", + "c:27017" + ], + "setName": "rs", + "setVersion": 1, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + }, + "c:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + } + }, + 
"topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000001" + } + } + }, + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000001" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000001" + } + } + } + ] +} diff --git a/test/discovery_and_monitoring/rs/migration_from_disaggregated_storage.json b/test/discovery_and_monitoring/rs/migration_from_disaggregated_storage.json new file mode 100644 index 0000000000..c5109026bc --- /dev/null +++ b/test/discovery_and_monitoring/rs/migration_from_disaggregated_storage.json @@ -0,0 +1,167 @@ +{ + "description": "DSC to ASC reverse migration - ASC primary with higher setVersion is accepted", + "uri": "mongodb://a/?replicaSet=rs", + "phases": [ + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000005" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ], + [ + "b:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": false, + "secondary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + 
"outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000005" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 1, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1, + "maxElectionId": { + "$oid": "000000000000000000000005" + } + } + }, + { + "responses": [ + [ + "b:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1000, + "electionId": { + "$oid": "000000000000000000000006" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "Unknown", + "setName": null, + "setVersion": null, + "electionId": null + }, + "b:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1000, + "electionId": { + "$oid": "000000000000000000000006" + } + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1000, + "maxElectionId": { + "$oid": "000000000000000000000006" + } + } + }, + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 1, + "electionId": { + "$oid": "000000000000000000000005" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "Unknown", + "setName": null, + "setVersion": null, + "electionId": null + }, + "b:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 1000, + "electionId": { + "$oid": "000000000000000000000006" + } + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 1000, + "maxElectionId": { + "$oid": 
"000000000000000000000006" + } + } + } + ] +} diff --git a/test/discovery_and_monitoring/rs/migration_to_disaggregated_storage.json b/test/discovery_and_monitoring/rs/migration_to_disaggregated_storage.json new file mode 100644 index 0000000000..57f39c93b2 --- /dev/null +++ b/test/discovery_and_monitoring/rs/migration_to_disaggregated_storage.json @@ -0,0 +1,119 @@ +{ + "description": "ASC to DSC forward migration - DSC uses setVersionASC + 1 to prevent false stale detection", + "uri": "mongodb://a/?replicaSet=rs", + "phases": [ + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 10, + "electionId": { + "$oid": "000000000000000000000005" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ], + [ + "b:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": false, + "secondary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 10, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 10, + "electionId": { + "$oid": "000000000000000000000005" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 10, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 10, + "maxElectionId": { + "$oid": "000000000000000000000005" + } + } + }, + { + "responses": [ + [ + "a:27017", + { + "ok": 1, + "helloOk": true, + "isWritablePrimary": true, + "hosts": [ + "a:27017", + "b:27017" + ], + "setName": "rs", + "setVersion": 11, + "electionId": { + "$oid": "000000000000000000000006" + }, + "minWireVersion": 0, + "maxWireVersion": 17 + } + ] + ], + "outcome": { + "servers": { + "a:27017": { + "type": "RSPrimary", + "setName": "rs", + "setVersion": 11, + "electionId": { + "$oid": 
"000000000000000000000006" + } + }, + "b:27017": { + "type": "RSSecondary", + "setName": "rs", + "setVersion": 10, + "electionId": null + } + }, + "topologyType": "ReplicaSetWithPrimary", + "logicalSessionTimeoutMinutes": null, + "setName": "rs", + "maxSetVersion": 11, + "maxElectionId": { + "$oid": "000000000000000000000006" + } + } + } + ] +} diff --git a/test/discovery_and_monitoring/unified/backpressure-network-error-fail-replicaset.json b/test/discovery_and_monitoring/unified/backpressure-network-error-fail-replicaset.json new file mode 100644 index 0000000000..ccaea8d135 --- /dev/null +++ b/test/discovery_and_monitoring/unified/backpressure-network-error-fail-replicaset.json @@ -0,0 +1,142 @@ +{ + "description": "backpressure-network-error-fail-replicaset", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "serverless": "forbid", + "topologies": [ + "replicaset" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "setupClient", + "useMultipleMongoses": false + } + } + ], + "initialData": [ + { + "collectionName": "backpressure-network-error-fail", + "databaseName": "sdam-tests", + "documents": [ + { + "_id": 1 + }, + { + "_id": 2 + } + ] + } + ], + "tests": [ + { + "description": "apply backpressure on network connection errors during connection establishment", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "serverDescriptionChangedEvent", + "poolClearedEvent" + ], + "uriOptions": { + "retryWrites": false, + "heartbeatFrequencyMS": 1000000, + "serverMonitoringMode": "poll", + "appname": "backpressureNetworkErrorFailTest" + } + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "sdam-tests" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "backpressure-network-error-fail" + } + } 
+ ] + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverDescriptionChangedEvent": { + "newDescription": { + "type": "RSPrimary" + } + } + }, + "count": 1 + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "setupClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "isMaster", + "hello" + ], + "appName": "backpressureNetworkErrorFailTest", + "closeConnection": true + } + } + } + }, + { + "name": "insertMany", + "object": "collection", + "arguments": { + "documents": [ + { + "_id": 3 + }, + { + "_id": 4 + } + ] + }, + "expectError": { + "isError": true, + "errorLabelsContain": [ + "SystemOverloadedError", + "RetryableError" + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "cmap", + "events": [] + } + ] + } + ] +} diff --git a/test/discovery_and_monitoring/unified/backpressure-network-error-fail-single.json b/test/discovery_and_monitoring/unified/backpressure-network-error-fail-single.json new file mode 100644 index 0000000000..c1ff67c732 --- /dev/null +++ b/test/discovery_and_monitoring/unified/backpressure-network-error-fail-single.json @@ -0,0 +1,142 @@ +{ + "description": "backpressure-network-error-fail-single", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "serverless": "forbid", + "topologies": [ + "single" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "setupClient", + "useMultipleMongoses": false + } + } + ], + "initialData": [ + { + "collectionName": "backpressure-network-error-fail", + "databaseName": "sdam-tests", + "documents": [ + { + "_id": 1 + }, + { + "_id": 2 + } + ] + } + ], + "tests": [ + { + "description": "apply backpressure on network connection errors during connection establishment", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + 
"entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "serverDescriptionChangedEvent", + "poolClearedEvent" + ], + "uriOptions": { + "retryWrites": false, + "heartbeatFrequencyMS": 1000000, + "serverMonitoringMode": "poll", + "appname": "backpressureNetworkErrorFailTest" + } + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "sdam-tests" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "backpressure-network-error-fail" + } + } + ] + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverDescriptionChangedEvent": { + "newDescription": { + "type": "Standalone" + } + } + }, + "count": 1 + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "setupClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "isMaster", + "hello" + ], + "appName": "backpressureNetworkErrorFailTest", + "closeConnection": true + } + } + } + }, + { + "name": "insertMany", + "object": "collection", + "arguments": { + "documents": [ + { + "_id": 3 + }, + { + "_id": 4 + } + ] + }, + "expectError": { + "isError": true, + "errorLabelsContain": [ + "SystemOverloadedError", + "RetryableError" + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "cmap", + "events": [] + } + ] + } + ] +} diff --git a/test/discovery_and_monitoring/unified/backpressure-network-timeout-fail-replicaset.json b/test/discovery_and_monitoring/unified/backpressure-network-timeout-fail-replicaset.json new file mode 100644 index 0000000000..35b088f422 --- /dev/null +++ b/test/discovery_and_monitoring/unified/backpressure-network-timeout-fail-replicaset.json @@ -0,0 +1,145 @@ +{ + "description": "backpressure-network-timeout-error-replicaset", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + 
"minServerVersion": "4.4", + "serverless": "forbid", + "topologies": [ + "replicaset" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "setupClient", + "useMultipleMongoses": false + } + } + ], + "initialData": [ + { + "collectionName": "backpressure-network-timeout-error", + "databaseName": "sdam-tests", + "documents": [ + { + "_id": 1 + }, + { + "_id": 2 + } + ] + } + ], + "tests": [ + { + "description": "apply backpressure on network timeout error during connection establishment", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "serverDescriptionChangedEvent", + "poolClearedEvent" + ], + "uriOptions": { + "retryWrites": false, + "heartbeatFrequencyMS": 1000000, + "appname": "backpressureNetworkTimeoutErrorTest", + "serverMonitoringMode": "poll", + "connectTimeoutMS": 250, + "socketTimeoutMS": 250 + } + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "sdam-tests" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "backpressure-network-timeout-error" + } + } + ] + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverDescriptionChangedEvent": { + "newDescription": { + "type": "RSPrimary" + } + } + }, + "count": 1 + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "setupClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "isMaster", + "hello" + ], + "blockConnection": true, + "blockTimeMS": 500, + "appName": "backpressureNetworkTimeoutErrorTest" + } + } + } + }, + { + "name": "insertMany", + "object": "collection", + "arguments": { + "documents": [ + { + "_id": 3 + }, + { + "_id": 4 + } + ] + }, + "expectError": { + "isError": true, + 
"errorLabelsContain": [ + "SystemOverloadedError", + "RetryableError" + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "cmap", + "events": [] + } + ] + } + ] +} diff --git a/test/discovery_and_monitoring/unified/backpressure-network-timeout-fail-single.json b/test/discovery_and_monitoring/unified/backpressure-network-timeout-fail-single.json new file mode 100644 index 0000000000..54b11d4d5b --- /dev/null +++ b/test/discovery_and_monitoring/unified/backpressure-network-timeout-fail-single.json @@ -0,0 +1,145 @@ +{ + "description": "backpressure-network-timeout-error-single", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "serverless": "forbid", + "topologies": [ + "single" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "setupClient", + "useMultipleMongoses": false + } + } + ], + "initialData": [ + { + "collectionName": "backpressure-network-timeout-error", + "databaseName": "sdam-tests", + "documents": [ + { + "_id": 1 + }, + { + "_id": 2 + } + ] + } + ], + "tests": [ + { + "description": "apply backpressure on network timeout error during connection establishment", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "useMultipleMongoses": false, + "observeEvents": [ + "serverDescriptionChangedEvent", + "poolClearedEvent" + ], + "uriOptions": { + "retryWrites": false, + "heartbeatFrequencyMS": 1000000, + "appname": "backpressureNetworkTimeoutErrorTest", + "serverMonitoringMode": "poll", + "connectTimeoutMS": 250, + "socketTimeoutMS": 250 + } + } + }, + { + "database": { + "id": "database", + "client": "client", + "databaseName": "sdam-tests" + } + }, + { + "collection": { + "id": "collection", + "database": "database", + "collectionName": "backpressure-network-timeout-error" + } + } + ] + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + 
"event": { + "serverDescriptionChangedEvent": { + "newDescription": { + "type": "Standalone" + } + } + }, + "count": 1 + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "setupClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "isMaster", + "hello" + ], + "blockConnection": true, + "blockTimeMS": 500, + "appName": "backpressureNetworkTimeoutErrorTest" + } + } + } + }, + { + "name": "insertMany", + "object": "collection", + "arguments": { + "documents": [ + { + "_id": 3 + }, + { + "_id": 4 + } + ] + }, + "expectError": { + "isError": true, + "errorLabelsContain": [ + "SystemOverloadedError", + "RetryableError" + ] + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "cmap", + "events": [] + } + ] + } + ] +} diff --git a/test/discovery_and_monitoring/unified/backpressure-server-description-unchanged-on-min-pool-size-population-error.json b/test/discovery_and_monitoring/unified/backpressure-server-description-unchanged-on-min-pool-size-population-error.json new file mode 100644 index 0000000000..f0597124b7 --- /dev/null +++ b/test/discovery_and_monitoring/unified/backpressure-server-description-unchanged-on-min-pool-size-population-error.json @@ -0,0 +1,106 @@ +{ + "description": "backpressure-server-description-unchanged-on-min-pool-size-population-error", + "schemaVersion": "1.17", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "serverless": "forbid", + "topologies": [ + "single" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "setupClient", + "useMultipleMongoses": false + } + } + ], + "tests": [ + { + "description": "the server description is not changed on handshake error during minPoolSize population", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "setupClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "skip": 1 + }, + "data": 
{ + "failCommands": [ + "hello", + "isMaster" + ], + "appName": "authErrorTest", + "closeConnection": true + } + } + } + }, + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client", + "observeEvents": [ + "serverDescriptionChangedEvent", + "connectionClosedEvent" + ], + "uriOptions": { + "appname": "authErrorTest", + "minPoolSize": 5, + "maxConnecting": 1, + "serverMonitoringMode": "poll", + "heartbeatFrequencyMS": 1000000 + } + } + } + ] + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "serverDescriptionChangedEvent": {} + }, + "count": 1 + } + }, + { + "name": "waitForEvent", + "object": "testRunner", + "arguments": { + "client": "client", + "event": { + "connectionClosedEvent": {} + }, + "count": 1 + } + } + ], + "expectEvents": [ + { + "client": "client", + "eventType": "sdam", + "events": [ + { + "serverDescriptionChangedEvent": {} + } + ] + } + ] + } + ] +} diff --git a/test/performance/async_perf_test.py b/test/performance/async_perf_test.py index 6eb31ea4fe..01a238c64f 100644 --- a/test/performance/async_perf_test.py +++ b/test/performance/async_perf_test.py @@ -206,6 +206,152 @@ async def runTest(self): self.results = results +# RUST COMPARISON MICRO-BENCHMARKS +class RustComparisonTest(PerformanceTest): + """Base class for tests that compare C vs Rust implementations.""" + + implementation: str = "c" # Default to C + + async def asyncSetUp(self): + await super().asyncSetUp() + # Set up environment for C or Rust + if self.implementation == "rust": + os.environ["PYMONGO_USE_RUST"] = "1" + else: + os.environ.pop("PYMONGO_USE_RUST", None) + + # Preserve extension modules when reloading + _cbson = sys.modules.get("bson._cbson") + _rbson = sys.modules.get("bson._rbson") + + # Clear bson modules except extensions + for key in list(sys.modules.keys()): + if key.startswith("bson") and not key.endswith(("_cbson", "_rbson")): + 
del sys.modules[key] + + # Restore extension modules + if _cbson: + sys.modules["bson._cbson"] = _cbson + if _rbson: + sys.modules["bson._rbson"] = _rbson + + # Re-import bson + import bson as bson_module + + self.bson = bson_module + + +class RustSimpleIntEncodingTest(RustComparisonTest): + """Test encoding of simple integer documents.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.document = {"number": 42} + self.data_size = len(encode(self.document)) * NUM_DOCS + + async def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustSimpleIntEncodingC(RustSimpleIntEncodingTest, AsyncPyMongoTestCase): + implementation = "c" + + +class TestRustSimpleIntEncodingRust(RustSimpleIntEncodingTest, AsyncPyMongoTestCase): + implementation = "rust" + + +class RustSimpleIntDecodingTest(RustComparisonTest): + """Test decoding of simple integer documents.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.document = encode({"number": 42}) + self.data_size = len(self.document) * NUM_DOCS + + async def do_task(self): + for _ in range(NUM_DOCS): + self.bson.decode(self.document) + + +class TestRustSimpleIntDecodingC(RustSimpleIntDecodingTest, AsyncPyMongoTestCase): + implementation = "c" + + +class TestRustSimpleIntDecodingRust(RustSimpleIntDecodingTest, AsyncPyMongoTestCase): + implementation = "rust" + + +class RustMixedTypesEncodingTest(RustComparisonTest): + """Test encoding of documents with mixed types.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.document = { + "string": "hello", + "int": 42, + "float": 3.14, + "bool": True, + "null": None, + } + self.data_size = len(encode(self.document)) * NUM_DOCS + + async def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustMixedTypesEncodingC(RustMixedTypesEncodingTest, AsyncPyMongoTestCase): + implementation = "c" + + +class 
TestRustMixedTypesEncodingRust(RustMixedTypesEncodingTest, AsyncPyMongoTestCase): + implementation = "rust" + + +class RustNestedEncodingTest(RustComparisonTest): + """Test encoding of nested documents.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.document = {"nested": {"level1": {"level2": {"value": "deep"}}}} + self.data_size = len(encode(self.document)) * NUM_DOCS + + async def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustNestedEncodingC(RustNestedEncodingTest, AsyncPyMongoTestCase): + implementation = "c" + + +class TestRustNestedEncodingRust(RustNestedEncodingTest, AsyncPyMongoTestCase): + implementation = "rust" + + +class RustListEncodingTest(RustComparisonTest): + """Test encoding of documents with lists.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + self.document = {"numbers": list(range(10))} + self.data_size = len(encode(self.document)) * NUM_DOCS + + async def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustListEncodingC(RustListEncodingTest, AsyncPyMongoTestCase): + implementation = "c" + + +class TestRustListEncodingRust(RustListEncodingTest, AsyncPyMongoTestCase): + implementation = "rust" + + # SINGLE-DOC BENCHMARKS class TestRunCommand(PerformanceTest, AsyncPyMongoTestCase): data_size = len(encode({"hello": True})) * NUM_DOCS diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index 5688d28d2d..59653f5b20 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -137,7 +137,11 @@ def tearDown(self): # Remove "Test" so that TestFlatEncoding is reported as "FlatEncoding". 
name = self.__class__.__name__[4:] median = self.percentile(50) - megabytes_per_sec = (self.data_size * self.n_threads) / median / 1000000 + # Protect against division by zero for very fast operations + if median > 0: + megabytes_per_sec = (self.data_size * self.n_threads) / median / 1000000 + else: + megabytes_per_sec = float("inf") print( f"Completed {self.__class__.__name__} {megabytes_per_sec:.3f} MB/s, MEDIAN={self.percentile(50):.3f}s, " f"total time={duration:.3f}s, iterations={len(self.results)}" @@ -273,6 +277,241 @@ class TestFullDecoding(BsonDecodingTest, unittest.TestCase): dataset = "full_bson.json" +# RUST COMPARISON MICRO-BENCHMARKS +# These tests compare C vs Rust implementations for the same BSON operations +class RustComparisonTest(PerformanceTest): + """Base class for tests that compare C vs Rust implementations.""" + + implementation: str = "c" # Default to C + + def setUp(self): + super().setUp() + # Set up environment for C or Rust + if self.implementation == "rust": + os.environ["PYMONGO_USE_RUST"] = "1" + else: + os.environ.pop("PYMONGO_USE_RUST", None) + + # Preserve extension modules when reloading + _cbson = sys.modules.get("bson._cbson") + _rbson = sys.modules.get("bson._rbson") + + # Clear bson modules except extensions + for key in list(sys.modules.keys()): + if key.startswith("bson") and not key.endswith(("_cbson", "_rbson")): + del sys.modules[key] + + # Restore extension modules + if _cbson: + sys.modules["bson._cbson"] = _cbson + if _rbson: + sys.modules["bson._rbson"] = _rbson + + # Re-import bson + import bson as bson_module + + self.bson = bson_module + + +class RustSimpleIntEncodingTest(RustComparisonTest): + """Test encoding of simple integer documents.""" + + def setUp(self): + super().setUp() + self.document = {"number": 42} + self.data_size = len(encode(self.document)) * NUM_DOCS + + def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class 
TestRustSimpleIntEncodingC(RustSimpleIntEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustSimpleIntEncodingRust(RustSimpleIntEncodingTest, unittest.TestCase): + implementation = "rust" + + +class RustSimpleIntDecodingTest(RustComparisonTest): + """Test decoding of simple integer documents.""" + + def setUp(self): + super().setUp() + self.document = encode({"number": 42}) + self.data_size = len(self.document) * NUM_DOCS + + def do_task(self): + for _ in range(NUM_DOCS): + self.bson.decode(self.document) + + +class TestRustSimpleIntDecodingC(RustSimpleIntDecodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustSimpleIntDecodingRust(RustSimpleIntDecodingTest, unittest.TestCase): + implementation = "rust" + + +class RustMixedTypesEncodingTest(RustComparisonTest): + """Test encoding of documents with mixed types.""" + + def setUp(self): + super().setUp() + self.document = { + "string": "hello", + "int": 42, + "float": 3.14, + "bool": True, + "null": None, + } + self.data_size = len(encode(self.document)) * NUM_DOCS + + def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustMixedTypesEncodingC(RustMixedTypesEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustMixedTypesEncodingRust(RustMixedTypesEncodingTest, unittest.TestCase): + implementation = "rust" + + +class RustNestedEncodingTest(RustComparisonTest): + """Test encoding of nested documents.""" + + def setUp(self): + super().setUp() + self.document = {"nested": {"level1": {"level2": {"value": "deep"}}}} + self.data_size = len(encode(self.document)) * NUM_DOCS + + def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustNestedEncodingC(RustNestedEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustNestedEncodingRust(RustNestedEncodingTest, unittest.TestCase): + implementation = "rust" + + +class RustListEncodingTest(RustComparisonTest): + 
"""Test encoding of documents with lists.""" + + def setUp(self): + super().setUp() + self.document = {"numbers": list(range(10))} + self.data_size = len(encode(self.document)) * NUM_DOCS + + def do_task(self): + for _ in range(NUM_DOCS): + self.bson.encode(self.document) + + +class TestRustListEncodingC(RustListEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustListEncodingRust(RustListEncodingTest, unittest.TestCase): + implementation = "rust" + + +# Rust comparison versions of standard BSON benchmarks +# These use the same test data as the standard benchmarks but compare C vs Rust + + +class RustFlatEncodingTest(RustComparisonTest, BsonEncodingTest): + """Rust comparison for flat BSON encoding.""" + + dataset = "flat_bson.json" + + +class TestRustFlatEncodingC(RustFlatEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustFlatEncodingRust(RustFlatEncodingTest, unittest.TestCase): + implementation = "rust" + + +class RustFlatDecodingTest(RustComparisonTest, BsonDecodingTest): + """Rust comparison for flat BSON decoding.""" + + dataset = "flat_bson.json" + + +class TestRustFlatDecodingC(RustFlatDecodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustFlatDecodingRust(RustFlatDecodingTest, unittest.TestCase): + implementation = "rust" + + +class RustDeepEncodingTest(RustComparisonTest, BsonEncodingTest): + """Rust comparison for deep BSON encoding.""" + + dataset = "deep_bson.json" + + +class TestRustDeepEncodingC(RustDeepEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustDeepEncodingRust(RustDeepEncodingTest, unittest.TestCase): + implementation = "rust" + + +class RustDeepDecodingTest(RustComparisonTest, BsonDecodingTest): + """Rust comparison for deep BSON decoding.""" + + dataset = "deep_bson.json" + + +class TestRustDeepDecodingC(RustDeepDecodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustDeepDecodingRust(RustDeepDecodingTest, unittest.TestCase): + 
implementation = "rust" + + +class RustFullEncodingTest(RustComparisonTest, BsonEncodingTest): + """Rust comparison for full BSON encoding.""" + + dataset = "full_bson.json" + + +class TestRustFullEncodingC(RustFullEncodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustFullEncodingRust(RustFullEncodingTest, unittest.TestCase): + implementation = "rust" + + +class RustFullDecodingTest(RustComparisonTest, BsonDecodingTest): + """Rust comparison for full BSON decoding.""" + + dataset = "full_bson.json" + + +class TestRustFullDecodingC(RustFullDecodingTest, unittest.TestCase): + implementation = "c" + + +class TestRustFullDecodingRust(RustFullDecodingTest, unittest.TestCase): + implementation = "rust" + + # JSON MICRO-BENCHMARKS class JsonEncodingTest(MicroTest): def setUp(self): diff --git a/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedNearestOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedNearestOnlyMatchingTags.json new file mode 100644 index 0000000000..5a9e8797e4 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedNearestOnlyMatchingTags.json @@ -0,0 +1,62 @@ +{ + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "Nearest", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "in_latency_window": [ + { + 
"address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedPrimaryPreferredOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedPrimaryPreferredOnlyMatchingTags.json new file mode 100644 index 0000000000..086532e710 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedPrimaryPreferredOnlyMatchingTags.json @@ -0,0 +1,62 @@ +{ + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "PrimaryPreferred", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "in_latency_window": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedSecondaryOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedSecondaryOnlyMatchingTags.json new file mode 100644 index 0000000000..18926581c3 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedSecondaryOnlyMatchingTags.json @@ -0,0 +1,62 @@ +{ + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + { + "address": "b:27017", + 
"avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "Secondary", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "in_latency_window": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedSecondaryPreferredOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedSecondaryPreferredOnlyMatchingTags.json new file mode 100644 index 0000000000..ab51345353 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetNoPrimary/read/DeprioritizedSecondaryPreferredOnlyMatchingTags.json @@ -0,0 +1,62 @@ +{ + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "SecondaryPreferred", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + 
"in_latency_window": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedNearestOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedNearestOnlyMatchingTags.json new file mode 100644 index 0000000000..021f361480 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedNearestOnlyMatchingTags.json @@ -0,0 +1,70 @@ +{ + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + }, + { + "address": "a:27017", + "avg_rtt_ms": 26, + "type": "RSPrimary", + "tags": { + "data_center": "tokyo" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "Nearest", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "in_latency_window": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedPrimaryPreferredOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedPrimaryPreferredOnlyMatchingTags.json new file mode 100644 index 0000000000..4002907b3b --- /dev/null +++ 
b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedPrimaryPreferredOnlyMatchingTags.json @@ -0,0 +1,70 @@ +{ + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + }, + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary", + "tags": { + "data_center": "nyc" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "PrimaryPreferred", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary", + "tags": { + "data_center": "nyc" + } + } + ], + "in_latency_window": [ + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedSecondaryOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedSecondaryOnlyMatchingTags.json new file mode 100644 index 0000000000..2de5bdd4c7 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedSecondaryOnlyMatchingTags.json @@ -0,0 +1,70 @@ +{ + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + }, + { + "address": "a:27017", + "avg_rtt_ms": 26, + "type": "RSPrimary", + "tags": { + "data_center": "tokyo" + 
} + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "Secondary", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "in_latency_window": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedSecondaryPreferredOnlyMatchingTags.json b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedSecondaryPreferredOnlyMatchingTags.json new file mode 100644 index 0000000000..7e1f39a606 --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/DeprioritizedSecondaryPreferredOnlyMatchingTags.json @@ -0,0 +1,70 @@ +{ + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + }, + { + "address": "c:27017", + "avg_rtt_ms": 100, + "type": "RSSecondary", + "tags": { + "data_center": "tokyo" + } + }, + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary", + "tags": { + "data_center": "tokyo" + } + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "SecondaryPreferred", + "tag_sets": [ + { + "data_center": "nyc" + } + ] + }, + "deprioritized_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary", + "tags": { + "data_center": "nyc" + } + } + ], + "suitable_servers": [ + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary", + "tags": { + "data_center": "tokyo" + } + } + ], + "in_latency_window": [ + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": 
"RSPrimary", + "tags": { + "data_center": "tokyo" + } + } + ] +} diff --git a/test/server_selection/server_selection/ReplicaSetWithPrimary/read/SecondaryPreferred_empty_tags.json b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/SecondaryPreferred_empty_tags.json new file mode 100644 index 0000000000..8ec8049efe --- /dev/null +++ b/test/server_selection/server_selection/ReplicaSetWithPrimary/read/SecondaryPreferred_empty_tags.json @@ -0,0 +1,41 @@ +{ + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + { + "address": "a:27017", + "avg_rtt_ms": 5, + "type": "RSPrimary" + }, + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary" + } + ] + }, + "operation": "read", + "read_preference": { + "mode": "SecondaryPreferred", + "tag_sets": [ + { + "data_center": "nyc" + }, + {} + ] + }, + "suitable_servers": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary" + } + ], + "in_latency_window": [ + { + "address": "b:27017", + "avg_rtt_ms": 5, + "type": "RSSecondary" + } + ] +} diff --git a/test/test_azure_helpers.py b/test/test_azure_helpers.py new file mode 100644 index 0000000000..6fe6451877 --- /dev/null +++ b/test/test_azure_helpers.py @@ -0,0 +1,155 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for _azure_helpers.py. + +These tests mock urlopen to avoid requiring a live Azure IMDS endpoint. 
+Integration tests that exercise the real endpoint are gated by environment +variables in test_on_demand_csfle.py and test_auth_oidc.py. +""" + +from __future__ import annotations + +import json +import sys +import unittest +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +sys.path[0:0] = [""] + +from pymongo._azure_helpers import _get_azure_response + + +@contextmanager +def _mock_urlopen(status: int, body: str): + """Context manager that patches ``urllib.request.urlopen`` with a fake response.""" + mock_response = MagicMock() + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = MagicMock(return_value=False) + mock_response.status = status + mock_response.read.return_value = body.encode("utf8") + + with patch("urllib.request.urlopen", return_value=mock_response) as mock_open: + yield mock_open + + +class TestGetAzureResponse(unittest.TestCase): + def _call(self, resource="https://example.com/", client_id=None, timeout=5): + return _get_azure_response(resource, client_id=client_id, timeout=timeout) + + def test_success_without_client_id(self): + body = json.dumps({"access_token": "tok", "expires_in": "3600"}) + with _mock_urlopen(200, body) as mock_open: + result = self._call() + + self.assertEqual(result["access_token"], "tok") + self.assertEqual(result["expires_in"], "3600") + + # Verify client_id was NOT added to the URL + url = mock_open.call_args[0][0].full_url + self.assertNotIn("client_id", url) + + def test_success_with_client_id(self): + body = json.dumps({"access_token": "tok", "expires_in": "3600"}) + with _mock_urlopen(200, body) as mock_open: + result = self._call(client_id="my-client-id") + + self.assertEqual(result["access_token"], "tok") + url = mock_open.call_args[0][0].full_url + self.assertIn("client_id=my-client-id", url) + + def test_url_contains_resource_and_api_version(self): + body = json.dumps({"access_token": "tok", "expires_in": "3600"}) + with _mock_urlopen(200, body) as mock_open: + 
self._call(resource="https://test-resource.example.com") + + url = mock_open.call_args[0][0].full_url + self.assertIn("api-version=2018-02-01", url) + self.assertIn("resource=https://test-resource.example.com", url) + + def test_request_headers(self): + body = json.dumps({"access_token": "tok", "expires_in": "3600"}) + with _mock_urlopen(200, body) as mock_open: + self._call() + + request = mock_open.call_args[0][0] + self.assertEqual(request.get_header("Metadata"), "true") + self.assertEqual(request.get_header("Accept"), "application/json") + + def test_urlopen_exception_raises_value_error(self): + with patch("urllib.request.urlopen", side_effect=OSError("connection refused")): + with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("Failed to acquire IMDS access token", str(ctx.exception)) + + def test_non_200_status_raises_value_error(self): + body = json.dumps({"error": "something went wrong"}) + with _mock_urlopen(400, body): + with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("Failed to acquire IMDS access token", str(ctx.exception)) + + def test_non_json_body_raises_value_error(self): + with _mock_urlopen(200, "not-json"): + with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("Azure IMDS response must be in JSON format", str(ctx.exception)) + + def test_missing_access_token_raises_value_error(self): + body = json.dumps({"expires_in": "3600"}) + with _mock_urlopen(200, body): + with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("access_token", str(ctx.exception)) + + def test_missing_expires_in_raises_value_error(self): + body = json.dumps({"access_token": "tok"}) + with _mock_urlopen(200, body): + with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("expires_in", str(ctx.exception)) + + def test_empty_access_token_raises_value_error(self): + body = json.dumps({"access_token": "", "expires_in": "3600"}) + with _mock_urlopen(200, body): + 
with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("access_token", str(ctx.exception)) + + def test_empty_expires_in_raises_value_error(self): + body = json.dumps({"access_token": "tok", "expires_in": ""}) + with _mock_urlopen(200, body): + with self.assertRaises(ValueError) as ctx: + self._call() + + self.assertIn("expires_in", str(ctx.exception)) + + def test_timeout_passed_to_urlopen(self): + body = json.dumps({"access_token": "tok", "expires_in": "3600"}) + with _mock_urlopen(200, body) as mock_open: + self._call(timeout=42) + + _, kwargs = mock_open.call_args + self.assertEqual(kwargs["timeout"], 42) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_binary.py b/test/test_binary.py index a64aa42280..7046062c54 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import IntegrationTest, client_context, skip_if_rust_bson, unittest import bson from bson import decode, encode @@ -137,6 +137,7 @@ def test_hash(self): self.assertNotEqual(hash(one), hash(two)) self.assertEqual(hash(Binary(b"hello world", 42)), hash(two)) + @skip_if_rust_bson def test_uuid_subtype_4(self): """Only STANDARD should decode subtype 4 as native uuid.""" expected_uuid = uuid.uuid4() @@ -153,6 +154,7 @@ def test_uuid_subtype_4(self): opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) self.assertEqual(expected_uuid, decode(encoded, opts)["uuid"]) + @skip_if_rust_bson def test_legacy_java_uuid(self): # Test decoding data = BinaryData.java_data @@ -193,6 +195,7 @@ def test_legacy_java_uuid(self): ) self.assertEqual(data, encoded) + @skip_if_rust_bson def test_legacy_csharp_uuid(self): data = BinaryData.csharp_data diff --git a/test/test_bson.py b/test/test_bson.py index ffc02965fb..d973c4c678 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -1746,9 +1746,11 @@ def test_long_long_to_string(self): try: from 
bson import _cbson + if _cbson is None: + self.skipTest("C extension not available") _cbson._test_long_long_to_str() except ImportError: - print("_cbson was not imported. Check compilation logs.") + self.skipTest("C extension not available") if __name__ == "__main__": diff --git a/test/test_bson_corpus.py b/test/test_bson_corpus.py index 3370c18bda..86a2457f53 100644 --- a/test/test_bson_corpus.py +++ b/test/test_bson_corpus.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] -from test import unittest +from test import skip_if_rust_bson, unittest from bson import decode, encode, json_util from bson.binary import STANDARD @@ -96,6 +96,7 @@ loads = functools.partial(json.loads, object_pairs_hook=SON) +@skip_if_rust_bson class TestBSONCorpus(unittest.TestCase): def assertJsonEqual(self, first, second, msg=None): """Fail if the two json strings are unequal. diff --git a/test/test_client.py b/test/test_client.py index 737b3afe60..75d585fdad 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -645,6 +645,38 @@ def test_detected_environment_warning(self, mock_get_hosts): with self.assertWarns(UserWarning): self.simple_client(multi_host) + def test_max_adaptive_retries(self): + # Assert that max adaptive retries defaults to 2. + c = self.simple_client(connect=False) + self.assertEqual(c.options.max_adaptive_retries, 2) + + # Assert that max adaptive retries can be configured through connection or client options. + c = self.simple_client(connect=False, max_adaptive_retries=10) + self.assertEqual(c.options.max_adaptive_retries, 10) + + c = self.simple_client(connect=False, maxAdaptiveRetries=10) + self.assertEqual(c.options.max_adaptive_retries, 10) + + c = self.simple_client(host="mongodb://localhost/?maxAdaptiveRetries=10", connect=False) + self.assertEqual(c.options.max_adaptive_retries, 10) + + def test_enable_overload_retargeting(self): + # Assert that overload retargeting defaults to false. 
+ c = self.simple_client(connect=False) + self.assertFalse(c.options.enable_overload_retargeting) + + # Assert that overload retargeting can be enabled through connection or client options. + c = self.simple_client(connect=False, enable_overload_retargeting=True) + self.assertTrue(c.options.enable_overload_retargeting) + + c = self.simple_client(connect=False, enableOverloadRetargeting=True) + self.assertTrue(c.options.enable_overload_retargeting) + + c = self.simple_client( + host="mongodb://localhost/?enableOverloadRetargeting=true", connect=False + ) + self.assertTrue(c.options.enable_overload_retargeting) + class TestClient(IntegrationTest): def test_multiple_uris(self): @@ -1007,7 +1039,7 @@ def test_list_database_names(self): db_names = self.client.list_database_names() self.assertIn("pymongo_test", db_names) self.assertIn("pymongo_test_mike", db_names) - self.assertEqual(db_names, cmd_names) + self.assertCountEqual(db_names, cmd_names) def test_drop_database(self): with self.assertRaises(TypeError): diff --git a/test/test_client_backpressure.py b/test/test_client_backpressure.py new file mode 100644 index 0000000000..61334a1218 --- /dev/null +++ b/test/test_client_backpressure.py @@ -0,0 +1,310 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import os +import pathlib +import sys +from time import perf_counter +from unittest.mock import patch + +from pymongo.common import MAX_ADAPTIVE_RETRIES + +sys.path[0:0] = [""] + +from test import ( + IntegrationTest, + client_context, + unittest, +) +from test.unified_format import generate_test_classes +from test.utils_shared import EventListener, OvertCommandListener + +from pymongo.errors import OperationFailure, PyMongoError + +_IS_SYNC = True + +# Mock a system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, +} + + +def get_mock_overload_error(times: int): + error = mock_overload_error.copy() + error["mode"] = {"times": times} + return error + + +class TestBackpressure(IntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @client_context.require_failCommand_appName + def test_retry_overload_error_command(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + with self.fail_point(fail_many): + self.db.command("find", "t") + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.command("find", "t") + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_find(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. 
+ fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + with self.fail_point(fail_many): + self.db.t.find_one() + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.find_one() + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_insert_one(self): + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + with self.fail_point(fail_many): + self.db.t.insert_one({"x": 1}) + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.insert_one({"x": 1}) + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "RetryableError" error label. + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES) + with self.fail_point(fail_many): + self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. 
+ fail_too_many = get_mock_overload_error(MAX_ADAPTIVE_RETRIES + 1) + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_getMore(self): + coll = self.db.t + coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": MAX_ADAPTIVE_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_many): + cursor.to_list() + + # Ensure command stops retrying after MAX_ADAPTIVE_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": MAX_ADAPTIVE_RETRIES + 1} + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + cursor.to_list() + + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + +# Prose tests. 
+class TestClientBackpressure(IntegrationTest): + listener: EventListener + + @classmethod + def setUpClass(cls) -> None: + cls.listener = OvertCommandListener() + + @client_context.require_connection + def setUp(self) -> None: + super().setUp() + self.listener.reset() + self.app_name = self.__class__.__name__.lower() + self.client = self.rs_or_single_client( + event_listeners=[self.listener], appName=self.app_name + ) + + @patch("random.random") + @client_context.require_failCommand_appName + def test_01_operation_retry_uses_exponential_backoff(self, random_func): + # Drivers should test that retries do not occur immediately when a SystemOverloadedError is encountered. + + # 1. let `client` be a `MongoClient` + client = self.client + + # 2. let `collection` be a collection + collection = client.test.test + + # 3. Now, run transactions without backoff: + + # a. Configure the random number generator used for jitter to always return `0` -- this effectively disables backoff. + random_func.return_value = 0 + + # b. Configure the following failPoint: + fail_point = dict( + mode="alwaysOn", + data=dict( + failCommands=["insert"], + errorCode=2, + errorLabels=["SystemOverloadedError", "RetryableError"], + appName=self.app_name, + ), + ) + with self.fail_point(fail_point): + # c. Execute the following command. Expect that the command errors. Measure the duration of the command execution. + start0 = perf_counter() + with self.assertRaises(OperationFailure): + collection.insert_one({"a": 1}) + end0 = perf_counter() + + # d. Configure the random number generator used for jitter to always return `1`. + random_func.return_value = 1 + + # e. Execute step c again. + start1 = perf_counter() + with self.assertRaises(OperationFailure): + collection.insert_one({"a": 1}) + end1 = perf_counter() + + # f. Compare the times between the two runs. + # The sum of 2 backoffs is 0.3 seconds. There is a 0.3-second window to account for potential variance between the two + # runs. 
+ self.assertTrue(abs((end1 - start1) - (end0 - start0 + 0.3)) < 0.3) + + @client_context.require_failCommand_appName + def test_03_overload_retries_limited(self): + # Drivers should test that overload errors are retried a maximum of two times. + + # 1. Let `client` be a `MongoClient`. + client = self.client + # 2. Let `coll` be a collection. + coll = client.pymongo_test.coll + + # 3. Configure the following failpoint: + failpoint = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["find"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + + # 4. Perform a find operation with `coll` that fails. + with self.fail_point(failpoint): + with self.assertRaises(PyMongoError) as error: + coll.find_one({}) + + # 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels. + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + # 6. Assert that the total number of started commands is MAX_ADAPTIVE_RETRIES + 1. + self.assertEqual(len(self.listener.started_events), MAX_ADAPTIVE_RETRIES + 1) + + @client_context.require_failCommand_appName + def test_04_overload_retries_limited_configured(self): + # Drivers should test that overload errors are retried a maximum of maxAdaptiveRetries times. + max_retries = 1 + + # 1. Let `client` be a `MongoClient` with `maxAdaptiveRetries=1` and command event monitoring enabled. + client = self.single_client(maxAdaptiveRetries=max_retries, event_listeners=[self.listener]) + # 2. Let `coll` be a collection. + coll = client.pymongo_test.coll + + # 3. 
Configure the following failpoint: + failpoint = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": ["find"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["RetryableError", "SystemOverloadedError"], + }, + } + + # 4. Perform a find operation with `coll` that fails. + with self.fail_point(failpoint): + with self.assertRaises(PyMongoError) as error: + coll.find_one({}) + + # 5. Assert that the raised error contains both the `RetryableError` and `SystemOverloadedError` error labels. + self.assertIn("RetryableError", str(error.exception)) + self.assertIn("SystemOverloadedError", str(error.exception)) + + # 6. Assert that the total number of started commands is max_retries + 1. + self.assertEqual(len(self.listener.started_events), max_retries + 1) + + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "client-backpressure") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-backpressure") + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_client_metadata.py b/test/test_client_metadata.py index 5f103f739a..8cdb728ea2 100644 --- a/test/test_client_metadata.py +++ b/test/test_client_metadata.py @@ -219,6 +219,19 @@ def test_duplicate_driver_name_no_op(self): # add same metadata again self.check_metadata_added(client, "Framework", None, None) + def test_handshake_documents_include_backpressure(self): + # Create a `MongoClient` that is configured to record all handshake documents sent to the server as a part of + # connection establishment. + client = self.rs_or_single_client("mongodb://" + self.server.address_string) + + # Send a `ping` command to the server and verify that the command succeeds. This ensure that a connection is + # established on all topologies. 
Note: MockupDB only supports standalone servers. + client.admin.command("ping") + + # Assert that for every handshake document intercepted: + # the document has a field `backpressure` whose value is `true`. + self.assertEqual(self.handshake_req["backpressure"], True) + if __name__ == "__main__": unittest.main() diff --git a/test/test_cursor.py b/test/test_cursor.py index b63638bfab..e9665e609d 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -30,7 +30,12 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + skip_if_rust_bson, + unittest, +) from test.utils import flaky from test.utils_shared import ( AllowListEventListener, @@ -1498,6 +1503,7 @@ def test_command_cursor_to_list_csot_applied(self): self.assertTrue(ctx.exception.timeout) +@skip_if_rust_bson class TestRawBatchCursor(IntegrationTest): def test_find_raw(self): c = self.db.test @@ -1671,6 +1677,7 @@ def test_monitoring(self): cursor.close() +@skip_if_rust_bson class TestRawBatchCommandCursor(IntegrationTest): def test_aggregate_raw(self): c = self.db.test diff --git a/test/test_custom_types.py b/test/test_custom_types.py index aba6b55119..782287efb9 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -28,7 +28,12 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + skip_if_rust_bson, + unittest, +) from bson import ( _BUILT_IN_TYPES, @@ -196,12 +201,14 @@ def test_decode_file_iter(self): fileobj.close() +@skip_if_rust_bson class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): cls.codecopts = DECIMAL_CODECOPTS +@skip_if_rust_bson class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): @@ -211,6 +218,7 @@ def setUpClass(cls): cls.codecopts = codec_options 
+@skip_if_rust_bson class TestBSONFallbackEncoder(unittest.TestCase): def _get_codec_options(self, fallback_encoder): type_registry = TypeRegistry(fallback_encoder=fallback_encoder) @@ -273,6 +281,7 @@ def fallback_encoder(value): self.assertEqual(called_with, [2 << 65]) +@skip_if_rust_bson class TestBSONTypeEnDeCodecs(unittest.TestCase): def test_instantiation(self): msg = "Can't instantiate abstract class" @@ -336,6 +345,7 @@ def test_type_checks(self): self.assertFalse(issubclass(TypeEncoder, TypeDecoder)) +@skip_if_rust_bson class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): TypeA: Any TypeB: Any @@ -432,6 +442,7 @@ def test_infinite_loop_exceeds_max_recursion_depth(self): encode({"x": self.TypeA(100)}, codec_options=codecopts) +@skip_if_rust_bson class TestTypeRegistry(unittest.TestCase): types: Tuple[object, object] codecs: Tuple[Type[TypeCodec], Type[TypeCodec]] @@ -622,6 +633,7 @@ class MyType(pytype): # type: ignore run_test(TypeCodec, {"bson_type": Decimal128, "transform_bson": lambda x: x}) +@skip_if_rust_bson class TestCollectionWCustomType(IntegrationTest): def setUp(self): super().setUp() @@ -744,6 +756,7 @@ def test_find_one_and__w_custom_type_decoder(self): self.assertIsNone(c.find_one()) +@skip_if_rust_bson class TestGridFileCustomType(IntegrationTest): def setUp(self): super().setUp() @@ -910,6 +923,7 @@ def run_test(doc_cls): run_test(doc_cls) +@skip_if_rust_bson class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @client_context.require_change_streams def setUp(self): @@ -927,6 +941,7 @@ def create_targets(self, *args, **kwargs): self.input_target.delete_many({}) +@skip_if_rust_bson class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @client_context.require_version_min(4, 2, 0) @client_context.require_change_streams @@ -945,6 +960,7 @@ def create_targets(self, *args, **kwargs): self.input_target.insert_one({"data": "dummy"}) 
+@skip_if_rust_bson class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @client_context.require_version_min(4, 2, 0) @client_context.require_change_streams diff --git a/test/test_dbref.py b/test/test_dbref.py index ac2767a1ce..4a6e745249 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from copy import deepcopy -from test import unittest +from test import skip_if_rust_bson, unittest from bson import decode, encode from bson.dbref import DBRef @@ -129,6 +129,7 @@ def test_dbref_hash(self): # https://github.com/mongodb/specifications/blob/master/source/dbref/dbref.md#test-plan +@skip_if_rust_bson class TestDBRefSpec(unittest.TestCase): def test_decoding_1_2_3(self): doc: Any diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 8375d63e97..7fb6f312c5 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -25,7 +25,9 @@ from pathlib import Path from test.helpers import ConcurrentRunner from test.utils import flaky +from test.utils_shared import delay +from pymongo.errors import ConnectionFailure from pymongo.operations import _Op from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.pool import Connection @@ -67,7 +69,12 @@ ) from pymongo.hello import Hello, HelloCompat from pymongo.helpers_shared import _check_command_response, _check_write_command_response -from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.monitoring import ( + ConnectionCheckOutFailedEvent, + PoolClearedEvent, + ServerHeartbeatFailedEvent, + ServerHeartbeatStartedEvent, +) from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext @@ -131,6 +138,9 @@ def got_app_error(topology, app_error): raise AssertionError 
except (AutoReconnect, NotPrimaryError, OperationFailure) as e: if when == "beforeHandshakeCompletes": + # The pool would have added the SystemOverloadedError in this case. + if isinstance(e, AutoReconnect): + e._add_error_label("SystemOverloadedError") completed_handshake = False elif when == "afterHandshakeCompletes": completed_handshake = True @@ -437,6 +447,57 @@ def mock_close(self, reason): Connection.close_conn = original_close +class TestPoolBackpressure(IntegrationTest): + @client_context.require_version_min(7, 0, 0) + def test_connection_pool_is_not_cleared(self): + listener = CMAPListener() + + # Create a client that listens to CMAP events, with maxConnecting=100. + client = self.rs_or_single_client(maxConnecting=100, event_listeners=[listener]) + + # Enable the ingress rate limiter. + client.admin.command( + "setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=True + ) + client.admin.command("setParameter", 1, ingressConnectionEstablishmentRatePerSec=20) + client.admin.command("setParameter", 1, ingressConnectionEstablishmentBurstCapacitySecs=1) + client.admin.command("setParameter", 1, ingressConnectionEstablishmentMaxQueueDepth=1) + + # Disable the ingress rate limiter on teardown. + # Sleep for 1 second before disabling to avoid the rate limiter. + def teardown(): + time.sleep(1) + client.admin.command( + "setParameter", 1, ingressConnectionEstablishmentRateLimiterEnabled=False + ) + + self.addCleanup(teardown) + + # Make sure the collection has at least one document. + client.test.test.delete_many({}) + client.test.test.insert_one({}) + + # Run a slow operation to tie up the connection. + def target(): + try: + client.test.test.find_one({"$where": delay(0.1)}) + except ConnectionFailure: + pass + + # Run 100 parallel operations that contend for connections. 
+        tasks = []
+        for _ in range(100):
+            tasks.append(ConcurrentRunner(target=target))
+        for t in tasks:
+            t.start()
+        for t in tasks:
+            t.join()
+
+        # Verify there were more than 10 connection checkout failed events but no pool cleared events.
+        self.assertGreater(len(listener.events_by_type(ConnectionCheckOutFailedEvent)), 10)
+        self.assertEqual(len(listener.events_by_type(PoolClearedEvent)), 0)
+
+
 class TestServerMonitoringMode(IntegrationTest):
     @client_context.require_no_load_balancer
     def setUp(self):
diff --git a/test/test_encryption.py b/test/test_encryption.py
index 88d37cfa0d..af9f2e3df7 100644
--- a/test/test_encryption.py
+++ b/test/test_encryption.py
@@ -872,6 +872,8 @@ def test_views_are_prohibited(self):
 
 
 class TestCorpus(EncryptionIntegrationTest):
+    # PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions.
+    @client_context.require_version_max(6, 99)
     @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
     def setUp(self):
         super().setUp()
@@ -1048,6 +1050,8 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
     client_encrypted: MongoClient
     listener: OvertCommandListener
 
+    # PYTHON-5708: Encryption tests sending large payloads fail on some mongocryptd versions.
+ @client_context.require_version_max(6, 99) def setUp(self): super().setUp() db = client_context.client.db diff --git a/test/test_pooling.py b/test/test_pooling.py index cb5b206996..95558d00d5 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -511,6 +511,39 @@ def test_connection_timeout_message(self): str(error.exception), ) + @client_context.require_failCommand_appName + def test_pool_backpressure_preserves_existing_connections(self): + client = self.rs_or_single_client() + coll = client.pymongo_test.t + pool = get_pool(client) + coll.insert_many([{"x": 1} for _ in range(10)]) + t = SocketGetter(self.c, pool) + t.start() + while t.state != "connection": + time.sleep(0.1) + + assert not t.sock.conn_closed() + + # Mock a session establishment overload. + mock_connection_fail = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "closeConnection": True, + }, + } + + with self.fail_point(mock_connection_fail): + coll.find_one({}) + + # Make sure the existing socket was not affected. 
+ assert not t.sock.conn_closed() + + # Cleanup + t.release_conn() + t.join() + pool.close() + class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 4d9a3ceb05..27d298e059 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -19,7 +19,12 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + skip_if_rust_bson, + unittest, +) from bson import Code, DBRef, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation @@ -31,6 +36,7 @@ _IS_SYNC = True +@skip_if_rust_bson class TestRawBSONDocument(IntegrationTest): # {'_id': ObjectId('556df68b6e32ab21a95e0785'), # 'name': 'Sherlock', diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index c9f72ae547..9e6aac821c 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -259,6 +259,128 @@ def test_retryable_reads_are_retried_on_the_same_implicit_session(self): self.assertEqual(command_docs[0]["lsid"], command_docs[1]["lsid"]) self.assertIsNot(command_docs[0], command_docs[1]) + @client_context.require_replica_set + @client_context.require_secondaries_count(1) + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, 0) + def test_03_01_retryable_reads_caused_by_overload_errors_are_retried_on_a_different_replicaset_server_when_one_is_available_and_overload_retargeting_is_enabled( + self + ): + listener = OvertCommandListener() + + # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, `enableOverloadRetargeting=True`, and command event monitoring enabled. + client = self.rs_or_single_client( + event_listeners=[listener], + retryReads=True, + readPreference="primaryPreferred", + enableOverloadRetargeting=True, + ) + + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. 
+ command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorLabels": ["RetryableError", "SystemOverloadedError"], + "errorCode": 6, + }, + } + set_fail_point(client, command_args) + + # 3. Reset the command event monitor to clear the fail point command from its stored events. + listener.reset() + + # 4. Execute a `find` command with `client`. + client.t.t.find_one({}) + + # 5. Assert that one failed command event and one successful command event occurred. + self.assertEqual(len(listener.failed_events), 1) + self.assertEqual(len(listener.succeeded_events), 1) + + # 6. Assert that both events occurred on different servers. + assert listener.failed_events[0].connection_id != listener.succeeded_events[0].connection_id + + @client_context.require_replica_set + @client_context.require_secondaries_count(1) + @client_context.require_failCommand_fail_point + @client_context.require_version_min(4, 4, 0) + def test_03_02_retryable_reads_caused_by_non_overload_errors_are_retried_on_the_same_replicaset_server( + self + ): + listener = OvertCommandListener() + + # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled. + client = self.rs_or_single_client( + event_listeners=[listener], retryReads=True, readPreference="primaryPreferred" + ) + + # 2. Configure a fail point with the RetryableError error label. + command_args = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "errorLabels": ["RetryableError"], + "errorCode": 6, + }, + } + set_fail_point(client, command_args) + + # 3. Reset the command event monitor to clear the fail point command from its stored events. + listener.reset() + + # 4. Execute a `find` command with `client`. + client.t.t.find_one({}) + + # 5. Assert that one failed command event and one successful command event occurred. 
+        self.assertEqual(len(listener.failed_events), 1)
+        self.assertEqual(len(listener.succeeded_events), 1)
+
+        # 6. Assert that both events occurred on the same server.
+        assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id
+
+    @client_context.require_replica_set
+    @client_context.require_secondaries_count(1)
+    @client_context.require_failCommand_fail_point
+    @client_context.require_version_min(4, 4, 0)
+    def test_03_03_retryable_reads_caused_by_overload_errors_are_retried_on_the_same_replicaset_server_when_one_is_available_and_overload_retargeting_is_disabled(
+        self
+    ):
+        listener = OvertCommandListener()
+
+        # 1. Create a client `client` with `retryReads=true`, `readPreference=primaryPreferred`, and command event monitoring enabled.
+        client = self.rs_or_single_client(
+            event_listeners=[listener],
+            retryReads=True,
+            readPreference="primaryPreferred",
+        )
+
+        # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels.
+        command_args = {
+            "configureFailPoint": "failCommand",
+            "mode": {"times": 1},
+            "data": {
+                "failCommands": ["find"],
+                "errorLabels": ["RetryableError", "SystemOverloadedError"],
+                "errorCode": 6,
+            },
+        }
+        set_fail_point(client, command_args)
+
+        # 3. Reset the command event monitor to clear the fail point command from its stored events.
+        listener.reset()
+
+        # 4. Execute a `find` command with `client`.
+        client.t.t.find_one({})
+
+        # 5. Assert that one failed command event and one successful command event occurred.
+        self.assertEqual(len(listener.failed_events), 1)
+        self.assertEqual(len(listener.succeeded_events), 1)
+
+        # 6. Assert that both events occurred on the same server.
+ assert listener.failed_events[0].connection_id == listener.succeeded_events[0].connection_id + if __name__ == "__main__": unittest.main() diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index a74a3e8030..5509083162 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -43,14 +43,17 @@ from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument from bson.son import SON +from pymongo import MongoClient from pymongo.errors import ( AutoReconnect, ConnectionFailure, - OperationFailure, + NotPrimaryError, + PyMongoError, ServerSelectionTimeoutError, WriteConcernError, ) from pymongo.monitoring import ( + CommandFailedEvent, CommandSucceededEvent, ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, @@ -597,5 +600,186 @@ def raise_connection_err_select_server(*args, **kwargs): self.assertEqual(sent_txn_id, final_txn_id, msg) +class TestErrorPropagationAfterEncounteringMultipleErrors(IntegrationTest): + # Only run against replica sets as mongos does not propagate the NoWritesPerformed label to the drivers. + @client_context.require_replica_set + # Run against server versions 6.0 and above. + @client_context.require_version_min(6, 0) # type: ignore[untyped-decorator] + def setUp(self) -> None: + super().setUp() + self.setup_client = MongoClient(**client_context.default_client_options) + self.addCleanup(self.setup_client.close) + + # TODO: After PYTHON-4595 we can use async event handlers and remove this workaround. + def configure_fail_point_sync(self, command_args, off=False) -> None: + cmd = {"configureFailPoint": "failCommand"} + cmd.update(command_args) + if off: + cmd["mode"] = "off" + cmd.pop("data", None) + self.setup_client.admin.command(cmd) + + def test_01_drivers_return_the_correct_error_when_receiving_only_errors_without_NoWritesPerformed( + self + ) -> None: + # Create a client with retryWrites=true. 
+        listener = OvertCommandListener()
+
+        # Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError and SystemOverloadedError error labels.
+        command_args = {
+            "configureFailPoint": "failCommand",
+            "mode": {"times": 1},
+            "data": {
+                "failCommands": ["insert"],
+                "errorLabels": ["RetryableError", "SystemOverloadedError"],
+                "errorCode": 91,
+            },
+        }
+
+        # Via the command monitoring CommandFailedEvent, configure a fail point with error code 10107 (NotWritablePrimary).
+        command_args_inner = {
+            "configureFailPoint": "failCommand",
+            "mode": "alwaysOn",
+            "data": {
+                "failCommands": ["insert"],
+                "errorCode": 10107,
+                "errorLabels": ["RetryableError", "SystemOverloadedError"],
+            },
+        }
+
+        def failed(event: CommandFailedEvent) -> None:
+            # Configure the 10107 fail point command only if the failed event is for the 91 error configured in step 2.
+            if listener.failed_events:
+                return
+            assert event.failure["code"] == 91
+            self.configure_fail_point_sync(command_args_inner)
+            self.addCleanup(self.configure_fail_point_sync, {}, off=True)
+            listener.failed_events.append(event)
+
+        listener.failed = failed
+
+        client = self.rs_client(retryWrites=True, event_listeners=[listener])
+
+        self.configure_fail_point_sync(command_args)
+        self.addCleanup(self.configure_fail_point_sync, {}, off=True)
+
+        # Attempt an insertOne operation on any record for any database and collection.
+        # Expect the insertOne to fail with a server error.
+        with self.assertRaises(NotPrimaryError) as exc:
+            client.test.test.insert_one({})
+
+        # Assert that the error code of the server error is 10107.
+        assert exc.exception.errors["code"] == 10107  # type:ignore[call-overload]
+
+    def test_02_drivers_return_the_correct_error_when_receiving_only_errors_with_NoWritesPerformed(
+        self
+    ) -> None:
+        # Create a client with retryWrites=true.
+        listener = OvertCommandListener()
+
+        # Configure a fail point with error code 91 (ShutdownInProgress) with the RetryableError, SystemOverloadedError, and NoWritesPerformed error labels.
+        command_args = {
+            "configureFailPoint": "failCommand",
+            "mode": {"times": 1},
+            "data": {
+                "failCommands": ["insert"],
+                "errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
+                "errorCode": 91,
+            },
+        }
+
+        # Via the command monitoring CommandFailedEvent, configure a fail point with error code `10107` (NotWritablePrimary)
+        # and a NoWritesPerformed label.
+        command_args_inner = {
+            "configureFailPoint": "failCommand",
+            "mode": "alwaysOn",
+            "data": {
+                "failCommands": ["insert"],
+                "errorCode": 10107,
+                "errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
+            },
+        }
+
+        def failed(event: CommandFailedEvent) -> None:
+            if listener.failed_events:
+                return
+            # Configure the 10107 fail point command only if the failed event is for the 91 error configured in step 2.
+            assert event.failure["code"] == 91
+            self.configure_fail_point_sync(command_args_inner)
+            self.addCleanup(self.configure_fail_point_sync, {}, off=True)
+            listener.failed_events.append(event)
+
+        listener.failed = failed
+
+        client = self.rs_client(retryWrites=True, event_listeners=[listener])
+
+        self.configure_fail_point_sync(command_args)
+        self.addCleanup(self.configure_fail_point_sync, {}, off=True)
+
+        # Attempt an insertOne operation on any record for any database and collection.
+        # Expect the insertOne to fail with a server error.
+        with self.assertRaises(NotPrimaryError) as exc:
+            client.test.test.insert_one({})
+
+        # Assert that the error code of the server error is 91.
+        assert exc.exception.errors["code"] == 91  # type:ignore[call-overload]
+
+    def test_03_drivers_return_the_correct_error_when_receiving_some_errors_with_NoWritesPerformed_and_some_without_NoWritesPerformed(
+        self
+    ) -> None:
+        # Create a client with retryWrites=true.
+        listener = OvertCommandListener()
+
+        # Configure the client to listen to CommandFailedEvents. In the attached listener, configure a fail point with error
+        # code `91` (ShutdownInProgress) and the `NoWritesPerformed`, `RetryableError` and `SystemOverloadedError` labels.
+        command_args_inner = {
+            "configureFailPoint": "failCommand",
+            "mode": "alwaysOn",
+            "data": {
+                "failCommands": ["insert"],
+                "errorLabels": ["RetryableError", "SystemOverloadedError", "NoWritesPerformed"],
+                "errorCode": 91,
+            },
+        }
+
+        # Configure a fail point with error code `91` (ShutdownInProgress) with the `RetryableError` and
+        # `SystemOverloadedError` error labels but without the `NoWritesPerformed` error label.
+        command_args = {
+            "configureFailPoint": "failCommand",
+            "mode": {"times": 1},
+            "data": {
+                "failCommands": ["insert"],
+                "errorCode": 91,
+                "errorLabels": ["RetryableError", "SystemOverloadedError"],
+            },
+        }
+
+        def failed(event: CommandFailedEvent) -> None:
+            # Configure the fail point command only if the failed event is for the 91 error configured in step 2.
+            if listener.failed_events:
+                return
+            assert event.failure["code"] == 91
+            self.configure_fail_point_sync(command_args_inner)
+            self.addCleanup(self.configure_fail_point_sync, {}, off=True)
+            listener.failed_events.append(event)
+
+        listener.failed = failed
+
+        client = self.rs_client(retryWrites=True, event_listeners=[listener])
+
+        self.configure_fail_point_sync(command_args)
+        self.addCleanup(self.configure_fail_point_sync, {}, off=True)
+
+        # Attempt an insertOne operation on any record for any database and collection.
+        # Expect the insertOne to fail with a server error.
+        with self.assertRaises(PyMongoError) as exc:
+            client.test.test.insert_one({})
+
+        # Assert that the error code of the server error is 91.
+        assert exc.exception.errors["code"] == 91
+        # Assert that the error does not contain the error label `NoWritesPerformed`.
+ assert "NoWritesPerformed" not in exc.exception.errors["errorLabels"] + + if __name__ == "__main__": unittest.main() diff --git a/test/test_session.py b/test/test_session.py index 40d0a53afb..3963f88da0 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -189,6 +189,52 @@ def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) + # Explicit bound session + for f, args, kw in ops: + with client.start_session() as s: + with s.bind(): + listener.reset() + s._materialize() + last_use = s._server_session.last_use + start = time.monotonic() + self.assertLessEqual(last_use, start) + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + f(*args, **kw) + self.assertGreaterEqual(len(listener.started_events), 1) + for event in listener.started_events: + self.assertIn( + "lsid", + event.command, + f"{f.__name__} sent no lsid with {event.command_name}", + ) + + self.assertEqual( + s.session_id, + event.command["lsid"], + f"{f.__name__} sent wrong lsid with {event.command_name}", + ) + + self.assertFalse(s.has_ended) + + self.assertTrue(s.has_ended) + with self.assertRaisesRegex(InvalidOperation, "ended session"): + with s.bind(): + f(*args, **kw) + + # Test a session cannot be used on another client. + with self.client2.start_session() as s: + with s.bind(): + # In case "f" modifies its inputs. + args = copy.copy(args) + kw = copy.copy(kw) + with self.assertRaisesRegex( + InvalidOperation, + "Only the client that created the bound session can perform operations within its context block", + ): + f(*args, **kw) + def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. 
@@ -825,6 +871,73 @@ def test_session_not_copyable(self): with client.start_session() as s: self.assertRaises(TypeError, lambda: copy.copy(s)) + def test_nested_session_binding(self): + coll = self.client.pymongo_test.test + coll.insert_one({"x": 1}) + + session1 = self.client.start_session() + session2 = self.client.start_session() + session1._materialize() + session2._materialize() + try: + self.listener.reset() + # Uses implicit session + coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + with session1.bind(end_session=False): + self.listener.reset() + # Uses bound session1 + coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + + with session2.bind(end_session=False): + self.listener.reset() + # Uses bound session2 + coll.find_one() + session2_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session2_lsid, session2.session_id) + self.assertNotEqual(session2_lsid, session1.session_id) + + self.listener.reset() + # Use bound session1 again + coll.find_one() + session1_lsid = self.listener.started_events[0].command.get("lsid") + self.assertEqual(session1_lsid, session1.session_id) + self.assertNotEqual(session1_lsid, session2.session_id) + + self.listener.reset() + # Uses implicit session + coll.find_one() + implicit_lsid = self.listener.started_events[0].command.get("lsid") + self.assertIsNotNone(implicit_lsid) + self.assertNotEqual(implicit_lsid, session1.session_id) + self.assertNotEqual(implicit_lsid, session2.session_id) + + finally: + session1.end_session() + session2.end_session() + + def test_session_binding_end_session(self): + coll = self.client.pymongo_test.test + coll.insert_one({"x": 1}) + + with self.client.start_session().bind() as s1: + 
coll.find_one() + + self.assertTrue(s1.has_ended) + + with self.client.start_session().bind(end_session=False) as s2: + coll.find_one() + + self.assertFalse(s2.has_ended) + + s2.end_session() + class TestCausalConsistency(UnitTest): listener: SessionTestListener diff --git a/test/test_son.py b/test/test_son.py index 36a6834889..3d2069a4c2 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -145,13 +145,11 @@ def test_iteration(self): self.assertEqual(ele * 100, test_son[ele]) def test_contains_has(self): - """has_key and __contains__""" + """Test key membership via 'in' and __contains__.""" test_son = SON([(1, 100), (2, 200), (3, 300)]) self.assertIn(1, test_son) self.assertIn(2, test_son, "in failed") self.assertNotIn(22, test_son, "in succeeded when it shouldn't") - self.assertTrue(test_son.has_key(2), "has_key failed") - self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") def test_clears(self): """Test clear()""" diff --git a/test/test_ssl.py b/test/test_ssl.py index b1e9a65eb5..77bb086ecb 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -48,19 +48,11 @@ _HAVE_PYOPENSSL = False try: - # All of these must be available to use PyOpenSSL - import OpenSSL - import requests - import service_identity - - # Ensure service_identity>=18.1 is installed - from service_identity.pyopenssl import verify_ip_address - - from pymongo.ocsp_support import _load_trusted_ca_certs + from pymongo import pyopenssl_context _HAVE_PYOPENSSL = True except ImportError: - _load_trusted_ca_certs = None # type: ignore + pass if HAVE_SSL: @@ -136,11 +128,6 @@ def test_config_ssl(self): def test_use_pyopenssl_when_available(self): self.assertTrue(HAVE_PYSSL) - @unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL") - def test_load_trusted_ca_certs(self): - trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM) - self.assertEqual(2, len(trusted_ca_certs)) - class TestSSL(IntegrationTest): saved_port: int diff --git a/test/test_transactions.py 
b/test/test_transactions.py index 01b7ba1553..609105ec21 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -16,9 +16,13 @@ from __future__ import annotations import asyncio +import random import sys +import time from io import BytesIO +from unittest.mock import patch +import pymongo from gridfs.synchronous.grid_file import GridFS, GridFSBucket from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.pool import PoolState @@ -40,7 +44,9 @@ CollectionInvalid, ConfigurationError, ConnectionFailure, + ExecutionTimeout, InvalidOperation, + NetworkTimeout, OperationFailure, ) from pymongo.operations import IndexModel, InsertOne @@ -426,7 +432,7 @@ def set_fail_point(self, command_args): self.configure_fail_point(client, command_args) @client_context.require_transactions - def test_callback_raises_custom_error(self): + def test_1_callback_raises_custom_error(self): class _MyException(Exception): pass @@ -438,7 +444,7 @@ def raise_error(_): s.with_transaction(raise_error) @client_context.require_transactions - def test_callback_returns_value(self): + def test_2_callback_returns_value(self): def callback(_): return "Foo" @@ -466,7 +472,7 @@ def callback(_): self.assertEqual(s.with_transaction(callback), "Foo") @client_context.require_transactions - def test_callback_not_retried_after_timeout(self): + def test_3_1_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -487,14 +493,16 @@ def callback(session): listener.reset() with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout) as context: s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. 
+ self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @client_context.require_test_commands @client_context.require_transactions - def test_callback_not_retried_after_commit_timeout(self): + def test_3_2_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -519,14 +527,16 @@ def callback(session): with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(OperationFailure): + with self.assertRaises(NetworkTimeout) as context: s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @client_context.require_test_commands @client_context.require_transactions - def test_commit_not_retried_after_timeout(self): + def test_3_3_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test @@ -548,7 +558,7 @@ def callback(session): with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(ConnectionFailure): + with self.assertRaises(NetworkTimeout) as context: s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic @@ -556,6 +566,40 @@ def callback(session): self.assertEqual( listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) + # Assert that the timeout error has the same labels as the error it wraps. 
+ self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult")) + + @client_context.require_transactions + def test_callback_not_retried_after_csot_timeout(self): + listener = OvertCommandListener() + client = self.rs_client(event_listeners=[listener]) + coll = client[self.db.name].test + + def callback(session): + coll.insert_one({}, session=session) + err: dict = { + "ok": 0, + "errmsg": "Transaction 7819 has been aborted.", + "code": 251, + "codeName": "NoSuchTransaction", + "errorLabels": ["TransientTransactionError"], + } + raise OperationFailure(err["errmsg"], err["code"], err) + + # Create the collection. + coll.insert_one({}) + listener.reset() + with client.start_session() as s: + with pymongo.timeout(1.0): + with self.assertRaises(ExecutionTimeout): + s.with_transaction(callback) + + # At least two attempts: the original and one or more retries. + inserts = len([x for x in listener.started_command_names() if x == "insert"]) + aborts = len([x for x in listener.started_command_names() if x == "abortTransaction"]) + + self.assertGreaterEqual(inserts, 2) + self.assertGreaterEqual(aborts, 2) # Tested here because this supports Motor's convenient transactions API. 
@client_context.require_transactions @@ -594,6 +638,63 @@ def callback(session): s.with_transaction(callback) self.assertFalse(s.in_transaction) + @client_context.require_test_commands + @client_context.require_transactions + def test_4_retry_backoff_is_enforced(self): + client = client_context.client + coll = client[self.db.name].test + end = start = no_backoff_time = 0 + + # Make random.random always return 0 (no backoff) + with patch.object(random, "random", return_value=0): + # set fail point to trigger transaction failure and trigger backoff + self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 13}, + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + + def callback(session): + coll.insert_one({}, session=session) + + start = time.monotonic() + with self.client.start_session() as s: + s.with_transaction(callback) + end = time.monotonic() + no_backoff_time = end - start + + # Make random.random always return 1 (max backoff) + with patch.object(random, "random", return_value=1): + # set fail point to trigger transaction failure and trigger backoff + self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": { + "times": 13 + }, # sufficiently high enough such that the time effect of backoff is noticeable + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, + }, + } + ) + self.addCleanup( + self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"} + ) + start = time.monotonic() + with self.client.start_session() as s: + s.with_transaction(callback) + end = time.monotonic() + self.assertLess(abs(end - start - (no_backoff_time + 2.2)), 1) # sum of 13 backoffs is 2.2 + class TestOptionsInsideTransactionProse(TransactionsBase): @client_context.require_transactions diff --git a/test/test_typing.py b/test/test_typing.py index 17dc21b4e0..41b475eea0 100644 
--- a/test/test_typing.py +++ b/test/test_typing.py @@ -67,7 +67,7 @@ class ImplicitMovie(TypedDict): sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, client_context +from test import IntegrationTest, PyMongoTestCase, client_context, skip_if_rust_bson from bson import CodecOptions, ObjectId, decode, decode_all, decode_file_iter, decode_iter, encode from bson.raw_bson import RawBSONDocument @@ -272,6 +272,7 @@ def test_with_options(self) -> None: assert retrieved["other"] == 1 # type:ignore[misc] +@skip_if_rust_bson class TestDecode(unittest.TestCase): def test_bson_decode(self) -> None: doc = {"_id": 1} diff --git a/test/transactions/unified/backpressure-retryable-abort.json b/test/transactions/unified/backpressure-retryable-abort.json new file mode 100644 index 0000000000..3a2a3b4368 --- /dev/null +++ b/test/transactions/unified/backpressure-retryable-abort.json @@ -0,0 +1,342 @@ +{ + "description": "backpressure-retryable-abort", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "replicaset", + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "transaction-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + }, + { + "session": { + "id": "session0", + "client": "client0" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ], + "tests": [ + { + "description": "abortTransaction retries if backpressure labels are added", + "operations": [ + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + 
"failCommands": [ + "abortTransaction" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "session0", + "name": "abortTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 1 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": true, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "abortTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "abortTransaction", + "databaseName": "admin" + } + }, + { + "commandStartedEvent": { + "command": { + "abortTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "abortTransaction", + "databaseName": "admin" + } + }, + { + "commandStartedEvent": { + "command": { + "abortTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + 
"commandName": "abortTransaction", + "databaseName": "admin" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ] + }, + { + "description": "abortTransaction is retried maxAttempts=2 times if backpressure labels are added", + "operations": [ + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "abortTransaction" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "session0", + "name": "abortTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 1 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": true, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "abortTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "abortTransaction", + "databaseName": "admin" + } + }, + { + "commandStartedEvent": { + "commandName": "abortTransaction" + } + }, + { + "commandStartedEvent": { + "commandName": 
"abortTransaction" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ] + } + ] +} diff --git a/test/transactions/unified/backpressure-retryable-commit.json b/test/transactions/unified/backpressure-retryable-commit.json new file mode 100644 index 0000000000..844ed25ab4 --- /dev/null +++ b/test/transactions/unified/backpressure-retryable-commit.json @@ -0,0 +1,359 @@ +{ + "description": "backpressure-retryable-commit", + "schemaVersion": "1.4", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "sharded", + "replicaset", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "transaction-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + }, + { + "session": { + "id": "session0", + "client": "client0" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ], + "tests": [ + { + "description": "commitTransaction retries if backpressure labels are added", + "runOnRequirements": [ + { + "serverless": "forbid" + } + ], + "operations": [ + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 2 + }, + "data": { + "failCommands": [ + "commitTransaction" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { 
+ "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "session0", + "name": "commitTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 1 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": true, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "commitTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "commitTransaction", + "databaseName": "admin" + } + }, + { + "commandStartedEvent": { + "command": { + "commitTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "commitTransaction", + "databaseName": "admin" + } + }, + { + "commandStartedEvent": { + "command": { + "commitTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "commitTransaction", + "databaseName": "admin" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [ + { + "_id": 1 + } + ] + } + ] + }, + { + "description": "commitTransaction is retried maxAttempts=2 times if backpressure labels are added", + "runOnRequirements": [ + { + "serverless": "forbid" + } + ], + 
"operations": [ + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "commitTransaction" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "session0", + "name": "commitTransaction", + "expectError": { + "isError": true + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 1 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": true, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "commitTransaction": 1, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "commitTransaction", + "databaseName": "admin" + } + }, + { + "commandStartedEvent": { + "commandName": "commitTransaction" + } + }, + { + "commandStartedEvent": { + "commandName": "commitTransaction" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ] + } + ] +} diff --git a/test/transactions/unified/backpressure-retryable-reads.json 
b/test/transactions/unified/backpressure-retryable-reads.json new file mode 100644 index 0000000000..a859ec4bda --- /dev/null +++ b/test/transactions/unified/backpressure-retryable-reads.json @@ -0,0 +1,313 @@ +{ + "description": "backpressure-retryable-reads", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "replicaset", + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "transaction-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + }, + { + "session": { + "id": "session0", + "client": "client0" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ], + "tests": [ + { + "description": "reads are retried if backpressure labels are added", + "operations": [ + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "collection0", + "name": "find", + "arguments": { + "filter": {}, + "session": "session0" + } + }, + { + "object": "session0", + "name": "commitTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + 
"command": { + "insert": "test", + "documents": [ + { + "_id": 1 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": true, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "test", + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "find", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "test", + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "find", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "abortTransaction": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "commitTransaction", + "databaseName": "admin" + } + } + ] + } + ] + }, + { + "description": "reads are retried maxAttempts=2 times if backpressure labels are added", + "operations": [ + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { 
+ "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "find" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "collection0", + "name": "find", + "arguments": { + "filter": {}, + "session": "session0" + }, + "expectError": { + "isError": true + } + }, + { + "object": "session0", + "name": "abortTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "commandName": "abortTransaction" + } + } + ] + } + ] + } + ] +} diff --git a/test/transactions/unified/backpressure-retryable-writes.json b/test/transactions/unified/backpressure-retryable-writes.json new file mode 100644 index 0000000000..6cbf450e5f --- /dev/null +++ b/test/transactions/unified/backpressure-retryable-writes.json @@ -0,0 +1,439 @@ +{ + "description": "backpressure-retryable-writes", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "minServerVersion": "4.4", + "topologies": [ + "replicaset", + "sharded", + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": false, + "observeEvents": [ + "commandStartedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "transaction-tests" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "test" + } + }, + { + "session": { + "id": "session0", + "client": "client0" + } + } + ], + "initialData": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ], + "tests": [ + { + "description": "writes are retried if backpressure labels are 
added", + "operations": [ + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 2 + } + } + }, + { + "object": "session0", + "name": "commitTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 1 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": true, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 2 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "test", + "documents": [ + { + "_id": 2 + } + ], + "ordered": true, + "readConcern": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + 
"$numberLong": "1" + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "abortTransaction": { + "$$exists": false + }, + "lsid": { + "$$sessionLsid": "session0" + }, + "txnNumber": { + "$numberLong": "1" + }, + "startTransaction": { + "$$exists": false + }, + "autocommit": false, + "writeConcern": { + "$$exists": false + } + }, + "commandName": "commitTransaction", + "databaseName": "admin" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [ + { + "_id": 1 + }, + { + "_id": 2 + } + ] + } + ] + }, + { + "description": "writes are retried maxAttempts=2 times if backpressure labels are added", + "operations": [ + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 1 + } + }, + "expectResult": { + "$$unsetOrMatches": { + "insertedId": { + "$$unsetOrMatches": 1 + } + } + } + }, + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 2 + } + }, + "expectError": { + "isError": true + } + }, + { + "object": "session0", + "name": "abortTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + 
"commandName": "insert" + } + }, + { + "commandStartedEvent": { + "commandName": "abortTransaction" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ] + }, + { + "description": "retry succeeds if backpressure labels are added to the first operation in a transaction", + "operations": [ + { + "object": "session0", + "name": "startTransaction" + }, + { + "object": "testRunner", + "name": "failPoint", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorLabels": [ + "RetryableError", + "SystemOverloadedError" + ], + "errorCode": 112 + } + } + } + }, + { + "object": "collection0", + "name": "insertOne", + "arguments": { + "session": "session0", + "document": { + "_id": 2 + } + } + }, + { + "object": "session0", + "name": "abortTransaction" + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "startTransaction": true + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "startTransaction": true + }, + "commandName": "insert", + "databaseName": "transaction-tests" + } + }, + { + "commandStartedEvent": { + "command": { + "startTransaction": { + "$$exists": false + } + }, + "commandName": "abortTransaction", + "databaseName": "admin" + } + } + ] + } + ], + "outcome": [ + { + "collectionName": "test", + "databaseName": "transaction-tests", + "documents": [] + } + ] + } + ] +} diff --git a/test/unified_format.py b/test/unified_format.py index 9aee287256..5516a7adf1 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1451,11 +1451,6 @@ def verify_outcome(self, spec): self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): - # Kill all sessions before and after each 
test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - self.kill_all_sessions() - # Handle flaky tests. flaky_tests = [ ("PYTHON-5170", ".*test_discovery_and_monitoring.*"), @@ -1491,6 +1486,15 @@ def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") + # Kill all sessions after each test with transactions to prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + for op in spec["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addCleanup(self.kill_all_sessions) + break + # process createEntities self._uri = uri self.entity_map = EntityMapUtil(self) diff --git a/test/uri_options/client-backpressure-options.json b/test/uri_options/client-backpressure-options.json new file mode 100644 index 0000000000..3e501d1f4c --- /dev/null +++ b/test/uri_options/client-backpressure-options.json @@ -0,0 +1,66 @@ +{ + "tests": [ + { + "description": "maxAdaptiveRetries is parsed correctly", + "uri": "mongodb://example.com/?maxAdaptiveRetries=3", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "maxAdaptiveRetries": 3 + } + }, + { + "description": "maxAdaptiveRetries=0 is parsed correctly", + "uri": "mongodb://example.com/?maxAdaptiveRetries=0", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "maxAdaptiveRetries": 0 + } + }, + { + "description": "maxAdaptiveRetries with invalid value causes a warning", + "uri": "mongodb://example.com/?maxAdaptiveRetries=-5", + "valid": true, + "warning": true, + "hosts": null, + "auth": null, + "options": null + }, + { + "description": "enableOverloadRetargeting is parsed correctly", + "uri": "mongodb://example.com/?enableOverloadRetargeting=true", + "valid": true, + "warning": false, + 
"hosts": null, + "auth": null, + "options": { + "enableOverloadRetargeting": true + } + }, + { + "description": "enableOverloadRetargeting=false is parsed correctly", + "uri": "mongodb://example.com/?enableOverloadRetargeting=false", + "valid": true, + "warning": false, + "hosts": null, + "auth": null, + "options": { + "enableOverloadRetargeting": false + } + }, + { + "description": "enableOverloadRetargeting with invalid value causes a warning", + "uri": "mongodb://example.com/?enableOverloadRetargeting=invalid", + "valid": true, + "warning": true, + "hosts": null, + "auth": null, + "options": null + } + ] +} diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 9bf155e8f3..f4c1c6bfcc 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -16,43 +16,13 @@ from __future__ import annotations import asyncio -import functools import os -import time -import unittest -from collections import abc -from inspect import iscoroutinefunction -from test import IntegrationTest, client_context, client_knobs +from test import client_context from test.helpers import ConcurrentRunner -from test.utils_shared import ( - CMAPListener, - CompareType, - EventListener, - OvertCommandListener, - ScenarioDict, - ServerAndTopologyEventListener, - camel_to_snake, - camel_to_snake_args, - parse_spec_options, - prepare_spec_arguments, -) -from typing import List - -from bson import ObjectId, decode, encode, json_util -from bson.binary import Binary -from bson.int64 import Int64 -from bson.son import SON -from gridfs import GridFSBucket -from gridfs.synchronous.grid_file import GridFSBucket -from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from test.utils_shared import ScenarioDict + +from bson import json_util from pymongo.lock import _cond_wait, _create_condition, _create_lock -from pymongo.read_concern import ReadConcern -from pymongo.read_preferences import ReadPreference -from pymongo.results import 
BulkWriteResult, _WriteResult -from pymongo.synchronous import client_session -from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.cursor import Cursor -from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -219,595 +189,3 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) - - -class SpecRunner(IntegrationTest): - mongos_clients: List - knobs: client_knobs - listener: EventListener - - def setUp(self) -> None: - super().setUp() - self.mongos_clients = [] - - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - self.knobs.enable() - self.targets = {} - self.listener = None # type: ignore - self.pool_listener = None - self.server_listener = None - self.maxDiff = None - - def tearDown(self) -> None: - self.knobs.disable() - - def set_fail_point(self, command_args): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - self.configure_fail_point(client, command_args) - - def targeted_fail_point(self, session, fail_point): - """Run the targetedFailPoint test operation. - - Enable the fail point on the session's pinned mongos. - """ - clients = {c.address: c for c in self.mongos_clients} - client = clients[session._pinned_address] - self.configure_fail_point(client, fail_point) - self.addCleanup(self.set_fail_point, {"mode": "off"}) - - def assert_session_pinned(self, session): - """Run the assertSessionPinned test operation. - - Assert that the given session is pinned. - """ - self.assertIsNotNone(session._transaction.pinned_address) - - def assert_session_unpinned(self, session): - """Run the assertSessionUnpinned test operation. - - Assert that the given session is not pinned. 
- """ - self.assertIsNone(session._pinned_address) - self.assertIsNone(session._transaction.pinned_address) - - def assert_collection_exists(self, database, collection): - """Run the assertCollectionExists test operation.""" - db = self.client[database] - self.assertIn(collection, db.list_collection_names()) - - def assert_collection_not_exists(self, database, collection): - """Run the assertCollectionNotExists test operation.""" - db = self.client[database] - self.assertNotIn(collection, db.list_collection_names()) - - def assert_index_exists(self, database, collection, index): - """Run the assertIndexExists test operation.""" - coll = self.client[database][collection] - self.assertIn(index, [doc["name"] for doc in coll.list_indexes()]) - - def assert_index_not_exists(self, database, collection, index): - """Run the assertIndexNotExists test operation.""" - coll = self.client[database][collection] - self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) - - def wait(self, ms): - """Run the "wait" test operation.""" - time.sleep(ms / 1000.0) - - def assertErrorLabelsContain(self, exc, expected_labels): - labels = [l for l in expected_labels if exc.has_error_label(l)] - self.assertEqual(labels, expected_labels) - - def assertErrorLabelsOmit(self, exc, omit_labels): - for label in omit_labels: - self.assertFalse( - exc.has_error_label(label), msg=f"error labels should not contain {label}" - ) - - def kill_all_sessions(self): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - try: - client.admin.command("killAllSessions", []) - except (OperationFailure, AutoReconnect): - # "operation was interrupted" by killing the command's - # own session. - # On 8.0+ killAllSessions sometimes returns a network error. - pass - - def check_command_result(self, expected_result, result): - # Only compare the keys in the expected result. 
- filtered_result = {} - for key in expected_result: - try: - filtered_result[key] = result[key] - except KeyError: - pass - self.assertEqual(filtered_result, expected_result) - - # TODO: factor the following function with test_crud.py. - def check_result(self, expected_result, result): - if isinstance(result, _WriteResult): - for res in expected_result: - prop = camel_to_snake(res) - # SPEC-869: Only BulkWriteResult has upserted_count. - if prop == "upserted_count" and not isinstance(result, BulkWriteResult): - if result.upserted_id is not None: - upserted_count = 1 - else: - upserted_count = 0 - self.assertEqual(upserted_count, expected_result[res], prop) - elif prop == "inserted_ids": - # BulkWriteResult does not have inserted_ids. - if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), result.inserted_count) - else: - # InsertManyResult may be compared to [id1] from the - # crud spec or {"0": id1} from the retryable write spec. - ids = expected_result[res] - if isinstance(ids, dict): - ids = [ids[str(i)] for i in range(len(ids))] - - self.assertEqual(ids, result.inserted_ids, prop) - elif prop == "upserted_ids": - # Convert indexes from strings to integers. 
- ids = expected_result[res] - expected_ids = {} - for str_index in ids: - expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) - else: - self.assertEqual(getattr(result, prop), expected_result[res], prop) - - return True - else: - - def _helper(expected_result, result): - if isinstance(expected_result, abc.Mapping): - for i in expected_result.keys(): - self.assertEqual(expected_result[i], result[i]) - - elif isinstance(expected_result, list): - for i, k in zip(expected_result, result): - _helper(i, k) - else: - self.assertEqual(expected_result, result) - - _helper(expected_result, result) - return None - - def get_object_name(self, op): - """Allow subclasses to override handling of 'object' - - Transaction spec says 'object' is required. - """ - return op["object"] - - @staticmethod - def parse_options(opts): - return parse_spec_options(opts) - - def run_operation(self, sessions, collection, operation): - original_collection = collection - name = camel_to_snake(operation["name"]) - if name == "run_command": - name = "command" - elif name == "download_by_name": - name = "open_download_stream_by_name" - elif name == "download": - name = "open_download_stream" - elif name == "map_reduce": - self.skipTest("PyMongo does not support mapReduce") - elif name == "count": - self.skipTest("PyMongo does not support count") - - database = collection.database - collection = database.get_collection(collection.name) - if "collectionOptions" in operation: - collection = collection.with_options( - **self.parse_options(operation["collectionOptions"]) - ) - - object_name = self.get_object_name(operation) - if object_name == "gridfsbucket": - # Only create the GridFSBucket when we need it (for the gridfs - # retryable reads tests). 
- obj = GridFSBucket(database, bucket_name=collection.name) - else: - objects = { - "client": database.client, - "database": database, - "collection": collection, - "testRunner": self, - } - objects.update(sessions) - obj = objects[object_name] - - # Combine arguments with options and handle special cases. - arguments = operation.get("arguments", {}) - arguments.update(arguments.pop("options", {})) - self.parse_options(arguments) - - cmd = getattr(obj, name) - - with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, in_with_transaction=True - ) - prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - - if name == "run_on_thread": - args = {"sessions": sessions, "collection": collection} - args.update(arguments) - arguments = args - - if not _IS_SYNC and iscoroutinefunction(cmd): - result = cmd(**dict(arguments)) - else: - result = cmd(**dict(arguments)) - # Cleanup open change stream cursors. - if name == "watch": - self.addCleanup(result.close) - - if name == "aggregate": - if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: - # Read from the primary to ensure causal consistency. 
- out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY - ) - return out.find() - if "download" in name: - result = Binary(result.read()) - - if isinstance(result, Cursor) or isinstance(result, CommandCursor): - return result.to_list() - - return result - - def allowable_errors(self, op): - """Allow encryption spec to override expected error classes.""" - return (PyMongoError,) - - def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get("result") - if expect_error(op): - with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - self.run_operation(sessions, collection, op.copy()) - exc = context.exception - if expect_error_message(expected_result): - if isinstance(exc, BulkWriteError): - errmsg = str(exc.details).lower() - else: - errmsg = str(exc).lower() - self.assertIn(expected_result["errorContains"].lower(), errmsg) - if expect_error_code(expected_result): - self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) - if expect_error_labels_contain(expected_result): - self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) - if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) - if expect_timeout_error(expected_result): - self.assertIsInstance(exc, PyMongoError) - if not exc.timeout: - # Re-raise the exception for better diagnostics. - raise exc - - # Reraise the exception if we're in the with_transaction - # callback. 
- if in_with_transaction: - raise context.exception - else: - result = self.run_operation(sessions, collection, op.copy()) - if "result" in op: - if op["name"] == "runCommand": - self.check_command_result(expected_result, result) - else: - self.check_result(expected_result, result) - - def run_operations(self, sessions, collection, ops, in_with_transaction=False): - for op in ops: - self._run_op(sessions, collection, op, in_with_transaction) - - # TODO: factor with test_command_monitoring.py - def check_events(self, test, listener, session_ids): - events = listener.started_events - if not len(test["expectations"]): - return - - # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in events]) - self.assertEqual(len(events), len(test["expectations"]), cmds) - for i, expectation in enumerate(test["expectations"]): - event_type = next(iter(expectation)) - event = events[i] - - # The tests substitute 42 for any number other than 0. - if event.command_name == "getMore" and event.command["getMore"]: - event.command["getMore"] = Int64(42) - elif event.command_name == "killCursors": - event.command["cursors"] = [Int64(42)] - elif event.command_name == "update": - # TODO: remove this once PYTHON-1744 is done. - # Add upsert and multi fields back into expectations. - updates = expectation[event_type]["command"]["updates"] - for update in updates: - update.setdefault("upsert", False) - update.setdefault("multi", False) - - # Replace afterClusterTime: 42 with actual afterClusterTime. 
- expected_cmd = expectation[event_type]["command"] - expected_read_concern = expected_cmd.get("readConcern") - if expected_read_concern is not None: - time = expected_read_concern.get("afterClusterTime") - if time == 42: - actual_time = event.command.get("readConcern", {}).get("afterClusterTime") - if actual_time is not None: - expected_read_concern["afterClusterTime"] = actual_time - - recovery_token = expected_cmd.get("recoveryToken") - if recovery_token == 42: - expected_cmd["recoveryToken"] = CompareType(dict) - - # Replace lsid with a name like "session0" to match test. - if "lsid" in event.command: - for name, lsid in session_ids.items(): - if event.command["lsid"] == lsid: - event.command["lsid"] = name - break - - for attr, expected in expectation[event_type].items(): - actual = getattr(event, attr) - expected = wrap_types(expected) - if isinstance(expected, dict): - for key, val in expected.items(): - if val is None: - if key in actual: - self.fail(f"Unexpected key [{key}] in {actual!r}") - elif key not in actual: - self.fail(f"Expected key [{key}] in {actual!r}") - else: - self.assertEqual( - val, decode_raw(actual[key]), f"Key [{key}] in {actual}" - ) - else: - self.assertEqual(actual, expected) - - def maybe_skip_scenario(self, test): - if test.get("skipReason"): - self.skipTest(test.get("skipReason")) - - def get_scenario_db_name(self, scenario_def): - """Allow subclasses to override a test's database name.""" - return scenario_def["database_name"] - - def get_scenario_coll_name(self, scenario_def): - """Allow subclasses to override a test's collection name.""" - return scenario_def["collection_name"] - - def get_outcome_coll_name(self, outcome, collection): - """Allow subclasses to override outcome collection.""" - return collection.name - - def run_test_ops(self, sessions, collection, test): - """Added to allow retryable writes spec to override a test's - operation. 
- """ - self.run_operations(sessions, collection, test["operations"]) - - def parse_client_options(self, opts): - """Allow encryption spec to override a clientOptions parsing.""" - return opts - - def setup_scenario(self, scenario_def): - """Allow specs to override a test's setup.""" - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - documents = scenario_def["data"] - - # Setup the collection with as few majority writes as possible. - db = client_context.client.get_database(db_name) - coll_exists = bool(db.list_collection_names(filter={"name": coll_name})) - if coll_exists: - db[coll_name].delete_many({}) - # Only use majority wc only on the final write. - wc = WriteConcern(w="majority") - if documents: - db.get_collection(coll_name, write_concern=wc).insert_many(documents) - elif not coll_exists: - # Ensure collection exists. - db.create_collection(coll_name, write_concern=wc) - - def run_scenario(self, scenario_def, test): - self.maybe_skip_scenario(test) - - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - self.kill_all_sessions() - self.addCleanup(self.kill_all_sessions) - self.setup_scenario(scenario_def) - database_name = self.get_scenario_db_name(scenario_def) - collection_name = self.get_scenario_coll_name(scenario_def) - # SPEC-1245 workaround StaleDbVersion on distinct - for c in self.mongos_clients: - c[database_name][collection_name].distinct("x") - - # Configure the fail point before creating the client. 
- if "failPoint" in test: - fp = test["failPoint"] - self.set_fail_point(fp) - self.addCleanup( - self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - listener = OvertCommandListener() - pool_listener = CMAPListener() - server_listener = ServerAndTopologyEventListener() - # Create a new client, to avoid interference from pooled sessions. - client_options = self.parse_client_options(test["clientOptions"]) - use_multi_mongos = test["useMultipleMongoses"] - host = None - if use_multi_mongos: - if client_context.load_balancer: - host = client_context.MULTI_MONGOS_LB_URI - elif client_context.is_mongos: - host = client_context.mongos_seeds() - client = self.rs_client( - h=host, event_listeners=[listener, pool_listener, server_listener], **client_options - ) - self.scenario_client = client - self.listener = listener - self.pool_listener = pool_listener - self.server_listener = server_listener - - # Create session0 and session1. - sessions = {} - session_ids = {} - for i in range(2): - # Don't attempt to create sessions if they are not supported by - # the running server version. - if not client_context.sessions_enabled: - break - session_name = "session%d" % i - opts = camel_to_snake_args(test["sessionOptions"][session_name]) - if "default_transaction_options" in opts: - txn_opts = self.parse_options(opts["default_transaction_options"]) - txn_opts = client_session.TransactionOptions(**txn_opts) - opts["default_transaction_options"] = txn_opts - - s = client.start_session(**dict(opts)) - - sessions[session_name] = s - # Store lsid so we can access it after end_session, in check_events. - session_ids[session_name] = s.session_id - - self.addCleanup(end_sessions, sessions) - - collection = client[database_name][collection_name] - self.run_test_ops(sessions, collection, test) - - end_sessions(sessions) - - self.check_events(test, listener, session_ids) - - # Disable fail points. 
- if "failPoint" in test: - fp = test["failPoint"] - self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"}) - - # Assert final state is expected. - outcome = test["outcome"] - expected_c = outcome.get("collection") - if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name(outcome, collection) - - # Read from the primary with local read concern to ensure causal - # consistency. - outcome_coll = client_context.client[collection.database.name].get_collection( - outcome_coll_name, - read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern("local"), - ) - actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() - - # The expected data needs to be the left hand side here otherwise - # CompareType(Binary) doesn't work. - self.assertEqual(wrap_types(expected_c["data"]), actual_data) - - -def expect_any_error(op): - if isinstance(op, dict): - return op.get("error") - - return False - - -def expect_error_message(expected_result): - if isinstance(expected_result, dict): - return isinstance(expected_result["errorContains"], str) - - return False - - -def expect_error_code(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorCodeName"] - - return False - - -def expect_error_labels_contain(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsContain"] - - return False - - -def expect_error_labels_omit(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsOmit"] - - return False - - -def expect_timeout_error(expected_result): - if isinstance(expected_result, dict): - return expected_result["isTimeoutError"] - - return False - - -def expect_error(op): - expected_result = op.get("result") - return ( - expect_any_error(op) - or expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result) - 
or expect_timeout_error(expected_result) - ) - - -def end_sessions(sessions): - for s in sessions.values(): - # Aborts the transaction if it's open. - s.end_session() - - -def decode_raw(val): - """Decode RawBSONDocuments in the given container.""" - if isinstance(val, (list, abc.Mapping)): - return decode(encode({"v": val}))["v"] - return val - - -TYPES = { - "binData": Binary, - "long": Int64, - "int": int, - "string": str, - "objectId": ObjectId, - "object": dict, - "array": list, -} - - -def wrap_types(val): - """Support $$type assertion in command results.""" - if isinstance(val, list): - return [wrap_types(v) for v in val] - if isinstance(val, abc.Mapping): - typ = val.get("$$type") - if typ: - if isinstance(typ, str): - types = TYPES[typ] - else: - types = tuple(TYPES[t] for t in typ) - return CompareType(types) - d = {} - for key in val: - d[key] = wrap_types(val[key]) - return d - return val diff --git a/tools/clean.py b/tools/clean.py index b6e1867a0a..15db9a411b 100644 --- a/tools/clean.py +++ b/tools/clean.py @@ -41,7 +41,7 @@ pass try: - from bson import _cbson # type: ignore[attr-defined] # noqa: F401 + from bson import _cbson # noqa: F401 sys.exit("could still import _cbson") except ImportError: diff --git a/tools/fail_if_no_c.py b/tools/fail_if_no_c.py index 64280a81d2..d8bc9d1e65 100644 --- a/tools/fail_if_no_c.py +++ b/tools/fail_if_no_c.py @@ -37,7 +37,7 @@ def main() -> None: except Exception as e: LOGGER.exception(e) try: - from bson import _cbson # type:ignore[attr-defined] # noqa: F401 + from bson import _cbson # noqa: F401 except Exception as e: LOGGER.exception(e) sys.exit("could not load C extensions") diff --git a/tools/synchro.py b/tools/synchro.py index 5735d0052a..ed794c5963 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -37,6 +37,7 @@ "AsyncRawBatchCursor": "RawBatchCursor", "AsyncRawBatchCommandCursor": "RawBatchCommandCursor", "AsyncClientSession": "ClientSession", + "_AsyncBoundSessionContext": "_BoundSessionContext", 
"AsyncChangeStream": "ChangeStream", "AsyncCollectionChangeStream": "CollectionChangeStream", "AsyncDatabaseChangeStream": "DatabaseChangeStream", @@ -212,6 +213,7 @@ def async_only_test(f: str) -> bool: "test_bulk.py", "test_change_stream.py", "test_client.py", + "test_client_backpressure.py", "test_client_bulk_write.py", "test_client_context.py", "test_client_metadata.py", @@ -349,7 +351,7 @@ def translate_async_sleeps(lines: list[str]) -> list[str]: sleeps = [line for line in lines if "asyncio.sleep" in line] for line in sleeps: - res = re.search(r"asyncio.sleep\(([^()]*)\)", line) + res = re.search(r"asyncio\.sleep\(\s*(.*?)\)", line) if res: old = res[0] index = lines.index(line) diff --git a/uv.lock b/uv.lock index 7d0ff9fb0f..78d0cc213f 100644 --- a/uv.lock +++ b/uv.lock @@ -1562,7 +1562,6 @@ zstd = [ [package.dev-dependencies] coverage = [ { name = "coverage", extra = ["toml"] }, - { name = "pytest-cov" }, ] gevent = [ { name = "gevent" }, @@ -1612,10 +1611,7 @@ requires-dist = [ provides-extras = ["aws", "docs", "encryption", "gssapi", "ocsp", "snappy", "test", "zstd"] [package.metadata.requires-dev] -coverage = [ - { name = "coverage", extras = ["toml"], specifier = ">=5,<=7.10.7" }, - { name = "pytest-cov", specifier = ">=4.0.0" }, -] +coverage = [{ name = "coverage", extras = ["toml"], specifier = ">=5,<=7.10.7" }] dev = [] gevent = [{ name = "gevent", specifier = ">=21.12" }] mockupdb = [{ name = "mockupdb", git = "https://github.com/mongodb-labs/mongo-mockup-db?rev=master" }] @@ -1763,21 +1759,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] -[[package]] -name = "pytest-cov" -version = "7.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "coverage", extra = 
["toml"] }, - { name = "pluggy" }, - { name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, - { name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, -] - [[package]] name = "python-dateutil" version = "2.9.0.post0"