diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index 07d6f5afc16e9..27d26c86d281a 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``1d6611b6ab7c`` (head) | ``888b59e02a5b`` | ``3.2.0`` | Add bundle_name to callback table. | +| ``a7e6d4c3b2f1`` (head) | ``1d6611b6ab7c`` | ``3.2.0`` | Add connection_test_request table for async connection | +| | | | testing. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``1d6611b6ab7c`` | ``888b59e02a5b`` | ``3.2.0`` | Add bundle_name to callback table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``888b59e02a5b`` | ``6222ce48e289`` | ``3.2.0`` | Fix migration file ORM inconsistencies. 
| +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py index f7cb944ebbf6b..f13c40c0c194b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py @@ -19,6 +19,7 @@ import json from collections.abc import Iterable, Mapping +from datetime import datetime from typing import Annotated, Any from pydantic import Field, field_validator @@ -72,12 +73,46 @@ class ConnectionCollectionResponse(BaseModel): class ConnectionTestResponse(BaseModel): - """Connection Test serializer for responses.""" + """Connection Test serializer for synchronous test responses.""" status: bool message: str +class ConnectionTestRequestBody(StrictBaseModel): + """Request body for async connection test.""" + + connection_id: str + conn_type: str + host: str | None = None + login: str | None = None + schema_: str | None = Field(None, alias="schema") + port: int | None = None + password: str | None = None + extra: str | None = None + commit_on_success: bool = False + executor: str | None = None + queue: str | None = None + + +class ConnectionTestQueuedResponse(BaseModel): + """Response returned when an async connection test is queued.""" + + token: str + connection_id: str + state: str + + +class ConnectionTestStatusResponse(BaseModel): + """Response returned when polling for async connection test status.""" + + token: str + connection_id: str + state: str + result_message: str | None = None + created_at: datetime + + class ConnectionHookFieldBehavior(BaseModel): """A class to store the behavior of each standard field of a Hook.""" diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml 
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 17c4f87a6e44f..7bde23abb32e2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -1335,6 +1335,12 @@ paths: schema: $ref: '#/components/schemas/HTTPExceptionResponse' description: Not Found + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict '422': description: Validation Error content: @@ -1681,6 +1687,110 @@ paths: security: - OAuth2PasswordBearer: [] - HTTPBearer: [] + /api/v2/connections/test-async: + post: + tags: + - Connection + summary: Test Connection Async + description: 'Queue an async connection test to be executed on a worker. + + + The connection data is stored in the test request table and the worker + + reads from there. Returns a token to poll for the result via + + GET /connections/test-async/{token}.' 
+ operationId: test_connection_async + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionTestRequestBody' + required: true + responses: + '202': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionTestQueuedResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '409': + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - OAuth2PasswordBearer: [] + - HTTPBearer: [] + /api/v2/connections/test-async/{connection_test_token}: + get: + tags: + - Connection + summary: Get Connection Test + description: "Poll for the status of an async connection test.\n\nKnowledge\ + \ of the token serves as authorization \u2014 only the client\nthat initiated\ + \ the test knows the crypto-random token." 
+ operationId: get_connection_test + security: + - OAuth2PasswordBearer: [] + - HTTPBearer: [] + parameters: + - name: connection_test_token + in: path + required: true + schema: + type: string + title: Connection Test Token + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionTestStatusResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /api/v2/connections/defaults: post: tags: @@ -10096,6 +10206,83 @@ components: - team_name title: ConnectionResponse description: Connection serializer for responses. + ConnectionTestQueuedResponse: + properties: + token: + type: string + title: Token + connection_id: + type: string + title: Connection Id + state: + type: string + title: State + type: object + required: + - token + - connection_id + - state + title: ConnectionTestQueuedResponse + description: Response returned when an async connection test is queued. 
+ ConnectionTestRequestBody: + properties: + connection_id: + type: string + title: Connection Id + conn_type: + type: string + title: Conn Type + host: + anyOf: + - type: string + - type: 'null' + title: Host + login: + anyOf: + - type: string + - type: 'null' + title: Login + schema: + anyOf: + - type: string + - type: 'null' + title: Schema + port: + anyOf: + - type: integer + - type: 'null' + title: Port + password: + anyOf: + - type: string + - type: 'null' + title: Password + extra: + anyOf: + - type: string + - type: 'null' + title: Extra + commit_on_success: + type: boolean + title: Commit On Success + default: false + executor: + anyOf: + - type: string + - type: 'null' + title: Executor + queue: + anyOf: + - type: string + - type: 'null' + title: Queue + additionalProperties: false + type: object + required: + - connection_id + - conn_type + title: ConnectionTestRequestBody + description: Request body for async connection test. ConnectionTestResponse: properties: status: @@ -10109,7 +10296,35 @@ components: - status - message title: ConnectionTestResponse - description: Connection Test serializer for responses. + description: Connection Test serializer for synchronous test responses. + ConnectionTestStatusResponse: + properties: + token: + type: string + title: Token + connection_id: + type: string + title: Connection Id + state: + type: string + title: State + result_message: + anyOf: + - type: string + - type: 'null' + title: Result Message + created_at: + type: string + format: date-time + title: Created At + type: object + required: + - token + - connection_id + - state + - created_at + title: ConnectionTestStatusResponse + description: Response returned when polling for async connection test status. 
CreateAssetEventsBody: properties: asset_id: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py index 05ec6b642941d..ac086caa2e434 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py @@ -41,7 +41,10 @@ ConnectionBodyPartial, ConnectionCollectionResponse, ConnectionResponse, + ConnectionTestQueuedResponse, + ConnectionTestRequestBody, ConnectionTestResponse, + ConnectionTestStatusResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.security import ( @@ -57,6 +60,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowNotFoundException from airflow.models import Connection +from airflow.models.connection_test import ACTIVE_STATES, ConnectionTestRequest from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils.db import create_default_connections as db_create_default_connections from airflow.utils.strings import get_random_string @@ -64,10 +68,36 @@ connections_router = AirflowRouter(tags=["Connection"], prefix="/connections") +def _ensure_test_connection_enabled() -> None: + """Raise 403 if connection testing is not enabled in the Airflow configuration.""" + if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled": + raise HTTPException( + status.HTTP_403_FORBIDDEN, + "Testing connections is disabled in Airflow configuration. 
" + "Contact your deployment admin to enable it.", + ) + + +def _check_no_active_test(connection_id: str, session: SessionDep) -> None: + """Raise 409 if there is an active connection test request for the given connection_id.""" + active_test = session.scalar( + select(ConnectionTestRequest).filter( + ConnectionTestRequest.connection_id == connection_id, + ConnectionTestRequest.state.in_(ACTIVE_STATES), + ) + ) + if active_test is not None: + raise HTTPException( + status.HTTP_409_CONFLICT, + f"Cannot modify connection `{connection_id}` while an async test is running. " + "This typically takes only a few seconds — please retry shortly.", + ) + + @connections_router.delete( "/{connection_id}", status_code=status.HTTP_204_NO_CONTENT, - responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_409_CONFLICT]), dependencies=[Depends(requires_access_connection(method="DELETE")), Depends(action_logging())], ) def delete_connection( @@ -75,6 +105,8 @@ def delete_connection( session: SessionDep, ): """Delete a connection entry.""" + _check_no_active_test(connection_id, session) + connection = session.scalar(select(Connection).filter_by(conn_id=connection_id)) if connection is None: @@ -191,6 +223,8 @@ def patch_connection( update_mask: list[str] | None = Query(None), ) -> ConnectionResponse: """Update a connection entry.""" + _check_no_active_test(connection_id, session) + if patch_body.connection_id != connection_id: raise HTTPException( status.HTTP_400_BAD_REQUEST, @@ -229,12 +263,7 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse: as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found. It also deletes the conn id env connection after the test. 
""" - if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled": - raise HTTPException( - status.HTTP_403_FORBIDDEN, - "Testing connections is disabled in Airflow configuration. " - "Contact your deployment admin to enable it.", - ) + _ensure_test_connection_enabled() transient_conn_id = get_random_string() conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}" @@ -257,6 +286,83 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse: os.environ.pop(conn_env_var, None) +@connections_router.post( + "/test-async", + status_code=status.HTTP_202_ACCEPTED, + responses=create_openapi_http_exception_doc([status.HTTP_403_FORBIDDEN, status.HTTP_409_CONFLICT]), + dependencies=[Depends(requires_access_connection(method="POST")), Depends(action_logging())], +) +def test_connection_async( + test_body: ConnectionTestRequestBody, + session: SessionDep, +) -> ConnectionTestQueuedResponse: + """ + Queue an async connection test to be executed on a worker. + + The connection data is stored in the test request table and the worker + reads from there. Returns a token to poll for the result via + GET /connections/test-async/{token}. + """ + _ensure_test_connection_enabled() + + # Only one active test per connection_id at a time. 
+ _check_no_active_test(test_body.connection_id, session) + + connection_test = ConnectionTestRequest( + connection_id=test_body.connection_id, + conn_type=test_body.conn_type, + host=test_body.host, + login=test_body.login, + password=test_body.password, + schema=test_body.schema_, + port=test_body.port, + extra=test_body.extra, + commit_on_success=test_body.commit_on_success, + executor=test_body.executor, + queue=test_body.queue, + ) + session.add(connection_test) + session.flush() + + return ConnectionTestQueuedResponse( + token=connection_test.token, + connection_id=connection_test.connection_id, + state=connection_test.state, + ) + + +@connections_router.get( + "/test-async/{connection_test_token}", + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_connection(method="GET"))], +) +def get_connection_test( + connection_test_token: str, + session: SessionDep, +) -> ConnectionTestStatusResponse: + """ + Poll for the status of an async connection test. + + Knowledge of the token serves as authorization — only the client + that initiated the test knows the crypto-random token. 
+ """ + connection_test = session.scalar(select(ConnectionTestRequest).filter_by(token=connection_test_token)) + + if connection_test is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"No connection test found for token: `{connection_test_token}`", + ) + + return ConnectionTestStatusResponse( + token=connection_test.token, + connection_id=connection_test.connection_id, + state=connection_test.state, + result_message=connection_test.result_message, + created_at=connection_test.created_at, + ) + + @connections_router.post( "/defaults", status_code=status.HTTP_204_NO_CONTENT, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py new file mode 100644 index 0000000000000..e0b63e8dae55c --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel +from airflow.models.connection_test import ConnectionTestState + + +class ConnectionTestResultBody(StrictBaseModel): + """Payload sent by workers to report connection test results.""" + + state: ConnectionTestState + result_message: str | None = None + + +class ConnectionTestConnectionResponse(BaseModel): + """Connection data returned to workers from a test request.""" + + conn_id: str + conn_type: str + host: str | None = None + login: str | None = None + password: str | None = None + schema_: str | None = Field(None, alias="schema") + port: int | None = None + extra: str | None = None diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index a076592d6471a..0d4291bbcc2de 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -22,6 +22,7 @@ from airflow.api_fastapi.execution_api.routes import ( asset_events, assets, + connection_tests, connections, dag_runs, dags, @@ -42,6 +43,9 @@ authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) +authenticated_router.include_router( + connection_tests.router, prefix="/connection-tests", tags=["Connection Tests"] +) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py new file mode 100644 index 0000000000000..47986523271cf --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from uuid import UUID + +from cadwyn import VersionedAPIRouter +from fastapi import HTTPException, status + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.connection_test import ( + ConnectionTestConnectionResponse, + ConnectionTestResultBody, +) +from airflow.models.connection_test import ( + TERMINAL_STATES, + ConnectionTestRequest, + ConnectionTestState, +) + +router = VersionedAPIRouter() + + +@router.get( + "/{connection_test_id}/connection", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"}, + }, +) +def get_connection_test_connection( + connection_test_id: UUID, + session: SessionDep, +) -> ConnectionTestConnectionResponse: + """Return the connection data stored in a test request (called by workers).""" + ct = session.get(ConnectionTestRequest, connection_test_id) + if ct is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Connection test {connection_test_id} not found", + }, + ) + + return ConnectionTestConnectionResponse( + conn_id=ct.connection_id, + conn_type=ct.conn_type, + host=ct.host, + login=ct.login, + password=ct.password, + schema=ct.schema, + port=ct.port, + extra=ct.extra, + ) + + +@router.patch( + "/{connection_test_id}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"}, + status.HTTP_409_CONFLICT: {"description": "Connection test already in a terminal state"}, + }, +) +def patch_connection_test( + connection_test_id: UUID, + body: ConnectionTestResultBody, + session: SessionDep, +) -> None: + """Update the result of a connection test (called by workers).""" + ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True) + if ct is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": 
"not_found", + "message": f"Connection test {connection_test_id} not found", + }, + ) + + if ct.state in TERMINAL_STATES: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "conflict", + "message": f"Connection test {connection_test_id} is already in terminal state: {ct.state}", + }, + ) + + ct.state = body.state + ct.result_message = body.result_message + + if body.state == ConnectionTestState.SUCCESS and ct.commit_on_success: + ct.commit_to_connection_table(session=session) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 2cbe2e3007b3f..2ddbf81fd03d7 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -34,6 +34,7 @@ MovePreviousRunEndpoint, ) from airflow.api_fastapi.execution_api.versions.v2026_03_31 import ( + AddConnectionTestEndpoint, AddNoteField, MakeDagRunStartDateNullable, ModifyDeferredTaskKwargsToJsonValue, @@ -46,6 +47,7 @@ Version("2026-04-13", AddDagEndpoint), Version( "2026-03-31", + AddConnectionTestEndpoint, MakeDagRunStartDateNullable, ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py index 2d14493e81fe6..76dab90340abe 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py @@ -19,7 +19,7 @@ from typing import Any -from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema +from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema from airflow.api_fastapi.common.types import UtcDateTime from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import ( @@ -29,6 +29,17 @@ ) +class AddConnectionTestEndpoint(VersionChange): + """Add connection-tests endpoints for async connection testing.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + endpoint("/connection-tests/{connection_test_id}", ["PATCH"]).didnt_exist, + endpoint("/connection-tests/{connection_test_id}/connection", ["GET"]).didnt_exist, + ) + + class ModifyDeferredTaskKwargsToJsonValue(VersionChange): """Change the types of `trigger_kwargs` and `next_kwargs` in TIDeferredStatePayload to JsonValue.""" diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 0c002f5276cfe..cc86cc92dec1d 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -481,6 +481,24 @@ core: type: string example: ~ default: "Disabled" + connection_test_timeout: + description: | + Maximum number of seconds an async connection test is allowed to run + before it is considered timed out. The scheduler reaper uses this value + plus a grace period to mark stale tests as failed. + version_added: 3.2.0 + type: integer + example: ~ + default: "60" + max_connection_test_concurrency: + description: | + Maximum number of connection tests that can be active (QUEUED + RUNNING) + at the same time. Excess tests will remain in PENDING state until slots + become available. + version_added: 3.2.0 + type: integer + example: ~ + default: "4" max_templated_field_length: description: | The maximum length of the rendered template field. If the value to be stored in the @@ -2549,6 +2567,15 @@ scheduler: type: float example: ~ default: "120.0" + connection_test_reaper_interval: + description: | + How often (in seconds) the scheduler should check for stale + connection tests (QUEUED or RUNNING past their timeout + grace period) + and mark them as failed. 
+ version_added: 3.2.0 + type: float + example: ~ + default: "30.0" allowed_run_id_pattern: description: | The run_id pattern used to verify the validity of user input to the run_id parameter when diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index d67c25c7bafaa..24f77add4bec3 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -143,6 +143,7 @@ class BaseExecutor(LoggingMixin): supports_ad_hoc_ti_run: bool = False supports_callbacks: bool = False supports_multi_team: bool = False + supports_connection_test: bool = False sentry_integration: str = "" is_local: bool = False @@ -186,6 +187,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.team_name: str | None = team_name self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {} + self.queued_connection_tests: dict[str, workloads.TestConnection] = {} self.running: set[WorkloadKey] = set() self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() @@ -231,10 +233,18 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None: f"See LocalExecutor or CeleryExecutor for reference implementation." ) self.queued_callbacks[workload.callback.id] = workload + elif isinstance(workload, workloads.TestConnection): + if not self.supports_connection_test: + raise NotImplementedError( + f"{type(self).__name__} does not support TestConnection workloads. " + f"Set supports_connection_test = True and implement connection test handling " + f"in _process_workloads(). See LocalExecutor for reference implementation." + ) + self.queued_connection_tests[str(workload.connection_test_id)] = workload else: raise ValueError( f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. 
" - f"Workload must be one of: ExecuteTask, ExecuteCallback." + f"Workload must be one of: ExecuteTask, ExecuteCallback, TestConnection." ) def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]: @@ -305,10 +315,24 @@ def heartbeat(self) -> None: self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads) self.trigger_tasks(open_slots) + self.trigger_connection_tests() + # Calling child class sync method self.log.debug("Calling the %s sync method", self.__class__) self.sync() + def trigger_connection_tests(self) -> None: + """Process queued connection tests, respecting available slot capacity.""" + if not self.supports_connection_test or not self.queued_connection_tests: + return + + available = self.slots_available + if available <= 0: + return + + tests_to_run = list(self.queued_connection_tests.values())[:available] + self._process_workloads(tests_to_run) + def _get_metric_name(self, metric_base_name: str) -> str: return ( f"{metric_base_name}.{self.__class__.__name__}" @@ -529,13 +553,24 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task @property def slots_available(self): - """Number of new workloads (tasks and callbacks) this executor instance can accept.""" - return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks) + """Number of new workloads (tasks, callbacks, and connection tests) this executor instance can accept.""" + return ( + self.parallelism + - len(self.running) + - len(self.queued_tasks) + - len(self.queued_callbacks) + - len(self.queued_connection_tests) + ) @property def slots_occupied(self): - """Number of workloads (tasks and callbacks) this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks) + """Number of workloads (tasks, callbacks, and connection tests) this executor instance is currently managing.""" + return ( + len(self.running) + + 
def _execute_connection_test(log: Logger, workload: workloads.TestConnection, team_conf) -> None:
    """
    Execute a connection test workload.

    Constructs an SDK ``Client``, fetches the connection via the Execution API,
    enforces a timeout via ``signal.alarm``, and reports all outcomes back
    through the Execution API.

    :param log: Logger instance
    :param workload: The TestConnection workload to execute
    :param team_conf: Team-specific executor configuration
    """
    # Lazy import: SDK modules must not be loaded at module level to avoid
    # coupling core (scheduler-loaded) code to the SDK.
    from airflow.sdk.api.client import Client

    setproctitle(
        f"{_get_executor_process_title_prefix(team_conf.team_name)} connection-test {workload.connection_id}",
        log,
    )

    # Resolve the Execution API server URL: a bare path in `api.base_url` is
    # assumed to be served from the local webserver.
    base_url = team_conf.get("api", "base_url", fallback="/")
    if base_url.startswith("/"):
        base_url = f"http://localhost:8080{base_url}"
    default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
    server = team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server)

    client = Client(base_url=server, token=workload.token)

    def _handle_timeout(signum, frame):
        raise TimeoutError(f"Connection test timed out after {workload.timeout}s")

    # Install the alarm handler, remembering whatever was there before.
    # Worker processes are long-lived and run many workloads back to back;
    # leaving a stale handler (closing over THIS workload's timeout) installed
    # after we return would corrupt later alarms in the same process.
    previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
    signal.alarm(workload.timeout)
    try:
        client.connection_tests.update_state(workload.connection_test_id, ConnectionTestState.RUNNING)

        # The worker reads connection details from the test-request table via
        # the Execution API, never from the real `connection` table.
        conn_response = client.connection_tests.get_connection(workload.connection_test_id)

        conn = Connection(
            conn_id=conn_response.conn_id,
            conn_type=conn_response.conn_type,
            host=conn_response.host,
            login=conn_response.login,
            password=conn_response.password,
            schema=conn_response.schema_,
            port=conn_response.port,
            extra=conn_response.extra,
        )
        success, message = run_connection_test(conn=conn)

        state = ConnectionTestState.SUCCESS if success else ConnectionTestState.FAILED
        client.connection_tests.update_state(workload.connection_test_id, state, message)
    except TimeoutError:
        log.error(
            "Connection test timed out after %ds",
            workload.timeout,
            connection_id=workload.connection_id,
        )
        client.connection_tests.update_state(
            workload.connection_test_id,
            ConnectionTestState.FAILED,
            f"Connection test timed out after {workload.timeout}s",
        )
    except Exception as e:
        log.exception("Connection test failed unexpectedly", connection_id=workload.connection_id)
        client.connection_tests.update_state(
            workload.connection_test_id,
            ConnectionTestState.FAILED,
            f"Connection test failed unexpectedly: {type(e).__name__}",
        )
    finally:
        # Cancel any pending alarm, then put back the handler we displaced so
        # the next workload in this process starts from a clean signal state.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
class TestConnection(BaseWorkloadSchema):
    """Workload instructing a worker to run a connection test."""

    # Primary key of the ConnectionTestRequest row this test reports against.
    connection_test_id: uuid.UUID
    # User-facing connection id being tested (used in logs / process titles).
    connection_id: str
    # Hard wall-clock limit, in seconds, enforced on the worker.
    timeout: int = 60
    # Optional executor queue to route this workload to.
    queue: str | None = None

    # Discriminator for the `workloads.All` tagged union.
    type: Literal["TestConnection"] = Field(init=False, default="TestConnection")

    @classmethod
    def make(
        cls,
        *,
        connection_test_id: uuid.UUID,
        connection_id: str,
        timeout: int = 60,
        queue: str | None = None,
        generator: JWTGenerator | None = None,
    ) -> TestConnection:
        """Build a workload, minting an API token scoped to the test id."""
        token = cls.generate_token(str(connection_test_id), generator)
        fields = {
            "connection_test_id": connection_test_id,
            "connection_id": connection_id,
            "timeout": timeout,
            "queue": queue,
            "token": token,
        }
        return cls(**fields)
c64bd166f1b3a..5783d2ddef874 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -73,6 +73,12 @@ ) from airflow.models.backfill import Backfill from airflow.models.callback import Callback, CallbackType, ExecutorCallback +from airflow.models.connection_test import ( + ACTIVE_STATES, + DISPATCHED_STATES, + ConnectionTestRequest, + ConnectionTestState, +) from airflow.models.dag import DagModel from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag @@ -1585,6 +1591,11 @@ def _run_scheduler_loop(self) -> None: action=bundle_cleanup_mgr.remove_stale_bundle_versions, ) + timers.call_regular_interval( + delay=conf.getfloat("scheduler", "connection_test_reaper_interval", fallback=30.0), + action=self._reap_stale_connection_tests, + ) + idle_count = 0 for loop_count in itertools.count(start=1): @@ -1629,6 +1640,9 @@ def _run_scheduler_loop(self) -> None: # Route ExecutorCallback workloads to executors (similar to task routing) self._enqueue_executor_callbacks(session) + # Enqueue pending connection tests to executors + self._enqueue_connection_tests(session=session) + # Heartbeat the scheduler periodically perform_heartbeat( job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True @@ -3114,6 +3128,86 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]: session.add(warning) existing_warned_dag_ids.add(warning.dag_id) + def _enqueue_connection_tests(self, *, session: Session) -> None: + """Enqueue pending connection tests to executors that support them.""" + max_concurrency = conf.getint("core", "max_connection_test_concurrency", fallback=4) + timeout = conf.getint("core", "connection_test_timeout", fallback=60) + + num_occupied_slots = sum(executor.slots_occupied for executor in self.executors) + parallelism_budget = conf.getint("core", "parallelism") - num_occupied_slots + if parallelism_budget <= 0: + return + + 
active_count = session.scalar( + select(func.count(ConnectionTestRequest.id)).where( + ConnectionTestRequest.state.in_(DISPATCHED_STATES) + ) + ) + concurrency_budget = max_concurrency - (active_count or 0) + budget = min(concurrency_budget, parallelism_budget) + if budget <= 0: + return + + pending_stmt = ( + select(ConnectionTestRequest) + .where(ConnectionTestRequest.state == ConnectionTestState.PENDING) + .order_by(ConnectionTestRequest.created_at) + .limit(budget) + ) + pending_stmt = with_row_locks(pending_stmt, session, of=ConnectionTestRequest, skip_locked=True) + pending_tests = session.scalars(pending_stmt).all() + + if not pending_tests: + return + + for ct in pending_tests: + executor = self._try_to_load_executor(ct, session) + if executor is not None and not executor.supports_connection_test: + executor = None + if executor is None: + reason = ( + f"No executor matches '{ct.executor}'" + if ct.executor + else "No executor supports connection testing" + ) + ct.state = ConnectionTestState.FAILED + ct.result_message = reason + self.log.warning("Failing connection test %s: %s", ct.id, reason) + continue + + workload = workloads.TestConnection.make( + connection_test_id=ct.id, + connection_id=ct.connection_id, + timeout=timeout, + queue=ct.queue, + generator=executor.jwt_generator, + ) + executor.queue_workload(workload, session=session) + ct.state = ConnectionTestState.QUEUED + + session.flush() + + @provide_session + def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) -> None: + """Mark connection tests that have exceeded their timeout as FAILED.""" + timeout = conf.getint("core", "connection_test_timeout", fallback=60) + grace_period = max(30, timeout // 2) + cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period) + + stale_stmt = select(ConnectionTestRequest).where( + ConnectionTestRequest.state.in_(ACTIVE_STATES), + ConnectionTestRequest.updated_at < cutoff, + ) + stale_stmt = with_row_locks(stale_stmt, session, 
of=ConnectionTestRequest, skip_locked=True) + stale_tests = session.scalars(stale_stmt).all() + + for ct in stale_tests: + ct.state = ConnectionTestState.FAILED + ct.result_message = f"Connection test timed out (exceeded {timeout}s + {grace_period}s grace)" + self.log.warning("Reaped stale connection test %s", ct.id) + + session.flush() + def _executor_to_workloads( self, workloads: Iterable[SchedulerWorkload], diff --git a/airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_connection_test_table.py b/airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_connection_test_table.py new file mode 100644 index 0000000000000..d8d3c360eb0c9 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_connection_test_table.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add connection_test_request table for async connection testing. + +Revision ID: a7e6d4c3b2f1 +Revises: 1d6611b6ab7c +Create Date: 2026-02-22 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.utils.sqlalchemy import UtcDateTime + +# revision identifiers, used by Alembic. 
revision = "a7e6d4c3b2f1"
down_revision = "1d6611b6ab7c"
branch_labels = None
depends_on = None
airflow_version = "3.2.0"


def upgrade():
    """Create connection_test_request table."""
    op.create_table(
        "connection_test_request",
        # Identity and bookkeeping.
        sa.Column("id", sa.Uuid(), nullable=False),
        # Crypto-random poll token; unique so it can be used as a lookup key.
        sa.Column("token", sa.String(64), nullable=False),
        sa.Column("connection_id", sa.String(250), nullable=False),
        sa.Column("state", sa.String(20), nullable=False),
        sa.Column("result_message", sa.Text(), nullable=True),
        sa.Column("created_at", UtcDateTime(timezone=True), nullable=False),
        sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False),
        # Scheduler routing hints.
        sa.Column("executor", sa.String(256), nullable=True),
        sa.Column("queue", sa.String(256), nullable=True),
        # Snapshot of the connection under test; password/extra are stored
        # Fernet-encrypted by the ORM layer — presumably, confirm against
        # the ConnectionTestRequest model.
        sa.Column("conn_type", sa.String(500), nullable=False),
        sa.Column("host", sa.String(500), nullable=True),
        sa.Column("login", sa.Text(), nullable=True),
        sa.Column("password", sa.Text(), nullable=True),
        sa.Column("schema", sa.String(500), nullable=True),
        sa.Column("port", sa.Integer(), nullable=True),
        sa.Column("extra", sa.Text(), nullable=True),
        sa.Column("commit_on_success", sa.Boolean(), nullable=False, server_default="0"),
        sa.PrimaryKeyConstraint("id", name=op.f("connection_test_request_pkey")),
        sa.UniqueConstraint("token", name=op.f("connection_test_request_token_uq")),
    )
    # Composite index serving the scheduler's "pending, oldest first" scan.
    op.create_index(
        op.f("idx_connection_test_request_state_created_at"),
        "connection_test_request",
        ["state", "created_at"],
    )


def downgrade():
    """Drop connection_test_request table."""
    op.drop_index(op.f("idx_connection_test_request_state_created_at"), table_name="connection_test_request")
    op.drop_table("connection_test_request")
import_all_models(): import airflow.models.asset import airflow.models.backfill + import airflow.models.connection_test import airflow.models.dag_favorite import airflow.models.dag_version import airflow.models.dagbag diff --git a/airflow-core/src/airflow/models/connection_test.py b/airflow-core/src/airflow/models/connection_test.py new file mode 100644 index 0000000000000..8f8ca4abdbbb2 --- /dev/null +++ b/airflow-core/src/airflow/models/connection_test.py @@ -0,0 +1,234 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import secrets +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID + +import structlog +import uuid6 +from sqlalchemy import Boolean, Index, Integer, String, Text, Uuid, select +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, synonym + +from airflow._shared.timezones import timezone +from airflow.models.base import Base +from airflow.models.connection import Connection +from airflow.models.crypto import get_fernet +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +log = structlog.get_logger(__name__) + + +class ConnectionTestState(str, Enum): + """All possible states of a connection test.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + +ACTIVE_STATES = frozenset( + (ConnectionTestState.PENDING, ConnectionTestState.QUEUED, ConnectionTestState.RUNNING) +) +DISPATCHED_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) +TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS, ConnectionTestState.FAILED)) + + +class ConnectionTestRequest(Base): + """ + Tracks an async connection test request dispatched to a worker. + + Stores the full connection details so the worker reads from this table + instead of the real ``connection`` table. The real ``connection`` table + is only modified if the test succeeds and ``commit_on_success`` is True. 
+ """ + + __tablename__ = "connection_test_request" + + id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7) + token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + connection_id: Mapped[str] = mapped_column(String(250), nullable=False) + state: Mapped[str] = mapped_column(String(20), nullable=False, default=ConnectionTestState.PENDING) + result_message: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False + ) + executor: Mapped[str | None] = mapped_column(String(256), nullable=True) + queue: Mapped[str | None] = mapped_column(String(256), nullable=True) + + # Connection fields — password and extra are Fernet-encrypted. + conn_type: Mapped[str] = mapped_column(String(500), nullable=False) + host: Mapped[str | None] = mapped_column(String(500), nullable=True) + login: Mapped[str | None] = mapped_column(Text, nullable=True) + _password: Mapped[str | None] = mapped_column("password", Text(), nullable=True) + schema: Mapped[str | None] = mapped_column("schema", String(500), nullable=True) + port: Mapped[int | None] = mapped_column(Integer, nullable=True) + _extra: Mapped[str | None] = mapped_column("extra", Text(), nullable=True) + commit_on_success: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, server_default="0" + ) + + __table_args__ = (Index("idx_connection_test_request_state_created_at", state, created_at),) + + def __init__( + self, + *, + connection_id: str, + conn_type: str, + host: str | None = None, + login: str | None = None, + password: str | None = None, + schema: str | None = None, + port: int | None = None, + extra: str | None = None, + commit_on_success: bool = False, + executor: str | None = None, + queue: str | None = None, + **kwargs, + ): + 
super().__init__(**kwargs) + self.connection_id = connection_id + self.conn_type = conn_type + self.host = host + self.login = login + self.password = password + self.schema = schema + self.port = port + self.extra = extra + self.commit_on_success = commit_on_success + self.executor = executor + self.queue = queue + self.token = secrets.token_urlsafe(32) + self.state = ConnectionTestState.PENDING + + def __repr__(self) -> str: + return ( + f"" + ) + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + + def get_dag_id(self) -> None: + """Return None — connection tests are not associated with any DAG.""" + return None + + def get_password(self) -> str | None: + if self._password: + fernet = get_fernet() + if not fernet.is_encrypted: + return self._password + return fernet.decrypt(bytes(self._password, "utf-8")).decode() + return self._password + + def set_password(self, value: str | None): + if value: + fernet = get_fernet() + self._password = fernet.encrypt(bytes(value, "utf-8")).decode() + else: + self._password = value + + @declared_attr + def password(cls): + """Password. The value is decrypted/encrypted when reading/setting the value.""" + return synonym("_password", descriptor=property(cls.get_password, cls.set_password)) + + def get_extra(self) -> str | None: + if self._extra: + fernet = get_fernet() + if not fernet.is_encrypted: + return self._extra + return fernet.decrypt(bytes(self._extra, "utf-8")).decode() + return self._extra + + def set_extra(self, value: str | None): + if value: + fernet = get_fernet() + self._extra = fernet.encrypt(bytes(value, "utf-8")).decode() + else: + self._extra = value + + @declared_attr + def extra(cls): + """Extra data. 
The value is decrypted/encrypted when reading/setting the value.""" + return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra)) + + def to_connection(self) -> Connection: + """Build a transient Connection object from the stored fields for testing.""" + return Connection( + conn_id=self.connection_id, + conn_type=self.conn_type, + host=self.host, + login=self.login, + password=self.password, + schema=self.schema, + port=self.port, + extra=self.extra, + ) + + def commit_to_connection_table(self, *, session: Session) -> None: + """Upsert the tested connection into the real ``connection`` table.""" + conn = session.scalar(select(Connection).filter_by(conn_id=self.connection_id)) + if conn is None: + conn = Connection( + conn_id=self.connection_id, + conn_type=self.conn_type, + host=self.host, + login=self.login, + password=self.password, + schema=self.schema, + port=self.port, + extra=self.extra, + ) + session.add(conn) + log.info("Created new connection from successful test", connection_id=self.connection_id) + else: + conn.conn_type = self.conn_type + conn.host = self.host + conn.login = self.login + conn.password = self.password + conn.schema = self.schema + conn.port = self.port + conn.extra = self.extra + log.info("Updated existing connection from successful test", connection_id=self.connection_id) + + +def run_connection_test(*, conn: Connection) -> tuple[bool, str]: + """ + Worker-side function to execute a connection test. + + Returns a (success, message) tuple. The caller is responsible for + reporting the result back via the Execution API. 
+ """ + try: + return conn.test_connection() + except Exception as e: + log.exception("Connection test failed", connection_id=conn.conn_id) + return False, str(e) diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index 612a3d56747e3..c68fe4bd49e49 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -122,6 +122,12 @@ export const UseConnectionServiceGetConnectionsKeyFn = ({ connectionIdPattern, l offset?: number; orderBy?: string[]; } = {}, queryKey?: Array) => [useConnectionServiceGetConnectionsKey, ...(queryKey ?? [{ connectionIdPattern, limit, offset, orderBy }])]; +export type ConnectionServiceGetConnectionTestDefaultResponse = Awaited>; +export type ConnectionServiceGetConnectionTestQueryResult = UseQueryResult; +export const useConnectionServiceGetConnectionTestKey = "ConnectionServiceGetConnectionTest"; +export const UseConnectionServiceGetConnectionTestKeyFn = ({ connectionTestToken }: { + connectionTestToken: string; +}, queryKey?: Array) => [useConnectionServiceGetConnectionTestKey, ...(queryKey ?? 
[{ connectionTestToken }])]; export type ConnectionServiceHookMetaDataDefaultResponse = Awaited>; export type ConnectionServiceHookMetaDataQueryResult = UseQueryResult; export const useConnectionServiceHookMetaDataKey = "ConnectionServiceHookMetaData"; @@ -923,6 +929,7 @@ export type BackfillServiceCreateBackfillMutationResult = Awaited>; export type ConnectionServicePostConnectionMutationResult = Awaited>; export type ConnectionServiceTestConnectionMutationResult = Awaited>; +export type ConnectionServiceTestConnectionAsyncMutationResult = Awaited>; export type ConnectionServiceCreateDefaultConnectionsMutationResult = Awaited>; export type DagRunServiceClearDagRunMutationResult = Awaited>; export type DagRunServiceTriggerDagRunMutationResult = Awaited>; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 8bc9e9df8e6df..e1751b9116ee1 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -225,6 +225,20 @@ export const ensureUseConnectionServiceGetConnectionsData = (queryClient: QueryC orderBy?: string[]; } = {}) => queryClient.ensureQueryData({ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({ connectionIdPattern, limit, offset, orderBy }), queryFn: () => ConnectionService.getConnections({ connectionIdPattern, limit, offset, orderBy }) }); /** +* Get Connection Test +* Poll for the status of an async connection test. +* +* Knowledge of the token serves as authorization — only the client +* that initiated the test knows the crypto-random token. +* @param data The data for the request. 
+* @param data.connectionTestToken +* @returns ConnectionTestStatusResponse Successful Response +* @throws ApiError +*/ +export const ensureUseConnectionServiceGetConnectionTestData = (queryClient: QueryClient, { connectionTestToken }: { + connectionTestToken: string; +}) => queryClient.ensureQueryData({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ connectionTestToken }), queryFn: () => ConnectionService.getConnectionTest({ connectionTestToken }) }); +/** * Hook Meta Data * Retrieve information about available connection types (hook classes) and their parameters. * @returns ConnectionHookMetaData Successful Response diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index f4cb6f482bdf6..65b9ef1a0dc1c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -225,6 +225,20 @@ export const prefetchUseConnectionServiceGetConnections = (queryClient: QueryCli orderBy?: string[]; } = {}) => queryClient.prefetchQuery({ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({ connectionIdPattern, limit, offset, orderBy }), queryFn: () => ConnectionService.getConnections({ connectionIdPattern, limit, offset, orderBy }) }); /** +* Get Connection Test +* Poll for the status of an async connection test. +* +* Knowledge of the token serves as authorization — only the client +* that initiated the test knows the crypto-random token. +* @param data The data for the request. 
+* @param data.connectionTestToken +* @returns ConnectionTestStatusResponse Successful Response +* @throws ApiError +*/ +export const prefetchUseConnectionServiceGetConnectionTest = (queryClient: QueryClient, { connectionTestToken }: { + connectionTestToken: string; +}) => queryClient.prefetchQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ connectionTestToken }), queryFn: () => ConnectionService.getConnectionTest({ connectionTestToken }) }); +/** * Hook Meta Data * Retrieve information about available connection types (hook classes) and their parameters. * @returns ConnectionHookMetaData Successful Response diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 8e9ef5aa29d0c..dd341e301993c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -2,7 +2,7 @@ import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query"; import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DeadlinesService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GanttService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PartitionedDagRunService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, TeamsService, VariableService, VersionService, XcomService } from "../requests/services.gen"; -import { BackfillPostBody, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, 
GenerateTokenBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; +import { BackfillPostBody, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, ConnectionTestRequestBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; import * as Common from "./common"; /** * Get Assets @@ -225,6 +225,20 @@ export const useConnectionServiceGetConnections = , "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({ connectionIdPattern, limit, offset, orderBy }, queryKey), queryFn: () => ConnectionService.getConnections({ connectionIdPattern, limit, offset, orderBy }) as TData, ...options }); /** +* Get Connection Test +* Poll for the status of an async connection test. +* +* Knowledge of the token serves as authorization — only the client +* that initiated the test knows the crypto-random token. +* @param data The data for the request. 
+* @param data.connectionTestToken +* @returns ConnectionTestStatusResponse Successful Response +* @throws ApiError +*/ +export const useConnectionServiceGetConnectionTest = = unknown[]>({ connectionTestToken }: { + connectionTestToken: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ connectionTestToken }, queryKey), queryFn: () => ConnectionService.getConnectionTest({ connectionTestToken }) as TData, ...options }); +/** * Hook Meta Data * Retrieve information about available connection types (hook classes) and their parameters. * @returns ConnectionHookMetaData Successful Response @@ -1828,6 +1842,23 @@ export const useConnectionServiceTestConnection = ({ mutationFn: ({ requestBody }) => ConnectionService.testConnection({ requestBody }) as unknown as Promise, ...options }); /** +* Test Connection Async +* Queue an async connection test to be executed on a worker. +* +* The connection data is stored in the test request table and the worker +* reads from there. Returns a token to poll for the result via +* GET /connections/test-async/{token}. +* @param data The data for the request. +* @param data.requestBody +* @returns ConnectionTestQueuedResponse Successful Response +* @throws ApiError +*/ +export const useConnectionServiceTestConnectionAsync = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ requestBody }) => ConnectionService.testConnectionAsync({ requestBody }) as unknown as Promise, ...options }); +/** * Create Default Connections * Create default connections. 
* @returns void Successful Response diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index c4a41691b1a2e..e2f532998f5eb 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -225,6 +225,20 @@ export const useConnectionServiceGetConnectionsSuspense = , "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({ connectionIdPattern, limit, offset, orderBy }, queryKey), queryFn: () => ConnectionService.getConnections({ connectionIdPattern, limit, offset, orderBy }) as TData, ...options }); /** +* Get Connection Test +* Poll for the status of an async connection test. +* +* Knowledge of the token serves as authorization — only the client +* that initiated the test knows the crypto-random token. +* @param data The data for the request. +* @param data.connectionTestToken +* @returns ConnectionTestStatusResponse Successful Response +* @throws ApiError +*/ +export const useConnectionServiceGetConnectionTestSuspense = = unknown[]>({ connectionTestToken }: { + connectionTestToken: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ connectionTestToken }, queryKey), queryFn: () => ConnectionService.getConnectionTest({ connectionTestToken }) as TData, ...options }); +/** * Hook Meta Data * Retrieve information about available connection types (hook classes) and their parameters. 
* @returns ConnectionHookMetaData Successful Response diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index f47aa107aea24..189e703f162ac 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -1698,6 +1698,138 @@ export const $ConnectionResponse = { description: 'Connection serializer for responses.' } as const; +export const $ConnectionTestQueuedResponse = { + properties: { + token: { + type: 'string', + title: 'Token' + }, + connection_id: { + type: 'string', + title: 'Connection Id' + }, + state: { + type: 'string', + title: 'State' + } + }, + type: 'object', + required: ['token', 'connection_id', 'state'], + title: 'ConnectionTestQueuedResponse', + description: 'Response returned when an async connection test is queued.' +} as const; + +export const $ConnectionTestRequestBody = { + properties: { + connection_id: { + type: 'string', + title: 'Connection Id' + }, + conn_type: { + type: 'string', + title: 'Conn Type' + }, + host: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Host' + }, + login: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Login' + }, + schema: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Schema' + }, + port: { + anyOf: [ + { + type: 'integer' + }, + { + type: 'null' + } + ], + title: 'Port' + }, + password: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Password' + }, + extra: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Extra' + }, + commit_on_success: { + type: 'boolean', + title: 'Commit On Success', + default: false + }, + executor: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Executor' + }, + queue: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + 
title: 'Queue' + } + }, + additionalProperties: false, + type: 'object', + required: ['connection_id', 'conn_type'], + title: 'ConnectionTestRequestBody', + description: 'Request body for async connection test.' +} as const; + export const $ConnectionTestResponse = { properties: { status: { @@ -1712,7 +1844,44 @@ export const $ConnectionTestResponse = { type: 'object', required: ['status', 'message'], title: 'ConnectionTestResponse', - description: 'Connection Test serializer for responses.' + description: 'Connection Test serializer for synchronous test responses.' +} as const; + +export const $ConnectionTestStatusResponse = { + properties: { + token: { + type: 'string', + title: 'Token' + }, + connection_id: { + type: 'string', + title: 'Connection Id' + }, + state: { + type: 'string', + title: 'State' + }, + result_message: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Result Message' + }, + created_at: { + type: 'string', + format: 'date-time', + title: 'Created At' + } + }, + type: 'object', + required: ['token', 'connection_id', 'state', 'created_at'], + title: 'ConnectionTestStatusResponse', + description: 'Response returned when polling for async connection test status.' 
} as const; export const $CreateAssetEventsBody = { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index 6e701bf68dc6e..7a633a252face 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, 
HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, 
GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, 
GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDagRunDeadlinesData, GetDagRunDeadlinesResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, TestConnectionAsyncData, TestConnectionAsyncResponse, GetConnectionTestData, GetConnectionTestResponse, 
CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, 
GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, 
GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDagRunDeadlinesData, GetDagRunDeadlinesResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; export class AssetService { /** @@ -632,6 +632,7 @@ export class ConnectionService { 401: 'Unauthorized', 403: 'Forbidden', 404: 'Not Found', + 409: 'Conflict', 422: 'Validation Error' } }); @@ -794,6 +795,60 @@ export class ConnectionService { }); } + /** + * Test Connection Async + * Queue an async connection test to be executed on a worker. + * + * The connection data is stored in the test request table and the worker + * reads from there. Returns a token to poll for the result via + * GET /connections/test-async/{token}. + * @param data The data for the request. + * @param data.requestBody + * @returns ConnectionTestQueuedResponse Successful Response + * @throws ApiError + */ + public static testConnectionAsync(data: TestConnectionAsyncData): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v2/connections/test-async', + body: data.requestBody, + mediaType: 'application/json', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 409: 'Conflict', + 422: 'Validation Error' + } + }); + } + + /** + * Get Connection Test + * Poll for the status of an async connection test. + * + * Knowledge of the token serves as authorization — only the client + * that initiated the test knows the crypto-random token. + * @param data The data for the request. 
+ * @param data.connectionTestToken + * @returns ConnectionTestStatusResponse Successful Response + * @throws ApiError + */ + public static getConnectionTest(data: GetConnectionTestData): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v2/connections/test-async/{connection_test_token}', + path: { + connection_test_token: data.connectionTestToken + }, + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 422: 'Validation Error' + } + }); + } + /** * Create Default Connections * Create default connections. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 7423c08d42c05..c56c0b59cb55e 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -496,13 +496,50 @@ export type ConnectionResponse = { }; /** - * Connection Test serializer for responses. + * Response returned when an async connection test is queued. + */ +export type ConnectionTestQueuedResponse = { + token: string; + connection_id: string; + state: string; +}; + +/** + * Request body for async connection test. + */ +export type ConnectionTestRequestBody = { + connection_id: string; + conn_type: string; + host?: string | null; + login?: string | null; + schema?: string | null; + port?: number | null; + password?: string | null; + extra?: string | null; + commit_on_success?: boolean; + executor?: string | null; + queue?: string | null; +}; + +/** + * Connection Test serializer for synchronous test responses. */ export type ConnectionTestResponse = { status: boolean; message: string; }; +/** + * Response returned when polling for async connection test status. + */ +export type ConnectionTestStatusResponse = { + token: string; + connection_id: string; + state: string; + result_message?: string | null; + created_at: string; +}; + /** * Create asset events request. 
*/ @@ -2482,6 +2519,18 @@ export type TestConnectionData = { export type TestConnectionResponse = ConnectionTestResponse; +export type TestConnectionAsyncData = { + requestBody: ConnectionTestRequestBody; +}; + +export type TestConnectionAsyncResponse = ConnectionTestQueuedResponse; + +export type GetConnectionTestData = { + connectionTestToken: string; +}; + +export type GetConnectionTestResponse = ConnectionTestStatusResponse; + export type CreateDefaultConnectionsResponse = void; export type HookMetaDataResponse = Array; @@ -4308,6 +4357,10 @@ export type $OpenApiTs = { * Not Found */ 404: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; /** * Validation Error */ @@ -4465,6 +4518,60 @@ export type $OpenApiTs = { }; }; }; + '/api/v2/connections/test-async': { + post: { + req: TestConnectionAsyncData; + res: { + /** + * Successful Response + */ + 202: ConnectionTestQueuedResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; + '/api/v2/connections/test-async/{connection_test_token}': { + get: { + req: GetConnectionTestData; + res: { + /** + * Successful Response + */ + 200: ConnectionTestStatusResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; '/api/v2/connections/defaults': { post: { res: { diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 9bc0608611b5a..5479aecf1d20b 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -115,7 +115,7 @@ class MappedClassProtocol(Protocol): "3.0.3": "fe199e1abd77", "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", - "3.2.0": 
"1d6611b6ab7c", + "3.2.0": "a7e6d4c3b2f1", } # Prefix used to identify tables holding data moved during migration. diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py index e6b5283669b86..46fa983e7c1bd 100644 --- a/airflow-core/src/airflow/utils/db_cleanup.py +++ b/airflow-core/src/airflow/utils/db_cleanup.py @@ -172,6 +172,7 @@ def readable_config(self): ), _TableConfig(table_name="deadline", recency_column_name="deadline_time", dag_id_column_name="dag_id"), _TableConfig(table_name="revoked_token", recency_column_name="exp"), + _TableConfig(table_name="connection_test_request", recency_column_name="created_at"), ] # We need to have `fallback="database"` because this is executed at top level code and provider configuration diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py index 1f63247cfa9ab..fead99e94ca0b 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py @@ -29,12 +29,18 @@ from airflow.api_fastapi.core_api.datamodels.connections import ConnectionBody from airflow.api_fastapi.core_api.services.public.connections import BulkConnectionService from airflow.models import Connection +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils.session import NEW_SESSION, provide_session from tests_common.test_utils.api_fastapi import _check_last_log from tests_common.test_utils.asserts import assert_queries_count -from tests_common.test_utils.db import clear_db_connections, clear_db_logs, clear_test_connections +from tests_common.test_utils.db import ( + clear_db_connection_tests, + clear_db_connections, + clear_db_logs, + clear_test_connections, +) from 
tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker pytestmark = pytest.mark.db_test @@ -94,10 +100,12 @@ class TestConnectionEndpoint: def setup(self) -> None: clear_test_connections(False) clear_db_connections(False) + clear_db_connection_tests() clear_db_logs() def teardown_method(self) -> None: clear_db_connections() + clear_db_connection_tests() def create_connection(self, team_name: str | None = None): _create_connection(team_name=team_name) @@ -1168,6 +1176,171 @@ def test_should_test_new_connection_without_existing(self, test_client): assert response.json()["status"] is True +class TestAsyncConnectionTest(TestConnectionEndpoint): + """Tests for the async connection test endpoints (POST + GET polling).""" + + TEST_REQUEST_BODY = { + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": TEST_CONN_HOST, + } + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_should_respond_202(self, test_client, session): + """POST /connections/test-async returns 202 + token.""" + response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code == 202 + body = response.json() + assert "token" in body + assert body["connection_id"] == TEST_CONN_ID + assert body["state"] == "pending" + assert len(body["token"]) > 0 + + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code == 403 + + def test_should_respond_403_by_default(self, test_client): + """Connection testing is disabled by default.""" + response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code 
== 403 + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_creates_connection_test_request_row(self, test_client, session): + """POST creates a ConnectionTestRequest row in PENDING state with connection fields.""" + response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code == 202 + token = response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + assert ct is not None + assert ct.connection_id == TEST_CONN_ID + assert ct.conn_type == TEST_CONN_TYPE + assert ct.host == TEST_CONN_HOST + assert ct.state == "pending" + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_passes_queue_parameter(self, test_client, session): + """POST /connections/test-async passes the queue parameter.""" + body = {**self.TEST_REQUEST_BODY, "queue": "gpu_workers"} + response = test_client.post("/connections/test-async", json=body) + assert response.status_code == 202 + token = response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + assert ct is not None + assert ct.queue == "gpu_workers" + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_stores_commit_on_success(self, test_client, session): + """POST /connections/test-async stores the commit_on_success flag.""" + body = {**self.TEST_REQUEST_BODY, "commit_on_success": True} + response = test_client.post("/connections/test-async", json=body) + assert response.status_code == 202 + token = response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + assert ct is not None + assert ct.commit_on_success is True + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_returns_409_for_duplicate_active_test(self, test_client, session): + """POST returns 409 when there's already an active 
test for the same connection_id.""" + response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code == 202 + + response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + assert response.status_code == 409 + assert "async test is running" in response.json()["detail"] + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_get_status_returns_pending(self, test_client, session): + """GET /connections/test-async/{token} returns current status.""" + post_response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + token = post_response.json()["token"] + + response = test_client.get(f"/connections/test-async/{token}") + assert response.status_code == 200 + body = response.json() + assert body["token"] == token + assert body["connection_id"] == TEST_CONN_ID + assert body["state"] == "pending" + assert body["result_message"] is None + assert "created_at" in body + assert "reverted" not in body + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_get_status_returns_completed_result(self, test_client, session): + """GET returns result after the worker has updated the test.""" + post_response = test_client.post("/connections/test-async", json=self.TEST_REQUEST_BODY) + token = post_response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + ct.state = ConnectionTestState.SUCCESS + ct.result_message = "Connection successfully tested" + session.commit() + + response = test_client.get(f"/connections/test-async/{token}") + assert response.status_code == 200 + body = response.json() + assert body["state"] == "success" + assert body["result_message"] == "Connection successfully tested" + + def test_get_status_returns_404_for_invalid_token(self, test_client): + """GET with an unknown token returns 404.""" + response = 
test_client.get("/connections/test-async/nonexistent-token") + assert response.status_code == 404 + + +class TestBlockEditDeleteDuringActiveTest(TestConnectionEndpoint): + """Tests that edit/delete is blocked while an async test is running.""" + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_patch_blocked_during_active_test(self, test_client, session): + """PATCH /{connection_id} returns 409 when an active test exists.""" + self.create_connection() + test_client.post( + "/connections/test-async", + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": TEST_CONN_HOST, + }, + ) + + response = test_client.patch( + f"/connections/{TEST_CONN_ID}", + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": "updated-host.example.com", + }, + ) + assert response.status_code == 409 + assert "async test is running" in response.json()["detail"] + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_delete_blocked_during_active_test(self, test_client, session): + """DELETE /{connection_id} returns 409 when an active test exists.""" + self.create_connection() + test_client.post( + "/connections/test-async", + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": TEST_CONN_HOST, + }, + ) + + response = test_client.delete(f"/connections/{TEST_CONN_ID}") + assert response.status_code == 409 + assert "async test is running" in response.json()["detail"] + + class TestCreateDefaultConnections(TestConnectionEndpoint): def test_should_respond_204(self, test_client, session): response = test_client.post("/connections/defaults") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py new file mode 100644 index 0000000000000..ac31cfd487159 --- /dev/null +++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest +from sqlalchemy import select + +from airflow.models.connection import Connection +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState + +from tests_common.test_utils.db import clear_db_connection_tests, clear_db_connections + +pytestmark = pytest.mark.db_test + + +class TestPatchConnectionTest: + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connection_tests() + yield + clear_db_connection_tests() + + def test_patch_updates_result(self, client, session): + """PATCH sets the state and result fields.""" + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={ + "state": "success", + "result_message": "Connection successfully tested", + }, + ) + assert response.status_code == 204 + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == "success" + assert 
ct.result_message == "Connection successfully tested" + + def test_patch_returns_404_for_nonexistent(self, client): + """PATCH with unknown id returns 404.""" + response = client.patch( + "/execution/connection-tests/00000000-0000-0000-0000-000000000000", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 404 + + def test_patch_returns_422_for_invalid_uuid(self, client): + """PATCH with invalid uuid returns 422.""" + response = client.patch( + "/execution/connection-tests/not-a-uuid", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 422 + + def test_patch_returns_409_for_terminal_state(self, client, session): + """PATCH on a test already in terminal state returns 409.""" + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + ct.state = ConnectionTestState.SUCCESS + ct.result_message = "Already done" + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "failed", "result_message": "retry"}, + ) + assert response.status_code == 409 + assert "terminal state" in response.json()["detail"]["message"] + + +class TestPatchConnectionTestCommitOnSuccess: + """Tests for the commit_on_success behavior in the execution API.""" + + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + yield + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + + def test_success_with_commit_creates_connection(self, client, session): + """PATCH with state=success and commit_on_success creates a new connection.""" + ct = ConnectionTestRequest( + connection_id="new_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + commit_on_success=True, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + 
f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "Connection OK"}, + ) + assert response.status_code == 204 + + conn = session.scalar(select(Connection).filter_by(conn_id="new_conn")) + assert conn is not None + assert conn.conn_type == "postgres" + assert conn.host == "db.example.com" + + def test_success_with_commit_updates_existing(self, client, session): + """PATCH with state=success and commit_on_success updates an existing connection.""" + conn = Connection(conn_id="existing_conn", conn_type="http", host="old-host.example.com") + session.add(conn) + session.flush() + + ct = ConnectionTestRequest( + connection_id="existing_conn", + conn_type="postgres", + host="new-host.example.com", + login="new_user", + commit_on_success=True, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "Connection OK"}, + ) + assert response.status_code == 204 + + session.expire_all() + conn = session.scalar(select(Connection).filter_by(conn_id="existing_conn")) + assert conn.conn_type == "postgres" + assert conn.host == "new-host.example.com" + + def test_success_without_commit_does_not_create(self, client, session): + """PATCH with state=success but commit_on_success=False does not create a connection.""" + ct = ConnectionTestRequest( + connection_id="no_commit_conn", + conn_type="postgres", + host="db.example.com", + commit_on_success=False, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "Connection OK"}, + ) + assert response.status_code == 204 + + conn = session.scalar(select(Connection).filter_by(conn_id="no_commit_conn")) + assert conn is None + + def test_failed_with_commit_does_not_create(self, client, session): + """PATCH with state=failed and 
commit_on_success=True does NOT create a connection.""" + ct = ConnectionTestRequest( + connection_id="fail_conn", + conn_type="postgres", + host="db.example.com", + commit_on_success=True, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "failed", "result_message": "Connection refused"}, + ) + assert response.status_code == 204 + + conn = session.scalar(select(Connection).filter_by(conn_id="fail_conn")) + assert conn is None + + +class TestGetConnectionTestConnection: + """Tests for the GET /{connection_test_id}/connection endpoint.""" + + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connection_tests() + yield + clear_db_connection_tests() + + def test_get_connection_returns_data(self, client, session): + """GET returns decrypted connection data from the test request.""" + ct = ConnectionTestRequest( + connection_id="test_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + extra='{"key": "value"}', + ) + session.add(ct) + session.commit() + + response = client.get(f"/execution/connection-tests/{ct.id}/connection") + assert response.status_code == 200 + + data = response.json() + assert data["conn_id"] == "test_conn" + assert data["conn_type"] == "postgres" + assert data["host"] == "db.example.com" + assert data["login"] == "user" + assert data["password"] == "secret" + assert data["schema"] == "mydb" + assert data["port"] == 5432 + assert data["extra"] == '{"key": "value"}' + + def test_get_connection_returns_404_for_nonexistent(self, client): + """GET with unknown id returns 404.""" + response = client.get("/execution/connection-tests/00000000-0000-0000-0000-000000000000/connection") + assert response.status_code == 404 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_connection_tests.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_connection_tests.py new file mode 100644 index 0000000000000..f907611f38ae9 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_connection_tests.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState + +from tests_common.test_utils.db import clear_db_connection_tests + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + """Client configured to use API version before connection-tests endpoint was added.""" + client.headers["Airflow-API-Version"] = "2025-12-08" + return client + + +class TestConnectionTestEndpointVersioning: + """Test that the connection-tests endpoint didn't exist in older API versions.""" + + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connection_tests() + yield + clear_db_connection_tests() + + def test_old_version_returns_404(self, old_ver_client, session): + """PATCH /connection-tests/{id} should not exist in older API versions.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = old_ver_client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 404 + + def test_head_version_works(self, client, session): + """PATCH /connection-tests/{id} should work in the current API version.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 204 diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index fa0f311d018fe..8171e223fcffd 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -21,12 +21,13 @@ import textwrap from datetime import 
timedelta from unittest import mock -from uuid import UUID +from uuid import UUID, uuid4 import pendulum import pytest import structlog import time_machine +from sqlalchemy.orm import Session from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import CallbackRequest @@ -407,6 +408,43 @@ def test_repr(): assert repr(executor) == "BaseExecutor(parallelism=10, team_name='teamA')" +def test_supports_connection_test_default_value(): + assert not BaseExecutor.supports_connection_test + + +def test_queue_connection_test_workload_rejected_by_default(): + """BaseExecutor (supports_connection_test=False) rejects TestConnection workloads.""" + executor = BaseExecutor() + wl = workloads.TestConnection.make( + connection_test_id=uuid4(), + connection_id="test_conn", + ) + with pytest.raises(NotImplementedError, match="does not support TestConnection workloads"): + executor.queue_workload(wl, session=mock.MagicMock(spec=Session)) + + +def test_queue_connection_test_workload_accepted_when_supported(): + """An executor with supports_connection_test=True accepts TestConnection workloads.""" + executor = LocalExecutor() + executor.queued_connection_tests.clear() + wl = workloads.TestConnection.make( + connection_test_id=uuid4(), + connection_id="test_conn", + ) + executor.queue_workload(wl, session=mock.MagicMock(spec=Session)) + assert len(executor.queued_connection_tests) == 1 + assert executor.queued_connection_tests[str(wl.connection_test_id)] is wl + + +def test_trigger_connection_tests_skipped_when_not_supported(): + """trigger_connection_tests is a no-op when supports_connection_test is False.""" + executor = BaseExecutor() + executor.queued_connection_tests["dummy"] = mock.MagicMock(spec=workloads.TestConnection) + with mock.patch.object(executor, "_process_workloads") as mock_process: + executor.trigger_connection_tests() + mock_process.assert_not_called() + + @mock.patch.dict("os.environ", {}, clear=True) class TestExecutorConf: """Test 
ExecutorConf shim class that provides team-specific configuration access.""" diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 59afffe6833fe..4bca67542ddf3 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -23,16 +23,20 @@ from unittest import mock import pytest +import structlog from kgb import spy_on from uuid6 import uuid7 from airflow._shared.timezones import timezone from airflow.executors import workloads -from airflow.executors.local_executor import LocalExecutor, _execute_work +from airflow.executors.base_executor import ExecutorConf +from airflow.executors.local_executor import LocalExecutor, _execute_connection_test, _execute_work from airflow.executors.workloads.base import BundleInfo from airflow.executors.workloads.callback import CallbackDTO from airflow.executors.workloads.task import TaskInstanceDTO from airflow.models.callback import CallbackFetchMethod +from airflow.models.connection_test import ConnectionTestState +from airflow.sdk.api.datamodels._generated import ConnectionResponse from airflow.settings import Session from airflow.utils.state import State @@ -376,6 +380,244 @@ def test_global_executor_without_team_name(self): executor.end() +class TestLocalExecutorConnectionTestSupport: + def test_supports_connection_test_flag_is_true(self): + executor = LocalExecutor() + assert executor.supports_connection_test is True + + +@mock.patch("airflow.executors.local_executor.signal", autospec=True) +@mock.patch("airflow.sdk.api.client.Client", autospec=True) +class TestLocalExecutorConnectionTestExecution: + def test_successful_connection_test(self, MockClient, _mock_signal): + """Fetches connection via Execution API, runs test, reports SUCCESS.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionResponse( + 
conn_id="test_conn", + conn_type="http", + host="httpbin.org", + port=443, + ) + + test_id = uuid7() + workload = workloads.TestConnection( + connection_test_id=test_id, + connection_id="test_conn", + timeout=60, + token="test-token", + ) + + with mock.patch( + "airflow.executors.local_executor.run_connection_test", + return_value=(True, "Connection OK"), + ): + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + calls = mock_client.connection_tests.update_state.call_args_list + assert len(calls) == 2 + assert calls[0].args == (test_id, ConnectionTestState.RUNNING) + assert calls[1].args == (test_id, ConnectionTestState.SUCCESS, "Connection OK") + + def test_failed_connection_test(self, MockClient, _mock_signal): + """Fetches connection via Execution API, test fails, reports FAILED.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionResponse( + conn_id="test_conn", + conn_type="postgres", + host="db.example.com", + ) + + test_id = uuid7() + workload = workloads.TestConnection( + connection_test_id=test_id, + connection_id="test_conn", + timeout=60, + token="test-token", + ) + + with mock.patch( + "airflow.executors.local_executor.run_connection_test", + return_value=(False, "Connection refused"), + ): + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + calls = mock_client.connection_tests.update_state.call_args_list + assert len(calls) == 2 + assert calls[0].args == (test_id, ConnectionTestState.RUNNING) + assert calls[1].args == (test_id, ConnectionTestState.FAILED, "Connection refused") + + def test_connection_not_found_via_execution_api(self, MockClient, _mock_signal): + """Reports FAILED when connection test is not found via Execution API.""" + mock_client = MockClient.return_value + 
mock_client.connection_tests.get_connection.side_effect = RuntimeError("Connection test not found") + + test_id = uuid7() + workload = workloads.TestConnection( + connection_test_id=test_id, + connection_id="missing_conn", + timeout=60, + token="test-token", + ) + + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + calls = mock_client.connection_tests.update_state.call_args_list + assert calls[-1].args[1] == ConnectionTestState.FAILED + assert "Connection test failed unexpectedly" in calls[-1].args[2] + + def test_unexpected_exception_reports_failed(self, MockClient, _mock_signal): + """Reports FAILED when an unexpected exception occurs.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionResponse( + conn_id="test_conn", + conn_type="http", + ) + + test_id = uuid7() + workload = workloads.TestConnection( + connection_test_id=test_id, + connection_id="test_conn", + timeout=60, + token="test-token", + ) + + with mock.patch( + "airflow.executors.local_executor.run_connection_test", + side_effect=RuntimeError("Something broke"), + ): + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + calls = mock_client.connection_tests.update_state.call_args_list + assert calls[-1].args[1] == ConnectionTestState.FAILED + assert "Connection test failed unexpectedly: RuntimeError" in calls[-1].args[2] + + def test_connection_fields_passed_correctly(self, MockClient, _mock_signal): + """Verifies all connection fields from the API response are passed to Connection.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionResponse( + conn_id="full_conn", + conn_type="postgres", + host="db.example.com", + login="admin", + password="s3cret", + schema="mydb", + port=5432, + extra='{"sslmode": 
"require"}', + ) + + workload = workloads.TestConnection( + connection_test_id=uuid7(), + connection_id="full_conn", + timeout=60, + token="test-token", + ) + + captured_conn = None + + def capture_conn(*, conn): + nonlocal captured_conn + captured_conn = conn + return True, "OK" + + with mock.patch( + "airflow.executors.local_executor.run_connection_test", + side_effect=capture_conn, + ): + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + assert captured_conn is not None + assert captured_conn.conn_id == "full_conn" + assert captured_conn.conn_type == "postgres" + assert captured_conn.host == "db.example.com" + assert captured_conn.login == "admin" + assert captured_conn.password == "s3cret" + assert captured_conn.schema == "mydb" + assert captured_conn.port == 5432 + assert captured_conn.extra == '{"sslmode": "require"}' + + def test_timeout_reports_failed(self, MockClient, _mock_signal): + """Reports FAILED with timeout message when TimeoutError is raised.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionResponse( + conn_id="test_conn", + conn_type="http", + ) + + test_id = uuid7() + workload = workloads.TestConnection( + connection_test_id=test_id, + connection_id="test_conn", + timeout=30, + token="test-token", + ) + + def raise_timeout(*, conn): + raise TimeoutError("Connection test timed out after 30s") + + with mock.patch( + "airflow.executors.local_executor.run_connection_test", + side_effect=raise_timeout, + ): + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + calls = mock_client.connection_tests.update_state.call_args_list + assert calls[-1].args[1] == ConnectionTestState.FAILED + assert "timed out" in calls[-1].args[2] + + def test_alarm_is_cancelled_in_finally(self, MockClient, mock_signal): + 
"""signal.alarm(0) is called to cancel the timer even on success.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionResponse( + conn_id="test_conn", + conn_type="http", + ) + + workload = workloads.TestConnection( + connection_test_id=uuid7(), + connection_id="test_conn", + timeout=60, + token="test-token", + ) + + with mock.patch( + "airflow.executors.local_executor.run_connection_test", + return_value=(True, "OK"), + ): + _execute_connection_test( + mock.MagicMock(spec=structlog.typing.FilteringBoundLogger), + workload, + ExecutorConf(team_name=None), + ) + + alarm_calls = mock_signal.alarm.call_args_list + assert alarm_calls[0].args == (60,) + assert alarm_calls[-1].args == (0,) + + class TestLocalExecutorCallbackSupport: def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 997ebd24bcf0d..10e0e2c0db6ea 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -67,6 +67,7 @@ ) from airflow.models.backfill import Backfill, _create_backfill from airflow.models.callback import ExecutorCallback +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel @@ -9445,3 +9446,450 @@ def test_fallback_values_used_only_when_dag_version_is_none(self): assert _extract_bundle_name(ti) == "fallback-bundle" assert _extract_bundle_version(ti) == "fallback-v1" + + +@pytest.fixture +def scheduler_job_runner_for_connection_tests(session): + """Create a SchedulerJobRunner with a mock Job and supporting executor.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + mock_job = 
mock.MagicMock(spec=Job) + mock_job.id = 1 + mock_job.max_tis_per_query = 16 + executor = LocalExecutor() + executor.queued_connection_tests.clear() + runner = SchedulerJobRunner.__new__(SchedulerJobRunner) + runner.job = mock_job + runner.executors = [executor] + runner.executor = executor + runner._log = mock.MagicMock(spec=logging.Logger) + yield runner + session.execute(delete(ConnectionTestRequest)) + session.commit() + + +class TestDispatchConnectionTests: + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_pending_tests(self, scheduler_job_runner_for_connection_tests, session): + """Pending connection tests are dispatched to a supporting executor.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + session.add(ct) + session.commit() + assert ct.state == ConnectionTestState.PENDING + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.QUEUED + assert len(scheduler_job_runner_for_connection_tests.executor.queued_connection_tests) == 1 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "1", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_respects_concurrency_limit(self, scheduler_job_runner_for_connection_tests, session): + """Excess pending tests stay PENDING when concurrency is at capacity.""" + ct_active = ConnectionTestRequest(conn_type="test_type", connection_id="active_conn") + ct_active.state = ConnectionTestState.QUEUED + session.add(ct_active) + + ct_pending = ConnectionTestRequest(conn_type="test_type", connection_id="pending_conn") + session.add(ct_pending) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + 
session.expire_all() + ct_pending = session.get(ConnectionTestRequest, ct_pending.id) + assert ct_pending.state == ConnectionTestState.PENDING + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_fails_fast_when_no_executor_supports( + self, scheduler_job_runner_for_connection_tests, session + ): + """Tests fail immediately when no executor supports connection testing.""" + unsupporting_executor = BaseExecutor() + unsupporting_executor.supports_connection_test = False + scheduler_job_runner_for_connection_tests.executors = [unsupporting_executor] + scheduler_job_runner_for_connection_tests.executor = unsupporting_executor + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "No executor supports connection testing" in ct.result_message + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_with_unmatched_executor_fails_fast( + self, scheduler_job_runner_for_connection_tests, session + ): + """Tests requesting an executor with no match are failed immediately.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn", executor="gpu_workers") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "gpu_workers" in ct.result_message + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "3", + 
"AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_budget_dispatches_up_to_remaining_slots( + self, scheduler_job_runner_for_connection_tests, session + ): + """When 1 slot is occupied, only budget (cap - active) pending tests are dispatched.""" + ct_active = ConnectionTestRequest(conn_type="test_type", connection_id="active_conn") + ct_active.state = ConnectionTestState.RUNNING + session.add(ct_active) + + pending_tests = [] + for i in range(3): + ct = ConnectionTestRequest(conn_type="test_type", connection_id=f"pending_{i}") + session.add(ct) + pending_tests.append(ct) + session.commit() + pending_ids = [ct.id for ct in pending_tests] + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + states = [session.get(ConnectionTestRequest, pid).state for pid in pending_ids] + assert states.count(ConnectionTestState.QUEUED) == 2 + assert states.count(ConnectionTestState.PENDING) == 1 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "2", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_order_is_fifo_by_created_at(self, scheduler_job_runner_for_connection_tests, session): + """Pending tests are dispatched in FIFO order based on created_at.""" + initial_time = timezone.utcnow() + + with time_machine.travel(initial_time - timedelta(minutes=5), tick=False): + ct_old = ConnectionTestRequest(conn_type="test_type", connection_id="old_conn") + session.add(ct_old) + session.flush() + + with time_machine.travel(initial_time, tick=False): + ct_new = ConnectionTestRequest(conn_type="test_type", connection_id="new_conn") + session.add(ct_new) + session.flush() + + with time_machine.travel(initial_time + timedelta(minutes=1), tick=False): + ct_newest = ConnectionTestRequest(conn_type="test_type", connection_id="newest_conn") + session.add(ct_newest) + session.flush() + + session.commit() + + 
scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + assert session.get(ConnectionTestRequest, ct_old.id).state == ConnectionTestState.QUEUED + assert session.get(ConnectionTestRequest, ct_new.id).state == ConnectionTestState.QUEUED + assert session.get(ConnectionTestRequest, ct_newest.id).state == ConnectionTestState.PENDING + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_fails_fast_for_unserved_executor( + self, scheduler_job_runner_for_connection_tests, session + ): + """Tests requesting an executor no team serves are failed immediately.""" + with mock.patch.object( + scheduler_job_runner_for_connection_tests, + "_try_to_load_executor", + return_value=None, + ): + ct = ConnectionTestRequest( + conn_type="test_type", connection_id="test_conn", executor="nonexistent_executor" + ) + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "nonexistent_executor" in ct.result_message + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_executor_matched_by_alias(self, session): + """When executor is specified, the executor whose name.alias matches is selected.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + mock_job = mock.MagicMock(spec=Job) + mock_job.id = 1 + mock_job.max_tis_per_query = 16 + + executor_a = LocalExecutor() + executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a") + executor_a.queued_connection_tests.clear() + + executor_b = LocalExecutor() + executor_b.name = 
ExecutorName(module_path="path.to.ExecutorB", alias="executor_b") + executor_b.queued_connection_tests.clear() + + runner = SchedulerJobRunner.__new__(SchedulerJobRunner) + runner.job = mock_job + runner.executors = [executor_a, executor_b] + runner.executor = executor_a + runner._log = mock.MagicMock(spec=logging.Logger) + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", executor="executor_b") + session.add(ct) + session.commit() + + runner._enqueue_connection_tests(session=session) + + assert len(executor_b.queued_connection_tests) == 1 + assert len(executor_a.queued_connection_tests) == 0 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_executor_matched_by_module_path(self, session): + """When executor is specified by module_path, the matching executor is selected.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + mock_job = mock.MagicMock(spec=Job) + mock_job.id = 1 + mock_job.max_tis_per_query = 16 + + executor_a = LocalExecutor() + executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a") + executor_a.queued_connection_tests.clear() + + executor_b = LocalExecutor() + executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b") + executor_b.queued_connection_tests.clear() + + runner = SchedulerJobRunner.__new__(SchedulerJobRunner) + runner.job = mock_job + runner.executors = [executor_a, executor_b] + runner.executor = executor_a + runner._log = mock.MagicMock(spec=logging.Logger) + + ct = ConnectionTestRequest( + conn_type="test_type", connection_id="team_conn", executor="path.to.ExecutorB" + ) + session.add(ct) + session.commit() + + runner._enqueue_connection_tests(session=session) + + assert len(executor_b.queued_connection_tests) == 1 + assert len(executor_a.queued_connection_tests) == 0 + + def 
test_dispatch_executor_matched_by_class_name(self, session): + """When executor is specified by class name only, the matching executor is selected.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + mock_job = mock.MagicMock(spec=Job) + mock_job.id = 1 + mock_job.max_tis_per_query = 16 + + executor_a = LocalExecutor() + executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a") + executor_a.queued_connection_tests.clear() + + executor_b = LocalExecutor() + executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b") + executor_b.queued_connection_tests.clear() + + runner = SchedulerJobRunner.__new__(SchedulerJobRunner) + runner.job = mock_job + runner.executors = [executor_a, executor_b] + runner.executor = executor_a + runner._log = mock.MagicMock(spec=logging.Logger) + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", executor="ExecutorB") + session.add(ct) + session.commit() + + runner._enqueue_connection_tests(session=session) + + assert len(executor_b.queued_connection_tests) == 1 + assert len(executor_a.queued_connection_tests) == 0 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + "AIRFLOW__CORE__PARALLELISM": "1", + }, + ) + def test_dispatch_respects_parallelism_budget(self, scheduler_job_runner_for_connection_tests, session): + """Connection tests are not dispatched when core.parallelism is exhausted.""" + executor = scheduler_job_runner_for_connection_tests.executor + # Simulate 1 running task so all parallelism slots are occupied + executor.running = {"fake_task_key"} + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert 
ct.state == ConnectionTestState.PENDING + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CORE__MAX_CONNECTION_TEST_CONCURRENCY": "4", + "AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60", + }, + ) + def test_dispatch_fails_when_executor_does_not_support_connection_test( + self, scheduler_job_runner_for_connection_tests, session + ): + """When the resolved executor does not support connection tests, the test is failed gracefully.""" + executor = scheduler_job_runner_for_connection_tests.executor + executor.supports_connection_test = False + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "No executor supports connection testing" in ct.result_message + + +class TestReapStaleConnectionTests: + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60"}) + def test_reap_stale_queued_test(self, scheduler_job_runner_for_connection_tests, session): + """Stale QUEUED tests are marked as FAILED by the reaper.""" + initial_time = timezone.utcnow() + + with time_machine.travel(initial_time, tick=False): + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.QUEUED + session.add(ct) + session.commit() + + with time_machine.travel(initial_time + timedelta(seconds=200), tick=False): + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "timed out" in ct.result_message + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60"}) + def test_does_not_reap_fresh_tests(self, scheduler_job_runner_for_connection_tests, session): + """Fresh 
QUEUED tests are not reaped.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.QUEUED + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.QUEUED + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60"}) + def test_reap_stale_running_test(self, scheduler_job_runner_for_connection_tests, session): + """Stale RUNNING tests are also reaped by the reaper.""" + initial_time = timezone.utcnow() + with time_machine.travel(initial_time, tick=False): + ct = ConnectionTestRequest(conn_type="test_type", connection_id="running_conn") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + with time_machine.travel(initial_time + timedelta(seconds=200), tick=False): + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "timed out" in ct.result_message + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__CONNECTION_TEST_TIMEOUT": "60"}) + def test_reaper_ignores_terminal_states(self, scheduler_job_runner_for_connection_tests, session): + """Tests in terminal states (SUCCESS, FAILED) are not touched by the reaper.""" + initial_time = timezone.utcnow() + with time_machine.travel(initial_time, tick=False): + ct_success = ConnectionTestRequest(conn_type="test_type", connection_id="success_conn") + ct_success.state = ConnectionTestState.SUCCESS + ct_success.result_message = "OK" + session.add(ct_success) + + ct_failed = ConnectionTestRequest(conn_type="test_type", connection_id="failed_conn") + ct_failed.state = ConnectionTestState.FAILED + ct_failed.result_message = "Error" + session.add(ct_failed) + session.commit() + + 
with time_machine.travel(initial_time + timedelta(seconds=200), tick=False): + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + assert session.get(ConnectionTestRequest, ct_success.id).state == ConnectionTestState.SUCCESS + assert session.get(ConnectionTestRequest, ct_failed.id).state == ConnectionTestState.FAILED diff --git a/airflow-core/tests/unit/models/test_connection_test.py b/airflow-core/tests/unit/models/test_connection_test.py new file mode 100644 index 0000000000000..0d8da8cef695a --- /dev/null +++ b/airflow-core/tests/unit/models/test_connection_test.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.models.connection import Connection +from airflow.models.connection_test import ( + ConnectionTestRequest, + ConnectionTestState, + run_connection_test, +) + +from tests_common.test_utils.db import clear_db_connection_tests, clear_db_connections + +pytestmark = pytest.mark.db_test + + +class TestConnectionTestRequestModel: + def test_token_is_generated(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.token is not None + assert len(ct.token) > 0 + + def test_initial_state_is_pending(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.state == ConnectionTestState.PENDING + + def test_tokens_are_unique(self): + ct1 = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + ct2 = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct1.token != ct2.token + + def test_repr(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + r = repr(ct) + assert "test_conn" in r + assert "pending" in r + + def test_executor_parameter(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", executor="my_executor") + assert ct.executor == "my_executor" + + def test_executor_defaults_to_none(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.executor is None + + def test_queue_parameter(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", queue="my_queue") + assert ct.queue == "my_queue" + + def test_queue_defaults_to_none(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.queue is None + + def test_connection_fields_stored(self): + ct = ConnectionTestRequest( + connection_id="test_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + 
schema="mydb", + port=5432, + extra='{"key": "value"}', + ) + assert ct.conn_type == "postgres" + assert ct.host == "db.example.com" + assert ct.login == "user" + assert ct.password == "secret" + assert ct.schema == "mydb" + assert ct.port == 5432 + assert ct.extra == '{"key": "value"}' + + def test_password_is_encrypted(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", password="secret") + assert ct._password is not None + assert ct._password != "secret" + assert ct.password == "secret" + + def test_extra_is_encrypted(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", extra='{"key": "val"}') + assert ct._extra is not None + assert ct._extra != '{"key": "val"}' + assert ct.extra == '{"key": "val"}' + + def test_null_password_and_extra(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="http") + assert ct._password is None + assert ct._extra is None + + def test_commit_on_success_default(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.commit_on_success is False + + def test_commit_on_success_true(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", commit_on_success=True) + assert ct.commit_on_success is True + + +class TestToConnection: + def test_to_connection_returns_transient_connection(self): + ct = ConnectionTestRequest( + connection_id="test_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + extra='{"key": "value"}', + ) + conn = ct.to_connection() + assert isinstance(conn, Connection) + assert conn.conn_id == "test_conn" + assert conn.conn_type == "postgres" + assert conn.host == "db.example.com" + assert conn.login == "user" + assert conn.password == "secret" + assert conn.schema == "mydb" + assert conn.port == 5432 + assert conn.extra == '{"key": "value"}' + + +class TestCommitToConnectionTable: + 
@pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + yield + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + + def test_creates_new_connection(self, session): + ct = ConnectionTestRequest( + connection_id="new_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + ) + session.add(ct) + session.flush() + + ct.commit_to_connection_table(session=session) + session.flush() + + from sqlalchemy import select + + conn = session.scalar(select(Connection).filter_by(conn_id="new_conn")) + assert conn is not None + assert conn.conn_type == "postgres" + assert conn.host == "db.example.com" + assert conn.password == "secret" + + def test_updates_existing_connection(self, session): + conn = Connection(conn_id="existing_conn", conn_type="http", host="old-host.example.com") + session.add(conn) + session.flush() + + ct = ConnectionTestRequest( + connection_id="existing_conn", + conn_type="postgres", + host="new-host.example.com", + login="new_user", + password="new_secret", + ) + session.add(ct) + session.flush() + + ct.commit_to_connection_table(session=session) + session.flush() + session.refresh(conn) + + assert conn.conn_type == "postgres" + assert conn.host == "new-host.example.com" + assert conn.login == "new_user" + assert conn.password == "new_secret" + + +class TestRunConnectionTest: + def test_successful_connection_test(self): + conn = mock.MagicMock(spec=Connection) + conn.conn_id = "test_conn" + conn.test_connection.return_value = (True, "Connection OK") + + success, message = run_connection_test(conn=conn) + + assert success is True + assert message == "Connection OK" + + def test_failed_connection_test(self): + conn = mock.MagicMock(spec=Connection) + conn.conn_id = "test_conn" + conn.test_connection.return_value = (False, "Connection failed") + + success, message 
= run_connection_test(conn=conn) + + assert success is False + assert message == "Connection failed" + + def test_exception_during_connection_test(self): + conn = mock.MagicMock(spec=Connection) + conn.conn_id = "test_conn" + conn.test_connection.side_effect = Exception("Could not resolve host: db.example.com") + + success, message = run_connection_test(conn=conn) + + assert success is False + assert "Could not resolve host" in message diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index bb375f888349c..9758167d72e6d 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -244,15 +244,58 @@ class ConnectionResponse(BaseModel): team_name: Annotated[str | None, Field(title="Team Name")] = None +class ConnectionTestQueuedResponse(BaseModel): + """ + Response returned when an async connection test is queued. + """ + + token: Annotated[str, Field(title="Token")] + connection_id: Annotated[str, Field(title="Connection Id")] + state: Annotated[str, Field(title="State")] + + +class ConnectionTestRequestBody(BaseModel): + """ + Request body for async connection test. 
+ """ + + model_config = ConfigDict( + extra="forbid", + ) + connection_id: Annotated[str, Field(title="Connection Id")] + conn_type: Annotated[str, Field(title="Conn Type")] + host: Annotated[str | None, Field(title="Host")] = None + login: Annotated[str | None, Field(title="Login")] = None + schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None + port: Annotated[int | None, Field(title="Port")] = None + password: Annotated[str | None, Field(title="Password")] = None + extra: Annotated[str | None, Field(title="Extra")] = None + commit_on_success: Annotated[bool | None, Field(title="Commit On Success")] = False + executor: Annotated[str | None, Field(title="Executor")] = None + queue: Annotated[str | None, Field(title="Queue")] = None + + class ConnectionTestResponse(BaseModel): """ - Connection Test serializer for responses. + Connection Test serializer for synchronous test responses. """ status: Annotated[bool, Field(title="Status")] message: Annotated[str, Field(title="Message")] +class ConnectionTestStatusResponse(BaseModel): + """ + Response returned when polling for async connection test status. + """ + + token: Annotated[str, Field(title="Token")] + connection_id: Annotated[str, Field(title="Connection Id")] + state: Annotated[str, Field(title="State")] + result_message: Annotated[str | None, Field(title="Result Message")] = None + created_at: Annotated[datetime, Field(title="Created At")] + + class CreateAssetEventsBody(BaseModel): """ Create asset events request. 
diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index cbfb0b377ae71..dfe181284888e 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -470,6 +470,14 @@ def clear_db_teams(): session.execute(delete(Team)) +def clear_db_connection_tests(): + with create_session() as session: + if AIRFLOW_V_3_2_PLUS: + from airflow.models.connection_test import ConnectionTestRequest + + session.execute(delete(ConnectionTestRequest)) + + @_retry_db def clear_db_revoked_tokens(): with create_session() as session: @@ -1001,3 +1009,5 @@ def clear_all(): clear_db_backfills() clear_db_dag_bundles() clear_db_dag_parsing_requests() + if AIRFLOW_V_3_2_PLUS: + clear_db_connection_tests() diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 90374f76be50f..1f882b97c20a9 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -45,6 +45,8 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + ConnectionTestResultBody, + ConnectionTestState, DagResponse, DagRun, DagRunStateResponse, @@ -851,6 +853,25 @@ def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse: return HITLDetailResponse.model_validate_json(resp.read()) +class ConnectionTestOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionResponse: + """Fetch connection data for a test request from the API server.""" + resp = self.client.get(f"connection-tests/{connection_test_id}/connection") + return ConnectionResponse.model_validate_json(resp.read()) + + def update_state( + self, id: uuid.UUID, state: ConnectionTestState, result_message: str | None = None + ) -> None: + """Report the state of a connection test to the API server.""" + body = ConnectionTestResultBody(state=state, 
result_message=result_message) + self.client.patch(f"connection-tests/{id}", content=body.model_dump_json()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -1025,6 +1046,12 @@ def hitl(self): """Operations related to HITL Responses.""" return HITLOperations(self) + @lru_cache() # type: ignore[misc] + @property + def connection_tests(self) -> ConnectionTestOperations: + """Operations related to Connection Tests.""" + return ConnectionTestOperations(self) + @lru_cache() # type: ignore[misc] @property def dags(self) -> DagsOperations: diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b6c08e9d76c82..c5f442b2b7741 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -78,6 +78,33 @@ class ConnectionResponse(BaseModel): extra: Annotated[str | None, Field(title="Extra")] = None +class ConnectionTestConnectionResponse(BaseModel): + """ + Connection data returned to workers from a test request. + """ + + conn_id: Annotated[str, Field(title="Conn Id")] + conn_type: Annotated[str, Field(title="Conn Type")] + host: Annotated[str | None, Field(title="Host")] = None + login: Annotated[str | None, Field(title="Login")] = None + password: Annotated[str | None, Field(title="Password")] = None + schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None + port: Annotated[int | None, Field(title="Port")] = None + extra: Annotated[str | None, Field(title="Extra")] = None + + +class ConnectionTestState(str, Enum): + """ + All possible states of a connection test. + """ + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + class DagResponse(BaseModel): """ Schema for DAG response. 
@@ -535,6 +562,18 @@ class AssetResponse(BaseModel): extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None +class ConnectionTestResultBody(BaseModel): + """ + Payload sent by workers to report connection test results. + """ + + model_config = ConfigDict( + extra="forbid", + ) + state: ConnectionTestState + result_message: Annotated[str | None, Field(title="Result Message")] = None + + class HITLDetailRequest(BaseModel): """ Schema for the request part of a Human-in-the-loop detail for a specific task instance.