class CloudFunctionInvokeOperator(GoogleCloudBaseOperator):
    """
    Invokes a deployed Cloud Function via its HTTP Trigger URL.

    Unlike CloudFunctionInvokeFunctionOperator which uses the testing-only `functions.call` API,
    this operator makes an authenticated HTTP request to the function's URL. This makes it
    suitable for production workloads.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudFunctionInvokeOperator`

    :param function_id: ID of the function to be called
    :param input_data: Input to be passed to the function (sent as the JSON request body)
    :param location: The location where the function is located.
    :param project_id: Optional, Google Cloud Project project_id where the function belongs.
    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
    :param api_version: API version passed to ``CloudFunctionsHook`` when looking up the function.
    :param deferrable: Run operator in the deferrable mode
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
    :param timeout: Optional. Timeout in seconds for the HTTP request to the function.
        ``None`` (the default) applies no timeout.
    """

    template_fields: Sequence[str] = (
        "function_id",
        "input_data",
        "location",
        "project_id",
        "impersonation_chain",
    )
    operator_extra_links = (CloudFunctionsDetailsLink(),)

    # Google's OAuth 2.0 token endpoint, used when the service-account credentials
    # do not carry their own token URI.
    _DEFAULT_TOKEN_URI = "https://oauth2.googleapis.com/token"

    def __init__(
        self,
        *,
        function_id: str,
        input_data: dict | None = None,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        gcp_conn_id: str = "google_cloud_default",
        api_version: str = "v1",
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        impersonation_chain: str | Sequence[str] | None = None,
        timeout: float | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.function_id = function_id
        self.input_data = input_data or {}
        self.location = location
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.api_version = api_version
        self.deferrable = deferrable
        self.impersonation_chain = impersonation_chain
        self.timeout = timeout

    @property
    def extra_links_params(self) -> dict[str, Any]:
        """Return the parameters used to render the Cloud Functions details link."""
        return {
            "location": self.location,
            "function_name": self.function_id,
        }

    def _get_id_token(self, hook, audience: str) -> str:
        """
        Mint an OpenID Connect ID token for ``audience`` from the hook's credentials.

        The credential type returned by ``hook.get_credentials()`` decides how the token is
        obtained; anything unrecognised falls back to ``google.oauth2.id_token.fetch_id_token``
        (Application Default Credentials).

        :param hook: Hook exposing ``get_credentials()``.
        :param audience: Target audience of the token (the function URL).
        :return: The encoded ID token.
        """
        import google.auth.transport.requests
        from google.auth import compute_engine, impersonated_credentials
        from google.oauth2 import id_token, service_account

        credentials = hook.get_credentials()
        request = google.auth.transport.requests.Request()

        id_token_credentials = None
        if isinstance(credentials, service_account.Credentials):
            # Access-token credentials cannot mint ID tokens themselves; build ID-token
            # credentials from the same signer instead.
            # NOTE(review): ``_token_uri`` is a private attribute of google-auth credentials;
            # the public default endpoint is used when it is unavailable.
            id_token_credentials = service_account.IDTokenCredentials(
                signer=credentials.signer,
                service_account_email=credentials.service_account_email,
                token_uri=getattr(credentials, "_token_uri", None) or self._DEFAULT_TOKEN_URI,
                target_audience=audience,
            )
        elif isinstance(credentials, impersonated_credentials.Credentials):
            id_token_credentials = impersonated_credentials.IDTokenCredentials(
                credentials,
                target_audience=audience,
                include_email=True,
            )
        elif isinstance(credentials, compute_engine.Credentials):
            id_token_credentials = compute_engine.IDTokenCredentials(
                request,
                target_audience=audience,
                use_metadata_identity_endpoint=True,
            )

        if id_token_credentials is None:
            # Fallback to fetch_id_token (ADC)
            return id_token.fetch_id_token(request, audience)

        id_token_credentials.refresh(request)
        return id_token_credentials.token

    def execute(self, context: Context):
        """
        Resolve the function's HTTP URL and invoke it, synchronously or via a trigger.

        :param context: Airflow task context.
        :return: In non-deferrable mode, the decoded JSON response, or the raw response text
            when the body is not valid JSON. In deferrable mode the task is deferred and the
            result is returned from ``execute_complete``.
        :raises AirflowException: If the function exposes no HTTP URL.
        """
        import requests

        from airflow.providers.google.cloud.triggers.functions import CloudFunctionInvokeTrigger

        hook = CloudFunctionsHook(
            api_version=self.api_version,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        project_id = self.project_id or hook.project_id
        self.log.info("Fetching function details for %s.", self.function_id)
        function = hook.get_function(
            name=f"projects/{project_id}/locations/{self.location}/functions/{self.function_id}"
        )

        # The URL is read from ``httpsTrigger.url`` first, then ``serviceConfig.uri``.
        url = (function.get("httpsTrigger") or {}).get("url") or (
            function.get("serviceConfig") or {}
        ).get("uri")
        if not url:
            raise AirflowException(f"Function {self.function_id} does not have an HTTP trigger URL.")

        self.log.info("Function HTTP URL: %s", url)

        # Get ID token for authentication
        token = self._get_id_token(hook, url)
        headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

        if project_id:
            CloudFunctionsDetailsLink.persist(
                context=context,
                project_id=project_id,
            )

        if self.deferrable:
            self.log.info("Deferring execution to CloudFunctionInvokeTrigger.")
            # NOTE(review): the bearer token travels inside the serialized trigger kwargs.
            self.defer(
                trigger=CloudFunctionInvokeTrigger(
                    function_uri=url,
                    json_payload=self.input_data,
                    headers=headers,
                    timeout=self.timeout,
                ),
                method_name="execute_complete",
            )

        self.log.info("Invoking function synchronously.")
        request_kwargs: dict[str, Any] = {"json": self.input_data, "headers": headers}
        if self.timeout is not None:
            request_kwargs["timeout"] = self.timeout
        response = requests.post(url, **request_kwargs)
        # HTTP errors propagate to the caller; only JSON decoding is best-effort.
        response.raise_for_status()
        try:
            return response.json()
        except requests.exceptions.JSONDecodeError:
            return response.text

    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
        """
        Resume after the trigger fires and return the function's response.

        :param context: Airflow task context.
        :param event: Payload emitted by ``CloudFunctionInvokeTrigger``.
        :return: The ``response`` entry of the event.
        :raises AirflowException: If the event reports ``status == "error"``.
        """
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info("Function invoked successfully.")
        return event["response"]
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class CloudFunctionInvokeTrigger(BaseTrigger): + """ + Trigger that makes an HTTP POST request to a Google Cloud Function and waits for the response. + + :param function_uri: The HTTPS trigger URL of the Cloud Function. + :param json_payload: The JSON payload to send in the request body. + :param headers: The headers to send in the request, including authentication headers. + :param timeout: Optional. The timeout in seconds for the HTTP request. 
+ """ + + def __init__( + self, + function_uri: str, + json_payload: dict[str, Any] | None, + headers: dict[str, str], + timeout: float | None = None, + ): + super().__init__() + self.function_uri = function_uri + self.json_payload = json_payload + self.headers = headers + self.timeout = timeout + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.functions.CloudFunctionInvokeTrigger", + { + "function_uri": self.function_uri, + "json_payload": self.json_payload, + "headers": self.headers, + "timeout": self.timeout, + }, + ) + + async def run(self): + """Make an async HTTP request to the Cloud Function.""" + try: + # We use aiohttp instead of the synchronous requests library + timeout_obj = aiohttp.ClientTimeout(total=self.timeout) if self.timeout else aiohttp.ClientTimeout() + async with aiohttp.ClientSession(timeout=timeout_obj) as session: + async with session.post( + self.function_uri, + json=self.json_payload, + headers=self.headers, + ) as response: + # Cloud Functions return whatever the user code returns. + # We capture the status code and text/json. 
+ status_code = response.status + response_text = await response.text() + + if status_code >= 400: + yield TriggerEvent( + { + "status": "error", + "message": f"Cloud Function invocation failed with status {status_code}: {response_text}", + "status_code": status_code, + } + ) + else: + try: + response_json = await response.json() + yield TriggerEvent( + { + "status": "success", + "response": response_json, + "status_code": status_code, + } + ) + except Exception: + # If response is not JSON, yield the text + yield TriggerEvent( + { + "status": "success", + "response": response_text, + "status_code": status_code, + } + ) + except asyncio.TimeoutError: + yield TriggerEvent( + { + "status": "error", + "message": "Cloud Function invocation timed out.", + "status_code": None, + } + ) + except Exception as e: + yield TriggerEvent( + { + "status": "error", + "message": str(e), + "status_code": None, + } + ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_functions.py b/providers/google/tests/unit/google/cloud/operators/test_functions.py index 9a08d695c68b5..20d3deaaf6087 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_functions.py +++ b/providers/google/tests/unit/google/cloud/operators/test_functions.py @@ -736,3 +736,97 @@ def test_execute(self, mock_gcf_hook): key="execution_id", value=exec_id, ) + +class TestCloudFunctionInvokeOperator: + @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook") + @mock.patch("airflow.providers.google.cloud.operators.functions.requests.post") + def test_execute_sync(self, mock_post, mock_hook): + mock_hook.return_value.get_function.return_value = { + "httpsTrigger": {"url": "https://example.com/function"} + } + + mock_response = mock.MagicMock() + mock_response.json.return_value = {"status": "ok"} + mock_post.return_value = mock_response + + from airflow.providers.google.cloud.operators.functions import CloudFunctionInvokeOperator + op = CloudFunctionInvokeOperator( 
class TestCloudFunctionInvokeOperator:
    """Unit tests for CloudFunctionInvokeOperator in synchronous and deferrable modes."""

    @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook")
    # ``requests`` is imported inside ``execute()``, so it is patched on the ``requests``
    # package itself rather than as an attribute of the operator module.
    @mock.patch("requests.post")
    def test_execute_sync(self, mock_post, mock_hook):
        from airflow.providers.google.cloud.operators.functions import CloudFunctionInvokeOperator

        mock_hook.return_value.get_function.return_value = {
            "httpsTrigger": {"url": "https://example.com/function"}
        }

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {"status": "ok"}
        mock_post.return_value = mock_response

        op = CloudFunctionInvokeOperator(
            task_id="test",
            function_id="my_func",
            location="us-central1",
            project_id="test_project",
            input_data={"data": "test"},
            deferrable=False,
        )

        op._get_id_token = mock.MagicMock(return_value="mocked-token")

        result = op.execute(context=mock.MagicMock())

        assert result == {"status": "ok"}
        mock_hook.return_value.get_function.assert_called_once_with(
            name="projects/test_project/locations/us-central1/functions/my_func"
        )
        mock_post.assert_called_once_with(
            "https://example.com/function",
            json={"data": "test"},
            headers={"Authorization": "Bearer mocked-token", "Content-Type": "application/json"},
        )
        mock_response.raise_for_status.assert_called_once_with()

    @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook")
    def test_execute_deferrable(self, mock_hook):
        from airflow.exceptions import TaskDeferred
        from airflow.providers.google.cloud.operators.functions import CloudFunctionInvokeOperator
        from airflow.providers.google.cloud.triggers.functions import CloudFunctionInvokeTrigger

        mock_hook.return_value.get_function.return_value = {
            "httpsTrigger": {"url": "https://example.com/function"}
        }

        op = CloudFunctionInvokeOperator(
            task_id="test",
            function_id="my_func",
            location="us-central1",
            project_id="test_project",
            input_data={"data": "test"},
            deferrable=True,
        )
        op._get_id_token = mock.MagicMock(return_value="mocked-token")

        with pytest.raises(TaskDeferred) as exc:
            op.execute(context=mock.MagicMock())

        trigger = exc.value.trigger
        assert isinstance(trigger, CloudFunctionInvokeTrigger)
        assert trigger.function_uri == "https://example.com/function"
        assert trigger.json_payload == {"data": "test"}
        assert trigger.headers == {
            "Authorization": "Bearer mocked-token",
            "Content-Type": "application/json",
        }

    def test_execute_complete_success(self):
        from airflow.providers.google.cloud.operators.functions import CloudFunctionInvokeOperator

        op = CloudFunctionInvokeOperator(
            task_id="test",
            function_id="my_func",
            location="us-central1",
            project_id="test_project",
        )

        event = {"status": "success", "response": {"result": "ok"}}
        result = op.execute_complete(context=mock.MagicMock(), event=event)

        assert result == {"result": "ok"}

    def test_execute_complete_error(self):
        from airflow.providers.google.cloud.operators.functions import CloudFunctionInvokeOperator

        op = CloudFunctionInvokeOperator(
            task_id="test",
            function_id="my_func",
            location="us-central1",
            project_id="test_project",
        )

        event = {"status": "error", "message": "Failed"}
        with pytest.raises(AirflowException, match="Failed"):
            op.execute_complete(context=mock.MagicMock(), event=event)
+ +from __future__ import annotations + +import asyncio +from unittest import mock + +import pytest + +from airflow.providers.google.cloud.triggers.functions import CloudFunctionInvokeTrigger +from airflow.triggers.base import TriggerEvent + +FUNCTION_URI = "https://example.com/function" +JSON_PAYLOAD = {"key": "value"} +HEADERS = {"Authorization": "Bearer token"} + + +class TestCloudFunctionInvokeTrigger: + def test_serialization(self): + trigger = CloudFunctionInvokeTrigger( + function_uri=FUNCTION_URI, + json_payload=JSON_PAYLOAD, + headers=HEADERS, + timeout=30.0, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.functions.CloudFunctionInvokeTrigger" + assert kwargs == { + "function_uri": FUNCTION_URI, + "json_payload": JSON_PAYLOAD, + "headers": HEADERS, + "timeout": 30.0, + } + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.post") + async def test_run_success(self, mock_post): + mock_response = mock.AsyncMock() + mock_response.status = 200 + mock_response.json = mock.AsyncMock(return_value={"result": "success"}) + mock_post.return_value.__aenter__.return_value = mock_response + + trigger = CloudFunctionInvokeTrigger( + function_uri=FUNCTION_URI, + json_payload=JSON_PAYLOAD, + headers=HEADERS, + ) + + generator = trigger.run() + event = await generator.asend(None) + + assert isinstance(event, TriggerEvent) + assert event.payload["status"] == "success" + assert event.payload["response"] == {"result": "success"} + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.post") + async def test_run_error(self, mock_post): + mock_response = mock.AsyncMock() + mock_response.status = 400 + mock_response.text = mock.AsyncMock(return_value="Bad Request") + mock_post.return_value.__aenter__.return_value = mock_response + + trigger = CloudFunctionInvokeTrigger( + function_uri=FUNCTION_URI, + json_payload=JSON_PAYLOAD, + headers=HEADERS, + ) + + generator = trigger.run() + event = await 
generator.asend(None) + + assert isinstance(event, TriggerEvent) + assert event.payload["status"] == "error" + assert event.payload["status_code"] == 400 + + @pytest.mark.asyncio + @mock.patch("aiohttp.ClientSession.post") + async def test_run_timeout(self, mock_post): + mock_post.side_effect = asyncio.TimeoutError() + + trigger = CloudFunctionInvokeTrigger( + function_uri=FUNCTION_URI, + json_payload=JSON_PAYLOAD, + headers=HEADERS, + timeout=10.0, + ) + + generator = trigger.run() + event = await generator.asend(None) + + assert isinstance(event, TriggerEvent) + assert event.payload["status"] == "error" + assert "timed out" in event.payload["message"]