Skip to content

Commit 1c29a9f

Browse files
guptaaka and copybara-github
authored and committed
Add background log streaming to detect TPU placement completion
A background thread streams the proxy pod's logs and watches for specific messages indicating that the proxy is still waiting for placement, until the TPU placement process has finished. This allows for better tracking of the Pathways service readiness. Continued "waiting" messages from the proxy might indicate that the Pathways service doesn't have enough TPU availability to process the request.

PiperOrigin-RevId: 888835053
1 parent 44d0853 commit 1c29a9f

2 files changed

Lines changed: 83 additions & 2 deletions

File tree

pathwaysutils/experimental/shared_pathways_service/gke_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import socket
55
import subprocess
6+
import time
67
import urllib.parse
78

89
import portpicker
@@ -189,6 +190,7 @@ def wait_for_pod(job_name: str) -> str:
189190
RuntimeError: If the pod is not ready.
190191
"""
191192
_logger.info("Waiting for pod to be created...")
193+
time.sleep(1)
192194
pod_name = get_pod_from_job(job_name)
193195

194196
_logger.info(
@@ -296,6 +298,33 @@ def enable_port_forwarding(
296298
return (port_available, port_forward_process)
297299

298300

301+
def stream_pod_logs(pod_name: str) -> subprocess.Popen[str]:
  """Starts a `kubectl logs -f` process that follows the given pod's logs.

  The process is started and returned immediately; this function does not read
  any output itself. The caller owns the returned process and is responsible
  for reading `stdout`, and for terminating and reaping the process when done.

  Args:
    pod_name: The name of the pod.

  Returns:
    The process for streaming the logs. Its `stdout` is a line-buffered text
    stream that carries both the standard output and the standard error of
    `kubectl`; its `stderr` attribute is therefore `None`.

  Raises:
    Exception: If the log streaming process cannot be started (for example an
      `OSError` when the `kubectl` executable cannot be launched). The error
      is logged and then re-raised unchanged.
  """
  command = ["kubectl", "logs", "-f", pod_name]
  try:
    return subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout instead of opening a second pipe. A
        # separate stderr pipe that nobody drains can fill up and block the
        # child process, and it would also have to be closed by every caller.
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # Line buffered
    )
  except Exception as e:
    _logger.exception("Error streaming logs for pod %s: %r", pod_name, e)
    raise
326+
327+
299328
def delete_gke_job(job_name: str) -> None:
300329
"""Deletes the given job from the GKE cluster.
301330

pathwaysutils/experimental/shared_pathways_service/isc_pathways.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import random
1010
import string
1111
import subprocess
12+
import threading
1213
from typing import Any
1314

1415
import jax
@@ -123,6 +124,41 @@ def _deploy_pathways_proxy_server(
123124
_logger.info("Successfully deployed Pathways proxy.")
124125

125126

127+
def _wait_for_placement(pod_name: str) -> None:
  """Waits for the placement to be complete by checking proxy logs.

  Follows the proxy pod's logs through `gke_utils.stream_pod_logs`, re-logs
  every line that mentions one of the placement-related keywords, and stops as
  soon as a line containing the "unplaced -> placed" transition is seen or the
  log stream ends. All matching is case-insensitive.

  The log streaming process and its pipes are always cleaned up, including
  when reading the logs raises.

  Args:
    pod_name: The name of the proxy pod whose logs should be followed.
  """
  _logger.info("Streaming proxy logs until the placement is complete...")
  log_process = gke_utils.stream_pod_logs(pod_name)

  # Lower-cased once up front so the per-line check is a plain substring test.
  keywords = tuple(
      keyword.lower()
      for keyword in (
          "placement",
          "Signaling to RM",
          "Transition slice",
          "FAILED_PRECONDITION",
      )
  )
  end_phrase = "unplaced -> placed"

  placement_seen = False
  try:
    if log_process.stdout:
      for line in log_process.stdout:
        line_lower = line.lower()
        if any(keyword in line_lower for keyword in keywords):
          _logger.info("Proxy log: %s", line.strip())

        if end_phrase in line_lower:
          _logger.info("TPU placement complete: %s", line.strip())
          placement_seen = True
          break
      if not placement_seen:
        # The stream hit EOF (kubectl exited) without the completion phrase.
        _logger.warning(
            "Proxy log stream ended before TPU placement completion was"
            " observed."
        )
  finally:
    # Runs on success and on error alike, so the kubectl process and its
    # pipes never outlive this function.
    _logger.info("Closing log process stdout.")
    for stream in (log_process.stdout, log_process.stderr):
      if stream:
        stream.close()

    # Ensure the process is terminated
    log_process.terminate()
    try:
      log_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
      _logger.warning("Log streaming process did not terminate gracefully.")
      log_process.kill()
      # Reap the killed process so it does not linger as a zombie.
      log_process.wait()
  _logger.info("Finished waiting for placement.")
160+
161+
126162
def _restore_env_var(key: str, original_value: str | None) -> None:
127163
"""Restores an environment variable to its original value or unsets it."""
128164
if original_value is None:
@@ -147,6 +183,7 @@ class _ISCPathways:
147183
expected_tpu_instances: A dictionary mapping TPU machine types to the number
148184
of instances.
149185
proxy_job_name: The name to use for the deployed proxy.
186+
proxy_pod_name: The name of the proxy pod, assigned during deployment.
150187
proxy_server_image: The image to use for the proxy server.
151188
proxy_options: Configuration options for the Pathways proxy.
152189
"""
@@ -171,6 +208,7 @@ def __init__(
171208
self.pathways_service = pathways_service
172209
self.expected_tpu_instances = expected_tpu_instances
173210
self._proxy_job_name = proxy_job_name
211+
self.proxy_pod_name = ""
174212
self._port_forward_process = None
175213
self._proxy_port = None
176214
self.proxy_server_image = proxy_server_image
@@ -218,9 +256,11 @@ def __enter__(self):
218256
)
219257
_logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link)
220258

221-
proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name)
259+
self.proxy_pod_name = gke_utils.wait_for_pod(self._proxy_job_name)
222260
self._proxy_port, self._port_forward_process = (
223-
gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT)
261+
gke_utils.enable_port_forwarding(
262+
self.proxy_pod_name, PROXY_SERVER_PORT
263+
)
224264
)
225265

226266
# Update the JAX backend to use the proxy.
@@ -349,4 +389,16 @@ def connect(
349389
proxy_server_image=proxy_server_image,
350390
proxy_options=proxy_options,
351391
) as t:
392+
if t.proxy_pod_name:
393+
placement_thread = threading.Thread(
394+
target=_wait_for_placement,
395+
args=(t.proxy_pod_name,),
396+
daemon=True,
397+
)
398+
placement_thread.start()
399+
else:
400+
_logger.warning(
401+
"proxy_pod_name not set on _ISCPathways instance, skipping background"
402+
" _wait_for_placement."
403+
)
352404
yield t

0 commit comments

Comments
 (0)