From bb2ff1969eab542b03031ed7e49745d309baea7e Mon Sep 17 00:00:00 2001 From: msorvoja Date: Mon, 25 Nov 2024 09:06:39 +0200 Subject: [PATCH 1/4] fix(SMOTETomek): improve documentation and parameter typing --- .../training_data_tools/class_balancing.py | 19 +++++++++++-------- .../class_balancing_test.py | 5 +++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/eis_toolkit/training_data_tools/class_balancing.py b/eis_toolkit/training_data_tools/class_balancing.py index 3120bf38..f8b80c10 100644 --- a/eis_toolkit/training_data_tools/class_balancing.py +++ b/eis_toolkit/training_data_tools/class_balancing.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd from beartype import beartype -from beartype.typing import Optional, Union +from beartype.typing import Literal, Optional, Union from imblearn.combine import SMOTETomek from eis_toolkit.exceptions import NonMatchingParameterLengthsException @@ -11,24 +11,27 @@ def balance_SMOTETomek( X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray], - sampling_strategy: Union[float, str, dict] = "auto", + sampling_strategy: Union[float, Literal["minority", "not minority", "not majority", "all", "auto"], dict] = "auto", random_state: Optional[int] = None, ) -> tuple[Union[pd.DataFrame, np.ndarray], Union[pd.Series, np.ndarray]]: - """Balances the classes of input dataset using SMOTETomek resampling method. + """ + Balances the classes of input dataset using SMOTETomek resampling method. + + For more information about Imblearn SMOTETomek read the documentation here: + https://imbalanced-learn.org/stable/references/generated/imblearn.combine.SMOTETomek.html. Args: - X: The feature matrix (input data as a DataFrame). - y: The target labels corresponding to the feature matrix. + X: Input feature data to be sampled. + y: Target labels corresponding to the input features. sampling_strategy: Parameter controlling how to perform the resampling. If float, specifies the ratio of samples in minority class to samples of majority class, if str, specifies classes to be resampled ("minority", "not minority", "not majority", "all", "auto"), if dict, the keys should be targeted classes and values the desired number of samples for the class. Defaults to "auto", which will resample all classes except the majority class. - random_state: Parameter controlling randomization of the algorithm. Can be given a seed (number). - Defaults to None, which randomizes the seed. + random_state: Seed for random number generation. Defaults to None. Returns: - Resampled feature matrix and target labels. + Resampled feature data and target labels. Raises: NonMatchingParameterLengthsException: If X and y have different length. diff --git a/tests/training_data_tools/class_balancing_test.py b/tests/training_data_tools/class_balancing_test.py index 6703d433..b2ce3d49 100644 --- a/tests/training_data_tools/class_balancing_test.py +++ b/tests/training_data_tools/class_balancing_test.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from beartype.roar import BeartypeCallHintParamViolation from sklearn.datasets import make_classification from eis_toolkit.exceptions import NonMatchingParameterLengthsException @@ -37,6 +38,6 @@ def test_invalid_label_length(): def test_invalid_sampling_strategy(): - """Test that invalid value for sampling strategy raises the correct exception (generated by imblearn).""" - with pytest.raises(ValueError): + """Test that invalid value for sampling strategy raises the correct exception.""" + with pytest.raises(BeartypeCallHintParamViolation): balance_SMOTETomek(X, y, sampling_strategy="invalid_strategy") From e751e941132d382a532594c858fb0cd551471c2e Mon Sep 17 00:00:00 2001 From: msorvoja Date: Tue, 26 Nov 2024 07:18:16 +0200 Subject: [PATCH 2/4] feat(SMOTETomek): add cli function --- eis_toolkit/cli.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/eis_toolkit/cli.py b/eis_toolkit/cli.py index d29d0eac..5f7ee1ff 100644 --- a/eis_toolkit/cli.py +++ b/eis_toolkit/cli.py @@ -3026,6 +3026,40 @@ def gamma_overlay_cli(input_rasters: INPUT_FILES_ARGUMENT, output_raster: OUTPUT # WOFE # TODO +# --- TRAINING DATA TOOLS --- + + +# BALANCE SMOTETOMEK +@app.command() +def balance_data_cli( + input_rasters: INPUT_FILES_ARGUMENT, + input_labels: INPUT_FILE_OPTION, + output_raster: OUTPUT_FILE_OPTION, + output_labels: OUTPUT_FILE_OPTION, + sampling_strategy: str = "auto", + random_state: Optional[int] = None, +): + """Resample feature data using SMOTETomek.""" + from eis_toolkit.prediction.machine_learning_general import prepare_data_for_ml + from eis_toolkit.training_data_tools.class_balancing import balance_SMOTETomek + + X, y, profile, _ = prepare_data_for_ml(input_rasters, input_labels) + typer.echo("Progress: 30%") + + X_res, y_res = balance_SMOTETomek(X, y, sampling_strategy, random_state) + typer.echo("Progress 80%") + + with rasterio.open(output_raster, "w", **profile) as dst: + dst.write(X_res, 1) + + with rasterio.open(output_labels, "w", **profile) as dst: + dst.write(y_res, 1) + typer.echo("Progress: 100%") + typer.echo( + f"Balancing data completed, writing resampled feature data to {output_raster} \ + and corresponding labels to {output_labels}." + ) + # --- TRANSFORMATIONS --- From e6dc779c1a9ea0531c81cc916fc16fcd24dd3d0c Mon Sep 17 00:00:00 2001 From: msorvoja Date: Tue, 26 Nov 2024 07:26:38 +0200 Subject: [PATCH 3/4] fix(balance_data_cli): add sampling strategy class --- eis_toolkit/cli.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/eis_toolkit/cli.py b/eis_toolkit/cli.py index 5f7ee1ff..104a073d 100644 --- a/eis_toolkit/cli.py +++ b/eis_toolkit/cli.py @@ -332,6 +332,16 @@ class KerasRegressorMetrics(str, Enum): mae = "mae" +class SMOTETomekSamplingStrategy(str, Enum): + """Sampling strategies available for SMOTETomek.""" + + minority = "minority" + not_minority = "not minority" + not_majority = "not majority" + all = "all" + auto = "auto" + + INPUT_FILE_OPTION = Annotated[ Path, typer.Option( @@ -3036,7 +3046,7 @@ def balance_data_cli( input_labels: INPUT_FILE_OPTION, output_raster: OUTPUT_FILE_OPTION, output_labels: OUTPUT_FILE_OPTION, - sampling_strategy: str = "auto", + sampling_strategy: Annotated[SMOTETomekSamplingStrategy, typer.Option()] = SMOTETomekSamplingStrategy.auto, random_state: Optional[int] = None, ): """Resample feature data using SMOTETomek.""" From b8c499365eb74fcd890150ec082f7e7c15a2bd0b Mon Sep 17 00:00:00 2001 From: msorvoja Date: Fri, 29 Nov 2024 09:32:56 +0200 Subject: [PATCH 4/4] Override sampling strategy literal parameter with float if given --- eis_toolkit/cli.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/eis_toolkit/cli.py b/eis_toolkit/cli.py index 104a073d..bea22b26 100644 --- a/eis_toolkit/cli.py +++ b/eis_toolkit/cli.py @@ -3046,16 +3046,25 @@ def balance_data_cli( input_labels: INPUT_FILE_OPTION, output_raster: OUTPUT_FILE_OPTION, output_labels: OUTPUT_FILE_OPTION, - sampling_strategy: Annotated[SMOTETomekSamplingStrategy, typer.Option()] = SMOTETomekSamplingStrategy.auto, + sampling_strategy_literal: Annotated[SMOTETomekSamplingStrategy, typer.Option()] = SMOTETomekSamplingStrategy.auto, + sampling_strategy_float: Optional[float] = None, random_state: Optional[int] = None, ): - """Resample feature data using SMOTETomek.""" + """Resample feature data using SMOTETomek. + + Parameter sampling_strategy_float will override sampling_strategy_literal if given. + """ from eis_toolkit.prediction.machine_learning_general import prepare_data_for_ml from eis_toolkit.training_data_tools.class_balancing import balance_SMOTETomek X, y, profile, _ = prepare_data_for_ml(input_rasters, input_labels) typer.echo("Progress: 30%") + if sampling_strategy_float is not None: + sampling_strategy = sampling_strategy_float + else: + sampling_strategy = sampling_strategy_literal + X_res, y_res = balance_SMOTETomek(X, y, sampling_strategy, random_state) typer.echo("Progress 80%")