diff --git a/airflow-core/src/airflow/config_templates/provider_config_fallback_defaults.cfg b/airflow-core/src/airflow/config_templates/provider_config_fallback_defaults.cfg index b49c633c5af1d..691fe6064e8f0 100644 --- a/airflow-core/src/airflow/config_templates/provider_config_fallback_defaults.cfg +++ b/airflow-core/src/airflow/config_templates/provider_config_fallback_defaults.cfg @@ -60,7 +60,7 @@ flower_url_prefix = flower_port = 5555 flower_basic_auth = sync_parallelism = 0 -celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG +celery_config_options = airflow.providers.celery.executors.default_celery.DEFAULT_CELERY_CONFIG ssl_active = False ssl_key = ssl_cert = diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py index 1426199c7e07e..be07df7f418c4 100644 --- a/airflow-core/src/airflow/configuration.py +++ b/airflow-core/src/airflow/configuration.py @@ -196,17 +196,17 @@ def __init__( *args, **kwargs, ): - configuration_description = retrieve_configuration_description(include_providers=False) + _configuration_description = retrieve_configuration_description(include_providers=False) # For those who would like to use a different data structure to keep defaults: # We have to keep the default values in a ConfigParser rather than in any other # data structure, because the values we have might contain %% which are ConfigParser # interpolation placeholders. The _default_values config parser will interpolate them # properly when we call get() on it. - _default_values = create_default_config_parser(configuration_description) + _default_values = create_default_config_parser(_configuration_description) from airflow.providers_manager import ProvidersManager super().__init__( - configuration_description, + _configuration_description, _default_values, ProvidersManager, create_default_config_parser, @@ -214,14 +214,13 @@ def __init__( *args, **kwargs, ) - self.configuration_description = configuration_description + self._configuration_description = _configuration_description self._default_values = _default_values if default_config is not None: self._update_defaults_from_string(default_config) self._update_logging_deprecated_template_to_one_from_defaults() self.is_validated = False self._suppress_future_warnings = False - self._providers_configuration_loaded = False @property def _validators(self) -> list[Callable[[], None]]: @@ -367,24 +366,6 @@ def write_custom_config( if content: file.write(f"{content}\n\n") - def _ensure_providers_config_loaded(self) -> None: - """Ensure providers configurations are loaded.""" - if not self._providers_configuration_loaded: - from airflow.providers_manager import ProvidersManager - - ProvidersManager().initialize_providers_configuration() - - def _ensure_providers_config_unloaded(self) -> bool: - """Ensure providers configurations are unloaded temporarily to load core configs. Returns True if providers get unloaded.""" - if self._providers_configuration_loaded: - self.restore_core_default_configuration() - return True - return False - - def _reload_provider_configs(self) -> None: - """Reload providers configuration.""" - self.load_providers_configuration() - def _upgrade_postgres_metastore_conn(self): """ Upgrade SQL schemas. @@ -514,7 +495,7 @@ def expand_all_configuration_values(self): for key, value in self.items(section): if value is not None: if self.has_option(section, key): - self.remove_option(section, key) + self.remove_option(section, key, remove_default=False) if self.is_template(section, key) or not isinstance(value, str): self.set(section, key, value) else: @@ -525,11 +506,6 @@ def remove_all_read_configurations(self): for section in self.sections(): self.remove_section(section) - @property - def providers_configuration_loaded(self) -> bool: - """Checks if providers have been loaded.""" - return self._providers_configuration_loaded - def _get_config_value_from_secret_backend(self, config_key: str) -> str | None: """ Override to use module-level function that reads from global conf. @@ -644,16 +620,18 @@ def write_default_airflow_configuration_if_needed() -> AirflowConfigParser: # We know that fernet_key is not set, so we can generate it, set as global key # and also write it to the config file so that same key will be used next time _SecretKeys.fernet_key = _generate_fernet_key() - conf.configuration_description["core"]["options"]["fernet_key"]["default"] = ( + conf._configuration_description["core"]["options"]["fernet_key"]["default"] = ( _SecretKeys.fernet_key ) conf._default_values.set("core", "fernet_key", _SecretKeys.fernet_key) _SecretKeys.jwt_secret_key = b64encode(os.urandom(16)).decode("utf-8") - conf.configuration_description["api_auth"]["options"]["jwt_secret"]["default"] = ( + conf._configuration_description["api_auth"]["options"]["jwt_secret"]["default"] = ( _SecretKeys.jwt_secret_key ) conf._default_values.set("api_auth", "jwt_secret", _SecretKeys.jwt_secret_key) + # Invalidate cached configuration_description so it recomputes with the updated base + conf.invalidate_cache() pathlib.Path(airflow_config.__fspath__()).touch() make_group_other_inaccessible(airflow_config.__fspath__()) with open(airflow_config, "w") as file: diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 02f5a6957c936..db5ecef66b7e6 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -613,10 +613,6 @@ def initialize_providers_configuration(self): """Lazy initialization of provider configuration metadata and merge it into ``conf``.""" self.initialize_providers_list() self._discover_config() - # Imported lazily to avoid a configuration/providers_manager import cycle during startup. - from airflow.configuration import conf - - conf.load_providers_configuration() @provider_info_cache("plugins") def initialize_providers_plugins(self): @@ -1455,10 +1451,6 @@ def provider_configs(self) -> list[tuple[str, dict[str, Any]]]: self.initialize_providers_configuration() return sorted(self._provider_configs.items(), key=lambda x: x[0]) - @property - def already_initialized_provider_configs(self) -> list[tuple[str, dict[str, Any]]]: - return sorted(self._provider_configs.items(), key=lambda x: x[0]) - def _cleanup(self): self._initialized_cache.clear() self._provider_dict.clear() diff --git a/airflow-core/tests/unit/core/test_configuration.py b/airflow-core/tests/unit/core/test_configuration.py index 9ee2f73753397..2d53d69c1d554 100644 --- a/airflow-core/tests/unit/core/test_configuration.py +++ b/airflow-core/tests/unit/core/test_configuration.py @@ -45,7 +45,13 @@ from airflow.providers_manager import ProvidersManager from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS -from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.config import ( + CFG_FALLBACK_CONFIG_OPTIONS, + PROVIDER_METADATA_CONFIG_OPTIONS, + PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK, + conf_vars, + create_fresh_airflow_config, +) from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker from tests_common.test_utils.reset_warning_registry import reset_warning_registry from unit.utils.test_config import ( @@ -1876,50 +1882,182 @@ def test_sensitive_values(): @skip_if_force_lowest_dependencies_marker -def test_restore_and_reload_provider_configuration(): +def test_provider_configuration_toggle_with_context_manager(): + """Test that make_sure_configuration_loaded toggles provider config on/off.""" from airflow.settings import conf - assert conf.providers_configuration_loaded is True + assert conf._use_providers_configuration is True + # With providers enabled, the provider value is returned via the fallback lookup chain. assert conf.get("celery", "celery_app_name") == "airflow.providers.celery.executors.celery_executor" - conf.restore_core_default_configuration() - assert conf.providers_configuration_loaded is False - # built-in pre-2-7 celery executor - assert conf.get("celery", "celery_app_name") == "airflow.executors.celery_executor" - conf.load_providers_configuration() - assert conf.providers_configuration_loaded is True + + with conf.make_sure_configuration_loaded(with_providers=False): + assert conf._use_providers_configuration is False + with pytest.raises( + AirflowConfigException, + match=re.escape("section/key [celery/celery_app_name] not found in config"), + ): + conf.get("celery", "celery_app_name") + # After the context manager exits, provider config is restored. + assert conf._use_providers_configuration is True assert conf.get("celery", "celery_app_name") == "airflow.providers.celery.executors.celery_executor" @skip_if_force_lowest_dependencies_marker -def test_error_when_contributing_to_existing_section(): +def test_provider_sections_do_not_overlap_with_core(): + """Test that provider config sections don't overlap with core configuration sections.""" from airflow.settings import conf - with conf.make_sure_configuration_loaded(with_providers=True): - assert conf.providers_configuration_loaded is True - assert conf.get("celery", "celery_app_name") == "airflow.providers.celery.executors.celery_executor" - conf.restore_core_default_configuration() - assert conf.providers_configuration_loaded is False - conf.configuration_description["celery"] = { - "description": "Celery Executor configuration", - "options": { - "celery_app_name": { - "default": "test", - } - }, - } - conf._default_values.add_section("celery") - conf._default_values.set("celery", "celery_app_name", "test") - assert conf.get("celery", "celery_app_name") == "test" - # patching restoring_core_default_configuration to avoid reloading the defaults - with patch.object(conf, "restore_core_default_configuration"): + core_sections = set(conf._configuration_description.keys()) + provider_sections = set(conf._provider_metadata_configuration_description.keys()) + overlap = core_sections & provider_sections + assert not overlap, ( + f"Provider configuration sections overlap with core sections: {overlap}. " + "Providers must only add new sections, not contribute to existing ones." + ) + + +@skip_if_force_lowest_dependencies_marker +class TestProviderConfigPriority: + """Tests that conf.get and conf.has_option respect provider metadata and cfg fallbacks with correct priority.""" + + @pytest.mark.parametrize( + ("section", "option", "expected"), + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_get_returns_provider_metadata_value(self, section, option, expected): + """conf.get returns provider metadata (provider.yaml) values.""" + from airflow.settings import conf + + assert conf.get(section, option) == expected + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS], + ) + def test_cfg_fallback_has_expected_value(self, section, option, expected): + """provider_config_fallback_defaults.cfg contains expected default values.""" + from airflow.settings import conf + + assert conf.get_from_provider_cfg_config_fallback_defaults(section, option) == expected + + @pytest.mark.parametrize( + ("section", "option", "expected"), + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_has_option_true_for_provider_metadata(self, section, option, expected): + """conf.has_option returns True for options defined in provider metadata.""" + from airflow.settings import conf + + assert conf.has_option(section, option) is True + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS], + ) + def test_has_option_true_for_cfg_fallback(self, section, option, expected): + """conf.has_option returns True for options in provider_config_fallback_defaults.cfg.""" + from airflow.settings import conf + + assert conf.has_option(section, option) is True + + def test_has_option_false_for_nonexistent_option(self): + """conf.has_option returns False for options not in any source.""" + from airflow.settings import conf + + assert conf.has_option("celery", "totally_nonexistent_option_xyz") is False + + @pytest.mark.parametrize( + ("section", "option", "metadata_value", "cfg_value"), + PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK, + ids=[f"{s}.{o}" for s, o, _, _ in PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK], + ) + def test_provider_metadata_overrides_cfg_fallback(self, section, option, metadata_value, cfg_value): + """Provider metadata values take priority over provider_config_fallback_defaults.cfg values.""" + from airflow.settings import conf + + assert conf.get(section, option) == metadata_value + assert conf.get_from_provider_cfg_config_fallback_defaults(section, option) == cfg_value + + @pytest.mark.parametrize( + ("section", "option", "metadata_value", "cfg_value"), + PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK, + ids=[f"{s}.{o}" for s, o, _, _ in PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK], + ) + def test_get_default_value_priority(self, section, option, metadata_value, cfg_value): + """get_default_value checks provider metadata before cfg fallback.""" + from airflow.settings import conf + + assert conf.get_default_value(section, option) == metadata_value + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS + PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_providers_disabled_dont_get_cfg_defaults_or_provider_metadata(self, section, option, expected): + """With providers disabled, conf.get raises for provider-only options.""" + test_conf = create_fresh_airflow_config() + with test_conf.make_sure_configuration_loaded(with_providers=False): with pytest.raises( AirflowConfigException, - match="The provider apache-airflow-providers-celery is attempting to contribute " - "configuration section celery that has already been added before. " - "The source of it: Airflow's core package", + match=re.escape(f"section/key [{section}/{option}] not found in config"), ): - conf.load_providers_configuration() - assert conf.get("celery", "celery_app_name") == "test" + test_conf.get(section, option) + + def test_provider_section_absent_when_providers_disabled(self): + """Provider-contributed sections are excluded from configuration_description when providers disabled.""" + test_conf = create_fresh_airflow_config() + with test_conf.make_sure_configuration_loaded(with_providers=False): + desc = test_conf.configuration_description + provider_only_sections = set(test_conf._provider_metadata_configuration_description.keys()) + for section in provider_only_sections: + if section not in test_conf._configuration_description: + assert section not in desc + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS], + ) + def test_has_option_returns_false_for_cfg_fallback_when_providers_disabled( + self, section, option, expected + ): + """With providers disabled, conf.has_option returns False for cfg-fallback-only options.""" + test_conf = create_fresh_airflow_config() + with test_conf.make_sure_configuration_loaded(with_providers=False): + assert test_conf.has_option(section, option) is False + + @pytest.mark.parametrize( + ("section", "option", "expected"), + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_has_option_returns_false_for_provider_metadata_when_providers_disabled( + self, section, option, expected + ): + """With providers disabled, conf.has_option returns False for provider-metadata-only options.""" + test_conf = create_fresh_airflow_config() + with test_conf.make_sure_configuration_loaded(with_providers=False): + assert test_conf.has_option(section, option) is False + + def test_env_var_overrides_provider_values(self): + """Environment variables override both provider metadata and cfg fallback values.""" + from airflow.settings import conf + + with mock.patch.dict("os.environ", {"AIRFLOW__CELERY__CELERY_APP_NAME": "env_override"}): + assert conf.get("celery", "celery_app_name") == "env_override" + + def test_user_config_overrides_provider_values(self): + """User-set config values (airflow.cfg) override provider defaults.""" + from airflow.settings import conf + + custom_value = "my_custom.celery_executor" + with conf_vars({("celery", "celery_app_name"): custom_value}): + assert conf.get("celery", "celery_app_name") == custom_value # Technically it's not a DB test, but we want to make sure it's not interfering with xdist non-db tests diff --git a/devel-common/src/tests_common/test_utils/config.py b/devel-common/src/tests_common/test_utils/config.py index e1d59b9f83aa8..9a278346d3068 100644 --- a/devel-common/src/tests_common/test_utils/config.py +++ b/devel-common/src/tests_common/test_utils/config.py @@ -19,6 +19,48 @@ import contextlib import os +from typing import TYPE_CHECKING, Literal, overload + +if TYPE_CHECKING: + from airflow.configuration import AirflowConfigParser + from airflow.sdk.configuration import AirflowSDKConfigParser + +# Provider config test data for parametrized tests. +# Options listed here must NOT be overridden in unit_tests.cfg, otherwise +# tests that assert default values via conf.get() will see the unit_tests.cfg +# value instead. + +# (section, option, expected_value) +# Options defined in provider metadata (provider.yaml) with non-None defaults. +PROVIDER_METADATA_CONFIG_OPTIONS: list[tuple[str, str, str]] = [ + ("celery", "celery_app_name", "airflow.providers.celery.executors.celery_executor"), + ("celery", "worker_enable_remote_control", "true"), + ("celery", "task_acks_late", "True"), + ("kubernetes_executor", "namespace", "default"), + ("kubernetes_executor", "delete_worker_pods", "True"), + ("celery_kubernetes_executor", "kubernetes_queue", "kubernetes"), +] + +# Options defined in provider_config_fallback_defaults.cfg. +CFG_FALLBACK_CONFIG_OPTIONS: list[tuple[str, str, str]] = [ + ("celery", "flower_host", "0.0.0.0"), + ("celery", "pool", "prefork"), + ("celery", "worker_precheck", "False"), + ("kubernetes_executor", "in_cluster", "True"), + ("kubernetes_executor", "verify_ssl", "True"), + ("elasticsearch", "end_of_log_mark", "end_of_log"), +] + +# Options where provider metadata and cfg fallback have DIFFERENT default values. +# (section, option, metadata_value, cfg_fallback_value) +PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK: list[tuple[str, str, str, str]] = [ + ( + "celery", + "celery_app_name", + "airflow.providers.celery.executors.celery_executor", + "airflow.executors.celery_executor", + ), +] @contextlib.contextmanager @@ -76,6 +118,37 @@ def conf_vars(overrides): settings.configure_vars() +@overload +def create_fresh_airflow_config(variant: Literal["core"] = ...) -> AirflowConfigParser: ... + + +@overload +def create_fresh_airflow_config(variant: Literal["task-sdk"]) -> AirflowSDKConfigParser: ... + + +def create_fresh_airflow_config( + variant: Literal["core", "task-sdk"] = "core", +) -> AirflowConfigParser | AirflowSDKConfigParser: + """Create a fresh, fully-initialized config parser independent of the singleton. + + Use this instead of ``from airflow.settings import conf`` when the test mutates + parser state (e.g. ``make_sure_configuration_loaded(with_providers=False)``). + A fresh instance avoids interference with other tests that may run in parallel. + + :param variant: Which config parser to create — ``"core"`` (default) for the + full Airflow config, or ``"task-sdk"`` for the lightweight SDK config. + """ + if variant == "core": + from airflow.configuration import initialize_config as initialize_core_config + + return initialize_core_config() + if variant == "task-sdk": + from airflow.sdk.configuration import initialize_config as initialize_sdk_config + + return initialize_sdk_config() + raise ValueError(f"Unknown variant: {variant!r}. Expected 'core' or 'task-sdk'.") + + @contextlib.contextmanager def env_vars(overrides): """ diff --git a/scripts/in_container/run_check_default_configuration.py b/scripts/in_container/run_check_default_configuration.py index b32917f2ec73c..10340d43b7949 100755 --- a/scripts/in_container/run_check_default_configuration.py +++ b/scripts/in_container/run_check_default_configuration.py @@ -36,23 +36,40 @@ if __name__ == "__main__": with tempfile.TemporaryDirectory() as tmp_dir: + # We need to explicitly set the logging level to ERROR to avoid debug logs from "airflow config lint" command that can spoil the output and make the test fail. + # This is needed in case the default config has logging level set to DEBUG, but it does not hurt to set it explicitly in any case to avoid any unexpected debug logs from the command. + env = os.environ.copy() + env["AIRFLOW__LOGGING__LOGGING_LEVEL"] = "ERROR" + # Write default config cmd output to a temporary file default_config_file = os.path.join(tmp_dir, "airflow.cfg") with open(default_config_file, "w") as f: - result = subprocess.run(list_default_config_cmd, check=False, stdout=f) + result = subprocess.run( + list_default_config_cmd, check=False, stdout=f, stderr=subprocess.PIPE, env=env + ) if result.returncode != 0: print(f"\033[0;31mERROR: when running `{' '.join(list_default_config_cmd)}`\033[0m\n") + if result.stderr: + print(result.stderr.decode()) + print(f"Default config (if any) was written to: {default_config_file}") exit(1) # Run airflow config lint to check the default config - env = os.environ.copy() env["AIRFLOW_HOME"] = tmp_dir env["AIRFLOW_CONFIG"] = default_config_file result = subprocess.run(lint_config_cmd, check=False, capture_output=True, env=env) - output: str = result.stdout.decode().strip() - if result.returncode != 0 or expected_output not in output: - print(f"\033[0;31mERROR: when running `{' '.join(lint_config_cmd)}`\033[0m\n") + output: str = result.stdout.decode().strip() + if result.returncode != 0 or expected_output not in output: + print(f"\033[0;31mERROR: when running `{' '.join(lint_config_cmd)}`\033[0m\n") + print(output) + # log the stderr as well if available + if result.stderr: + print(f"\033[0;31mERROR: stderr from `{' '.join(lint_config_cmd)}`\033[0m\n") + print(result.stderr.decode()) + # log the default config that was generated for debugging + print("\033[0;31mGenerated default config for debugging:\033[0m\n") + with open(default_config_file) as f: + print(f.read()) + exit(1) print(output) - exit(1) - print(output) - exit(0) + exit(0) diff --git a/shared/configuration/src/airflow_shared/configuration/parser.py b/shared/configuration/src/airflow_shared/configuration/parser.py index 55f24c529e08d..e83590dccbdd2 100644 --- a/shared/configuration/src/airflow_shared/configuration/parser.py +++ b/shared/configuration/src/airflow_shared/configuration/parser.py @@ -269,34 +269,97 @@ def _lookup_sequence(self) -> list[Callable]: Subclasses can override this to customise lookup order. """ - return [ + lookup_methods = [ self._get_environment_variables, self._get_option_from_config_file, self._get_option_from_commands, self._get_option_from_secrets, self._get_option_from_defaults, - self._get_option_from_provider_cfg_config_fallbacks, - self._get_option_from_provider_metadata_config_fallbacks, ] + if self._use_providers_configuration: + # Provider fallback lookups are last so they have the lowest priority in the lookup sequence. + lookup_methods += [ + self._get_option_from_provider_metadata_config_fallbacks, + self._get_option_from_provider_cfg_config_fallbacks, + ] + return lookup_methods - def _get_config_sources_for_as_dict(self) -> list[tuple[str, ConfigParser]]: + @functools.cached_property + def configuration_description(self) -> dict[str, dict[str, Any]]: + """ + Return configuration description from multiple sources. + + Respects the ``_use_providers_configuration`` flag to decide whether to include + provider configuration. + + The merged description is built as follows: + + 1. Start from the base configuration description provided in ``__init__``, usually + loaded from ``config.yml`` in core. Values defined here are never overridden. + 2. Merge provider metadata from ``_provider_metadata_configuration_description``, + loaded from provider packages' ``get_provider_info`` method. Only adds missing + sections/options; does not overwrite existing entries from the base configuration. + 3. Merge default values from ``_provider_cfg_config_fallback_default_values``, + loaded from ``provider_config_fallback_defaults.cfg``. Only sets ``"default"`` + (and heuristically ``"sensitive"``) for options that do not already define them. + + Base configuration takes precedence, then provider metadata fills in missing + descriptions/options, and finally cfg-based fallbacks provide defaults only where + none are defined. + + We use ``cached_property`` to cache the merged result; clear this cache (via + ``invalidate_cache``) when toggling ``_use_providers_configuration``. + """ + if not self._use_providers_configuration: + return self._configuration_description + + merged_description: dict[str, dict[str, Any]] = deepcopy(self._configuration_description) + + # Merge full provider config descriptions (with metadata like sensitive, description, etc.) + # from provider packages' get_provider_info method, reusing the cached raw dict. + for section, section_content in self._provider_metadata_configuration_description.items(): + if section not in merged_description: + merged_description[section] = deepcopy(section_content) + else: + existing_options = merged_description[section].setdefault("options", {}) + for option, option_content in section_content.get("options", {}).items(): + if option not in existing_options: + existing_options[option] = deepcopy(option_content) + + # Merge default values from cfg-based fallbacks (key=value only, no metadata). + # Uses setdefault so provider metadata values above take priority. + cfg = self._provider_cfg_config_fallback_default_values + for section in cfg.sections(): + section_options = merged_description.setdefault(section, {"options": {}}).setdefault( + "options", {} + ) + for option in cfg.options(section): + opt_dict = section_options.setdefault(option, {}) + opt_dict.setdefault("default", cfg.get(section, option)) + # For cfg-only options with no provider metadata, infer sensitivity from name. + if "sensitive" not in opt_dict and option.endswith(("password", "secret")): + opt_dict["sensitive"] = True + + return merged_description + + @property + def _config_sources_for_as_dict(self) -> list[tuple[str, ConfigParser]]: """Override the base method to add provider fallbacks when providers are loaded.""" - sources: list[tuple[str, ConfigParser]] = [ - ("default", self._default_values), - ("airflow.cfg", self), - ] - if self._providers_configuration_loaded: - sources.insert( - 0, + sources: list[tuple[str, ConfigParser]] = [] + if self._use_providers_configuration: + # Provider fallback defaults are listed first so they have the lowest priority + # in as_dict()'s "last source wins" semantics. + sources += [ + ("provider-cfg-fallback-defaults", self._provider_cfg_config_fallback_default_values), ( "provider-metadata-fallback-defaults", self._provider_metadata_config_fallback_default_values, ), - ) - sources.insert( - 0, - ("provider-cfg-fallback-defaults", self._provider_cfg_config_fallback_default_values), - ) + ] + sources += [ + ("default", self._default_values), + ("airflow.cfg", self), + ] return sources def _get_option_from_provider_cfg_config_fallbacks( @@ -327,7 +390,7 @@ def _get_option_from_provider_metadata_config_fallbacks( ) -> str | ValueNotFound: """Get config option from provider metadata fallback defaults.""" value = self.get_from_provider_metadata_config_fallback_defaults(section, key, **kwargs) - if value is not None: + if value is not VALUE_NOT_FOUND_SENTINEL: return value return VALUE_NOT_FOUND_SENTINEL @@ -339,20 +402,25 @@ def get_from_provider_cfg_config_fallback_defaults(self, section: str, key: str, section, key, fallback=None, raw=raw, vars=vars_ ) + @functools.cached_property + def _provider_metadata_configuration_description(self) -> dict[str, dict[str, Any]]: + """Raw provider configuration descriptions with full metadata (sensitive, description, etc.).""" + result: dict[str, dict[str, Any]] = {} + for _, config in self._provider_manager_type().provider_configs: + result.update(config) + return result + @functools.cached_property def _provider_metadata_config_fallback_default_values(self) -> ConfigParser: """Return Provider metadata config fallback default values.""" - base_configuration_description: dict[str, dict[str, Any]] = {} - for _, config in self._provider_manager_type().provider_configs: - base_configuration_description.update(config) - return self._create_default_config_parser_callable(base_configuration_description) + return self._create_default_config_parser_callable(self._provider_metadata_configuration_description) def get_from_provider_metadata_config_fallback_defaults(self, section: str, key: str, **kwargs) -> Any: """Get provider metadata config fallback default values.""" raw = kwargs.get("raw", False) vars_ = kwargs.get("vars") return self._provider_metadata_config_fallback_default_values.get( - section, key, fallback=None, raw=raw, vars=vars_ + section, key, fallback=VALUE_NOT_FOUND_SENTINEL, raw=raw, vars=vars_ ) @property @@ -428,8 +496,7 @@ def __init__( :param provider_config_fallback_defaults_cfg_path: Path to the `provider_config_fallback_defaults.cfg` file. """ super().__init__(*args, **kwargs) - self.configuration_description = configuration_description - self._base_configuration_description = deepcopy(configuration_description) + self._configuration_description = configuration_description self._default_values = _default_values self._provider_manager_type = provider_manager_type self._create_default_config_parser_callable = create_default_config_parser_callable @@ -438,7 +505,9 @@ def __init__( ) self._suppress_future_warnings = False self.upgraded_values: dict[tuple[str, str], str] = {} - self._providers_configuration_loaded = False + # The _use_providers_configuration flag will always be True unless we call `write(include_providers=False)` or `with self.make_sure_configuration_loaded(with_providers=False)`. + # Even we call those methods, the flag will be set back to True after the method is done, so it only affects the current call to `as_dict()` and does not have any effect on subsequent calls. + self._use_providers_configuration = True def invalidate_cache(self) -> None: """ @@ -454,6 +523,11 @@ def invalidate_cache(self) -> None: ): self.__dict__.pop(attr_name, None) + def _invalidate_provider_flag_caches(self) -> None: + """Invalidate caches related to provider configuration flags.""" + self.__dict__.pop("configuration_description", None) + self.__dict__.pop("sensitive_config_values", None) + @functools.cached_property def inversed_deprecated_options(self): """Build inverse mapping from old options to new options.""" @@ -520,12 +594,13 @@ def _update_defaults_from_string(self, config_string: str) -> None: def get_default_value(self, section: str, key: str, fallback: Any = None, raw=False, **kwargs) -> Any: """ - Retrieve default value from default config parser. + Retrieve default value from default config parser, including provider fallbacks. - This will retrieve the default value from the default config parser. Optionally a raw, stored - value can be retrieved by setting skip_interpolation to True. This is useful for example when - we want to write the default value to a file, and we don't want the interpolation to happen - as it is going to be done later when the config is read. + This will retrieve the default value from the core default config parser first. If not found + and providers configuration is loaded, it also checks provider fallback defaults. + Optionally a raw, stored value can be retrieved by setting skip_interpolation to True. + This is useful for example when we want to write the default value to a file, and we don't + want the interpolation to happen as it is going to be done later when the config is read. :param section: section of the config :param key: key to use @@ -534,7 +609,18 @@ def get_default_value(self, section: str, key: str, fallback: Any = None, raw=Fa :param kwargs: other args :return: """ - value = self._default_values.get(section, key, fallback=fallback, **kwargs) + value = self._default_values.get(section, key, fallback=VALUE_NOT_FOUND_SENTINEL, **kwargs) + # Provider metadata has higher priority than cfg fallback — check it first. + if value is VALUE_NOT_FOUND_SENTINEL and self._use_providers_configuration: + value = self._provider_metadata_config_fallback_default_values.get( + section, key, fallback=VALUE_NOT_FOUND_SENTINEL, **kwargs + ) + if value is VALUE_NOT_FOUND_SENTINEL and self._use_providers_configuration: + value = self._provider_cfg_config_fallback_default_values.get( + section, key, fallback=VALUE_NOT_FOUND_SENTINEL, **kwargs + ) + if value is VALUE_NOT_FOUND_SENTINEL: + value = fallback if raw and value is not None: return value.replace("%", "%%") return value @@ -1189,51 +1275,38 @@ def load_providers_configuration(self) -> None: """ Load configuration for providers. - This should be done after initial configuration have been performed. Initializing and discovering - providers is an expensive operation and cannot be performed when we load configuration for the first - time when airflow starts, because we initialize configuration very early, during importing of the - `airflow` package and the module is not yet ready to be used when it happens and until configuration - and settings are loaded. Therefore, in order to reload provider configuration we need to additionally - load provider - specific configuration. + .. deprecated:: 3.2.0 + Provider configuration is now loaded lazily via the ``configuration_description`` + cached property. This method is kept for backwards compatibility and will be + removed in a future version. """ - log.debug("Loading providers configuration") - - self.restore_core_default_configuration() - for provider, config in self._provider_manager_type().already_initialized_provider_configs: - for provider_section, provider_section_content in config.items(): - provider_options = provider_section_content["options"] - section_in_current_config = self.configuration_description.get(provider_section) - if not section_in_current_config: - self.configuration_description[provider_section] = deepcopy(provider_section_content) - section_in_current_config = self.configuration_description.get(provider_section) - section_in_current_config["source"] = f"default-{provider}" - for option in provider_options: - section_in_current_config["options"][option]["source"] = f"default-{provider}" - else: - section_source = section_in_current_config.get("source", "Airflow's core package").split( - "default-" - )[-1] - raise AirflowConfigException( - f"The provider {provider} is attempting to contribute " - f"configuration section {provider_section} that " - f"has already been added before. The source of it: {section_source}. " - "This is forbidden. A provider can only add new sections. It " - "cannot contribute options to existing sections or override other " - "provider's configuration.", - UserWarning, - ) - self._default_values = self._create_default_config_parser_callable(self.configuration_description) - # Cached properties derived from configuration_description (e.g. sensitive_config_values) need - # to be recomputed now that provider config has been merged in. - self.invalidate_cache() - self._providers_configuration_loaded = True + warnings.warn( + "load_providers_configuration() is deprecated. " + "Provider configuration is now loaded lazily via the " + "`configuration_description` cached property.", + DeprecationWarning, + stacklevel=2, + ) + self._use_providers_configuration = True + self._invalidate_provider_flag_caches() def restore_core_default_configuration(self) -> None: - """Restore the parser state before provider-contributed sections were loaded.""" - self.configuration_description = deepcopy(self._base_configuration_description) - self._default_values = self._create_default_config_parser_callable(self.configuration_description) - self.invalidate_cache() - self._providers_configuration_loaded = False + """ + Restore the parser state before provider-contributed sections were loaded. + + .. deprecated:: 3.2.0 + Use ``make_sure_configuration_loaded(with_providers=False)`` context manager + instead. This method is kept for backwards compatibility and will be removed + in a future version. + """ + warnings.warn( + "restore_core_default_configuration() is deprecated. " + "Use `make_sure_configuration_loaded(with_providers=False)` instead.", + DeprecationWarning, + stacklevel=2, + ) + self._use_providers_configuration = False + self._invalidate_provider_flag_caches() @overload # type: ignore[override] def get(self, section: str, key: str, fallback: str = ..., **kwargs) -> str: ... @@ -1521,6 +1594,17 @@ def read_dict( # type: ignore[override] """ super().read_dict(dictionary=dictionary, source=source) + def _has_section_in_any_defaults(self, section: str) -> bool: + """Check if section exists in core defaults or provider fallback defaults.""" + if self._default_values.has_section(section): + return True + if self._use_providers_configuration: + if self._provider_cfg_config_fallback_default_values.has_section(section): + return True + if self._provider_metadata_config_fallback_default_values.has_section(section): + return True + return False + def get_sections_including_defaults(self) -> list[str]: """ Retrieve all sections from the configuration parser, including sections defined by built-in defaults. @@ -1563,13 +1647,13 @@ def has_option(self, section: str, option: str, lookup_from_deprecated: bool = T value = self.get( section, option, - fallback=None, + fallback=VALUE_NOT_FOUND_SENTINEL, _extra_stacklevel=1, suppress_warnings=True, lookup_from_deprecated=lookup_from_deprecated, **kwargs, ) - if value is None: + if value is VALUE_NOT_FOUND_SENTINEL: return False return True except (NoOptionError, NoSectionError, AirflowConfigException): @@ -1602,7 +1686,7 @@ def remove_option(self, section: str, option: str, remove_default: bool = True): if super().has_option(section, option): super().remove_option(section, option) - if self.get_default_value(section, option) is not None and remove_default: + if remove_default and self._default_values.has_option(section, option): self._default_values.remove_option(section, option) def optionxform(self, optionstr: str) -> str: @@ -1669,7 +1753,7 @@ def as_dict( config_sources: ConfigSourcesType = {} # We check sequentially all those sources and the last one we saw it in will "win" - configs = self._get_config_sources_for_as_dict() + configs = self._config_sources_for_as_dict self._replace_config_with_display_sources( config_sources, @@ -1927,7 +2011,7 @@ def write( # type: ignore[override] section_config_description = self.configuration_description.get(section_to_write, {}) if section_to_write != section and section is not None: continue - if self._default_values.has_section(section_to_write) or self.has_section(section_to_write): + if self._has_section_in_any_defaults(section_to_write) or self.has_section(section_to_write): self._write_section_header( file, include_descriptions, section_config_description, section_to_write ) @@ -1968,28 +2052,21 @@ def make_sure_configuration_loaded(self, with_providers: bool) -> Generator[None """ Make sure configuration is loaded with or without providers. - This happens regardless if the provider configuration has been loaded before or not. - Restores configuration to the state before entering the context. + The context manager will only toggle the `self._use_providers_configuration` flag if `with_providers` is False, and will reset `self._use_providers_configuration` to True after the context block. + Nop for `with_providers=True` as the configuration already loads providers configuration by default. :param with_providers: whether providers should be loaded """ - needs_reload = False - if with_providers: - self._ensure_providers_config_loaded() - else: - needs_reload = self._ensure_providers_config_unloaded() - yield - if needs_reload: - self._reload_provider_configs() - - def _ensure_providers_config_loaded(self) -> None: - """Ensure providers configurations are loaded.""" - raise NotImplementedError("Subclasses must implement _ensure_providers_config_loaded method") - - def _ensure_providers_config_unloaded(self) -> bool: - """Ensure providers configurations are unloaded temporarily to load core configs. Returns True if providers get unloaded.""" - raise NotImplementedError("Subclasses must implement _ensure_providers_config_unloaded method") - - def _reload_provider_configs(self) -> None: - """Reload providers configuration.""" - raise NotImplementedError("Subclasses must implement _reload_provider_configs method") + if not with_providers: + self._use_providers_configuration = False + # Only invalidate cached properties that depend on _use_providers_configuration. + # Do NOT use invalidate_cache() here — it would also evict expensive provider-discovery + # caches (_provider_metadata_configuration_description, _provider_metadata_config_fallback_default_values) + # that don't depend on this flag. + self._invalidate_provider_flag_caches() + try: + yield + finally: + if not with_providers: + self._use_providers_configuration = True + self._invalidate_provider_flag_caches() diff --git a/shared/configuration/tests/configuration/test_parser.py b/shared/configuration/tests/configuration/test_parser.py index 45b8ee1fd57ee..6c1b5c1989f43 100644 --- a/shared/configuration/tests/configuration/test_parser.py +++ b/shared/configuration/tests/configuration/test_parser.py @@ -45,10 +45,6 @@ class _NoOpProvidersManager: def provider_configs(self): return [] - @property - def already_initialized_provider_configs(self): - return [] - def _create_empty_config_parser(desc: dict) -> ConfigParser: return ConfigParser() @@ -92,39 +88,13 @@ def __init__( *args, **kwargs, ) - self.configuration_description = configuration_description + self._configuration_description = configuration_description self._default_values = _default_values self._suppress_future_warnings = False if default_config is not None: self._update_defaults_from_string(default_config) - def _update_defaults_from_string(self, config_string: str): - """Update defaults from string for testing.""" - parser = ConfigParser() - parser.read_string(config_string) - for section in parser.sections(): - if section not in self._default_values.sections(): - self._default_values.add_section(section) - for key, value in parser.items(section): - self._default_values.set(section, key, value) - - def _ensure_providers_config_loaded(self) -> None: - """Load provider configuration for tests when requested.""" - if not self._providers_configuration_loaded: - self.load_providers_configuration() - - def _ensure_providers_config_unloaded(self) -> bool: - """Unload provider configuration for tests when requested.""" - if self._providers_configuration_loaded: - self.restore_core_default_configuration() - return True - return False - - def _reload_provider_configs(self) -> None: - """Reload provider configuration for tests after temporary unloads.""" - self.load_providers_configuration() - class TestAirflowConfigParser: """Test the shared AirflowConfigParser parser methods.""" @@ -836,16 +806,6 @@ def test_get_mandatory_list_value(self): with pytest.raises(ValueError, match=r"The value test/missing_key should be set!"): test_conf.get_mandatory_list_value("test", "missing_key", fallback=None) - def test_as_dict_only_materializes_provider_sources_after_loading_providers(self): - test_conf = AirflowConfigParser() - - test_conf.as_dict(display_source=True) - assert "_provider_metadata_config_fallback_default_values" not in test_conf.__dict__ - - test_conf.load_providers_configuration() - test_conf.as_dict(display_source=True) - assert "_provider_metadata_config_fallback_default_values" in test_conf.__dict__ - def test_write_materializes_provider_sources_in_requested_context(self): test_conf = AirflowConfigParser() @@ -855,7 +815,12 @@ def test_write_materializes_provider_sources_in_requested_context(self): test_conf.write(StringIO(), include_sources=True, include_providers=True) assert "_provider_metadata_config_fallback_default_values" in test_conf.__dict__ - def test_get_uses_provider_metadata_fallback_before_loading_providers(self): + # we will not clear the cached _provider_metadata_config_fallback_default_values after the first call + test_conf.write(StringIO(), include_sources=True, include_providers=False) + assert "_provider_metadata_config_fallback_default_values" in test_conf.__dict__ + + def test_get_resolves_provider_metadata_fallback(self): + """conf.get returns values from provider metadata for provider-only sections.""" provider_configs = [ ( "apache-airflow-providers-test", @@ -876,20 +841,55 @@ class ProvidersManagerWithConfig: def provider_configs(self): return provider_configs + test_conf = AirflowConfigParser( + provider_manager_type=ProvidersManagerWithConfig, + create_default_config_parser_callable=_create_default_config_parser, + ) + + assert test_conf._use_providers_configuration is True + assert test_conf.get("test_provider", "test_option") == "provider-default" + # Provider metadata is merged into configuration_description + assert test_conf.configuration_description.get("test_provider") is not None + # Base configuration is not mutated + assert "test_provider" not in test_conf._configuration_description + + def test_has_option_uses_provider_metadata_fallback(self): + """has_option must reach provider-metadata fallback for provider-only sections. + + Regression test: has_option passes ``fallback=None`` to get(), which leaked + into _get_option_from_defaults via **kwargs. The ``"fallback" in kwargs`` + guard caused _get_option_from_defaults to return None (the fallback) instead + of VALUE_NOT_FOUND_SENTINEL, short-circuiting the lookup before the provider + metadata fallback was consulted. + """ + provider_configs = [ + ( + "apache-airflow-providers-test", + { + "test_provider": { + "options": { + "test_option": { + "default": "provider-default", + } + } + } + }, + ) + ] + + class ProvidersManagerWithConfig: @property - def already_initialized_provider_configs(self): - return [] + def provider_configs(self): + return provider_configs test_conf = AirflowConfigParser( provider_manager_type=ProvidersManagerWithConfig, create_default_config_parser_callable=_create_default_config_parser, ) - assert test_conf._providers_configuration_loaded is False - assert test_conf.configuration_description.get("test_provider") is None - assert test_conf.get("test_provider", "test_option") == "provider-default" - assert test_conf._providers_configuration_loaded is False - assert test_conf.configuration_description.get("test_provider") is None + assert test_conf.has_option("test_provider", "test_option") is True + assert test_conf.has_option("test_provider", "nonexistent_option") is False + assert test_conf.has_option("nonexistent_section", "nonexistent_option") is False def test_set_case_insensitive(self): # both get and set should be case insensitive @@ -1123,3 +1123,68 @@ def test_team_env_var_format(self): {"AIRFLOW__MY_TEAM___MY_SECTION__MY_KEY": "team_value"}, ): assert test_conf.get("my_section", "my_key", team_name="my_team") == "team_value" + + @pytest.mark.parametrize( + "populate_caches", + [ + pytest.param(set(), id="neither_cached"), + pytest.param({"configuration_description"}, id="only_configuration_description"), + pytest.param({"sensitive_config_values"}, id="only_sensitive_config_values"), + pytest.param({"configuration_description", "sensitive_config_values"}, id="both_cached"), + ], + ) + def test_invalidate_provider_flag_caches(self, populate_caches): + """Test that _invalidate_provider_flag_caches clears cached properties without error.""" + test_conf = AirflowConfigParser() + if "configuration_description" in populate_caches: + _ = test_conf.configuration_description + if "sensitive_config_values" in populate_caches: + _ = test_conf.sensitive_config_values + + test_conf._invalidate_provider_flag_caches() + + assert "configuration_description" not in test_conf.__dict__ + assert "sensitive_config_values" not in test_conf.__dict__ + + def test_invalidate_provider_flag_caches_allows_recomputation(self): + """Test that cached properties are recomputed after invalidation.""" + test_conf = AirflowConfigParser() + desc_before = test_conf.configuration_description + sensitive_before = test_conf.sensitive_config_values + + test_conf._invalidate_provider_flag_caches() + + # Access again — should recompute, not error + desc_after = test_conf.configuration_description + sensitive_after = test_conf.sensitive_config_values + assert desc_after == desc_before + assert sensitive_after == sensitive_before + + def test_load_providers_configuration_emits_deprecation_warning(self): + """Test that load_providers_configuration emits a DeprecationWarning.""" + test_conf = AirflowConfigParser() + with pytest.warns(DeprecationWarning, match="load_providers_configuration.*deprecated"): + test_conf.load_providers_configuration() + assert test_conf._use_providers_configuration is True + + def test_restore_core_default_configuration_emits_deprecation_warning(self): + """Test that restore_core_default_configuration emits a DeprecationWarning.""" + test_conf = AirflowConfigParser() + with pytest.warns(DeprecationWarning, match="restore_core_default_configuration.*deprecated"): + test_conf.restore_core_default_configuration() + assert test_conf._use_providers_configuration is False + + def test_deprecated_load_restore_round_trip(self): + """Test that the deprecated methods toggle _use_providers_configuration correctly.""" + test_conf = AirflowConfigParser() + assert test_conf._use_providers_configuration is True + + with pytest.warns(DeprecationWarning, match="restore_core_default_configuration"): + test_conf.restore_core_default_configuration() + assert test_conf._use_providers_configuration is False + assert "configuration_description" not in test_conf.__dict__ + + with pytest.warns(DeprecationWarning, match="load_providers_configuration"): + test_conf.load_providers_configuration() + assert test_conf._use_providers_configuration is True + assert "configuration_description" not in test_conf.__dict__ diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py index dcab0fe3034aa..f0b698dea421c 100644 --- a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py +++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py @@ -21,6 +21,7 @@ import contextlib import json +import logging import pathlib from collections.abc import Callable, MutableMapping from dataclasses import dataclass @@ -29,12 +30,11 @@ from time import perf_counter from typing import Any, NamedTuple, ParamSpec, Protocol, cast -import structlog from packaging.utils import canonicalize_name from ..module_loading import entry_points_with_dist -log = structlog.getLogger(__name__) +log = logging.getLogger(__name__) PS = ParamSpec("PS") diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py index 64bda4b3a56eb..fba300e4a8477 100644 --- a/task-sdk/src/airflow/sdk/configuration.py +++ b/task-sdk/src/airflow/sdk/configuration.py @@ -130,11 +130,11 @@ def __init__( from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime # Read Core's config.yml (Phase 1: shared config.yml) - configuration_description = retrieve_configuration_description() + _configuration_description = retrieve_configuration_description() # Create default values parser - _default_values = create_default_config_parser(configuration_description) + _default_values = create_default_config_parser(_configuration_description) super().__init__( - configuration_description, + _configuration_description, _default_values, ProvidersManagerTaskRuntime, create_default_config_parser, @@ -142,7 +142,7 @@ def __init__( *args, **kwargs, ) - self.configuration_description = configuration_description + self._configuration_description = _configuration_description self._default_values = _default_values self._suppress_future_warnings = False diff --git a/task-sdk/src/airflow/sdk/providers_manager_runtime.py b/task-sdk/src/airflow/sdk/providers_manager_runtime.py index 4d764596814ed..abc4fa490abbe 100644 --- a/task-sdk/src/airflow/sdk/providers_manager_runtime.py +++ b/task-sdk/src/airflow/sdk/providers_manager_runtime.py @@ -225,10 +225,6 @@ def initialize_provider_configs(self): """Lazy initialization of provider configuration metadata and merge it into SDK ``conf``.""" self.initialize_providers_list() self._discover_config() - # Imported lazily to preserve SDK conf lazy initialization and avoid a configuration/runtime cycle. - from airflow.sdk.configuration import conf - - conf.load_providers_configuration() def _discover_config(self) -> None: """Retrieve all configs defined in the providers.""" @@ -620,10 +616,6 @@ def provider_configs(self) -> list[tuple[str, dict[str, Any]]]: self.initialize_provider_configs() return sorted(self._provider_configs.items(), key=lambda x: x[0]) - @property - def already_initialized_provider_configs(self) -> list[tuple[str, dict[str, Any]]]: - return sorted(self._provider_configs.items(), key=lambda x: x[0]) - def _cleanup(self): self._initialized_cache.clear() self._provider_dict.clear() diff --git a/task-sdk/tests/task_sdk/test_configuration.py b/task-sdk/tests/task_sdk/test_configuration.py new file mode 100644 index 0000000000000..32ecc7ff3a109 --- /dev/null +++ b/task-sdk/tests/task_sdk/test_configuration.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import re +from unittest import mock + +import pytest + +from airflow.sdk._shared.configuration.exceptions import AirflowConfigException +from airflow.sdk.configuration import conf +from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime + +from tests_common.test_utils.config import ( + CFG_FALLBACK_CONFIG_OPTIONS, + PROVIDER_METADATA_CONFIG_OPTIONS, + PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK, + conf_vars, + create_fresh_airflow_config, +) +from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker + + +@pytest.fixture(scope="module", autouse=True) +def restore_providers_manager_runtime_configuration(): + yield + ProvidersManagerTaskRuntime()._cleanup() + + +@skip_if_force_lowest_dependencies_marker +class TestSDKProviderConfigPriority: + """Tests that SDK conf.get and conf.has_option respect provider metadata and cfg fallbacks.""" + + @pytest.mark.parametrize( + ("section", "option", "expected"), + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_get_returns_provider_metadata_value(self, section, option, expected): + """conf.get returns provider metadata (provider.yaml) values.""" + assert conf.get(section, option) == expected + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS], + ) + def test_cfg_fallback_has_expected_value(self, section, option, expected): + """provider_config_fallback_defaults.cfg contains expected default values.""" + assert conf.get_from_provider_cfg_config_fallback_defaults(section, option) == expected + + @pytest.mark.parametrize( + ("section", "option", "expected"), + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_has_option_true_for_provider_metadata(self, section, option, expected): + """conf.has_option returns True for options defined in provider metadata.""" + assert conf.has_option(section, option) is True + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS], + ) + def test_has_option_true_for_cfg_fallback(self, section, option, expected): + """conf.has_option returns True for options in provider_config_fallback_defaults.cfg.""" + assert conf.has_option(section, option) is True + + def test_has_option_false_for_nonexistent_option(self): + """conf.has_option returns False for options not in any source.""" + assert conf.has_option("celery", "totally_nonexistent_option_xyz") is False + + @pytest.mark.parametrize( + ("section", "option", "metadata_value", "cfg_value"), + PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK, + ids=[f"{s}.{o}" for s, o, _, _ in PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK], + ) + def test_provider_metadata_overrides_cfg_fallback(self, section, option, metadata_value, cfg_value): + """Provider metadata values take priority over provider_config_fallback_defaults.cfg values.""" + assert conf.get(section, option) == metadata_value + assert conf.get_from_provider_cfg_config_fallback_defaults(section, option) == cfg_value + + @pytest.mark.parametrize( + ("section", "option", "metadata_value", "cfg_value"), + PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK, + ids=[f"{s}.{o}" for s, o, _, _ in PROVIDER_METADATA_OVERRIDES_CFG_FALLBACK], + ) + def test_get_default_value_priority(self, section, option, metadata_value, cfg_value): + """get_default_value checks provider metadata before cfg fallback.""" + assert conf.get_default_value(section, option) == metadata_value + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS + PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_providers_disabled_dont_get_cfg_defaults_or_provider_metadata(self, section, option, expected): + """With providers disabled, conf.get raises for provider-only options.""" + test_conf = create_fresh_airflow_config("task-sdk") + with test_conf.make_sure_configuration_loaded(with_providers=False): + with pytest.raises( + AirflowConfigException, + match=re.escape(f"section/key [{section}/{option}] not found in config"), + ): + test_conf.get(section, option) + + @pytest.mark.parametrize( + ("section", "option", "expected"), + CFG_FALLBACK_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in CFG_FALLBACK_CONFIG_OPTIONS], + ) + def test_has_option_returns_false_for_cfg_fallback_when_providers_disabled( + self, section, option, expected + ): + """With providers disabled, conf.has_option returns False for cfg-fallback-only options.""" + test_conf = create_fresh_airflow_config("task-sdk") + with test_conf.make_sure_configuration_loaded(with_providers=False): + assert test_conf.has_option(section, option) is False + + @pytest.mark.parametrize( + ("section", "option", "expected"), + PROVIDER_METADATA_CONFIG_OPTIONS, + ids=[f"{s}.{o}" for s, o, _ in PROVIDER_METADATA_CONFIG_OPTIONS], + ) + def test_has_option_returns_false_for_provider_metadata_when_providers_disabled( + self, section, option, expected + ): + """With providers disabled, conf.has_option returns False for provider-metadata-only options.""" + test_conf = create_fresh_airflow_config("task-sdk") + with test_conf.make_sure_configuration_loaded(with_providers=False): + assert test_conf.has_option(section, option) is False + + def test_env_var_overrides_provider_values(self): + """Environment variables override both provider metadata and cfg fallback values.""" + with mock.patch.dict("os.environ", {"AIRFLOW__CELERY__CELERY_APP_NAME": "env_override"}): + assert conf.get("celery", "celery_app_name") == "env_override" + + def test_user_config_overrides_provider_values(self): + """User-set config values (airflow.cfg) override provider defaults.""" + custom_value = "my_custom.celery_executor" + with conf_vars({("celery", "celery_app_name"): custom_value}): + assert conf.get("celery", "celery_app_name") == custom_value diff --git a/task-sdk/tests/task_sdk/test_providers_manager_runtime.py b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py index aee5a115363f5..eeb651b4bd0ac 100644 --- a/task-sdk/tests/task_sdk/test_providers_manager_runtime.py +++ b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py @@ -260,7 +260,7 @@ def initialize_provider_configs() -> None: with patch.object(providers_manager, "initialize_providers_list"): providers_manager.initialize_provider_configs() - conf.restore_core_default_configuration() + conf.invalidate_cache() try: initialize_provider_configs() assert conf.get("test_sdk_provider", "test_option") == "provider-default" @@ -270,4 +270,4 @@ def initialize_provider_configs() -> None: initialize_provider_configs() assert conf.get("test_sdk_provider", "test_option") == "provider-default" finally: - conf.restore_core_default_configuration() + conf.invalidate_cache()