diff --git a/core/wren/README.md b/core/wren/README.md index 8f77e4a5b2..1eba56566d 100644 --- a/core/wren/README.md +++ b/core/wren/README.md @@ -17,6 +17,7 @@ pip install wren-engine[mysql] # MySQL pip install wren-engine[bigquery] # BigQuery pip install wren-engine[snowflake] # Snowflake pip install wren-engine[clickhouse] # ClickHouse +pip install wren-engine[ytsaurus] # YTsaurus (via CHYT) pip install wren-engine[trino] # Trino pip install wren-engine[mssql] # SQL Server pip install wren-engine[databricks] # Databricks diff --git a/core/wren/docs/connections.md b/core/wren/docs/connections.md index 0b62805656..e15b2931ec 100644 --- a/core/wren/docs/connections.md +++ b/core/wren/docs/connections.md @@ -125,3 +125,18 @@ Both formats are accepted. The CLI auto-flattens the envelope format. "format": "parquet" } ``` + +## YTsaurus (via CHYT) + +```json +{ + "datasource": "ytsaurus", + "proxy": "yt-proxy.example.com", + "clique": "*ch_public", + "token": "y0_AgAA..." +} +``` + +`token` is optional — if omitted, the connector reads `YT_TOKEN` from the +environment. See [`connectors/ytsaurus.md`](connectors/ytsaurus.md) for the +full field reference and CHYT-specific behavior. diff --git a/core/wren/docs/connectors/ytsaurus.md b/core/wren/docs/connectors/ytsaurus.md new file mode 100644 index 0000000000..4dc73f99c8 --- /dev/null +++ b/core/wren/docs/connectors/ytsaurus.md @@ -0,0 +1,95 @@ +# YTsaurus connector + +Connects Wren Engine to a [YTsaurus](https://ytsaurus.tech/en) cluster through +its **CHYT** clique (ClickHouse-over-YT). CHYT exposes a ClickHouse HTTP +protocol on the YT HTTP proxy, so this connector reuses Wren's existing +ClickHouse / Ibis path with YT-flavored auth (`Authorization: OAuth <token>` +and the clique alias sent as the `chyt.clique_alias` URL parameter). + +The sqlglot dialect is `clickhouse`, so all CHYT-compatible ClickHouse SQL — +including `toUnixTimestamp`, `startsWith`, `now() - INTERVAL N DAY`, +`COUNT(DISTINCT ...)` — works as-is. 
+ +## Install + +```bash +pip install "wren-engine[ytsaurus]" +``` + +The `ytsaurus` extra pulls `ibis-framework[clickhouse]`. + +## Connection info + +```python +from wren.model import YTsaurusConnectionInfo +from wren.model.data_source import DataSource +from wren.connector.factory import get_connector + +info = DataSource.ytsaurus.get_connection_info({ + "proxy": "yt-proxy.example.com", # YT HTTP proxy host + "clique": "*ch_public", # CHYT clique alias incl. leading "*" + # "token": "y0_AgAA...", # optional — falls back to YT_TOKEN env + # "secure": True, # default + # "port": 443, # default 443 / 80 by secure flag + # "settings": {"max_threads": "8"}, + # "kwargs": {"connect_timeout": "30"}, +}) + +connector = get_connector(DataSource.ytsaurus, info) +table = connector.query("SELECT now()", limit=1) +print(table.to_pandas()) +``` + +| Field | Type | Default | Meaning | +|---|---|---|---| +| `proxy` | str (required) | — | YT HTTP proxy host (no scheme). | +| `clique` | str (required) | — | CHYT clique alias including the `*` prefix. | +| `token` | SecretStr | env `YT_TOKEN` | YT OAuth token. | +| `secure` | bool | `True` | HTTPS vs HTTP. | +| `port` | int | 443 / 80 | Override proxy port. | +| `settings` | dict | `None` | ClickHouse session settings (e.g. `max_execution_time`). | +| `kwargs` | dict | `None` | Passed to `clickhouse_connect.get_client()`. Supports `http_headers` (the connector merges `Authorization` in automatically). | + +JSON form for use with `--connection-info` / `--connection-file`: + +```json +{ + "datasource": "ytsaurus", + "proxy": "yt-proxy.example.com", + "clique": "*ch_public", + "token": "y0_AgAA..." +} +``` + +## Auth + +The connector resolves the YT OAuth token in this order: + +1. `connection_info.token` if provided +2. `YT_TOKEN` environment variable + +The token is sent as an `Authorization: OAuth <token>` header on every +request. The ClickHouse `user` and `password` are left empty, so no Basic +auth header is sent. 
+ +If neither source produces a token, the connector raises +`WrenError(INVALID_CONNECTION_INFO)`. + +## Statement timeout + +Like the ClickHouse connector, the YTsaurus connector honors the +`x-wren-db-statement_timeout` header by setting the CHYT session's +`max_execution_time` (defaults to 180 seconds). + +## Limitations + +- **CHYT only**: the connector targets the ClickHouse-over-YT engine. + Query-Tracker-only features (raw YQL, SPYT) are not exposed. If you need a + YT-native YQL path, fork the connector and replace + `get_ytsaurus_connection` with a Query Tracker REST client; the rest of + the Wren plumbing (factory, enum, connection info) stays the same. +- **Clique availability**: queries fail if the named CHYT clique is not + running. Cliques are managed in the YT UI under "CHYT cliques". +- **Schema discovery**: `system.tables` works for CHYT-attached tables. + Static YT tables outside the clique's exposed schema must be referenced + by their full YT path (`"//home/.../table"`) inside CHYT queries. 
diff --git a/core/wren/pyproject.toml b/core/wren/pyproject.toml index f22008f47c..0143b0fe09 100644 --- a/core/wren/pyproject.toml +++ b/core/wren/pyproject.toml @@ -48,6 +48,7 @@ mysql = ["mysqlclient>=2.2", "ibis-framework[mysql]"] bigquery = ["ibis-framework[bigquery]", "google-auth"] snowflake = ["ibis-framework[snowflake]"] clickhouse = ["ibis-framework[clickhouse]"] +ytsaurus = ["ibis-framework[clickhouse]"] trino = ["ibis-framework[trino]", "trino>=0.321"] mssql = ["ibis-framework[mssql]"] databricks = ["databricks-sql-connector", "databricks-sdk"] @@ -60,7 +61,7 @@ interactive = ["InquirerPy>=0.3.4"] ui = ["starlette>=0.37", "uvicorn>=0.29", "jinja2>=3.1", "python-multipart>=0.0.9"] main = ["wren-engine[interactive,ui]"] all = [ - "wren-engine[postgres,mysql,bigquery,snowflake,clickhouse,trino,mssql,databricks,redshift,athena,oracle,spark,main,memory]", + "wren-engine[postgres,mysql,bigquery,snowflake,clickhouse,ytsaurus,trino,mssql,databricks,redshift,athena,oracle,spark,main,memory]", ] dev = [ "pytest>=8", diff --git a/core/wren/src/wren/connector/factory.py b/core/wren/src/wren/connector/factory.py index a58217856f..a639e4f1db 100644 --- a/core/wren/src/wren/connector/factory.py +++ b/core/wren/src/wren/connector/factory.py @@ -24,6 +24,7 @@ DataSource.oracle: "wren.connector.oracle", DataSource.snowflake: "wren.connector.ibis", DataSource.athena: "wren.connector.ibis", + DataSource.ytsaurus: "wren.connector.ytsaurus", } # Map data sources to the correct pip extra when they share a connector module diff --git a/core/wren/src/wren/connector/ytsaurus.py b/core/wren/src/wren/connector/ytsaurus.py new file mode 100644 index 0000000000..346c732384 --- /dev/null +++ b/core/wren/src/wren/connector/ytsaurus.py @@ -0,0 +1,118 @@ +"""YTsaurus (CHYT) connector. + +Talks to a YTsaurus cluster through its CHYT (ClickHouse-over-YT) clique. 
CHYT +exposes a ClickHouse-compatible HTTP protocol on the YT HTTP proxy, so the +underlying machinery is ibis' ClickHouse backend with YT-flavored auth. + +Auth: YT OAuth token. Resolution order: + 1. ``connection_info.token`` (SecretStr) if provided + 2. ``YT_TOKEN`` environment variable + +CHYT diverges from a stock ClickHouse server in two ways the IbisConnector +default can't handle: + + * **No CREATE VIEW.** ibis introspects query schemas by creating a temporary + view, but CHYT is read-only at the SQL layer and rejects DDL with + ``std::out_of_range``. This connector overrides ``query`` and + ``dry_run`` to bypass ibis and talk to the underlying ``clickhouse_connect`` + HttpClient directly via ``query`` (CHYT rejects ``query_arrow``). + * **OAuth-only auth.** The token is sent as ``Authorization: OAuth <token>`` + (the ``Bearer`` and ``Basic`` schemes are explicitly rejected by the YT + proxy). The clique alias is passed via the ``chyt.clique_alias`` URL + parameter, both wired in :func:`wren.model.data_source.DataSourceExtension.get_ytsaurus_connection`. 
+""" + +from __future__ import annotations + +import pyarrow as pa + +from wren.connector.base import IbisConnector +from wren.model.data_source import DataSource +from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError + +try: + import clickhouse_connect + + _ClickHouseDbError = clickhouse_connect.driver.exceptions.DatabaseError +except ImportError: + + class _ClickHouseDbError(Exception): + """Fallback stand-in when ``clickhouse_connect`` is not installed.""" + + +class YTsaurusConnector(IbisConnector): + """Connector for YTsaurus clusters via their CHYT (ClickHouse-over-YT) clique.""" + + def __init__(self, connection_info): + """Build the connector with a :class:`YTsaurusConnectionInfo`-shaped payload.""" + super().__init__(DataSource.ytsaurus, connection_info) + + @property + def _ch_client(self): + """Underlying clickhouse_connect HttpClient (set up by data_source.py).""" + return self.connection.con + + def query(self, sql: str, limit: int | None = None) -> pa.Table: + """Execute ``sql`` against CHYT and return the result as a ``pyarrow.Table``. + + ``limit``, if given, is appended as a ``LIMIT`` on a wrapping ``SELECT``. + Non-timeout backend errors are remapped to ``WrenError(INVALID_SQL)``. + """ + wrapped = sql + if limit is not None: + # ``limit`` is interpolated into the SQL string, so refuse anything + # that isn't a non-negative integer to make the f-string safe even + # if a caller bypasses the type hint. + if isinstance(limit, bool) or not isinstance(limit, int) or limit < 0: + raise ValueError(f"limit must be a non-negative int, got {limit!r}") + wrapped = f"SELECT * FROM (\n{sql}\n) LIMIT {limit}" + try: + # CHYT speaks the ClickHouse Native protocol but rejects + # ``query_arrow`` (UNKNOWN_FORMAT for Arrow). Fall back to native + # rows + columns and assemble a pyarrow.Table here. 
+ result = self._ch_client.query(wrapped) + columns = list(result.column_names) + data = list(result.result_columns) + if len(columns) != len(data): + raise WrenError( + ErrorCode.INVALID_SQL, + f"CHYT returned mismatched column metadata: " + f"{len(columns)} names vs {len(data)} column arrays", + phase=ErrorPhase.SQL_EXECUTION, + metadata={DIALECT_SQL: sql}, + ) + return pa.table({name: col for name, col in zip(columns, data)}) + except _ClickHouseDbError as e: + if "TIMEOUT_EXCEEDED" not in str(e): + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_EXECUTION, + metadata={DIALECT_SQL: sql}, + ) from e + raise + except (WrenError, TimeoutError): + raise + + def dry_run(self, sql: str) -> None: + """Validate ``sql`` against CHYT via ``EXPLAIN AST`` without materializing rows.""" + # CHYT supports `EXPLAIN AST` for syntax/planning validation without + # materializing rows. Wrap the user SQL and let CHYT parse it. + try: + self._ch_client.query(f"EXPLAIN AST {sql}") + except _ClickHouseDbError as e: + if "TIMEOUT_EXCEEDED" not in str(e): + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_DRY_RUN, + metadata={DIALECT_SQL: sql}, + ) from e + raise + except (WrenError, TimeoutError): + raise + + +def create_connector(connection_info) -> YTsaurusConnector: + """Factory hook used by :mod:`wren.connector.factory`.""" + return YTsaurusConnector(connection_info) diff --git a/core/wren/src/wren/engine.py b/core/wren/src/wren/engine.py index 49711949bf..836318a63a 100644 --- a/core/wren/src/wren/engine.py +++ b/core/wren/src/wren/engine.py @@ -22,6 +22,7 @@ import base64 import json +import logging from typing import Any import pyarrow as pa @@ -35,6 +36,8 @@ from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError from wren.policy import validate_sql_policy +logger = logging.getLogger(__name__) + class WrenEngine: """Thin facade over wren-core MDL processing and connector execution. 
@@ -97,8 +100,12 @@ def dry_plan(self, sql: str, properties: dict | None = None) -> str: → per-model: sqlglot parse (Wren dialect) → inject as CTE → sqlglot generate (target dialect) → output SQL with model CTEs in target dialect + + Data-source-specific physical rewrites (e.g. YTsaurus path + substitution) are applied to the result so callers see the same SQL + that ``query()`` and ``dry_run()`` send to the backend. """ - return self._plan(sql, properties) + return self._apply_physical_overrides(self._plan(sql, properties)) # ------------------------------------------------------------------ # SQL execution @@ -141,6 +148,110 @@ def dry_run(self, sql: str, properties: dict | None = None) -> None: metadata={DIALECT_SQL: dialect_sql}, ) from e + # ------------------------------------------------------------------ + # Physical SQL rewriting (currently: YT path substitution) + # ------------------------------------------------------------------ + + def _apply_physical_overrides(self, sql: str) -> str: + """Apply data-source-specific rewrites to dialect SQL. + + For YTsaurus, the MDL's ``table_reference`` carries a synthetic + ``schema.table`` name (e.g. ``cdm_clients.tenant_index``) that wren + emits as an unquoted identifier. CHYT can't resolve those — it needs + the full YT path in backticks. If a model declares + ``properties.ytPath``, this rewrites every reference to that model + into the backticked path form CHYT understands. 
+ """ + if self.data_source != DataSource.ytsaurus: + return sql + path_map = self._yt_path_map() + if not path_map: + return sql + + try: + dialect = get_sqlglot_dialect(self.data_source) + tree = parse_one(sql, read=dialect) + except Exception: + return sql + + def _rewrite(node): + """Replace a sqlglot ``Table`` node with the model's YT path when one is mapped.""" + if not isinstance(node, exp.Table): + return node + db = node.args.get("db") + name = node.args.get("this") + db_name = db.name if db is not None else "" + tbl_name = name.name if name is not None else "" + if not tbl_name: + return node + yt_path = path_map.get(f"{db_name}.{tbl_name}") or path_map.get(tbl_name) + if not yt_path: + return node + # Replace with a single backtick-quoted identifier carrying the + # YT path. Set quoted=True so sqlglot preserves the backticks + # when serializing to the ClickHouse dialect. + replacement = exp.Table( + this=exp.Identifier(this=yt_path, quoted=True), + alias=node.args.get("alias"), + ) + return replacement + + tree = tree.transform(_rewrite) + return tree.sql(dialect=dialect) + + def _yt_path_map(self) -> dict[str, str]: + """Build a `schema.table` / `table` → yt_path map from the manifest.""" + cached = getattr(self, "_yt_path_map_cache", None) + if cached is not None: + return cached + try: + manifest = ( + json.loads(self.manifest_str) + if self.manifest_str.lstrip().startswith("{") + else json.loads(base64.b64decode(self.manifest_str)) + ) + except Exception: + self._yt_path_map_cache = {} + return self._yt_path_map_cache + out: dict[str, str] = {} + # Unqualified `table` keys we've already chosen to remove because two + # models in different schemas share that bare name — the rewrite must + # not silently pick one yt_path over the other. 
+ ambiguous: set[str] = set() + for m in manifest.get("models", []): + props = m.get("properties", {}) or {} + yt_path = props.get("ytPath") or props.get("yt_path") + if not yt_path: + continue + tr = m.get("tableReference") or m.get("table_reference") or {} + schema = (tr.get("schema") or "").strip() + table = (tr.get("table") or m.get("name") or "").strip() + if not table: + continue + if schema: + out[f"{schema}.{table}"] = yt_path + if table in ambiguous: + continue + existing = out.get(table) + if existing is None: + out[table] = yt_path + elif existing != yt_path: + # Conflict: drop the bare-name mapping so a query referencing + # just `` falls through to whatever CHYT resolves + # natively rather than rewriting to the wrong YT path. + logger.warning( + "YT path map collision on unqualified table %r " + "(paths %r vs %r) — dropping bare-name rewrite; " + "qualify with a schema to disambiguate.", + table, + existing, + yt_path, + ) + del out[table] + ambiguous.add(table) + self._yt_path_map_cache = out + return out + # ------------------------------------------------------------------ # Lifecycle # ------------------------------------------------------------------ diff --git a/core/wren/src/wren/mdl/cte_rewriter.py b/core/wren/src/wren/mdl/cte_rewriter.py index 62be74ce93..94e879c719 100644 --- a/core/wren/src/wren/mdl/cte_rewriter.py +++ b/core/wren/src/wren/mdl/cte_rewriter.py @@ -30,6 +30,7 @@ DataSource.s3_file: "duckdb", DataSource.minio_file: "duckdb", DataSource.gcs_file: "duckdb", + DataSource.ytsaurus: "clickhouse", } diff --git a/core/wren/src/wren/model/__init__.py b/core/wren/src/wren/model/__init__.py index b22fe1be95..9e623671a7 100644 --- a/core/wren/src/wren/model/__init__.py +++ b/core/wren/src/wren/model/__init__.py @@ -112,6 +112,38 @@ class ClickHouseConnectionInfo(BaseConnectionInfo): kwargs: dict[str, str] | None = Field(default=None) +class YTsaurusConnectionInfo(BaseConnectionInfo): + """Connection info for a YTsaurus cluster via its 
CHYT clique. + + ``proxy`` is the YT HTTP proxy host (e.g. ``yt-proxy.example.com``). + ``clique`` is the CHYT clique alias including the leading ``*`` + (e.g. ``*ch_public``). ``token`` is the YT OAuth token; if omitted the + connector reads ``YT_TOKEN`` from the environment. + """ + + proxy: str = Field(examples=["yt-proxy.example.com"]) + clique: str = Field(examples=["*ch_public"]) + token: SecretStr | None = Field( + default=None, + description="YT OAuth token. Falls back to YT_TOKEN env var if unset.", + ) + secure: bool = Field(default=True) + port: StrPort | None = Field( + default=None, + description="Override the proxy port. Defaults to 443 (secure) or 80.", + ) + query_path: str = Field( + default="/query", + description=( + "URL path on the YT HTTP proxy that exposes the CHYT endpoint. " + "The Nebius and open-source YT default is '/query'. Override only " + "if your proxy mounts CHYT elsewhere." + ), + ) + settings: dict[str, str] | None = Field(default=None) + kwargs: dict[str, str] | None = Field(default=None) + + class MSSqlConnectionInfo(BaseConnectionInfo): host: str = Field(examples=["localhost"]) port: StrPort = Field(examples=["1433"]) @@ -309,6 +341,7 @@ class ConnectionUrl(BaseConnectionInfo): | S3FileConnectionInfo | MinioFileConnectionInfo | GcsFileConnectionInfo + | YTsaurusConnectionInfo ) diff --git a/core/wren/src/wren/model/data_source.py b/core/wren/src/wren/model/data_source.py index 0ac6e88983..220190cf5a 100644 --- a/core/wren/src/wren/model/data_source.py +++ b/core/wren/src/wren/model/data_source.py @@ -1,7 +1,9 @@ from __future__ import annotations import base64 +import os import ssl +import threading import urllib from enum import Enum, StrEnum, auto from json import loads @@ -39,13 +41,23 @@ SparkConnectionInfo, SSLMode, TrinoConnectionInfo, + YTsaurusConnectionInfo, ) from wren.model.error import ErrorCode, WrenError X_WREN_DB_STATEMENT_TIMEOUT = "x-wren-db-statement_timeout" +# Serializes the class-level monkey-patching of 
clickhouse_connect's +# HttpClient inside ``DataSourceExtension.get_ytsaurus_connection``. +# Concurrent callers would otherwise race on ``HttpClient.params`` and the +# per-class ``_wren_yt_token`` (potentially leaking one caller's token into +# another caller's connection). +_YTSAURUS_PATCH_LOCK = threading.Lock() + class DataSource(StrEnum): + """Enumeration of supported data-source backends.""" + athena = auto() bigquery = auto() canner = auto() @@ -66,8 +78,10 @@ class DataSource(StrEnum): duckdb = auto() spark = auto() databricks = auto() + ytsaurus = auto() def get_connection(self, info: ConnectionInfo) -> BaseBackend: + """Dispatch to the corresponding :class:`DataSourceExtension` connection factory.""" try: return DataSourceExtension[self].get_connection(info) except KeyError: @@ -78,6 +92,7 @@ def get_connection_info( data: dict[str, Any] | ConnectionInfo, headers: dict[str, str] | None = None, ) -> ConnectionInfo: + """Validate raw payload into a typed ``ConnectionInfo`` and apply header overrides.""" headers = headers or {} if isinstance(data, BaseConnectionInfo): info = data @@ -95,7 +110,7 @@ def get_connection_info( options += f"-c statement_timeout={headers.get(X_WREN_DB_STATEMENT_TIMEOUT, 180)}s" kwargs["options"] = options info.kwargs = kwargs - case DataSource.clickhouse: + case DataSource.clickhouse | DataSource.ytsaurus: session_timeout = headers.get(X_WREN_DB_STATEMENT_TIMEOUT, 180) if info.settings is None: info.settings = {} @@ -118,7 +133,14 @@ def get_connection_info( return info def _build_connection_info(self, data: dict) -> ConnectionInfo: + """Build a ``ConnectionInfo`` instance for this data source from a raw dict.""" if "connectionUrl" in data or "connection_url" in data: + if self == DataSource.ytsaurus: + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "YTsaurus does not support connectionUrl; " + "use proxy/clique/token fields.", + ) if self == DataSource.clickhouse: return self._handle_clickhouse_url( 
urllib.parse.urlparse( @@ -175,6 +197,8 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo: ): return DatabricksServicePrincipalConnectionInfo.model_validate(data) return DatabricksTokenConnectionInfo.model_validate(data) + case DataSource.ytsaurus: + return YTsaurusConnectionInfo.model_validate(data) case _: raise NotImplementedError(f"Unsupported data source: {self}") @@ -209,6 +233,8 @@ def _safe_strtobool(self, val: str) -> bool: class DataSourceExtension(Enum): + """Mirror of :class:`DataSource` that owns the per-source connection factories.""" + athena = "athena" bigquery = "bigquery" canner = "canner" @@ -229,8 +255,10 @@ class DataSourceExtension(Enum): gcs_file = "gcs_file" databricks = "databricks" spark = "spark" + ytsaurus = "ytsaurus" def get_connection(self, info: ConnectionInfo) -> BaseBackend: + """Dispatch to the connection factory matching this data source.""" try: if hasattr(info, "connection_url"): kwargs = info.kwargs if info.kwargs else {} @@ -316,15 +344,138 @@ def get_canner_connection(info: CannerConnectionInfo) -> BaseBackend: @staticmethod def get_clickhouse_connection(info: ClickHouseConnectionInfo) -> BaseBackend: - return ibis.clickhouse.connect( - host=info.host, - port=int(info.port), - database=info.database, - user=info.user, - password=(info.password and info.password.get_secret_value()), - settings=info.settings if info.settings else {}, - **info.kwargs if info.kwargs else {}, + """Build a plain ClickHouse ibis backend from a connection-info payload.""" + # Serialize with get_ytsaurus_connection(): that path temporarily mutates + # clickhouse_connect's module-level HttpClient class and its `params` + # dict to inject CHYT auth. Without this lock, a concurrent plain + # ClickHouse connect can snapshot the patched state and leak the YT + # OAuth token + `chyt.clique_alias` into a regular ClickHouse client. 
+ with _YTSAURUS_PATCH_LOCK: + return ibis.clickhouse.connect( + host=info.host, + port=int(info.port), + database=info.database, + user=info.user, + password=(info.password and info.password.get_secret_value()), + settings=info.settings if info.settings else {}, + **info.kwargs if info.kwargs else {}, + ) + + @staticmethod + def get_ytsaurus_connection(info: YTsaurusConnectionInfo) -> BaseBackend: + """Build an ibis ClickHouse backend wired to a YT CHYT clique. + + Resolves the YT OAuth token from ``info.token`` or ``YT_TOKEN``, then + temporarily monkey-patches ``clickhouse_connect``'s ``HttpClient`` so + every request carries the ``Authorization: OAuth `` header and + the ``chyt.clique_alias`` URL parameter. The patch window is guarded + by :data:`_YTSAURUS_PATCH_LOCK` so concurrent callers can't observe + each other's token or clique. Returns a live ibis backend; the + original ``HttpClient`` class is restored before returning. + """ + token = (info.token and info.token.get_secret_value()) or os.environ.get( + "YT_TOKEN" ) + if not token: + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "YTsaurus requires a YT OAuth token via connection_info.token " + "or the YT_TOKEN environment variable.", + ) + + proxy = info.proxy.removeprefix("https://").removeprefix("http://").rstrip("/") + port = int(info.port) if info.port else (443 if info.secure else 80) + kwargs: dict[str, Any] = dict(info.kwargs) if info.kwargs else {} + + # YT exposes CHYT at a non-root URL path (default `/query`). The + # clickhouse_connect HttpClient calls this `proxy_path`. Allow + # override via YTsaurusConnectionInfo but default to the YT + # convention. 
+ kwargs.setdefault("proxy_path", info.query_path) + + # YT's HTTP proxy has two requirements that vanilla clickhouse_connect + # doesn't meet: + # (1) the CHYT clique alias must be sent as URL parameter + # `chyt.clique_alias=` on every request, including the + # `SELECT version()` and `SELECT FROM system.settings` queries + # run during HttpClient construction; + # (2) the Authorization header must use the `OAuth` scheme, not + # `Basic` or `Bearer`. + # + # clickhouse_connect's HttpClient snapshots class-level `params` into + # the instance during __init__ (before any query runs), and creates + # an empty `self.headers` dict. We patch `HttpClient.params` for the + # duration of construction, and swap in an HttpClient subclass that + # injects the OAuth header into `self.headers` right before the + # parent `Client.__init__` runs the startup queries. + try: + # Deferred: clickhouse_connect is an optional dependency installed + # only with the [ytsaurus] / [clickhouse] extras. + import clickhouse_connect.driver as _ch_driver # noqa: PLC0415 + import clickhouse_connect.driver.httpclient as _ch_http # noqa: PLC0415 + except ImportError as e: + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "clickhouse_connect is required for YTsaurus. Install the " + "extra: pip install 'wren-engine[ytsaurus]'.", + ) from e + + _BaseHttpClient = _ch_http.HttpClient + + class _CHYTHttpClient(_BaseHttpClient): + """HttpClient with YT OAuth auth pre-injected.""" + + _wren_yt_token: str | None = None + + def _init_common_settings(self, tz_source): + """Inject the YT OAuth header before clickhouse_connect's startup queries run.""" + token_val = type(self)._wren_yt_token + if token_val: + self.headers["Authorization"] = f"OAuth {token_val}" + return super()._init_common_settings(tz_source) + + # `clickhouse_connect.driver.create_client` does + # `from clickhouse_connect.driver.httpclient import HttpClient`, so it + # binds the class into its own namespace. 
We have to patch BOTH the + # source module and the importer's binding for the override to take + # effect. Hold the module-level lock so concurrent callers don't + # observe each other's patches or token. + with _YTSAURUS_PATCH_LOCK: + _CHYTHttpClient._wren_yt_token = token + original_class_params = _BaseHttpClient.params + _BaseHttpClient.params = dict(original_class_params) + _BaseHttpClient.params["chyt.clique_alias"] = info.clique + _ch_http.HttpClient = _CHYTHttpClient + _ch_driver.HttpClient = _CHYTHttpClient + try: + backend = ibis.clickhouse.connect( + host=proxy, + port=port, + database="", + user="", # empty → no Basic auth header (we set OAuth above) + password="", + secure=info.secure, + settings=info.settings if info.settings else {}, + **kwargs, + ) + finally: + _BaseHttpClient.params = original_class_params + _ch_http.HttpClient = _BaseHttpClient + _ch_driver.HttpClient = _BaseHttpClient + _CHYTHttpClient._wren_yt_token = None + + # Belt-and-braces: ensure the live instance also carries the alias + # and OAuth header (defends against clickhouse_connect ever + # re-snapshotting class state). + ch_client = getattr(backend, "con", backend) + params = getattr(ch_client, "params", None) + if params is not None: + params["chyt.clique_alias"] = info.clique + headers = getattr(ch_client, "headers", None) + if headers is not None: + headers["Authorization"] = f"OAuth {token}" + + return backend @classmethod def get_mssql_connection(cls, info: MSSqlConnectionInfo) -> BaseBackend: diff --git a/core/wren/src/wren/model/field_registry.py b/core/wren/src/wren/model/field_registry.py index 3c35ef84be..9dac995d1e 100644 --- a/core/wren/src/wren/model/field_registry.py +++ b/core/wren/src/wren/model/field_registry.py @@ -38,6 +38,7 @@ SnowflakeConnectionInfo, SparkConnectionInfo, TrinoConnectionInfo, + YTsaurusConnectionInfo, ) # Canonical datasource → ConnectionInfo model(s) mapping. 
@@ -66,6 +67,7 @@ "snowflake": [SnowflakeConnectionInfo], "spark": [SparkConnectionInfo], "trino": [TrinoConnectionInfo], + "ytsaurus": [YTsaurusConnectionInfo], "connection_url": [ConnectionUrl], } @@ -174,6 +176,12 @@ class FieldDef: "OracleConnectionInfo": { "dsn": {"label": "DSN (optional — overrides host/port/database)"}, }, + "YTsaurusConnectionInfo": { + "proxy": {"label": "YT Proxy", "placeholder": "yt-proxy.example.com"}, + "clique": {"label": "CHYT Clique", "placeholder": "*ch_public"}, + "token": {"label": "YT OAuth Token"}, + "query_path": {"label": "Query Path"}, + }, } # Datasource-level overrides: datasource_name → field_name → override dict.