|
| 1 | +import os |
| 2 | +import pickle as pkl |
| 3 | +from typing import Any, Dict, List, Optional |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +import tqdm |
| 8 | +from sklearn.exceptions import NotFittedError |
| 9 | +from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression |
| 10 | + |
| 11 | +from chebai.models.base import ChebaiBaseNet |
| 12 | + |
# Directory where one fitted sklearn model per class is cached as a pickle file
# (see LogisticRegression.fit_sklearn); path is relative to the working directory.
LR_MODEL_PATH = os.path.join("models", "LR")
| 14 | + |
| 15 | + |
class LogisticRegression(ChebaiBaseNet):
    """
    Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.

    Trains one independent binary ``sklearn`` LogisticRegression per class
    (one-vs-rest). Fitted models are cached as pickle files under
    ``LR_MODEL_PATH`` and reloaded on subsequent runs instead of refitting.
    """

    def __init__(
        self,
        out_dim: int,
        input_dim: int,
        only_predict_classes: Optional[List] = None,
        n_classes: int = 1528,
        **kwargs,
    ):
        """
        Args:
            out_dim: output dimension, forwarded to ChebaiBaseNet.
            input_dim: input feature dimension, forwarded to ChebaiBaseNet.
            only_predict_classes: indices of classes (in the training dataset)
                for which a model should be trained; ``None`` trains all.
            n_classes: number of per-class binary models to create.
        """
        super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
        # One binary classifier per class (one-vs-rest scheme).
        self.models = [
            SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
        ]
        # indices of classes (in the dataset used for training) where a model should be trained
        self.only_predict_classes = only_predict_classes

    def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
        """
        Predict all classes for a batch.

        Args:
            x: batch dict with at least ``x["features"]``; ``x["labels"]`` is
               additionally required while ``self.training`` is True.

        Returns:
            Tensor of per-class predictions, shape (batch, n_classes)
            (squeezed on the last dim), on the same device as the features.
        """
        print(
            f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}"
        )
        # sklearn models are fitted lazily on the first training batch seen.
        if self.training:
            self.fit_sklearn(x["features"], x["labels"])
        preds = []
        for model in self.models:
            try:
                p = torch.from_numpy(model.predict(x["features"])).float()
                p = p.to(x["features"].device)
                preds.append(p)
            except (NotFittedError, AttributeError):
                # Model was skipped/never fitted (or is not a predictor at
                # all); fall back to an all-zero prediction for this class.
                preds.append(
                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
                )
        preds = torch.stack(preds, dim=1)
        print(f"preds shape {preds.shape}")
        return preds.squeeze(-1)

    def fit_sklearn(self, X, y):
        """
        Fit (or load from cache) one sklearn model per class column of ``y``.

        Args:
            X: feature matrix, shape (n_samples, n_features).
            y: multi-label target matrix, shape (n_samples, n_classes);
               column ``i`` is the binary target for ``self.models[i]``.
        """
        # Ensure the cache directory exists before any pickle dump.
        os.makedirs(LR_MODEL_PATH, exist_ok=True)
        for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
            model_path = os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")
            if os.path.exists(model_path):
                print(f"Loading model {i} from file")
                # NOTE(review): pickle.load on untrusted files is unsafe; these
                # are assumed to be locally produced caches.
                with open(model_path, "rb") as f:
                    self.models[i] = pkl.load(f)
            else:
                if (
                    self.only_predict_classes and i not in self.only_predict_classes
                ):  # only try these classes
                    continue
                try:
                    model.fit(X, y[:, i])
                except ValueError:
                    # fit raises e.g. when y[:, i] contains only one class;
                    # substitute a placeholder that always predicts 1.
                    self.models[i] = PlaceholderModel()
                # Dump whatever ended up in self.models[i]. (Bug fix: the old
                # code dumped the local `model` even after a ValueError, caching
                # an unfitted model that would raise NotFittedError on reload.)
                with open(model_path, "wb") as f:
                    pkl.dump(self.models[i], f)

    def configure_optimizers(self, **kwargs):
        # No torch optimizer: training happens inside forward() via sklearn.
        pass
| 88 | + |
| 89 | + |
class PlaceholderModel:
    """Acts like a trained model, but isn't. Use this if training fails and you need a placeholder."""

    def __init__(self, default_prediction=1):
        # Constant value returned for every sample.
        self.default_prediction = default_prediction

    def predict(self, preds):
        # Mirror the sklearn predict() contract: one value per input row,
        # as a float array (matching np.ones(...) * value in the original).
        n_samples = preds.shape[0]
        return np.full(n_samples, self.default_prediction, dtype=float)
0 commit comments