Skip to content

Commit 033d185

Browse files
authored
Fix unclosed aiohttp ClientSession in AzureDataFactoryAsyncHook (#60650)
* Fix unclosed aiohttp ClientSession in AzureDataFactoryAsyncHook
* Close the new connection if cancel_pipeline_run() creates one
* Address review comments
* review: early return
* add tests: build_trigger_event
---------
Co-authored-by: Akshay <cruseakshay@users.noreply.github.com>
1 parent 66dc37e commit 033d185

4 files changed

Lines changed: 217 additions & 104 deletions

File tree

providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,20 @@ def __init__(self, azure_data_factory_conn_id: str = default_conn_name):
11211121
self.conn_id = azure_data_factory_conn_id
11221122
super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
11231123

1124+
async def __aenter__(self):
1125+
"""Enter async context manager - returns self for use in async with blocks."""
1126+
return self
1127+
1128+
async def __aexit__(self, exc_type, exc_val, exc_tb):
1129+
"""Exit async context manager - closes the async connection."""
1130+
await self.close()
1131+
1132+
async def close(self) -> None:
1133+
"""Close the async connection to Azure Data Factory."""
1134+
if self._async_conn is not None:
1135+
await self._async_conn.close()
1136+
self._async_conn = None
1137+
11241138
async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
11251139
"""Get async connection and connect to azure data factory."""
11261140
if self._async_conn is not None:

providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/data_factory.py

Lines changed: 119 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -69,44 +69,50 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
6969
},
7070
)
7171

72+
def _build_trigger_event(self, pipeline_status: str) -> TriggerEvent | None:
73+
"""Build TriggerEvent based on pipeline status. Returns None if status is not terminal."""
74+
if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
75+
return TriggerEvent({"status": "error", "message": f"Pipeline run {self.run_id} has Failed."})
76+
if pipeline_status == AzureDataFactoryPipelineRunStatus.CANCELLED:
77+
return TriggerEvent(
78+
{"status": "error", "message": f"Pipeline run {self.run_id} has been Cancelled."}
79+
)
80+
if pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
81+
return TriggerEvent(
82+
{"status": "success", "message": f"Pipeline run {self.run_id} has been Succeeded."}
83+
)
84+
return None
85+
7286
async def run(self) -> AsyncIterator[TriggerEvent]:
7387
"""Make async connection to Azure Data Factory, polls for the pipeline run status."""
74-
hook = AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
7588
executed_after_token_refresh = False
7689
try:
77-
while True:
78-
try:
79-
pipeline_status = await hook.get_adf_pipeline_run_status(
80-
run_id=self.run_id,
81-
resource_group_name=self.resource_group_name,
82-
factory_name=self.factory_name,
83-
)
84-
executed_after_token_refresh = False
85-
if pipeline_status == AzureDataFactoryPipelineRunStatus.FAILED:
86-
yield TriggerEvent(
87-
{"status": "error", "message": f"Pipeline run {self.run_id} has Failed."}
90+
async with AzureDataFactoryAsyncHook(
91+
azure_data_factory_conn_id=self.azure_data_factory_conn_id
92+
) as hook:
93+
while True:
94+
try:
95+
pipeline_status = await hook.get_adf_pipeline_run_status(
96+
run_id=self.run_id,
97+
resource_group_name=self.resource_group_name,
98+
factory_name=self.factory_name,
8899
)
89-
return
90-
elif pipeline_status == AzureDataFactoryPipelineRunStatus.CANCELLED:
91-
msg = f"Pipeline run {self.run_id} has been Cancelled."
92-
yield TriggerEvent({"status": "error", "message": msg})
93-
return
94-
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
95-
msg = f"Pipeline run {self.run_id} has been Succeeded."
96-
yield TriggerEvent({"status": "success", "message": msg})
97-
return
98-
await asyncio.sleep(self.poke_interval)
99-
except ServiceRequestError:
100-
# conn might expire during long running pipeline.
101-
# If exception is caught, it tries to refresh connection once.
102-
# If it still doesn't fix the issue,
103-
# then executed_after_token_refresh would still be False
104-
# and an exception will be raised
105-
if executed_after_token_refresh:
100+
executed_after_token_refresh = False
101+
event = self._build_trigger_event(pipeline_status)
102+
if event:
103+
yield event
104+
return
105+
await asyncio.sleep(self.poke_interval)
106+
except ServiceRequestError:
107+
# conn might expire during long running pipeline.
108+
# If exception is caught, it tries to refresh connection once.
109+
# If it still doesn't fix the issue,
110+
# then executed_after_token_refresh would still be False
111+
# and an exception will be raised
112+
if not executed_after_token_refresh:
113+
raise
106114
await hook.refresh_conn()
107115
executed_after_token_refresh = False
108-
else:
109-
raise
110116
except Exception as e:
111117
yield TriggerEvent({"status": "error", "message": str(e)})
112118

@@ -160,84 +166,93 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
160166
},
161167
)
162168

169+
def _build_trigger_event(self, pipeline_status: str) -> TriggerEvent | None:
170+
"""Build TriggerEvent based on pipeline status. Returns None if status is not terminal."""
171+
if pipeline_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
172+
return TriggerEvent(
173+
{
174+
"status": "error",
175+
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
176+
"run_id": self.run_id,
177+
}
178+
)
179+
if pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
180+
return TriggerEvent(
181+
{
182+
"status": "success",
183+
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
184+
"run_id": self.run_id,
185+
}
186+
)
187+
return None
188+
163189
async def run(self) -> AsyncIterator[TriggerEvent]:
164190
"""Make async connection to Azure Data Factory, polls for the pipeline run status."""
165-
hook = AzureDataFactoryAsyncHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
166-
try:
167-
pipeline_status = await hook.get_adf_pipeline_run_status(
168-
run_id=self.run_id,
169-
resource_group_name=self.resource_group_name,
170-
factory_name=self.factory_name,
171-
)
172-
executed_after_token_refresh = True
173-
if self.wait_for_termination:
174-
while self.end_time > time.time():
175-
try:
176-
pipeline_status = await hook.get_adf_pipeline_run_status(
177-
run_id=self.run_id,
178-
resource_group_name=self.resource_group_name,
179-
factory_name=self.factory_name,
180-
)
181-
executed_after_token_refresh = True
182-
if pipeline_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
183-
yield TriggerEvent(
184-
{
185-
"status": "error",
186-
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
187-
"run_id": self.run_id,
188-
}
191+
async with AzureDataFactoryAsyncHook(
192+
azure_data_factory_conn_id=self.azure_data_factory_conn_id
193+
) as hook:
194+
try:
195+
pipeline_status = await hook.get_adf_pipeline_run_status(
196+
run_id=self.run_id,
197+
resource_group_name=self.resource_group_name,
198+
factory_name=self.factory_name,
199+
)
200+
executed_after_token_refresh = True
201+
if self.wait_for_termination:
202+
while self.end_time > time.time():
203+
try:
204+
pipeline_status = await hook.get_adf_pipeline_run_status(
205+
run_id=self.run_id,
206+
resource_group_name=self.resource_group_name,
207+
factory_name=self.factory_name,
189208
)
190-
return
191-
elif pipeline_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
192-
yield TriggerEvent(
193-
{
194-
"status": "success",
195-
"message": f"The pipeline run {self.run_id} has {pipeline_status}.",
196-
"run_id": self.run_id,
197-
}
209+
executed_after_token_refresh = True
210+
event = self._build_trigger_event(pipeline_status)
211+
if event:
212+
yield event
213+
return
214+
self.log.info(
215+
"Sleeping for %s. The pipeline state is %s.",
216+
self.check_interval,
217+
pipeline_status,
198218
)
199-
return
200-
self.log.info(
201-
"Sleeping for %s. The pipeline state is %s.", self.check_interval, pipeline_status
202-
)
203-
await asyncio.sleep(self.check_interval)
204-
except ServiceRequestError:
205-
# conn might expire during long running pipeline.
206-
# If exception is caught, it tries to refresh connection once.
207-
# If it still doesn't fix the issue,
208-
# then executed_after_token_refresh would still be False
209-
# and an exception will be raised
210-
if executed_after_token_refresh:
219+
await asyncio.sleep(self.check_interval)
220+
except ServiceRequestError:
221+
# conn might expire during long running pipeline.
222+
# If exception is caught, it tries to refresh connection once.
223+
# If it still doesn't fix the issue,
224+
# then executed_after_token_refresh would still be False
225+
# and an exception will be raised
226+
if not executed_after_token_refresh:
227+
raise
211228
await hook.refresh_conn()
212229
executed_after_token_refresh = False
213-
else:
214-
raise
215230

216-
yield TriggerEvent(
217-
{
218-
"status": "error",
219-
"message": f"Timeout: The pipeline run {self.run_id} has {pipeline_status}.",
220-
"run_id": self.run_id,
221-
}
222-
)
223-
else:
224-
yield TriggerEvent(
225-
{
226-
"status": "success",
227-
"message": f"The pipeline run {self.run_id} has {pipeline_status} status.",
228-
"run_id": self.run_id,
229-
}
230-
)
231-
except Exception as e:
232-
self.log.exception(e)
233-
if self.run_id:
234-
try:
235-
self.log.info("Cancelling pipeline run %s", self.run_id)
236-
await hook.cancel_pipeline_run(
237-
run_id=self.run_id,
238-
resource_group_name=self.resource_group_name,
239-
factory_name=self.factory_name,
231+
yield TriggerEvent(
232+
{
233+
"status": "error",
234+
"message": f"Timeout: The pipeline run {self.run_id} has {pipeline_status}.",
235+
"run_id": self.run_id,
236+
}
240237
)
241-
except Exception:
242-
self.log.exception("Failed to cancel pipeline run %s", self.run_id)
243-
yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})
238+
else:
239+
yield TriggerEvent(
240+
{
241+
"status": "success",
242+
"message": f"The pipeline run {self.run_id} has {pipeline_status} status.",
243+
"run_id": self.run_id,
244+
}
245+
)
246+
except Exception as e:
247+
self.log.exception(e)
248+
if self.run_id:
249+
try:
250+
self.log.info("Cancelling pipeline run %s", self.run_id)
251+
await hook.cancel_pipeline_run(
252+
run_id=self.run_id,
253+
resource_group_name=self.resource_group_name,
254+
factory_name=self.factory_name,
255+
)
256+
except Exception:
257+
self.log.exception("Failed to cancel pipeline run %s", self.run_id)
258+
yield TriggerEvent({"status": "error", "message": str(e), "run_id": self.run_id})

providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_data_factory.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,38 @@ async def test_refresh_conn(self, mock_get_async_conn):
876876
await hook.refresh_conn()
877877
assert not hook._conn
878878
assert mock_get_async_conn.called
879+
880+
@pytest.mark.asyncio
881+
async def test_close_method(self):
882+
"""Test close method properly closes the async connection"""
883+
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
884+
mock_conn = mock.AsyncMock()
885+
hook._async_conn = mock_conn
886+
887+
await hook.close()
888+
889+
mock_conn.close.assert_called_once()
890+
assert hook._async_conn is None
891+
892+
@pytest.mark.asyncio
893+
async def test_close_method_when_conn_is_none(self):
894+
"""Test close method does nothing when connection is None"""
895+
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
896+
hook._async_conn = None
897+
898+
# Should not raise any exception
899+
await hook.close()
900+
assert hook._async_conn is None
901+
902+
@pytest.mark.asyncio
903+
async def test_context_manager_calls_close(self):
904+
"""Test async context manager calls close on exit"""
905+
hook = AzureDataFactoryAsyncHook(AZURE_DATA_FACTORY_CONN_ID)
906+
mock_conn = mock.AsyncMock()
907+
hook._async_conn = mock_conn
908+
909+
async with hook:
910+
pass
911+
912+
mock_conn.close.assert_called_once()
913+
assert hook._async_conn is None

providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_data_factory.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,30 @@ class TestADFPipelineRunStatusSensorTrigger:
5151
poke_interval=POKE_INTERVAL,
5252
)
5353

54+
@pytest.mark.parametrize(
55+
("pipeline_status", "expected_status", "expected_message"),
56+
[
57+
("Failed", "error", f"Pipeline run {RUN_ID} has Failed."),
58+
("Cancelled", "error", f"Pipeline run {RUN_ID} has been Cancelled."),
59+
("Succeeded", "success", f"Pipeline run {RUN_ID} has been Succeeded."),
60+
],
61+
)
62+
def test_build_trigger_event_terminal_states(self, pipeline_status, expected_status, expected_message):
63+
"""Test _build_trigger_event returns correct TriggerEvent for terminal states."""
64+
event = self.TRIGGER._build_trigger_event(pipeline_status)
65+
assert event is not None
66+
assert event.payload["status"] == expected_status
67+
assert event.payload["message"] == expected_message
68+
69+
@pytest.mark.parametrize(
70+
"pipeline_status",
71+
["Queued", "InProgress", "Canceling"],
72+
)
73+
def test_build_trigger_event_non_terminal_states(self, pipeline_status):
74+
"""Test _build_trigger_event returns None for non-terminal states."""
75+
event = self.TRIGGER._build_trigger_event(pipeline_status)
76+
assert event is None
77+
5478
def test_adf_pipeline_run_status_sensors_trigger_serialization(self):
5579
"""
5680
Asserts that the TaskStateTrigger correctly serializes its arguments
@@ -186,6 +210,31 @@ class TestAzureDataFactoryTrigger:
186210
end_time=AZ_PIPELINE_END_TIME,
187211
)
188212

213+
@pytest.mark.parametrize(
214+
("pipeline_status", "expected_status"),
215+
[
216+
("Failed", "error"),
217+
("Cancelled", "error"),
218+
("Succeeded", "success"),
219+
],
220+
)
221+
def test_build_trigger_event_terminal_states(self, pipeline_status, expected_status):
222+
"""Test _build_trigger_event returns correct TriggerEvent for terminal states."""
223+
event = self.TRIGGER._build_trigger_event(pipeline_status)
224+
assert event is not None
225+
assert event.payload["status"] == expected_status
226+
assert event.payload["run_id"] == AZ_PIPELINE_RUN_ID
227+
assert f"The pipeline run {AZ_PIPELINE_RUN_ID} has {pipeline_status}." in event.payload["message"]
228+
229+
@pytest.mark.parametrize(
230+
"pipeline_status",
231+
["Queued", "InProgress", "Canceling"],
232+
)
233+
def test_build_trigger_event_non_terminal_states(self, pipeline_status):
234+
"""Test _build_trigger_event returns None for non-terminal states."""
235+
event = self.TRIGGER._build_trigger_event(pipeline_status)
236+
assert event is None
237+
189238
def test_azure_data_factory_trigger_serialization(self):
190239
"""Asserts that the AzureDataFactoryTrigger correctly serializes its arguments and classpath."""
191240

0 commit comments

Comments
 (0)