Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from googleapiclient.errors import HttpError

from airflow.configuration import conf
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook
from airflow.providers.google.cloud.links.cloud_functions import (
Expand Down Expand Up @@ -498,3 +499,166 @@ def execute(self, context: Context):
)

return result

class CloudFunctionInvokeOperator(GoogleCloudBaseOperator):
    """
    Invokes a deployed Cloud Function via its HTTP Trigger URL.

    Unlike CloudFunctionInvokeFunctionOperator which uses the testing-only `functions.call` API,
    this operator makes an authenticated HTTP request to the function's URL. This makes it
    suitable for production workloads.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudFunctionInvokeOperator`

    :param function_id: ID of the function to be called
    :param input_data: Input to be passed to the function as the JSON request body.
        ``None`` is sent as an empty object.
    :param location: The location where the function is located.
    :param project_id: Optional, Google Cloud Project project_id where the function belongs.
        When falsy, the project of the hook is used.
    :param gcp_conn_id: The connection ID passed to ``CloudFunctionsHook``.
    :param api_version: The API version passed to ``CloudFunctionsHook``.
    :param deferrable: Run operator in the deferrable mode
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
    :param request_timeout: Optional timeout in seconds for the HTTP request to the function.
        ``None`` (the default) leaves the HTTP client without an explicit timeout.
    """

    template_fields: Sequence[str] = (
        "function_id",
        "input_data",
        "location",
        "project_id",
        "impersonation_chain",
    )
    operator_extra_links = (CloudFunctionsDetailsLink(),)

    def __init__(
        self,
        *,
        function_id: str,
        input_data: dict | None = None,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        gcp_conn_id: str = "google_cloud_default",
        api_version: str = "v1",
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        impersonation_chain: str | Sequence[str] | None = None,
        request_timeout: float | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.function_id = function_id
        self.input_data = input_data or {}
        self.location = location
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.api_version = api_version
        self.deferrable = deferrable
        self.impersonation_chain = impersonation_chain
        self.request_timeout = request_timeout

    @property
    def extra_links_params(self) -> dict[str, Any]:
        """Return the location and function name used to build the operator's extra link."""
        return {
            "location": self.location,
            "function_name": self.function_id,
        }

    def _get_id_token(self, hook, audience: str) -> str:
        """
        Obtain an ID token for ``audience`` from the hook's credentials.

        Service-account, impersonated and Compute Engine credentials are each converted to
        the matching google-auth ``IDTokenCredentials`` class; any other credential type
        falls back to ``google.oauth2.id_token.fetch_id_token`` (Application Default
        Credentials).

        :param hook: Hook exposing ``get_credentials()``.
        :param audience: Target audience of the token (the function URL).
        :return: The encoded ID token.
        """
        import google.auth.transport.requests
        from google.auth import compute_engine, impersonated_credentials
        from google.oauth2 import id_token, service_account

        credentials = hook.get_credentials()
        request = google.auth.transport.requests.Request()

        id_token_creds = None
        if isinstance(credentials, service_account.Credentials):
            # ``service_account.Credentials`` has no ``with_target_audience`` and overriding the
            # ``aud`` claim via ``with_claims`` does not yield an ID token, so build the
            # dedicated ID-token credentials from the same signer instead.
            # NOTE(review): google-auth exposes no public accessor for the token URI; the
            # private attribute is read with the default Google endpoint as fallback.
            id_token_creds = service_account.IDTokenCredentials(
                signer=credentials.signer,
                service_account_email=credentials.service_account_email,
                token_uri=getattr(credentials, "_token_uri", "https://oauth2.googleapis.com/token"),
                target_audience=audience,
            )
        elif isinstance(credentials, compute_engine.Credentials):
            # Public equivalent of querying the metadata server's identity endpoint.
            id_token_creds = compute_engine.IDTokenCredentials(
                request,
                audience,
                use_metadata_identity_endpoint=True,
            )
        elif isinstance(credentials, impersonated_credentials.Credentials):
            # The target principal is taken from the wrapped credentials; the constructor
            # does not accept it as a separate argument.
            id_token_creds = impersonated_credentials.IDTokenCredentials(
                credentials,
                target_audience=audience,
                include_email=True,
            )

        if id_token_creds is None:
            # Fallback to fetch_id_token (ADC)
            return id_token.fetch_id_token(request, audience)

        id_token_creds.refresh(request)
        return id_token_creds.token

    def execute(self, context: Context):
        """
        Resolve the function's HTTP URL and invoke it with an ID-token bearer header.

        In deferrable mode the request is handed to ``CloudFunctionInvokeTrigger`` and the
        result is delivered to :meth:`execute_complete`; otherwise the request is made
        synchronously.

        :return: The decoded JSON response, or the raw response text if it is not JSON.
        :raises AirflowException: If the function exposes no HTTP URL.
        """
        import requests

        from airflow.providers.google.cloud.triggers.functions import CloudFunctionInvokeTrigger

        hook = CloudFunctionsHook(
            api_version=self.api_version,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        project_id = self.project_id or hook.project_id
        self.log.info("Fetching function details for %s.", self.function_id)
        function_name = f"projects/{project_id}/locations/{self.location}/functions/{self.function_id}"
        function = hook.get_function(name=function_name)

        # Prefer ``httpsTrigger.url`` and fall back to ``serviceConfig.uri``.
        url = (function.get("httpsTrigger") or {}).get("url") or (
            function.get("serviceConfig") or {}
        ).get("uri")
        if not url:
            raise AirflowException(f"Function {self.function_id} does not have an HTTP trigger URL.")

        self.log.info("Function HTTP URL: %s", url)

        # Get ID token for authentication; the function URL is used as the token audience.
        token = self._get_id_token(hook, url)
        headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

        if project_id:
            CloudFunctionsDetailsLink.persist(
                context=context,
                project_id=project_id,
            )

        if self.deferrable:
            self.log.info("Deferring execution to CloudFunctionInvokeTrigger.")
            # NOTE(review): the bearer token travels in the trigger's kwargs, which the
            # trigger serializes as-is, and it is minted before the trigger starts running.
            # Consider minting the token inside the trigger instead.
            self.defer(
                trigger=CloudFunctionInvokeTrigger(
                    function_uri=url,
                    json_payload=self.input_data,
                    headers=headers,
                    timeout=self.request_timeout,
                ),
                method_name="execute_complete",
            )

        self.log.info("Invoking function synchronously.")
        request_kwargs: dict[str, Any] = {"json": self.input_data, "headers": headers}
        # Only pass ``timeout`` when configured so the default call stays unchanged.
        if self.request_timeout is not None:
            request_kwargs["timeout"] = self.request_timeout
        response = requests.post(url, **request_kwargs)
        response.raise_for_status()
        try:
            return response.json()
        except ValueError:
            # ``requests`` JSON decode errors derive from ValueError in every release,
            # whereas ``requests.exceptions.JSONDecodeError`` only exists in newer ones.
            return response.text

    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
        """
        Handle the event emitted by the trigger after a deferred invocation.

        :param event: Trigger payload with a ``status`` key and either ``message``
            (on error) or ``response`` (on success).
        :return: The ``response`` value of the event.
        :raises AirflowException: If the event status is ``"error"``.
        """
        if event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info("Function invoked successfully.")
        return event["response"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import json
from typing import Any

import aiohttp

from airflow.triggers.base import BaseTrigger, TriggerEvent


class CloudFunctionInvokeTrigger(BaseTrigger):
    """
    Trigger that makes an HTTP POST request to a Google Cloud Function and waits for the response.

    :param function_uri: The HTTPS trigger URL of the Cloud Function.
    :param json_payload: The JSON payload to send in the request body.
    :param headers: The headers to send in the request, including authentication headers.
    :param timeout: Optional. The timeout in seconds for the HTTP request. A falsy value
        disables the total timeout.
    """

    def __init__(
        self,
        function_uri: str,
        json_payload: dict[str, Any] | None,
        headers: dict[str, str],
        timeout: float | None = None,
    ):
        super().__init__()
        self.function_uri = function_uri
        self.json_payload = json_payload
        self.headers = headers
        self.timeout = timeout

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize class arguments and classpath."""
        # NOTE(review): ``headers`` is serialized verbatim, including any
        # authentication header the caller supplied.
        return (
            "airflow.providers.google.cloud.triggers.functions.CloudFunctionInvokeTrigger",
            {
                "function_uri": self.function_uri,
                "json_payload": self.json_payload,
                "headers": self.headers,
                "timeout": self.timeout,
            },
        )

    @staticmethod
    def _build_event(status_code: int, response_text: str) -> dict[str, Any]:
        """Translate an HTTP status code and body into the trigger's event payload."""
        if status_code >= 400:
            return {
                "status": "error",
                "message": f"Cloud Function invocation failed with status {status_code}: {response_text}",
                "status_code": status_code,
            }
        # Decode the body ourselves: aiohttp's ``response.json()`` rejects bodies whose
        # Content-Type is not JSON, which would make this path disagree with a plain
        # "try to parse, else return text" client.
        try:
            payload: Any = json.loads(response_text)
        except ValueError:
            # If response is not JSON, return the text
            payload = response_text
        return {
            "status": "success",
            "response": payload,
            "status_code": status_code,
        }

    async def run(self):
        """
        Make an async HTTP request to the Cloud Function.

        Yields exactly one ``TriggerEvent`` whose payload has ``status`` (``"success"`` or
        ``"error"``), ``status_code`` (``None`` when no response was received) and either
        ``response`` or ``message``.
        """
        try:
            # We use aiohttp instead of the synchronous requests library
            # A falsy timeout maps to ``total=None``, i.e. no total timeout.
            timeout_obj = aiohttp.ClientTimeout(total=self.timeout or None)
            async with aiohttp.ClientSession(timeout=timeout_obj) as session:
                async with session.post(
                    self.function_uri,
                    json=self.json_payload,
                    headers=self.headers,
                ) as response:
                    # Cloud Functions return whatever the user code returns.
                    # We capture the status code and text/json.
                    status_code = response.status
                    response_text = await response.text()
            event = self._build_event(status_code, response_text)
        except asyncio.TimeoutError:
            event = {
                "status": "error",
                "message": "Cloud Function invocation timed out.",
                "status_code": None,
            }
        except Exception as e:
            # Any other failure is reported to the operator as an error event.
            event = {
                "status": "error",
                "message": str(e),
                "status_code": None,
            }
        # Yield outside the try block so the event is emitted exactly once.
        yield TriggerEvent(event)
Original file line number Diff line number Diff line change
Expand Up @@ -736,3 +736,97 @@ def test_execute(self, mock_gcf_hook):
key="execution_id",
value=exec_id,
)

class TestCloudFunctionInvokeOperator:
    """Unit tests for ``CloudFunctionInvokeOperator``."""

    @staticmethod
    def _make_operator(**kwargs):
        """Build an operator with the identifiers shared by every test in this class."""
        from airflow.providers.google.cloud.operators.functions import CloudFunctionInvokeOperator

        return CloudFunctionInvokeOperator(
            task_id="test",
            function_id="my_func",
            location="us-central1",
            project_id="test_project",
            **kwargs,
        )

    @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook")
    # Patch ``requests.post`` on the library itself: the operator imports ``requests``
    # inside ``execute``, so the name need not exist on the operators module.
    @mock.patch("requests.post")
    def test_execute_sync(self, mock_post, mock_hook):
        mock_hook.return_value.get_function.return_value = {
            "httpsTrigger": {"url": "https://example.com/function"}
        }

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {"status": "ok"}
        mock_post.return_value = mock_response

        op = self._make_operator(input_data={"data": "test"}, deferrable=False)
        # Token minting needs real credentials, so stub it out.
        op._get_id_token = mock.MagicMock(return_value="mocked-token")

        result = op.execute(context=mock.MagicMock())

        assert result == {"status": "ok"}
        mock_hook.return_value.get_function.assert_called_once_with(
            name="projects/test_project/locations/us-central1/functions/my_func"
        )
        mock_post.assert_called_once_with(
            "https://example.com/function",
            json={"data": "test"},
            headers={"Authorization": "Bearer mocked-token", "Content-Type": "application/json"},
        )
        # HTTP errors must be surfaced before the body is decoded.
        mock_response.raise_for_status.assert_called_once_with()

    @mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook")
    def test_execute_deferrable(self, mock_hook):
        from airflow.exceptions import TaskDeferred
        from airflow.providers.google.cloud.triggers.functions import CloudFunctionInvokeTrigger

        mock_hook.return_value.get_function.return_value = {
            "httpsTrigger": {"url": "https://example.com/function"}
        }

        op = self._make_operator(input_data={"data": "test"}, deferrable=True)
        op._get_id_token = mock.MagicMock(return_value="mocked-token")

        with pytest.raises(TaskDeferred) as exc:
            op.execute(context=mock.MagicMock())

        assert isinstance(exc.value.trigger, CloudFunctionInvokeTrigger)
        assert exc.value.trigger.function_uri == "https://example.com/function"
        assert exc.value.trigger.headers == {
            "Authorization": "Bearer mocked-token",
            "Content-Type": "application/json",
        }

    def test_execute_complete_success(self):
        op = self._make_operator()

        event = {"status": "success", "response": {"result": "ok"}}
        result = op.execute_complete(context=mock.MagicMock(), event=event)

        assert result == {"result": "ok"}

    def test_execute_complete_error(self):
        op = self._make_operator()

        event = {"status": "error", "message": "Failed"}
        with pytest.raises(AirflowException, match="Failed"):
            op.execute_complete(context=mock.MagicMock(), event=event)

Loading