diff --git a/src/glean/api_client/_hooks/registration.py b/src/glean/api_client/_hooks/registration.py index a064445e..8d8610f1 100644 --- a/src/glean/api_client/_hooks/registration.py +++ b/src/glean/api_client/_hooks/registration.py @@ -1,4 +1,5 @@ from .types import Hooks +from .server_url_normalizer import ServerURLNormalizerHook from .multipart_fix_hook import MultipartFileFieldFixHook from .agent_file_upload_error_hook import AgentFileUploadErrorHook from .x_glean import XGlean @@ -15,6 +16,9 @@ def init_hooks(hooks: Hooks): with an instance of a hook that implements that specific Hook interface Hooks are registered per SDK instance, and are valid for the lifetime of the SDK instance""" + # Register hook to normalize server URLs (prepend https:// if no scheme provided) + hooks.register_sdk_init_hook(ServerURLNormalizerHook()) + # Register hook to fix multipart file field names that incorrectly have '[]' suffix hooks.register_sdk_init_hook(MultipartFileFieldFixHook()) diff --git a/src/glean/api_client/_hooks/server_url_normalizer.py b/src/glean/api_client/_hooks/server_url_normalizer.py new file mode 100644 index 00000000..42561988 --- /dev/null +++ b/src/glean/api_client/_hooks/server_url_normalizer.py @@ -0,0 +1,21 @@ +"""Hook to normalize server URLs, prepending https:// if no scheme is provided.""" + +import re +from typing import Tuple +from .types import SDKInitHook +from glean.api_client.httpclient import HttpClient + + +def normalize_server_url(url: str) -> str: + normalized = url + if not re.match(r'^https?://', normalized, re.IGNORECASE): + normalized = f'https://{normalized}' + normalized = normalized.rstrip('/') + return normalized + + +class ServerURLNormalizerHook(SDKInitHook): + """Normalizes server URLs by prepending https:// if no scheme is provided.""" + + def sdk_init(self, base_url: str, client: HttpClient) -> Tuple[str, HttpClient]: + return normalize_server_url(base_url), client diff --git a/tests/test_server_url_normalizer.py b/tests/test_server_url_normalizer.py new file mode 100644 index 00000000..ef9f8696 --- /dev/null +++ b/tests/test_server_url_normalizer.py @@ -0,0 +1,71 @@ +"""Tests for the server URL normalizer hook.""" + +from unittest.mock import Mock + +import pytest + +from src.glean.api_client._hooks.server_url_normalizer import ( + ServerURLNormalizerHook, + normalize_server_url, +) +from src.glean.api_client.httpclient import HttpClient + + +class TestNormalizeServerUrl: + """Test cases for the normalize_server_url function.""" + + def test_no_scheme_prepends_https(self): + assert normalize_server_url("example.glean.com") == "https://example.glean.com" + + def test_https_preserved(self): + assert normalize_server_url("https://example.glean.com") == "https://example.glean.com" + + def test_http_localhost_preserved(self): + assert normalize_server_url("http://localhost:8080") == "http://localhost:8080" + + def test_http_non_localhost_preserved(self): + assert normalize_server_url("http://example.glean.com") == "http://example.glean.com" + + def test_trailing_slash_stripped(self): + assert normalize_server_url("https://example.glean.com/") == "https://example.glean.com" + + def test_multiple_trailing_slashes_stripped(self): + assert normalize_server_url("https://example.glean.com///") == "https://example.glean.com" + + def test_no_scheme_with_trailing_slash(self): + assert normalize_server_url("example.glean.com/") == "https://example.glean.com" + + def test_url_with_path(self): + assert normalize_server_url("https://example.glean.com/api/v1") == "https://example.glean.com/api/v1" + + def test_url_with_path_and_trailing_slash(self): + assert normalize_server_url("https://example.glean.com/api/v1/") == "https://example.glean.com/api/v1" + + def test_no_scheme_with_path(self): + assert normalize_server_url("example.glean.com/api/v1") == "https://example.glean.com/api/v1" + + def test_case_insensitive_scheme(self): + assert normalize_server_url("HTTPS://example.glean.com") == "HTTPS://example.glean.com" + assert normalize_server_url("HTTP://localhost") == "HTTP://localhost" + + +class TestServerURLNormalizerHook: + """Test cases for the ServerURLNormalizerHook.""" + + def setup_method(self): + self.hook = ServerURLNormalizerHook() + self.mock_client = Mock(spec=HttpClient) + + def test_sdk_init_normalizes_url(self): + result_url, result_client = self.hook.sdk_init("example.glean.com", self.mock_client) + assert result_url == "https://example.glean.com" + assert result_client == self.mock_client + + def test_sdk_init_preserves_client(self): + result_url, result_client = self.hook.sdk_init("https://example.glean.com", self.mock_client) + assert result_url == "https://example.glean.com" + assert result_client is self.mock_client + + +if __name__ == "__main__": + pytest.main([__file__])