-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeploy_lightning_inference.py
More file actions
133 lines (108 loc) · 4.18 KB
/
deploy_lightning_inference.py
File metadata and controls
133 lines (108 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
import sys
ROOT_DIR = Path(__file__).resolve().parent
QP_SRC_DIR = ROOT_DIR / "quant_platform" / "src"
if str(QP_SRC_DIR) not in sys.path:
sys.path.insert(0, str(QP_SRC_DIR))
from lightning_cloud_utils import ( # noqa: E402
ensure_auth_env,
find_app_by_name,
get_client_and_project,
json_safe,
phase_name,
set_process_env,
)
try: # noqa: E402
from lightning.app.runners.runtime import dispatch
from lightning.app.runners.runtime_type import RuntimeType
except ModuleNotFoundError: # noqa: E402
from lightning_app.runners.runtime import dispatch
from lightning_app.runners.runtime_type import RuntimeType
# Names of the local environment variables that _collect_env() looks up and,
# when set to a non-empty value, forwards to the deployed app through the
# `env_vars` argument of dispatch().
# NOTE(review): this list includes PATH as well as TRAINED_MODEL_API_KEY and
# TRAINED_MODEL_ADAPTER_ARCHIVE_TOKEN, and main() passes `secrets={}`, so these
# values travel as plain environment variables -- confirm that is intended.
ENV_KEYS = (
"PATH",
"TRAINED_MODEL_BASE_MODEL",
"TRAINED_MODEL_NAME",
"TRAINED_MODEL_CPU_THREADS",
"TRAINED_MODEL_CPU",
"TRAINED_MODEL_API_KEY",
"TRAINED_MODEL_ADAPTER_PATH",
"TRAINED_MODEL_ADAPTER_ARCHIVE_URL",
"TRAINED_MODEL_ADAPTER_ARCHIVE_TOKEN",
"TRAINED_MODEL_CACHE_DIR",
"LIGHTNING_INFERENCE_COMPUTE_NAME",
"LIGHTNING_INFERENCE_DISK_GB",
"LIGHTNING_INFERENCE_PORT",
"TRAINED_MODEL_LOG_LEVEL",
"LIGHTNING_CLOUD_PROJECT_ID",
)
def _default_entrypoint() -> Path:
    """Return the path of the Lightning app script to deploy.

    The copy inside ``ROOT_DIR / "lightning_inference_bundle"`` wins when that
    path exists; otherwise the script directly under ``ROOT_DIR`` is returned
    (without checking that it exists).
    """
    app_filename = "lightning_trained_model_app.py"
    bundled_script = ROOT_DIR / "lightning_inference_bundle" / app_filename
    return bundled_script if bundled_script.exists() else ROOT_DIR / app_filename
def _patch_lightning_dispatch_compat() -> None:
    """Make lightning-cloud's instance-listing calls tolerate an ``app_id`` kwarg.

    Lightning 2.3.2 can pass an unsupported ``app_id`` kwarg into older
    lightning-cloud OpenAPI clients. The two listing methods of
    ``LightningappInstanceServiceApi`` are replaced by thin wrappers that drop
    that kwarg before delegating, so cloud dispatch works.

    The patch is best-effort and safe to call repeatedly:

    * if the client module cannot be imported, nothing happens;
    * a method that the installed client does not define is skipped;
    * a method that is already wrapped is not wrapped a second time.
    """
    import functools  # local import: keeps the shim's only extra dependency with it

    try:
        from lightning_cloud.openapi.api.lightningapp_instance_service_api import (
            LightningappInstanceServiceApi,
        )
    except Exception:
        # Deliberately broad: any failure to load the client simply means there
        # is nothing to patch, and deployment should proceed unpatched.
        return

    def _drop_app_id(original):
        """Wrap *original* so that an ``app_id`` keyword is discarded."""

        @functools.wraps(original)
        def patched(self, project_id, **kwargs):
            kwargs.pop("app_id", None)
            return original(self, project_id, **kwargs)

        # Marker consulted below so a repeated call does not stack wrappers.
        patched._app_id_compat_patched = True
        return patched

    method_names = (
        "lightningapp_instance_service_list_lightningapp_instances",
        "lightningapp_instance_service_list_lightningapp_instances_with_http_info",
    )
    for method_name in method_names:
        original = getattr(LightningappInstanceServiceApi, method_name, None)
        if original is None:
            # This client version has no such method; leave the class alone.
            continue
        if getattr(original, "_app_id_compat_patched", False):
            continue
        setattr(LightningappInstanceServiceApi, method_name, _drop_app_id(original))
def _collect_env() -> dict[str, str]:
    """Return the ``ENV_KEYS`` variables that are set to a non-empty value.

    Keys that are unset, or set to an empty string, are left out of the
    result entirely.
    """
    looked_up = ((name, os.getenv(name)) for name in ENV_KEYS)
    return {name: setting for name, setting in looked_up if setting}
def main() -> None:
    """Dispatch the inference app to the Lightning cloud and print a summary.

    Command-line options:
        --app-name  name given to the dispatched app
                    (default ``trading-bot-lightning-inference``)
        --blocking  forwarded to ``dispatch`` as ``blocking``
        --open-ui   forwarded to ``dispatch`` as ``open_ui``

    After dispatching, the app is looked up by name and a JSON document with
    the project, the app id and its phase is written to stdout.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--app-name", default="trading-bot-lightning-inference")
    for switch in ("--blocking", "--open-ui"):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()

    # Export whatever ensure_auth_env() yields before any cloud call is made,
    # then install the client compatibility shim.
    set_process_env(ensure_auth_env())
    _patch_lightning_dispatch_compat()

    # First non-empty of the two variables; blank or missing becomes None.
    raw_project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID") or os.getenv("LIGHTNING_PROJECT_ID") or ""
    project_id = str(raw_project_id).strip() or None
    client, project = get_client_and_project(project_id=project_id)

    app_script = _default_entrypoint()
    dispatch_kwargs = {
        "start_server": False,
        "no_cache": False,
        "blocking": options.blocking,
        "open_ui": options.open_ui,
        "name": options.app_name,
        "env_vars": _collect_env(),
        "secrets": {},
    }
    dispatch(app_script, RuntimeType.CLOUD, **dispatch_kwargs)

    # Report on the app that now carries this name, if the lookup finds one.
    deployed = find_app_by_name(client, project.project_id, options.app_name)
    if deployed:
        app_id = getattr(deployed, "id", None)
        phase = phase_name(deployed)
    else:
        app_id = phase = None

    summary = {
        "project_id": project.project_id,
        "project_name": project.name,
        "app_name": options.app_name,
        "app_id": app_id,
        "phase": phase,
        "note": "Copy the Lightning service URL from the app layout once the inference work is running.",
    }
    print(json.dumps(json_safe(summary), indent=2))
# Run the deployment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()