diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index bd8d76823..e8dde3ecb 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -18,18 +18,18 @@ } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 15, - "endColumn": 31, - "lineCount": 1 + "startColumn": 32, + "endColumn": 9, + "lineCount": 5 } }, { - "code": "reportArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 42, - "endColumn": 44, + "startColumn": 15, + "endColumn": 31, "lineCount": 1 } }, @@ -180,18 +180,18 @@ } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 37, - "endColumn": 45, - "lineCount": 1 + "startColumn": 32, + "endColumn": 9, + "lineCount": 5 } }, { - "code": "reportArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 75, - "endColumn": 79, + "startColumn": 37, + "endColumn": 45, "lineCount": 1 } }, @@ -486,18 +486,18 @@ } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 28, - "endColumn": 36, - "lineCount": 1 + "startColumn": 32, + "endColumn": 9, + "lineCount": 5 } }, { - "code": "reportArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 66, - "endColumn": 69, + "startColumn": 28, + "endColumn": 36, "lineCount": 1 } }, @@ -680,18 +680,18 @@ } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 37, - "endColumn": 45, - "lineCount": 1 + "startColumn": 32, + "endColumn": 9, + "lineCount": 5 } }, { - "code": "reportArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 75, - "endColumn": 79, + "startColumn": 37, + "endColumn": 45, "lineCount": 1 } }, @@ -1165,22 +1165,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 45, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 35, - "endColumn": 69, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -1221,14 +1205,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 70, - "endColumn": 73, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -3723,6 +3699,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 8, + "endColumn": 12, + "lineCount": 1 + } + }, { "code": "reportUnknownParameterType", "range": { @@ -3771,6 +3755,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3779,6 +3771,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 30, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3803,6 +3803,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 40, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3811,6 +3819,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 43, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3819,6 +3835,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 39, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3891,6 +3915,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 36, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3915,6 +3947,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 22, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3939,6 +3979,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 17, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3947,6 +3995,30 @@ "lineCount": 1 } }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 27, + "lineCount": 1 + } + }, + { + "code": "reportMissingParameterType", + "range": { + "startColumn": 12, + "endColumn": 27, + "lineCount": 1 + } + }, + { + "code": "reportUnknownParameterType", + "range": { + "startColumn": 12, + "endColumn": 31, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -3979,6 +4051,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 45, + "endColumn": 73, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -3988,23 +4068,23 @@ } }, { - "code": "reportArgumentType", + "code": "reportUnknownArgumentType", "range": { "startColumn": 22, - "endColumn": 72, + "endColumn": 71, "lineCount": 3 } }, { - "code": "reportArgumentType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 36, - "endColumn": 50, - "lineCount": 5 + "startColumn": 32, + "endColumn": 46, + "lineCount": 4 } }, { - "code": "reportArgumentType", + "code": "reportUnknownArgumentType", "range": { "startColumn": 48, "endColumn": 62, @@ -4012,7 +4092,7 @@ } }, { - "code": "reportArgumentType", + "code": "reportUnknownArgumentType", "range": { "startColumn": 44, "endColumn": 58, @@ -4052,7 +4132,7 @@ } }, { - "code": "reportArgumentType", + "code": "reportUnknownArgumentType", "range": { "startColumn": 41, "endColumn": 55, @@ -4068,7 +4148,7 @@ } }, { - "code": "reportArgumentType", + "code": "reportUnknownArgumentType", "range": { "startColumn": 27, "endColumn": 41, @@ -4099,6 +4179,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 18, + "endColumn": 24, + "lineCount": 1 + } + }, { "code": "reportArgumentType", "range": { @@ -5083,6 +5171,14 @@ "lineCount": 1 } }, + { + "code": "reportEmptyAbstractUsage", + "range": { + "startColumn": 19, + "endColumn": 72, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -21805,7 +21901,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 7, - "endColumn": 38, + "endColumn": 34, "lineCount": 1 } }, @@ -22113,6 +22209,30 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 15, + "endColumn": 42, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 32, + "endColumn": 54, + "lineCount": 1 + } + }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 16, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -22121,6 +22241,22 @@ "lineCount": 1 } }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 12, + "endColumn": 32, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 35, + "endColumn": 48, + "lineCount": 1 + } + }, { "code": "reportUnknownMemberType", "range": { @@ -22129,6 +22265,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 31, + "endColumn": 35, + "lineCount": 1 + } + }, { "code": "reportUnknownArgumentType", "range": { @@ -22301,7 +22445,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 7, - "endColumn": 38, + "endColumn": 34, "lineCount": 1 } }, @@ -22505,6 +22649,22 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 15, + "endColumn": 42, + "lineCount": 1 + } + }, + { + "code": "reportUnknownArgumentType", + "range": { + "startColumn": 32, + "endColumn": 54, + "lineCount": 1 + } + }, { "code": "reportUnknownVariableType", "range": { diff --git a/pyproject.toml b/pyproject.toml index 157ecf4f4..4cde04131 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ dependencies = [ "pytools>=2024.1", "scipy>=1.2", "sumpy>=2022.1", + # for sentinel + "typing-extensions>=4.14", ] [project.optional-dependencies] @@ -163,6 +165,7 @@ reportPrivateUsage = "hint" pythonVersion = "3.10" pythonPlatform = "All" +enableExperimentalFeatures = true ignore = [ "build", "doc/_build", diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index e392844d2..33d716a8f 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Literal, TypeAlias, cast import numpy as np -from typing_extensions import override +from typing_extensions import Sentinel, override from arraycontext import ( Array, @@ -46,6 +46,7 @@ ) from sumpy.expansion.local import LocalExpansionBase +from pytential.qbx.refinement import QBXRefinementMode, QBXRefinementNeededError from pytential.qbx.target_assoc import QBXTargetAssociationFailedError from pytential.source import LayerPotentialSourceBase @@ -86,6 +87,10 @@ .. autoclass:: NonFFTExpansionFactory .. autodata:: FMMBackend + +.. autoclass:: QBXRefinementMode + +.. autoclass:: QBXRefinementNeededError """ @@ -118,8 +123,7 @@ class NonFFTExpansionFactory(QBXDefaultExpansionFactory): get_local_expansion_class = QBXDefaultExpansionFactory.get_qbx_local_expansion_class -class _not_provided: # noqa: N801 - pass +NOT_PROVIDED = Sentinel("NOT_PROVIDED") class QBXLayerPotentialSource(LayerPotentialSourceBase): @@ -147,7 +151,7 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): target_association_tolerance: float debug: bool - _disable_refinement: bool + refinement_mode: QBXRefinementMode _expansions_in_tree_have_extent: bool _expansion_stick_out_factor: float _well_sep_is_n_away: int @@ -170,12 +174,13 @@ def __init__( fmm_level_to_order: Literal[False] | FMMLevelToOrder | None = None, expansion_factory: QBXDefaultExpansionFactory | None = None, target_association_tolerance: ( - float | type[_not_provided] | None) = _not_provided, + float | NOT_PROVIDED | None) = NOT_PROVIDED, # begin experimental arguments # FIXME default debug=False once everything has matured debug: bool = True, - _disable_refinement: bool = False, + refinement_mode: QBXRefinementMode | None = None, + _disable_refinement: bool | None = None, _expansions_in_tree_have_extent: bool = True, _expansion_stick_out_factor: float = 0.5, _max_leaf_refine_weight: int | None = None, @@ -208,6 +213,8 @@ def __init__( the FMM evaluations. :arg target_association_tolerance: passed on to :func:`pytential.qbx.target_assoc.associate_targets_to_qbx_centers`. + :arg refinement_mode: A :class:`~pytential.qbx.refinement.QBXRefinementMode` + controlling whether and how refinement is performed. Experimental arguments without a promise of forward compatibility: @@ -234,6 +241,7 @@ def __init__( :arg cost_model: Either *None* or an object implementing the :class:`~pytential.qbx.cost.AbstractQBXCostModel` interface, used for gathering modeled costs if provided (experimental). + :arg _disable_refinement: Deprecated. Use *refinement_mode* instead. """ # {{{ argument processing @@ -246,7 +254,7 @@ def __init__( raise ValueError("'qbx_order' must be provided.") assert isinstance(qbx_order, int) - if target_association_tolerance is _not_provided: + if target_association_tolerance is NOT_PROVIDED: target_association_tolerance = ( 1.0e+3 * float(np.finfo(density_discr.real_dtype).eps)) assert isinstance(target_association_tolerance, float) @@ -317,6 +325,21 @@ def fmm_lto(kernel, kernel_args, tree, level): from pytential.qbx.cost import QBXCostModel cost_model = QBXCostModel() + if _disable_refinement is not None: + from warnings import warn + warn( + "'_disable_refinement' is deprecated. " + "Use 'refinement_mode' instead.", + DeprecationWarning, stacklevel=2) + if refinement_mode is None: + refinement_mode = ( + QBXRefinementMode.NO_REFINEMENT + if _disable_refinement + else QBXRefinementMode.REFINE) + + if refinement_mode is None: + refinement_mode = QBXRefinementMode.REFINE + # }}} if density_discr.dim != density_discr.ambient_dim - 1: @@ -335,7 +358,7 @@ def fmm_lto(kernel, kernel_args, tree, level): self.target_association_tolerance = target_association_tolerance self.debug = debug - self._disable_refinement = _disable_refinement + self.refinement_mode = refinement_mode self._expansions_in_tree_have_extent = _expansions_in_tree_have_extent self._expansion_stick_out_factor = _expansion_stick_out_factor self._well_sep_is_n_away = _well_sep_is_n_away @@ -360,40 +383,53 @@ def copy( density_discr=None, fine_order=None, qbx_order=None, - fmm_order=_not_provided, - fmm_level_to_order=_not_provided, + fmm_order=NOT_PROVIDED, + fmm_level_to_order=NOT_PROVIDED, expansion_factory=None, - target_association_tolerance=_not_provided, - _expansions_in_tree_have_extent=_not_provided, - _expansion_stick_out_factor=_not_provided, + target_association_tolerance=NOT_PROVIDED, + _expansions_in_tree_have_extent=NOT_PROVIDED, + _expansion_stick_out_factor=NOT_PROVIDED, _max_leaf_refine_weight=None, _box_extent_norm=None, _from_sep_smaller_crit=None, _tree_kind=None, - _use_target_specific_qbx=_not_provided, + _use_target_specific_qbx=NOT_PROVIDED, geometry_data_inspector=None, - cost_model=_not_provided, + cost_model=NOT_PROVIDED, fmm_backend=None, - debug=_not_provided, - _disable_refinement=_not_provided, + debug=NOT_PROVIDED, + refinement_mode=NOT_PROVIDED, + _disable_refinement=NOT_PROVIDED, ): - if target_association_tolerance is _not_provided: + if target_association_tolerance is NOT_PROVIDED: target_association_tolerance = self.target_association_tolerance kwargs = {} - if (fmm_order is not _not_provided - and fmm_level_to_order is not _not_provided): + if (fmm_order is not NOT_PROVIDED + and fmm_level_to_order is not NOT_PROVIDED): raise TypeError( "may not specify both 'fmm_order' and 'fmm_level_to_order'") - elif fmm_order is not _not_provided: + elif fmm_order is not NOT_PROVIDED: kwargs["fmm_order"] = fmm_order - elif fmm_level_to_order is not _not_provided: + elif fmm_level_to_order is not NOT_PROVIDED: kwargs["fmm_level_to_order"] = fmm_level_to_order else: kwargs["fmm_level_to_order"] = self.fmm_level_to_order + if _disable_refinement is not NOT_PROVIDED: + from warnings import warn + warn( + "'_disable_refinement' is deprecated. " + "Use 'refinement_mode' instead.", + DeprecationWarning, stacklevel=2) + if refinement_mode is NOT_PROVIDED: + refinement_mode = ( + QBXRefinementMode.NO_REFINEMENT + if _disable_refinement + else QBXRefinementMode.REFINE) + # FIXME Could/should share wrangler and geometry kernels # if no relevant changes have been made. return type(self)( @@ -408,21 +444,20 @@ def copy( debug=( # False is a valid value here - debug if debug is not _not_provided else self.debug), - _disable_refinement=( - # False is a valid value here - _disable_refinement - if _disable_refinement is not _not_provided - else self._disable_refinement), + debug if debug is not NOT_PROVIDED else self.debug), + refinement_mode=( + refinement_mode + if refinement_mode is not NOT_PROVIDED + else self.refinement_mode), _expansions_in_tree_have_extent=( # False is a valid value here _expansions_in_tree_have_extent - if _expansions_in_tree_have_extent is not _not_provided + if _expansions_in_tree_have_extent is not NOT_PROVIDED else self._expansions_in_tree_have_extent), _expansion_stick_out_factor=( # 0 is a valid value here _expansion_stick_out_factor - if _expansion_stick_out_factor is not _not_provided + if _expansion_stick_out_factor is not NOT_PROVIDED else self._expansion_stick_out_factor), _well_sep_is_n_away=self._well_sep_is_n_away, _max_leaf_refine_weight=( @@ -434,14 +469,14 @@ def copy( self._from_sep_smaller_min_nsources_cumul), _tree_kind=_tree_kind or self._tree_kind, _use_target_specific_qbx=(_use_target_specific_qbx - if _use_target_specific_qbx is not _not_provided + if _use_target_specific_qbx is not NOT_PROVIDED else self._use_target_specific_qbx), geometry_data_inspector=( geometry_data_inspector or self.geometry_data_inspector), cost_model=( # None is a valid value here cost_model - if cost_model is not _not_provided + if cost_model is not NOT_PROVIDED else self.cost_model), fmm_backend=fmm_backend or self.fmm_backend, **kwargs) @@ -569,7 +604,7 @@ def drive_cost_model( def _dispatch_compute_potential_insn(self, actx, insn, bound_expr, evaluate, func, extra_args=None): - if self._disable_refinement: + if self.refinement_mode == QBXRefinementMode.NO_REFINEMENT: from warnings import warn warn( "Executing global QBX without refinement. " @@ -1072,6 +1107,8 @@ def get_flat_strengths_from_densities( "LocalExpansionBase", "QBXDefaultExpansionFactory", "QBXLayerPotentialSource", + "QBXRefinementMode", + "QBXRefinementNeededError", "QBXTargetAssociationFailedError", ) diff --git a/pytential/qbx/refinement.py b/pytential/qbx/refinement.py index 429ac6b4b..a690ff062 100644 --- a/pytential/qbx/refinement.py +++ b/pytential/qbx/refinement.py @@ -27,6 +27,7 @@ """ import logging +from enum import Enum, auto from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -77,8 +78,15 @@ The element size is bounded by a kernel length scale. This applies only to Helmholtz kernels. -Warnings emitted by refinement -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Refinement mode +^^^^^^^^^^^^^^^ + +.. autoclass:: QBXRefinementMode + +Errors and warnings emitted by refinement +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: QBXRefinementNeededError .. autoclass:: RefinerNotConvergedWarning @@ -97,6 +105,48 @@ .. autofunction:: refine_geometry_collection """ + +# {{{ QBXRefinementMode + +class QBXRefinementMode(Enum): + """Controls the refinement behavior of a + :class:`~pytential.qbx.QBXLayerPotentialSource`. + + .. attribute:: REFINE + + Perform refinement as needed. This is the default behavior. + + .. attribute:: NO_REFINEMENT + + Skip refinement entirely. An + :class:`~meshmode.discretization.connection.IdentityDiscretizationConnection` + is returned instead of performing any mesh refinement. + + .. warning:: + + Executing global QBX without refinement is unlikely to give + accurate results. + + .. attribute:: COMPLAIN + + Do not perform any refinement, but raise a + :class:`QBXRefinementNeededError` if stage-1 or stage-2 refinement + would be required to satisfy the QBX refinement criteria. + """ + + REFINE = auto() + NO_REFINEMENT = auto() + COMPLAIN = auto() + + +class QBXRefinementNeededError(RuntimeError): + """Raised when :attr:`QBXRefinementMode.COMPLAIN` is in effect and + refinement would be needed to satisfy the QBX refinement criteria. + """ + +# }}} + + # {{{ kernels # Refinement checker for Condition 1. @@ -616,7 +666,7 @@ def _refine_qbx_stage1(lpot_source, density_discr, expansion_disturbance_tolerance=None, maxiter=None, debug=None, visualize=False): from pytential import bind, sym - if lpot_source._disable_refinement: + if lpot_source.refinement_mode == QBXRefinementMode.NO_REFINEMENT: from meshmode.discretization.connection import IdentityDiscretizationConnection return density_discr, IdentityDiscretizationConnection(density_discr) @@ -717,6 +767,13 @@ def _refine_qbx_stage1(lpot_source, density_discr, if iter_violated_criteria: violated_criteria.append(" and ".join(iter_violated_criteria)) + if lpot_source.refinement_mode == QBXRefinementMode.COMPLAIN: + raise QBXRefinementNeededError( + "Stage-1 QBX refinement is needed but refinement mode is " + f"'{QBXRefinementMode.COMPLAIN.name}'. " + "Criteria requiring refinement: " + + ", ".join(iter_violated_criteria)) + conn = wrangler.refine( stage1_density_discr, refiner, refine_flags, group_factory, debug) @@ -737,7 +794,7 @@ def _refine_qbx_stage2(lpot_source, stage1_density_discr, expansion_disturbance_tolerance=None, force_stage2_uniform_refinement_rounds=None, maxiter=None, debug=None, visualize=False): - if lpot_source._disable_refinement: + if lpot_source.refinement_mode == QBXRefinementMode.NO_REFINEMENT: from meshmode.discretization.connection import IdentityDiscretizationConnection return (stage1_density_discr, IdentityDiscretizationConnection(stage1_density_discr)) @@ -789,6 +846,13 @@ def _refine_qbx_stage2(lpot_source, stage1_density_discr, if iter_violated_criteria: violated_criteria.append(" and ".join(iter_violated_criteria)) + if lpot_source.refinement_mode == QBXRefinementMode.COMPLAIN: + raise QBXRefinementNeededError( + "Stage-2 QBX refinement is needed but refinement mode is " + f"'{QBXRefinementMode.COMPLAIN.name}'. " + "Criteria requiring refinement: " + + ", ".join(iter_violated_criteria)) + conn = wrangler.refine( stage2_density_discr, refiner, refine_flags, group_factory, debug) diff --git a/test/extra_int_eq_data.py b/test/extra_int_eq_data.py index 35675ff01..944ce5a70 100644 --- a/test/extra_int_eq_data.py +++ b/test/extra_int_eq_data.py @@ -39,6 +39,7 @@ from pytential import sym from pytential.qbx import FMMBackend, QBXLayerPotentialSource +from pytential.qbx.refinement import QBXRefinementMode from pytential.source import PointPotentialSource from pytential.target import PointsTarget @@ -261,7 +262,10 @@ def get_layer_potential(self, qbx_order=self.qbx_order, fmm_backend=fmm_backend, **fmm_kwargs, - _disable_refinement=not self.use_refinement, + refinement_mode=( + QBXRefinementMode.REFINE + if self.use_refinement + else QBXRefinementMode.NO_REFINEMENT), _box_extent_norm=self.box_extent_norm, _from_sep_smaller_crit=self.from_sep_smaller_crit, _from_sep_smaller_min_nsources_cumul=30,