Skip to content

Commit d7c8099

Browse files
authored
Merge pull request #15 from ysims/init-model
Load a model and either use directly or continue training
2 parents df47861 + 8a5203a commit d7c8099

2 files changed

Lines changed: 36 additions & 2 deletions

File tree

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ EcoNetToolkit lets you train a shallow neural network or classical models on you
2525
- [Inspecting Saved Models](#inspecting-saved-models)
2626
- [Config reference (YAML)](#config-reference-yaml)
2727
- [Simple example (single model, classification)](#simple-example-single-model-classification)
28+
- [Multi-output (multi-target) prediction](#multi-output-multi-target-prediction)
2829
- [Available models and key parameters](#available-models-and-key-parameters)
30+
- [Loading a model from a saved path](#loading-a-model-from-a-saved-path)
2931
- [Notes on metrics](#notes-on-metrics)
3032
- [Additional notes](#additional-notes)
3133
- [Hyperparameter Tuning](#hyperparameter-tuning)
@@ -306,6 +308,27 @@ See `configs/penguins_multilabel.yaml` and `configs/possum_multilabel.yaml` for
306308
**Linear Regression** (regression only)
307309
- `fit_intercept`: Whether to calculate the intercept (default: `true`)
308310

311+
## Loading a model from a saved path
312+
313+
You can load a previously trained model directly from a file by specifying `model_path` in the model's `params` section of your YAML config. This is useful for reusing or updating models without retraining.
314+
315+
**Example:**
316+
317+
```yaml
318+
models:
319+
- name: random_forest
320+
params:
321+
model_path: outputs/possum/random_forest/model_random_forest_seed42.joblib
322+
no_train: true # If true, use the loaded model with no further training
323+
```
324+
325+
- If `no_train: true`, EcoNetToolkit will use the loaded model for prediction only and will not retrain it with the current data.
326+
- If `no_train` is omitted or set to `false`, the loaded model will be further trained (fit) on the current data, allowing you to continue training or fine-tune.
327+
328+
If `model_path` is provided, EcoNetToolkit will load the model from disk using joblib and use it for predictions or further evaluation. All other parameters are ignored when loading from a path and `no_train: true`.
329+
330+
---
331+
309332
### Notes on metrics
310333

311334
**Classification:**

ecosci/models.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
class ModelZoo:
2121
@staticmethod
2222
def get_model(
23-
name: str,
24-
problem_type: str = "classification",
23+
name: str,
24+
problem_type: str = "classification",
2525
params: Dict[str, Any] = None,
2626
n_outputs: int = 1
2727
):
@@ -39,6 +39,17 @@ def get_model(
3939
Number of output targets. If > 1, wraps the model in MultiOutput wrapper.
4040
"""
4141
params = params or {}
42+
# If model_path is provided, load the model from disk
43+
model_path = params.get("model_path")
44+
no_train = params.get("no_train", False)
45+
if model_path:
46+
import joblib
47+
model = joblib.load(model_path)
48+
# If no_train is True, return the loaded model directly
49+
if no_train:
50+
return model
51+
# Otherwise, return the loaded model for further training (fit will be called)
52+
return model
4253

4354
if name.lower() == "mlp":
4455
from sklearn.neural_network import MLPClassifier, MLPRegressor

0 commit comments

Comments
 (0)