diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index cb48d156acf..2a1d1ad2bb2 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -2,11 +2,11 @@ # Licensed under the MIT License. from urllib.parse import urlparse +from urllib.request import url2pathname import mlflow from filelock import FileLock from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode from mlflow.entities import ViewType -import os from typing import Optional, Text from pathlib import Path @@ -19,6 +19,14 @@ logger = get_module_logger("workflow") +def _file_uri_to_path(uri: Text) -> Path: + pr = urlparse(uri) + path = url2pathname(pr.path) + if pr.netloc and pr.netloc != "localhost": + path = f"//{pr.netloc}{path}" + return Path(path) + + class ExpManager: """ This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. @@ -233,7 +241,7 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec # So we supported it in the interface wrapper pr = urlparse(self.uri) if pr.scheme == "file": - with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110 + with FileLock(_file_uri_to_path(self.uri) / "filelock"): # pylint: disable=E0110 return self.create_exp(experiment_name), True # NOTE: for other schemes like http, we double check to avoid create exp conflicts try: diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 5fd99c0769f..32f80102490 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -366,16 +366,30 @@ def _log_uncommitted_code(self): """ # TODO: the sub-directories maybe git repos. # So it will be better if we can walk the sub-directories and log the uncommitted changes. + try: + proc = subprocess.run( + ["git", "rev-parse", "--is-inside-work-tree"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + ) + except (FileNotFoundError, subprocess.CalledProcessError): + logger.debug(f"Skip logging uncommitted code because $CWD({os.getcwd()}) is not a git work tree.") + return + if proc.stdout.decode().strip() != "true": + logger.debug(f"Skip logging uncommitted code because $CWD({os.getcwd()}) is not a git work tree.") + return + for cmd, fname in [ - ("git diff", "code_diff.txt"), - ("git status", "code_status.txt"), - ("git diff --cached", "code_cached.txt"), + (["git", "diff"], "code_diff.txt"), + (["git", "status"], "code_status.txt"), + (["git", "diff", "--cached"], "code_cached.txt"), ]: try: - out = subprocess.check_output(cmd, shell=True) - self.client.log_text(self.id, out.decode(), fname) # this behaves same as above - except subprocess.CalledProcessError: - logger.info(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}.") + out = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + self.client.log_text(self.id, out.stdout.decode(), fname) # this behaves same as above + except (FileNotFoundError, subprocess.CalledProcessError): + logger.debug(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {' '.join(cmd)}.") def end_run(self, status: str = Recorder.STATUS_S): assert status in [ diff --git a/tests/dependency_tests/test_mlflow.py b/tests/dependency_tests/test_mlflow.py index 4b4d0105ba4..17a62c42f65 100644 --- a/tests/dependency_tests/test_mlflow.py +++ b/tests/dependency_tests/test_mlflow.py @@ -2,10 +2,25 @@ # Licensed under the MIT License. import unittest import platform +import contextlib +import io import mlflow +import os import time from pathlib import Path import shutil +import tempfile + +from qlib.workflow.expm import _file_uri_to_path +from qlib.workflow.recorder import MLflowRecorder + + +class DummyClient: + def __init__(self): + self.logged_text = [] + + def log_text(self, run_id, text, fname): + self.logged_text.append((run_id, text, fname)) class MLflowTest(unittest.TestCase): @@ -33,6 +48,28 @@ def test_creating_client(self): self.assertLess(elapsed, 2e-2) print(elapsed) + def test_file_uri_to_path_keeps_absolute_paths(self): + self.assertEqual(_file_uri_to_path("file:///tmp/qlib/mlruns"), Path("/tmp/qlib/mlruns")) + self.assertEqual(_file_uri_to_path("file:/tmp/qlib/mlruns"), Path("/tmp/qlib/mlruns")) + + def test_log_uncommitted_code_skips_non_git_cwd_quietly(self): + recorder = object.__new__(MLflowRecorder) + recorder.id = "run-id" + recorder.client = DummyClient() + stderr = io.StringIO() + + old_cwd = os.getcwd() + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.chdir(tmpdir) + with contextlib.redirect_stderr(stderr): + recorder._log_uncommitted_code() + finally: + os.chdir(old_cwd) + + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(recorder.client.logged_text, []) + if __name__ == "__main__": unittest.main()