@@ -64,6 +64,29 @@ def _find_uv() -> str:
6464 return uv
6565
6666
67+ def _activate_accelerate (opts : dict [str , str | None ]) -> None :
68+ """Enable model loading acceleration after environment is ready."""
69+ from zerostart .accelerate import accelerate
70+ accelerate (cache_dir = opts .get ("cache_dir" ))
71+
72+
73+ def _activate_env (site_packages : Path ) -> None :
74+ """Activate the zerostart environment by adding it to sys.path and isolating from system packages.
75+
76+ System dist-packages (e.g. /usr/local/lib/python3.11/dist-packages) can contain
77+ packages incompatible with the zerostart env (e.g. torchvision built for an older torch).
78+ We remove system dist-packages from sys.path to prevent these from leaking through.
79+ """
80+ sys .path .insert (0 , str (site_packages ))
81+ # Remove system dist-packages to prevent incompatible system packages from leaking.
82+ # Keep the zerostart site-packages and stdlib paths.
83+ zs_sp = str (site_packages )
84+ sys .path [:] = [
85+ p for p in sys .path
86+ if p == zs_sp or "dist-packages" not in p
87+ ]
88+
89+
6790def _is_script (target : str ) -> bool :
6891 """Determine if target is a Python script (vs a package name)."""
6992 if target .endswith (".py" ):
@@ -408,7 +431,7 @@ def prepare_env(requirements: list[str]) -> tuple[Path, Path, ArtifactPlan | Non
408431 # Try uv-only fast path first
409432 if complete_marker .exists ():
410433 log .info ("Cache hit — environment ready" )
411- sys . path . insert ( 0 , str ( site_packages ) )
434+ _activate_env ( site_packages )
412435 return venv , site_packages , None , None , [], None
413436
414437 # Cold path: resolve to figure out which wheels need daemon
@@ -418,7 +441,7 @@ def prepare_env(requirements: list[str]) -> tuple[Path, Path, ArtifactPlan | Non
418441 if not plan .artifacts :
419442 log .warning ("No artifacts resolved" )
420443 complete_marker .touch ()
421- sys . path . insert ( 0 , str ( site_packages ) )
444+ _activate_env ( site_packages )
422445 return venv , site_packages , plan , None , [], None
423446
424447 log .info (
@@ -428,7 +451,7 @@ def prepare_env(requirements: list[str]) -> tuple[Path, Path, ArtifactPlan | Non
428451 len (plan .fast_wheels ),
429452 )
430453
431- sys . path . insert ( 0 , str ( site_packages ) )
454+ _activate_env ( site_packages )
432455
433456 # Install only SMALL wheels via uv (fast, metadata-sensitive)
434457 # Large wheels go through the daemon for streaming extraction
@@ -546,6 +569,7 @@ def run(
546569 script : str ,
547570 requirements : list [str ] | None = None ,
548571 requirements_file : str | None = None ,
572+ accelerate_opts : dict [str , str | None ] | None = None ,
549573) -> None :
550574 """Run a Python script with lazy imports and progressive installation."""
551575 if requirements is None :
@@ -562,11 +586,17 @@ def run(
562586
563587 if not requirements :
564588 log .warning ("No requirements found — running script directly" )
589+ if accelerate_opts is not None :
590+ _activate_accelerate (accelerate_opts )
565591 exec (compile (open (script ).read (), script , "exec" ), {"__name__" : "__main__" })
566592 return
567593
568594 venv , site_packages , plan , daemon , whl_paths , uv_thread = prepare_env (requirements )
569595
596+ # Enable accelerate AFTER env is ready so we import the correct torch
597+ if accelerate_opts is not None :
598+ _activate_accelerate (accelerate_opts )
599+
570600 if not daemon :
571601 # Warm path or all-uv — just run
572602 exec (compile (open (script ).read (), script , "exec" ), {"__name__" : "__main__" })
@@ -590,6 +620,7 @@ def run_package(
590620 package : str ,
591621 args : list [str ] | None = None ,
592622 extra_packages : list [str ] | None = None ,
623+ accelerate_opts : dict [str , str | None ] | None = None ,
593624) -> None :
594625 """Install a package and run its console_script entry point."""
595626 if args is None :
@@ -601,6 +632,10 @@ def run_package(
601632
602633 venv , site_packages , plan , daemon , whl_paths , uv_thread = prepare_env (requirements )
603634
635+ # Enable accelerate AFTER env is ready so we import the correct torch
636+ if accelerate_opts is not None :
637+ _activate_accelerate (accelerate_opts )
638+
604639 if plan and not (venv / ".complete" ).exists ():
605640 # Wait for the target package's metadata to be on disk
606641 pkg_normalized = re .sub (r"[-_.]+" , "-" , package .split ("[" )[0 ]).lower ()
@@ -677,22 +712,27 @@ def main() -> None:
677712 datefmt = "%H:%M:%S" ,
678713 )
679714
715+ # NOTE: accelerate() must be called AFTER prepare_env() sets up the
716+ # environment, otherwise it imports system torch (which may be older)
717+ # and caches it in sys.modules before the zerostart env is on sys.path.
718+ accelerate_opts = None
680719 if args .accelerate :
681- from zerostart .accelerate import accelerate
682- accelerate (cache_dir = args .model_cache_dir )
720+ accelerate_opts = {"cache_dir" : args .model_cache_dir }
683721
684722 if _is_script (args .target ):
685723 sys .argv = [args .target ] + args .target_args
686724 run (
687725 script = args .target ,
688726 requirements = args .packages ,
689727 requirements_file = args .requirements ,
728+ accelerate_opts = accelerate_opts ,
690729 )
691730 else :
692731 run_package (
693732 package = args .target ,
694733 args = args .target_args ,
695734 extra_packages = args .packages ,
735+ accelerate_opts = accelerate_opts ,
696736 )
697737
698738
0 commit comments