Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions qlib/workflow/expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
# Licensed under the MIT License.

from urllib.parse import urlparse
from urllib.request import url2pathname
import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.entities import ViewType
import os
from typing import Optional, Text
from pathlib import Path

Expand All @@ -19,6 +19,14 @@
logger = get_module_logger("workflow")


def _file_uri_to_path(uri: Text) -> Path:
pr = urlparse(uri)
path = url2pathname(pr.path)
if pr.netloc and pr.netloc != "localhost":
path = f"//{pr.netloc}{path}"
return Path(path)


class ExpManager:
"""
This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
Expand Down Expand Up @@ -233,7 +241,7 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
with FileLock(_file_uri_to_path(self.uri) / "filelock"): # pylint: disable=E0110
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:
Expand Down
28 changes: 21 additions & 7 deletions qlib/workflow/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,30 @@ def _log_uncommitted_code(self):
"""
# TODO: the sub-directories maybe git repos.
# So it will be better if we can walk the sub-directories and log the uncommitted changes.
try:
proc = subprocess.run(
["git", "rev-parse", "--is-inside-work-tree"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
)
except (FileNotFoundError, subprocess.CalledProcessError):
logger.debug(f"Skip logging uncommitted code because $CWD({os.getcwd()}) is not a git work tree.")
return
if proc.stdout.decode().strip() != "true":
logger.debug(f"Skip logging uncommitted code because $CWD({os.getcwd()}) is not a git work tree.")
return

for cmd, fname in [
("git diff", "code_diff.txt"),
("git status", "code_status.txt"),
("git diff --cached", "code_cached.txt"),
(["git", "diff"], "code_diff.txt"),
(["git", "status"], "code_status.txt"),
(["git", "diff", "--cached"], "code_cached.txt"),
]:
try:
out = subprocess.check_output(cmd, shell=True)
self.client.log_text(self.id, out.decode(), fname) # this behaves same as above
except subprocess.CalledProcessError:
logger.info(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}.")
out = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
self.client.log_text(self.id, out.stdout.decode(), fname) # this behaves same as above
except (FileNotFoundError, subprocess.CalledProcessError):
logger.debug(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {' '.join(cmd)}.")

def end_run(self, status: str = Recorder.STATUS_S):
assert status in [
Expand Down
37 changes: 37 additions & 0 deletions tests/dependency_tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,25 @@
# Licensed under the MIT License.
import unittest
import platform
import contextlib
import io
import mlflow
import os
import time
from pathlib import Path
import shutil
import tempfile

from qlib.workflow.expm import _file_uri_to_path
from qlib.workflow.recorder import MLflowRecorder


class DummyClient:
def __init__(self):
self.logged_text = []

def log_text(self, run_id, text, fname):
self.logged_text.append((run_id, text, fname))


class MLflowTest(unittest.TestCase):
Expand Down Expand Up @@ -33,6 +48,28 @@ def test_creating_client(self):
self.assertLess(elapsed, 2e-2)
print(elapsed)

def test_file_uri_to_path_keeps_absolute_paths(self):
self.assertEqual(_file_uri_to_path("file:///tmp/qlib/mlruns"), Path("/tmp/qlib/mlruns"))
self.assertEqual(_file_uri_to_path("file:/tmp/qlib/mlruns"), Path("/tmp/qlib/mlruns"))

def test_log_uncommitted_code_skips_non_git_cwd_quietly(self):
recorder = object.__new__(MLflowRecorder)
recorder.id = "run-id"
recorder.client = DummyClient()
stderr = io.StringIO()

old_cwd = os.getcwd()
with tempfile.TemporaryDirectory() as tmpdir:
try:
os.chdir(tmpdir)
with contextlib.redirect_stderr(stderr):
recorder._log_uncommitted_code()
finally:
os.chdir(old_cwd)

self.assertEqual(stderr.getvalue(), "")
self.assertEqual(recorder.client.logged_text, [])


if __name__ == "__main__":
unittest.main()