Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/probly/method/ensemble/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Protocol, runtime_checkable
from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:
from flax.nnx import rnglib

from lazy_dispatch import lazydispatch
from probly.method.method import predictor_transformation
Expand Down Expand Up @@ -81,19 +84,25 @@ def register_ensemble_members(ensemble: EnsemblePredictor, t: type[Predictor] |
)
@EnsemblePredictor.register_factory
def ensemble[**In, Out](
base: Predictor[In, Out], num_members: int, reset_params: bool = True
base: Predictor[In, Out],
num_members: int,
reset_params: bool = True,
seed: int = 1,
rngs: rnglib.Rngs | None = None,
) -> EnsemblePredictor[In, Out]:
"""Create an ensemble predictor from a base predictor based on :cite:`lakshminarayananSimpleScalable2017`.

Args:
base: Predictor, The base model to be used for the ensemble.
num_members: The number of members in the ensemble.
reset_params: Whether to reset the parameters of each member.
seed: int, seed to be used for deterministic member reset.
rngs: nnx.Rngs used for flax member re-initialization, overwrites seed.

Returns:
Predictor, The ensemble predictor.
"""
return ensemble_generator(base, num_members=num_members, reset_params=reset_params)
return ensemble_generator(base, num_members=num_members, reset_params=reset_params, seed=seed, rngs=rngs) # ty:ignore[unknown-argument]


@predict_raw.register(EnsemblePredictor)
Expand Down
38 changes: 31 additions & 7 deletions src/probly/method/ensemble/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,58 @@
from flax import nnx

from probly.traverse_nn import nn_compose, nn_traverser
from pytraverse import CLONE, singledispatch_traverser, traverse
from pytraverse import CLONE, GlobalVariable, singledispatch_traverser, traverse
from pytraverse.core import State # noqa: TC001

from ._common import ensemble_generator

reset_traverser = singledispatch_traverser[nnx.Module](name="reset_traverser")

RNGS = GlobalVariable[nnx.Rngs]("RNGS")


@reset_traverser.register
def _(obj: nnx.Module, state: State) -> tuple[nnx.Module, State]:
    """Re-seed or re-initialize a leaf flax module from the RNGs stored in the traversal state.

    Args:
        obj: The module currently visited by the traverser.
        state: The traversal state; ``RNGS(state)`` yields the ``nnx.Rngs`` to draw from.

    Returns:
        The (possibly replaced) module together with the unchanged state.
    """
    # Both reset strategies below apply only to modules without child modules.
    if any(obj.iter_children()):
        return obj, state

    # Leaf that carries its own ``rngs`` and ``rng_collection``: give it a freshly forked
    # stream from the matching collection instead of rebuilding the module.
    if hasattr(obj, "rngs") and hasattr(obj, "rng_collection"):
        obj.rngs = RNGS(state)[obj.rng_collection].fork()  # ty:ignore[invalid-assignment, invalid-argument-type]
        return obj, state

    # Modules without a Python-level ``__init__`` expose no code object; nothing to re-run.
    init_code = getattr(obj.__init__, "__code__", None)
    if init_code is None:
        return obj, state

    # ``co_varnames`` lists the parameters first and the function's locals afterwards.
    # Keep only the real parameters so that a local which is also stored as an attribute
    # is never passed to the constructor as an unexpected keyword.
    init_params = init_code.co_varnames[: init_code.co_argcount + init_code.co_kwonlyargcount]
    if "rngs" not in init_params:
        return obj, state

    # Rebuild the module from the constructor arguments it stored on itself,
    # replacing only the random number generators.
    params = {name: getattr(obj, name) for name in init_params if name in obj.__dict__ and name != "rngs"}
    params["rngs"] = RNGS(state)

    # Kept from the previous implementation.
    # NOTE(review): presumably ``kernel_shape`` is derived inside some constructors rather
    # than accepted by them - confirm before removing.
    params.pop("kernel_shape", None)

    return obj.__class__(**params), state


def _clone(obj: nnx.Module) -> nnx.Module:
    """Traverse ``obj`` with the plain ``nn_traverser`` and the ``CLONE`` flag set.

    No reset traverser is composed in, so this path does not re-initialize anything.
    """
    clone_flags = {CLONE: True}
    return traverse(obj, nn_traverser, init=clone_flags)


def _clone_reset(obj: nnx.Module, rngs: nnx.Rngs) -> nnx.Module:
    """Traverse ``obj`` with the reset traverser composed in and the ``CLONE`` flag set.

    Args:
        obj: The module to traverse.
        rngs: Random number generators stored under ``RNGS`` in the initial traversal
            state, where the reset traverser reads them.

    Returns:
        The result of the traversal.
    """
    return traverse(obj, nn_compose(reset_traverser), init={CLONE: True, RNGS: rngs})


@ensemble_generator.register(nnx.Module)
def generate_flax_ensemble(
    obj: nnx.Module,
    num_members: int,
    reset_params: bool = True,
    seed: int = 1,
    rngs: nnx.Rngs | None = None,
) -> nnx.List:
    """Build a flax ensemble based on :cite:`lakshminarayananSimpleScalable2017`.

    Args:
        obj: The base module each member is derived from.
        num_members: The number of members in the ensemble.
        reset_params: Whether each member goes through the reset traversal instead of a plain clone.
        seed: Seed used to create the ``nnx.Rngs`` when ``rngs`` is not given.
        rngs: Random number generators used for the reset; when provided, ``seed`` is ignored.

    Returns:
        An ``nnx.List`` holding ``num_members`` members.
    """
    if not reset_params:
        return nnx.List([_clone(obj) for _ in range(num_members)])

    if rngs is None:
        rngs = nnx.Rngs(seed)
    # The same ``rngs`` object is handed to every member's reset traversal.
    return nnx.List([_clone_reset(obj, rngs) for _ in range(num_members)])
2 changes: 1 addition & 1 deletion src/probly/method/ensemble/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@ensemble_generator.register(BaseEstimator)
def generate_sklearn_ensemble(obj: BaseEstimator, num_members: int, reset_params: bool) -> list[object]:
def generate_sklearn_ensemble(obj: BaseEstimator, num_members: int, reset_params: bool, **_kwargs) -> list[object]: # noqa: ANN003
"""Generates an ensemble model from a sklearn base estimator."""
if reset_params:
obj.__setattr__("random_state", None)
Expand Down
1 change: 1 addition & 0 deletions src/probly/method/ensemble/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def generate_torch_ensemble(
obj: nn.Module,
num_members: int,
reset_params: bool = True,
**_kwargs, # noqa: ANN003
) -> nn.ModuleList:
"""Build a torch ensemble based on :cite:`lakshminarayananSimpleScalable2017`."""
if reset_params:
Expand Down
7 changes: 5 additions & 2 deletions stubs/probly/method/ensemble/_common.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ This type stub file was generated by pyright.
import probly

from collections.abc import Iterable
from typing import Protocol, runtime_checkable
from typing import Protocol, TYPE_CHECKING, runtime_checkable
from flax.nnx import rnglib
from lazy_dispatch import lazydispatch
from probly.method.method import predictor_transformation
from probly.predictor import IterablePredictor, Predictor, predict_raw
Expand All @@ -13,6 +14,8 @@ from probly.representation.distribution import CategoricalDistribution, Dirichle
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING:
...
@runtime_checkable
class EnsemblePredictor[**In,Out](IterablePredictor[In, Out], Iterable[Predictor[In, Out]], Protocol):
"""Protocol for ensemble predictors."""
Expand Down Expand Up @@ -48,7 +51,7 @@ def ensemble_generator[**In,Out](base: Predictor[In, Out], num_members: int, res
def register_ensemble_members(ensemble: EnsemblePredictor, t: type[Predictor] | None) -> EnsemblePredictor:
"""Register the members of an ensemble predictor."""
...
def ensemble[**In, Out](base: Predictor[In, Out], num_members: int, reset_params: bool = True, *, predictor_type: probly.predictor.PredictorName | type[probly.predictor.Predictor] | None = None) -> EnsemblePredictor[In, Out]: ...
def ensemble[**In, Out](base: Predictor[In, Out], num_members: int, reset_params: bool = True, seed: int = 1, rngs: rnglib.Rngs | None = None, *, predictor_type: probly.predictor.PredictorName | type[probly.predictor.Predictor] | None = None) -> EnsemblePredictor[In, Out]: ...

@predict_raw.register(EnsemblePredictor)
def predict_list[**In,Out](predictor: EnsemblePredictor[In, Out], *args: In.args, **kwargs: In.kwargs) -> list[Out]:
Expand Down
2 changes: 2 additions & 0 deletions tests/probly/method/ensemble/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,7 @@ def test_registered_generator_called(dummy_predictor: Predictor) -> None:
dummy_predictor,
num_members=4,
reset_params=True,
seed=1,
rngs=None,
)
assert result is expected_result
34 changes: 25 additions & 9 deletions tests/probly/method/ensemble/test_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def test_ensemble_attributes_without_reset(self, flax_model_small_2d_2d) -> None
jax.tree_util.tree_map(jnp.array_equal, original_params, member_params),
) # no difference

@pytest.mark.skip(reason="not implemented yet")
def test_ensemble_attributes_with_reset(self, flax_model_small_2d_2d) -> None:
"""Tests if the member attributes are not inherited from the base model."""
num_members = 2
ensemble_model = ensemble(flax_model_small_2d_2d, num_members=2, reset_params=True)
seed = 2
ensemble_model = ensemble(flax_model_small_2d_2d, num_members=2, reset_params=True, seed=seed)

assert ensemble_model is not None
assert isinstance(ensemble_model, nnx.List)
Expand Down Expand Up @@ -122,12 +122,28 @@ def test_custom_model(self, flax_custom_model) -> None:
count_module_member = count_layers(member, nnx.Module)
assert count_module_member == count_module_original

def test_rngs_reset(self, flax_dropout_model) -> None:
    """Tests if a reset with explicit rngs reproduces a manual replay of the same rngs."""
    num_members = 1
    rngs = nnx.Rngs(0, params=1, dropout=2)
    ensemble_model = ensemble(flax_dropout_model, num_members=num_members, reset_params=True, rngs=rngs)

    assert ensemble_model is not None
    assert isinstance(ensemble_model, nnx.List)
    assert len(ensemble_model) == num_members

    # Replay the draws by hand on an identically seeded Rngs, in the order the assertions
    # below rely on: linear1 kernel, linear1 bias, dropout fork, linear2 kernel.
    new_rngs = nnx.Rngs(0, params=1, dropout=2)
    kernel_init = jax.nn.initializers.lecun_normal()
    kernel1_key = new_rngs.params()
    linear1_kernel = kernel_init(kernel1_key, (2, 2))
    _ = new_rngs.params()  # linear1 bias
    dropout_rngs = new_rngs["dropout"].fork()
    kernel2_key = new_rngs.params()
    linear2_kernel = kernel_init(kernel2_key, (2, 2))

    member = ensemble_model[0]
    assert member.layers[1].rngs.key.value == dropout_rngs.key.value
    assert jnp.equal(member.layers[0].kernel.value, linear1_kernel).all()
    assert jnp.equal(member.layers[2].kernel.value, linear2_kernel).all()


class TestEnsembleCalls:
Expand All @@ -145,10 +161,10 @@ def test_ensemble_flax_custom_model_call(self, flax_custom_model) -> None:
assert custom_model_out.shape == member_out.shape
assert jnp.equal(custom_model_out, member_out).all() # no parameter reset

@pytest.mark.skip(reason="not implemented yet")
def test_ensemble_flax_custom_model_call_with_reset(self, flax_custom_model) -> None:
num_members = 2
ensemble_model = ensemble(flax_custom_model, num_members=num_members, reset_params=True)
seed = 2
ensemble_model = ensemble(flax_custom_model, num_members=num_members, reset_params=True, seed=seed)

x = jnp.ones((2, 1, 10))
custom_model_out = flax_custom_model(x)
Expand Down
Loading