Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/easymode/core/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,41 @@ def list_remote_models():
except Exception as e:
print(f"Error listing remote models: {e}")
return []


FIXED_MODEL_TITLES = {"n2n_splits", "n2n_direct", "ddw_splits", "ddw_direct", "tilt"}


def download_models(model_titles=None, silent=False):
"""Download model weights for offline use. If model_titles is None or empty, download all; otherwise download only the listed models."""
if not is_online():
print("\nAn internet connection is required to download models. Aborting.\n")
return
if not model_titles:
models = list_remote_models()
if not models:
return
for d in models:
if d["has_3d"]:
get_model(d["title"], silent=silent)
if d["has_2d"]:
get_model(d["title"], _2d=True, silent=silent)
for t in FIXED_MODEL_TITLES:
get_model(t, silent=silent)
else:
models = list_remote_models()
if not models:
return
segment_by_title = {d["title"].lower(): d for d in models}
for name in model_titles:
key = name.lower()
if key in segment_by_title:
d = segment_by_title[key]
if d["has_3d"]:
get_model(d["title"], silent=silent)
if d["has_2d"]:
get_model(d["title"], _2d=True, silent=silent)
elif name in FIXED_MODEL_TITLES:
get_model(name, silent=silent)
else:
print(f"Warning: unknown model '{name}' - skipping.")
7 changes: 7 additions & 0 deletions src/easymode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def main():

subparsers.add_parser('list', help='List the features for which pretrained general segmentation networks are available.')

download = subparsers.add_parser('download', help='Download model weights for offline use. With no arguments, download all models; otherwise download only the listed models (e.g. ribosome, n2n_direct). Use "easymode list" to see segment feature names. Requires internet.')
download.add_argument('model_titles', metavar='MODEL', nargs='*', type=str, help='One or more model names to download (default: all).')
download.add_argument('--quiet', action='store_true', help='Suppress per-model download messages.')

if os.path.exists('/lmb/home/mlast/easymode_dev'):
package = subparsers.add_parser('package', description='Package model and weights. Note that this is used for 3D models only; 2D models are packaged and distributed with Ais.')
package.add_argument('-c', "--checkpoint_directory", type=str, required=True, help="Path to the checkpoint directory to package from.")
Expand Down Expand Up @@ -260,6 +264,9 @@ def main():
elif args.command == 'list':
from easymode.core.distribution import list_remote_models
list_remote_models()
elif args.command == 'download':
from easymode.core.distribution import download_models
download_models(model_titles=args.model_titles if args.model_titles else None, silent=args.quiet)

if __name__ == "__main__":
main()
Expand Down
Loading