-
Notifications
You must be signed in to change notification settings - Fork 380
Expand file tree
/
Copy pathfabric.py
More file actions
321 lines (253 loc) · 13.9 KB
/
fabric.py
File metadata and controls
321 lines (253 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
from __future__ import annotations
import typing as t
import logging
import requests
import time
from functools import cached_property
from sqlglot import exp
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result
from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
from sqlmesh.core.engine_adapter.shared import (
InsertOverwriteStrategy,
)
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.connection_pool import ConnectionPool
logger = logging.getLogger(__name__)
class FabricEngineAdapter(LogicalMergeMixin, MSSQLEngineAdapter):
"""
Adapter for Microsoft Fabric.
"""
DIALECT = "fabric"
SUPPORTS_INDEXES = False
SUPPORTS_TRANSACTIONS = False
SUPPORTS_CREATE_DROP_CATALOG = True
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.DELETE_INSERT
def __init__(
self, connection_factory_or_pool: t.Union[t.Callable, t.Any], *args: t.Any, **kwargs: t.Any
) -> None:
# Wrap connection factory to support changing the catalog dynamically at runtime
if not isinstance(connection_factory_or_pool, ConnectionPool):
original_connection_factory = connection_factory_or_pool
connection_factory_or_pool = lambda *args, **kwargs: original_connection_factory(
target_catalog=self._target_catalog, *args, **kwargs
)
super().__init__(connection_factory_or_pool, *args, **kwargs)
@property
def _target_catalog(self) -> t.Optional[str]:
return self._connection_pool.get_attribute("target_catalog")
@_target_catalog.setter
def _target_catalog(self, value: t.Optional[str]) -> None:
self._connection_pool.set_attribute("target_catalog", value)
@property
def api_client(self) -> FabricHttpClient:
# the requests Session is not guaranteed to be threadsafe
# so we create a http client per thread on demand
if existing_client := self._connection_pool.get_attribute("api_client"):
return existing_client
tenant_id: t.Optional[str] = self._extra_config.get("tenant_id")
workspace_id: t.Optional[str] = self._extra_config.get("workspace_id")
client_id: t.Optional[str] = self._extra_config.get("user")
client_secret: t.Optional[str] = self._extra_config.get("password")
if not tenant_id or not client_id or not client_secret:
raise SQLMeshError(
"Service Principal authentication requires tenant_id, client_id, and client_secret "
"in the Fabric connection configuration"
)
if not workspace_id:
raise SQLMeshError(
"Fabric requires the workspace_id to be configured in the connection configuration to create / drop catalogs"
)
client = FabricHttpClient(
tenant_id=tenant_id,
workspace_id=workspace_id,
client_id=client_id,
client_secret=client_secret,
)
self._connection_pool.set_attribute("api_client", client)
return client
def _create_catalog(self, catalog_name: exp.Identifier) -> None:
"""Create a catalog (warehouse) in Microsoft Fabric via REST API."""
warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False)
logger.info(f"Creating Fabric warehouse: {warehouse_name}")
self.api_client.create_warehouse(warehouse_name)
def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
"""Drop a catalog (warehouse) in Microsoft Fabric via REST API."""
warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False)
current_catalog = self.get_current_catalog()
logger.info(f"Deleting Fabric warehouse: {warehouse_name}")
self.api_client.delete_warehouse(warehouse_name)
if warehouse_name == current_catalog:
# Somewhere around 2025-09-08, Fabric started validating the "Database=" connection argument and throwing 'Authentication failed' if the database doesnt exist
# In addition, set_current_catalog() is implemented using a threadlocal variable "target_catalog"
# So, when we drop a warehouse, and there are still threads with "target_catalog" set to reference it, any operations on those threads
# that use an either use an existing connection pointing to this warehouse or trigger a new connection
# will fail with an 'Authentication Failed' error unless we close all connections here, which also clears all the threadlocal data
self.close()
def set_current_catalog(self, catalog_name: str) -> None:
"""
Set the current catalog for Microsoft Fabric connections.
Override to handle Fabric's stateless session limitation where USE statements
don't persist across queries. Instead, we close existing connections and
recreate them with the new catalog in the connection configuration.
Args:
catalog_name: The name of the catalog (warehouse) to switch to
Note:
Fabric doesn't support catalog switching via USE statements because each
statement runs as an independent session. This method works around this
limitation by updating the connection pool with new catalog configuration.
See:
https://learn.microsoft.com/en-us/fabric/data-warehouse/sql-query-editor#limitations
"""
current_catalog = self.get_current_catalog()
# If already using the requested catalog, do nothing
if current_catalog and current_catalog == catalog_name:
logger.debug(f"Already using catalog '{catalog_name}', no action needed")
return
logger.info(f"Switching from catalog '{current_catalog}' to '{catalog_name}'")
# commit the transaction before closing the connection to help prevent errors like:
# > Snapshot isolation transaction failed in database because the object accessed by the statement has been modified by a
# > DDL statement in another concurrent transaction since the start of this transaction
# on subsequent queries in the new connection
self._connection_pool.commit()
# note: we call close() on the connection pool instead of self.close() because self.close() calls close_all()
# on the connection pool but we just want to close the connection for this thread
self._connection_pool.close()
self._target_catalog = catalog_name # new connections will use this catalog
catalog_after_switch = self.get_current_catalog()
if catalog_after_switch != catalog_name:
# We need to raise an error if the catalog switch failed to prevent the operation that needed the catalog switch from being run against the wrong catalog
raise SQLMeshError(
f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}"
)
class FabricHttpClient:
def __init__(self, tenant_id: str, workspace_id: str, client_id: str, client_secret: str):
self.tenant_id = tenant_id
self.client_id = client_id
self.client_secret = client_secret
self.workspace_id = workspace_id
def create_warehouse(
self, warehouse_name: str, if_not_exists: bool = True, attempt: int = 0
) -> None:
"""Create a catalog (warehouse) in Microsoft Fabric via REST API."""
# attempt count is arbitrary, it essentially equates to 5 minutes of 30 second waits
if attempt > 10:
raise SQLMeshError(
f"Gave up waiting for Fabric warehouse {warehouse_name} to become available"
)
logger.info(f"Creating Fabric warehouse: {warehouse_name}")
request_data = {
"displayName": warehouse_name,
"description": f"Warehouse created by SQLMesh: {warehouse_name}",
}
response = self.session.post(self._endpoint_url("warehouses"), json=request_data)
if (
if_not_exists
and response.status_code == 400
and (errorCode := response.json().get("errorCode", None))
):
if errorCode == "ItemDisplayNameAlreadyInUse":
logger.warning(f"Fabric warehouse {warehouse_name} already exists")
return
if errorCode == "ItemDisplayNameNotAvailableYet":
logger.warning(f"Fabric warehouse {warehouse_name} is still spinning up; waiting")
# Fabric error message is something like:
# - "Requested 'circleci_51d7087e__dev' is not available yet and is expected to become available in the upcoming minutes."
# This seems to happen if a catalog is dropped and then a new one with the same name is immediately created.
# There appears to be some delayed async process on the Fabric side that actually drops the warehouses and frees up the names to be used again
time.sleep(30)
return self.create_warehouse(
warehouse_name=warehouse_name, if_not_exists=if_not_exists, attempt=attempt + 1
)
try:
response.raise_for_status()
except:
# the important information to actually debug anything is in the response body which Requests never prints
logger.exception(
f"Failed to create warehouse {warehouse_name}. status: {response.status_code}, body: {response.text}"
)
raise
# Handle direct success (201) or async creation (202)
if response.status_code == 201:
logger.info(f"Successfully created Fabric warehouse: {warehouse_name}")
return
if response.status_code == 202 and (location_header := response.headers.get("location")):
logger.info(f"Warehouse creation initiated for: {warehouse_name}")
self._wait_for_completion(location_header, warehouse_name)
logger.info(f"Successfully created Fabric warehouse: {warehouse_name}")
else:
logger.error(f"Unexpected response from Fabric API: {response}\n{response.text}")
raise SQLMeshError(f"Unable to create warehouse: {response}")
def delete_warehouse(self, warehouse_name: str, if_exists: bool = True) -> None:
"""Drop a catalog (warehouse) in Microsoft Fabric via REST API."""
logger.info(f"Deleting Fabric warehouse: {warehouse_name}")
# Get the warehouse ID by listing warehouses
# TODO: handle continuationUri for pagination, ref: https://learn.microsoft.com/en-us/rest/api/fabric/warehouse/items/list-warehouses?tabs=HTTP#warehouses
response = self.session.get(self._endpoint_url("warehouses"))
response.raise_for_status()
warehouse_name_to_id = {
warehouse.get("displayName"): warehouse.get("id")
for warehouse in response.json().get("value", [])
}
warehouse_id = warehouse_name_to_id.get(warehouse_name, None)
if not warehouse_id:
logger.warning(
f"Fabric warehouse does not exist: {warehouse_name}\n(available warehouses: {', '.join(warehouse_name_to_id)})"
)
if if_exists:
return
raise SQLMeshError(
f"Unable to delete Fabric warehouse {warehouse_name} as it doesnt exist"
)
# Delete the warehouse by ID
response = self.session.delete(self._endpoint_url(f"warehouses/{warehouse_id}"))
response.raise_for_status()
logger.info(f"Successfully deleted Fabric warehouse: {warehouse_name}")
@cached_property
def session(self) -> requests.Session:
s = requests.Session()
access_token = self._get_access_token()
s.headers.update({"Authorization": f"Bearer {access_token}"})
return s
def _endpoint_url(self, endpoint: str) -> str:
if endpoint.startswith("/"):
endpoint = endpoint[1:]
return f"https://api.fabric.microsoft.com/v1/workspaces/{self.workspace_id}/{endpoint}"
def _get_access_token(self) -> str:
"""Get access token using Service Principal authentication."""
# Use Azure AD OAuth2 token endpoint
token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
"scope": "https://api.fabric.microsoft.com/.default",
}
response = requests.post(token_url, data=data)
response.raise_for_status()
token_data = response.json()
return token_data["access_token"]
def _wait_for_completion(self, location_url: str, operation_name: str) -> None:
"""Poll the operation status until completion."""
@retry(
wait=wait_exponential(multiplier=1, min=1, max=30),
stop=stop_after_attempt(20),
retry=retry_if_result(lambda result: result not in ["Succeeded", "Failed"]),
)
def _poll() -> str:
response = self.session.get(location_url)
response.raise_for_status()
result = response.json()
status = result.get("status", "Unknown")
logger.debug(f"Operation {operation_name} status: {status}")
if status == "Failed":
error_msg = result.get("error", {}).get("message", "Unknown error")
raise SQLMeshError(f"Operation {operation_name} failed: {error_msg}")
elif status in ["InProgress", "Running"]:
logger.debug(f"Operation {operation_name} still in progress...")
elif status not in ["Succeeded"]:
logger.warning(f"Unknown status '{status}' for operation {operation_name}")
return status
final_status = _poll()
if final_status != "Succeeded":
raise SQLMeshError(f"Operation {operation_name} completed with status: {final_status}")