From a08666cbf1f2f0b37d172504213865feba58e5ea Mon Sep 17 00:00:00 2001 From: nicknoproblems Date: Mon, 31 Oct 2016 00:55:23 +0300 Subject: [PATCH] fix predict on ubuntu --- pylightgbm/models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pylightgbm/models.py b/pylightgbm/models.py index 6079a3e..3554cc7 100644 --- a/pylightgbm/models.py +++ b/pylightgbm/models.py @@ -112,9 +112,10 @@ def fit(self, X, y, test_data=None): if test_data and self.param['early_stopping_round'] > 0: self.best_round = max(map(int, re.findall("Tree=(\d+)", self.model))) + 1 + def predict(self, X): tmp_dir = tempfile.mkdtemp() - predict_filepath = os.path.abspath(os.path.join(tmp_dir, "X_to_pred.svm")) + predict_filepath = os.path.abspath(os.path.join(tmp_dir, "X_to_pred.csv")) output_model = os.path.abspath(os.path.join(tmp_dir, "model")) output_results = os.path.abspath(os.path.join(tmp_dir, "LightGBM_predict_result.txt")) conf_filepath = os.path.join(tmp_dir, "predict.conf") @@ -122,7 +123,7 @@ def predict(self, X): with open(output_model, mode="w") as file: file.write(self.model) - datasets.dump_svmlight_file(X, np.zeros(len(X)), f=predict_filepath) + np.savetxt(predict_filepath, X, delimiter=",") calls = ["task = predict\n", "data = {}\n".format(predict_filepath), @@ -265,7 +266,7 @@ def predict_proba(self, X): with open(output_model, mode="w") as file: file.write(self.model) - datasets.dump_svmlight_file(X, np.zeros(len(X)), f=predict_filepath) + np.savetxt(predict_filepath, X, delimiter=",") calls = [ "task = predict\n", @@ -344,3 +345,4 @@ def __init__(self, exec_path="LighGBM/lightgbm", is_unbalance=is_unbalance, verbose=verbose, model=model) +