From e212dd172c6e56d042e18668c82d9e94a6d4e20e Mon Sep 17 00:00:00 2001 From: Mu Huai Date: Mon, 23 Mar 2026 11:42:46 +0800 Subject: [PATCH] feat: add Prometheus datasource connector for real-time metrics querying Add PrometheusConnector that enables OpenDerisk agents to query Prometheus HTTP API for real-time metrics data, replacing the dependency on static OpenRCA datasets. Features: - Instant query (PromQL evaluation at a single point in time) - Range query (PromQL evaluation over a time range with configurable step) - Series discovery (find time series matching label selectors) - Label enumeration (list label names and values) - Target and rule inspection (scrape targets, alerting/recording rules) - Active alerts listing - Health check endpoint - Metric metadata retrieval - Basic authentication and custom headers support - SSL/TLS configuration - Formatted output compatible with BaseConnector.run() interface The connector inherits from BaseConnector and can be used by SRE agents for real-time diagnostics and root cause analysis. Includes comprehensive unit tests with mocked HTTP responses. --- .../derisk_ext/datasource/conn_prometheus.py | 487 ++++++++++++++++++ .../derisk_ext/datasource/tests/__init__.py | 0 .../datasource/tests/test_conn_prometheus.py | 202 ++++++++ 3 files changed, 689 insertions(+) create mode 100644 packages/derisk-ext/src/derisk_ext/datasource/conn_prometheus.py create mode 100644 packages/derisk-ext/src/derisk_ext/datasource/tests/__init__.py create mode 100644 packages/derisk-ext/src/derisk_ext/datasource/tests/test_conn_prometheus.py diff --git a/packages/derisk-ext/src/derisk_ext/datasource/conn_prometheus.py b/packages/derisk-ext/src/derisk_ext/datasource/conn_prometheus.py new file mode 100644 index 00000000..68945c9f --- /dev/null +++ b/packages/derisk-ext/src/derisk_ext/datasource/conn_prometheus.py @@ -0,0 +1,487 @@ +"""Prometheus datasource connector. 
# Default HTTP request timeout (seconds) and default range-query resolution.
_DEFAULT_TIMEOUT_SECONDS = 30
_DEFAULT_STEP = "60s"


@dataclass
class PrometheusParameters:
    """Connection parameters for the Prometheus HTTP API.

    The fields describe where the server lives (``scheme``/``host``/``port``),
    how to authenticate (``username``/``password``), and per-request transport
    settings (``timeout``, ``verify_ssl``, ``custom_headers``).  The derived
    ``base_url`` and ``api_url`` properties are what the connector uses to
    build request URLs.
    """

    host: str = field(
        default="localhost",
        metadata={"help": "Prometheus server hostname or IP address."},
    )
    port: int = field(
        default=9090,
        metadata={"help": "Prometheus server port."},
    )
    scheme: str = field(
        default="http",
        metadata={"help": "Connection scheme, 'http' or 'https'."},
    )
    username: Optional[str] = field(
        default=None,
        metadata={"help": "Username for basic authentication."},
    )
    password: Optional[str] = field(
        default=None,
        # Keep the secret out of the auto-generated repr so that logging or
        # printing a parameters object never leaks credentials.
        repr=False,
        metadata={"help": "Password for basic authentication."},
    )
    timeout: int = field(
        default=_DEFAULT_TIMEOUT_SECONDS,
        metadata={"help": "Request timeout in seconds."},
    )
    verify_ssl: bool = field(
        default=True,
        metadata={"help": "Whether to verify SSL certificates."},
    )
    custom_headers: Optional[Dict[str, str]] = field(
        default=None,
        metadata={"help": "Custom HTTP headers for requests."},
    )

    @property
    def base_url(self) -> str:
        """Return the server root URL, ``<scheme>://<host>:<port>``.

        A bare IPv6 literal (e.g. ``::1``) is wrapped in square brackets,
        because an unbracketed one makes the port separator ambiguous and
        yields an invalid URL.  Hosts that are already bracketed, hostnames
        and IPv4 addresses are used unchanged.
        """
        host = self.host
        # IPv6 literals always contain at least two colons; a hostname or
        # IPv4 address contains none.
        if host.count(":") >= 2 and not host.startswith("["):
            host = f"[{host}]"
        return f"{self.scheme}://{host}:{self.port}"

    @property
    def api_url(self) -> str:
        """Return the Prometheus API v1 root, ``<base_url>/api/v1``."""
        return f"{self.base_url}/api/v1"
class PrometheusConnector(BaseConnector):
    """Connector for the Prometheus time-series database.

    Wraps the Prometheus HTTP API (``/api/v1``) behind a ``requests.Session``
    and exposes instant queries, range queries, series discovery, label
    enumeration, target/rule/alert inspection, metric metadata and a health
    probe.  ``run()`` adapts an instant query to the ``BaseConnector``
    interface.
    """

    db_type: str = "prometheus"
    driver: str = "prometheus_http"

    def __init__(
        self,
        host: str = "localhost",
        port: int = 9090,
        scheme: str = "http",
        username: Optional[str] = None,
        password: Optional[str] = None,
        timeout: int = _DEFAULT_TIMEOUT_SECONDS,
        verify_ssl: bool = True,
        custom_headers: Optional[Dict[str, str]] = None,
    ):
        """Initialize PrometheusConnector.

        Args:
            host: Prometheus server hostname or IP address.
            port: Prometheus server port.
            scheme: Connection scheme, 'http' or 'https'.
            username: Username for basic authentication.
            password: Password for basic authentication.
            timeout: Request timeout in seconds.
            verify_ssl: Whether to verify SSL certificates.
            custom_headers: Custom HTTP headers for requests.
        """
        # NOTE(review): BaseConnector.__init__ is not invoked here; confirm
        # the base class needs no initialisation of its own.
        self._params = PrometheusParameters(
            host=host,
            port=port,
            scheme=scheme,
            username=username,
            password=password,
            timeout=timeout,
            verify_ssl=verify_ssl,
            custom_headers=custom_headers,
        )
        self._session = requests.Session()
        # Basic auth is only configured when both credentials are supplied.
        if username and password:
            self._session.auth = (username, password)
        if custom_headers:
            self._session.headers.update(custom_headers)

    @classmethod
    def param_class(cls) -> Type[PrometheusParameters]:
        """Return the parameter class."""
        return PrometheusParameters

    @classmethod
    def from_parameters(cls, parameters: PrometheusParameters) -> "PrometheusConnector":
        """Create a connector from a ``PrometheusParameters`` instance."""
        return cls(
            host=parameters.host,
            port=parameters.port,
            scheme=parameters.scheme,
            username=parameters.username,
            password=parameters.password,
            timeout=parameters.timeout,
            verify_ssl=parameters.verify_ssl,
            custom_headers=parameters.custom_headers,
        )

    @staticmethod
    def _api_error_detail(response: Any) -> str:
        """Return `` [errorType]: error`` from an error response body, or ''.

        Prometheus sends its JSON error envelope together with non-2xx
        status codes; this extracts the useful part so it is not lost when
        the HTTP error is reported.
        """
        if response is None:
            return ""
        try:
            payload = response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML page from a proxy).
            return ""
        if not isinstance(payload, dict) or "error" not in payload:
            return ""
        error_type = payload.get("errorType", "unknown")
        return f" [{error_type}]: {payload['error']}"

    def _request(
        self, method: str, endpoint: str, params: Optional[Dict] = None
    ) -> Any:
        """Send an HTTP request to the Prometheus API.

        Args:
            method: HTTP method. 'GET' (case-insensitive) issues a GET with
                ``params`` as the query string; anything else issues a POST
                with ``params`` as form data.
            endpoint: API endpoint path appended to the v1 URL
                (e.g., '/query').
            params: Query parameters or POST data.

        Returns:
            The ``data`` member of the JSON envelope (a dict or a list,
            depending on the endpoint), or ``{}`` when it is absent.

        Raises:
            ConnectionError: If the request could not be completed
                (connection failure, timeout, or any other transport error).
            ValueError: If the server answers with an HTTP error status, a
                body that is not JSON, or an envelope whose status is not
                'success'.
        """
        url = f"{self._params.api_url}{endpoint}"
        try:
            if method.upper() == "GET":
                response = self._session.get(
                    url,
                    params=params,
                    timeout=self._params.timeout,
                    verify=self._params.verify_ssl,
                )
            else:
                response = self._session.post(
                    url,
                    data=params,
                    timeout=self._params.timeout,
                    verify=self._params.verify_ssl,
                )
            response.raise_for_status()
        except requests.exceptions.ConnectionError as error:
            raise ConnectionError(
                f"Failed to connect to Prometheus at {url}: {error}"
            ) from error
        except requests.exceptions.Timeout as error:
            raise ConnectionError(
                f"Request to Prometheus timed out after "
                f"{self._params.timeout}s: {error}"
            ) from error
        except requests.exceptions.HTTPError as error:
            # Keep the original message prefix, but append the server's own
            # explanation when the error body carries one.
            detail = self._api_error_detail(error.response)
            raise ValueError(
                f"Prometheus API returned HTTP error: {error}{detail}"
            ) from error
        except requests.exceptions.RequestException as error:
            # Remaining transport-level failures (redirect loops, invalid
            # URLs, broken chunked bodies, ...) are reported uniformly.
            raise ConnectionError(
                f"Request to Prometheus at {url} failed: {error}"
            ) from error

        try:
            result = response.json()
        except ValueError as error:
            raise ValueError(
                f"Prometheus returned a non-JSON response from {url}: {error}"
            ) from error

        if not isinstance(result, dict) or result.get("status") != "success":
            envelope = result if isinstance(result, dict) else {}
            error_type = envelope.get("errorType", "unknown")
            error_message = envelope.get("error", "Unknown error")
            raise ValueError(
                f"Prometheus query failed [{error_type}]: {error_message}"
            )
        return result.get("data", {})

    @staticmethod
    def _as_selector_list(match: Union[str, List[str]]) -> List[str]:
        """Wrap a single series selector in a list; pass lists through."""
        return [match] if isinstance(match, str) else match

    @staticmethod
    def _add_window(
        params: Dict[str, Any], start: Optional[str], end: Optional[str]
    ) -> Dict[str, Any]:
        """Add 'start'/'end' to ``params`` when they are set; return it."""
        if start:
            params["start"] = start
        if end:
            params["end"] = end
        return params

    def instant_query(
        self,
        query: str,
        time: Optional[str] = None,
        timeout: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Execute an instant query against Prometheus.

        Args:
            query: PromQL expression to evaluate.
            time: Evaluation timestamp (RFC3339 or Unix timestamp).
                Omitted from the request when not given, in which case the
                server uses its current time.
            timeout: Evaluation timeout.  The server caps it at its own
                ``-query.timeout`` setting.

        Returns:
            The ``result`` member of the response.  For vector results this
            is a list of dictionaries, each containing 'metric' labels and a
            'value' (timestamp, value) pair; for scalar/string results
            Prometheus returns a bare ``[timestamp, "value"]`` pair.
        """
        params: Dict[str, str] = {"query": query}
        if time:
            params["time"] = time
        if timeout:
            params["timeout"] = timeout

        data = self._request("GET", "/query", params)
        return data.get("result", [])

    def range_query(
        self,
        query: str,
        start: str,
        end: str,
        step: str = _DEFAULT_STEP,
        timeout: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Execute a range query against Prometheus.

        Args:
            query: PromQL expression to evaluate.
            start: Start timestamp (RFC3339 or Unix timestamp).
            end: End timestamp (RFC3339 or Unix timestamp).
            step: Query resolution step width (e.g., '15s', '1m', '5m').
            timeout: Evaluation timeout.

        Returns:
            List of result dictionaries, each containing 'metric' labels
            and 'values' (list of [timestamp, value] pairs).
        """
        params: Dict[str, str] = {
            "query": query,
            "start": start,
            "end": end,
            "step": step,
        }
        if timeout:
            params["timeout"] = timeout

        data = self._request("GET", "/query_range", params)
        return data.get("result", [])

    def get_series(
        self,
        match: Union[str, List[str]],
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> List[Dict[str, str]]:
        """Find time series matching label selectors.

        Args:
            match: One or more series selectors (e.g., 'up',
                '{job="prometheus"}').
            start: Start timestamp for the lookup window.
            end: End timestamp for the lookup window.

        Returns:
            List of label sets for matching series.
        """
        params: Dict[str, Any] = {"match[]": self._as_selector_list(match)}
        return self._request("GET", "/series", self._add_window(params, start, end))

    def get_label_names(
        self,
        match: Optional[Union[str, List[str]]] = None,
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> List[str]:
        """Return the list of label names known to the server.

        Args:
            match: Optional series selectors to filter labels.
            start: Start timestamp.
            end: End timestamp.

        Returns:
            List of label names as returned by the server.
        """
        params: Dict[str, Any] = {}
        if match:
            params["match[]"] = self._as_selector_list(match)
        return self._request("GET", "/labels", self._add_window(params, start, end))

    def get_label_values(
        self,
        label_name: str,
        match: Optional[Union[str, List[str]]] = None,
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> List[str]:
        """Return a list of values for a given label name.

        Args:
            label_name: The label name to query values for.
            match: Optional series selectors to filter results.
            start: Start timestamp.
            end: End timestamp.

        Returns:
            List of label values.
        """
        from urllib.parse import quote

        params: Dict[str, Any] = {}
        if match:
            params["match[]"] = self._as_selector_list(match)
        # The label name becomes a URL path segment: percent-encode it so
        # that '/', '?', '#' or '..' in caller-supplied input cannot change
        # which endpoint is addressed.
        segment = quote(label_name, safe="")
        return self._request(
            "GET", f"/label/{segment}/values", self._add_window(params, start, end)
        )

    def get_targets(self, state: Optional[str] = None) -> Dict[str, Any]:
        """Return an overview of the current state of scrape targets.

        Args:
            state: Filter targets by state ('active', 'dropped', 'any').

        Returns:
            Dictionary with 'activeTargets' and 'droppedTargets' lists.
        """
        params: Dict[str, str] = {}
        if state:
            params["state"] = state
        return self._request("GET", "/targets", params)

    def get_rules(self, rule_type: Optional[str] = None) -> Dict[str, Any]:
        """Return a list of alerting and recording rules.

        Args:
            rule_type: Filter by rule type ('alert' or 'record').

        Returns:
            Dictionary with 'groups' containing rule definitions.
        """
        params: Dict[str, str] = {}
        if rule_type:
            params["type"] = rule_type
        return self._request("GET", "/rules", params)

    def get_alerts(self) -> List[Dict[str, Any]]:
        """Return a list of all active alerts.

        Returns:
            List of active alert dictionaries.
        """
        data = self._request("GET", "/alerts")
        return data.get("alerts", [])

    def check_health(self) -> bool:
        """Check if the Prometheus server is healthy.

        Probes ``<base_url>/-/healthy`` (outside the ``/api/v1`` prefix).

        Returns:
            True if the server answers with HTTP 200; False for any other
            status or when the request itself fails.
        """
        try:
            url = f"{self._params.base_url}/-/healthy"
            response = self._session.get(
                url,
                timeout=self._params.timeout,
                verify=self._params.verify_ssl,
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            # Deliberate best-effort probe: unreachable means unhealthy.
            return False

    def get_metadata(
        self, metric: Optional[str] = None, limit: Optional[int] = None
    ) -> Dict[str, List[Dict[str, str]]]:
        """Return metadata about metrics currently scraped.

        Args:
            metric: Filter metadata for a specific metric name.
            limit: Maximum number of metrics to return.

        Returns:
            Dictionary mapping metric names to lists of metadata entries,
            each containing 'type', 'help', and 'unit'.
        """
        params: Dict[str, Any] = {}
        if metric:
            params["metric"] = metric
        if limit is not None:
            params["limit"] = str(limit)
        return self._request("GET", "/metadata", params)

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute a PromQL query (implements BaseConnector interface).

        The 'command' parameter is treated as a PromQL expression for an
        instant query, and the result is flattened by ``_format_results``.

        Args:
            command: PromQL expression to evaluate.
            fetch: Unused, kept for interface compatibility.

        Returns:
            List of formatted query results.
        """
        results = self.instant_query(command)
        return self._format_results(results)

    @staticmethod
    def _format_results(results: List[Any]) -> List[Dict[str, Any]]:
        """Flatten Prometheus query results into one row per sample.

        Handles the three shapes the query endpoints return:

        * vector - list of ``{"metric": ..., "value": [ts, v]}``
        * matrix - list of ``{"metric": ..., "values": [[ts, v], ...]}``
        * scalar/string - a bare ``[ts, "v"]`` pair, which has no labels
          and is therefore reported with metric name 'unknown' and empty
          labels

        Args:
            results: Raw Prometheus API ``result`` value.

        Returns:
            List of dictionaries with 'metric', 'labels', 'timestamp'
            (ISO-8601, UTC) and 'value' keys.
        """

        def is_bare_sample(item: Any) -> bool:
            # A scalar/string result is exactly [number, str]; a vector or
            # matrix result is a list of dicts and never matches.
            return (
                isinstance(item, (list, tuple))
                and len(item) == 2
                and isinstance(item[0], (int, float))
                and not isinstance(item[0], bool)
                and isinstance(item[1], str)
            )

        def to_row(name: str, labels: Dict[str, Any], sample: Any) -> Dict[str, Any]:
            timestamp, value = sample
            return {
                "metric": name,
                "labels": labels,
                "timestamp": datetime.fromtimestamp(
                    timestamp, tz=timezone.utc
                ).isoformat(),
                "value": value,
            }

        if is_bare_sample(results):
            return [to_row("unknown", {}, results)]

        formatted: List[Dict[str, Any]] = []
        for result in results:
            metric_labels = result.get("metric", {})
            metric_name = metric_labels.get("__name__", "unknown")

            if "value" in result:
                formatted.append(to_row(metric_name, metric_labels, result["value"]))
            elif "values" in result:
                formatted.extend(
                    to_row(metric_name, metric_labels, sample)
                    for sample in result["values"]
                )
        return formatted

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return whether the connector is a normal type."""
        return True

    def close(self):
        """Close the HTTP session."""
        if self._session:
            self._session.close()

    def __repr__(self) -> str:
        """Return a string representation (host, port and scheme only)."""
        return (
            f"PrometheusConnector("
            f"host={self._params.host!r}, "
            f"port={self._params.port}, "
            f"scheme={self._params.scheme!r})"
        )
# Canned API payloads shared by several tests.
_VECTOR_UP = {
    "status": "success",
    "data": {
        "resultType": "vector",
        "result": [
            {
                "metric": {"__name__": "up", "job": "prometheus"},
                "value": [1700000000, "1"],
            }
        ],
    },
}

_MATRIX_UP = {
    "status": "success",
    "data": {
        "resultType": "matrix",
        "result": [
            {
                "metric": {"__name__": "up", "job": "prometheus"},
                "values": [
                    [1700000000, "1"],
                    [1700000060, "1"],
                ],
            }
        ],
    },
}


def _connector_with_response(payload=None, status_code=200):
    """Build a connector whose HTTP session is a mock.

    Every ``session.get`` call returns one canned response whose ``json()``
    yields ``payload`` and whose ``status_code`` is ``status_code``;
    ``raise_for_status`` is a no-op mock.

    Returns:
        ``(connector, session)`` so tests can assert on the mock's calls.
    """
    response = MagicMock()
    response.status_code = status_code
    response.json.return_value = payload
    session = MagicMock()
    session.get.return_value = response
    with patch(
        "derisk_ext.datasource.conn_prometheus.requests.Session",
        return_value=session,
    ):
        connector = PrometheusConnector(host="localhost", port=9090)
    return connector, session


class TestPrometheusParameters:
    """Tests for PrometheusParameters."""

    def test_default_values(self):
        params = PrometheusParameters()
        assert params.host == "localhost"
        assert params.port == 9090
        assert params.scheme == "http"
        assert params.username is None
        assert params.password is None
        assert params.timeout == 30
        assert params.verify_ssl is True

    def test_base_url(self):
        params = PrometheusParameters(host="prom.example.com", port=9090)
        assert params.base_url == "http://prom.example.com:9090"

    def test_api_url(self):
        params = PrometheusParameters(
            host="prom.example.com", port=9090, scheme="https"
        )
        assert params.api_url == "https://prom.example.com:9090/api/v1"


class TestPrometheusConnector:
    """Tests for PrometheusConnector."""

    def setup_method(self):
        self.connector = PrometheusConnector(
            host="localhost", port=9090, scheme="http"
        )

    def test_from_parameters(self):
        params = PrometheusParameters(
            host="prom.example.com", port=9090, scheme="https"
        )
        connector = PrometheusConnector.from_parameters(params)
        assert connector._params.host == "prom.example.com"
        assert connector._params.scheme == "https"

    def test_db_type(self):
        assert self.connector.db_type == "prometheus"

    def test_repr(self):
        result = repr(self.connector)
        assert "PrometheusConnector" in result
        assert "localhost" in result

    def test_instant_query(self):
        connector, session = _connector_with_response(_VECTOR_UP)
        results = connector.instant_query("up")
        assert len(results) == 1
        assert results[0]["metric"]["__name__"] == "up"
        # The request must hit the instant-query endpoint with the PromQL
        # expression as a query parameter.
        call_args, call_kwargs = session.get.call_args
        assert call_args[0] == "http://localhost:9090/api/v1/query"
        assert call_kwargs["params"] == {"query": "up"}

    def test_range_query(self):
        connector, session = _connector_with_response(_MATRIX_UP)
        results = connector.range_query(
            "up", start="1700000000", end="1700000120", step="60s"
        )
        assert len(results) == 1
        assert len(results[0]["values"]) == 2
        call_args, call_kwargs = session.get.call_args
        assert call_args[0] == "http://localhost:9090/api/v1/query_range"
        assert call_kwargs["params"] == {
            "query": "up",
            "start": "1700000000",
            "end": "1700000120",
            "step": "60s",
        }

    def test_run_interface(self):
        connector, _ = _connector_with_response(_VECTOR_UP)
        results = connector.run("up")
        assert len(results) == 1
        assert results[0]["metric"] == "up"
        assert results[0]["value"] == "1"

    def test_check_health_success(self):
        connector, _ = _connector_with_response(status_code=200)
        assert connector.check_health() is True

    def test_check_health_failure(self):
        connector, _ = _connector_with_response(status_code=503)
        assert connector.check_health() is False

    def test_query_error_handling(self):
        connector, _ = _connector_with_response(
            {
                "status": "error",
                "errorType": "bad_data",
                "error": "invalid expression",
            }
        )
        with pytest.raises(ValueError, match="invalid expression"):
            connector.instant_query("invalid{")

    def test_format_results_instant(self):
        raw_results = [
            {
                "metric": {"__name__": "cpu_usage", "instance": "host1:9090"},
                "value": [1700000000, "0.85"],
            }
        ]
        formatted = PrometheusConnector._format_results(raw_results)
        assert len(formatted) == 1
        assert formatted[0]["metric"] == "cpu_usage"
        assert formatted[0]["value"] == "0.85"
        assert "timestamp" in formatted[0]
        assert formatted[0]["labels"]["instance"] == "host1:9090"

    def test_format_results_range(self):
        raw_results = [
            {
                "metric": {"__name__": "cpu_usage"},
                "values": [
                    [1700000000, "0.80"],
                    [1700000060, "0.85"],
                    [1700000120, "0.90"],
                ],
            }
        ]
        formatted = PrometheusConnector._format_results(raw_results)
        assert len(formatted) == 3
        assert formatted[0]["value"] == "0.80"
        assert formatted[2]["value"] == "0.90"

    def test_close(self):
        # Closing the connector must close the underlying HTTP session.
        connector, session = _connector_with_response()
        connector.close()
        session.close.assert_called_once_with()