From b362ee321ece75f67263804f3779aed418a44f24 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Mon, 13 Jan 2025 16:07:54 +0000 Subject: [PATCH 1/2] clean up repo structure --- .bumpversion.cfg | 8 -- .coveragerc | 2 - .github/workflows/ci.yml | 54 +++++++++ .github/workflows/release.yaml | 10 -- .github/workflows/workflows.yaml | 18 --- .gitignore | 35 +++++- .pre-commit-config.yaml | 45 +++---- .readthedocs.yml | 8 -- CODE_OF_CONDUCT.md | 47 ++++++++ CONTRIBUTING.md | 53 +++++++++ MANIFEST.in | 1 - README.md | 34 +++--- pyproject.toml | 112 ++++++++++++++++++ requirements.txt | 6 - ruff.toml | 57 --------- run.sh | 1 - sat_pred/__init__.py | 1 - setup.py | 24 ---- src/sat_pred/__init__.py | 7 ++ .../sat_pred}/load_model_from_checkpoint.py | 0 {sat_pred => src/sat_pred}/loss.py | 0 .../sat_pred}/models/earthformer_model.py | 0 .../sat_pred}/models/simvp_model.py | 0 {sat_pred => src/sat_pred}/optimizers.py | 0 {sat_pred => src/sat_pred}/ssim.py | 0 {sat_pred => src/sat_pred}/training_module.py | 0 sat_pred/train.py => train.py | 0 27 files changed, 344 insertions(+), 179 deletions(-) delete mode 100644 .bumpversion.cfg delete mode 100644 .coveragerc create mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/release.yaml delete mode 100644 .github/workflows/workflows.yaml delete mode 100644 .readthedocs.yml create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md delete mode 100644 MANIFEST.in create mode 100644 pyproject.toml delete mode 100644 requirements.txt delete mode 100644 ruff.toml delete mode 100644 run.sh delete mode 100644 sat_pred/__init__.py delete mode 100644 setup.py create mode 100644 src/sat_pred/__init__.py rename {sat_pred => src/sat_pred}/load_model_from_checkpoint.py (100%) rename {sat_pred => src/sat_pred}/loss.py (100%) rename {sat_pred => src/sat_pred}/models/earthformer_model.py (100%) rename {sat_pred => src/sat_pred}/models/simvp_model.py (100%) rename {sat_pred => src/sat_pred}/optimizers.py 
(100%) rename {sat_pred => src/sat_pred}/ssim.py (100%) rename {sat_pred => src/sat_pred}/training_module.py (100%) rename sat_pred/train.py => train.py (100%) diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index 930ac03..0000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[bumpversion] -commit = False -tag = True -current_version = 0.0.1 - -[bumpversion:file:setup.py] -search = version="{current_version}" -replace = version="{new_version}" diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index c712d25..0000000 --- a/.coveragerc +++ /dev/null @@ -1,2 +0,0 @@ -[run] -omit = tests/* diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..105bbf0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,54 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +jobs: + pre-commit: + name: Format + lint code + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --hook-stage manual --all-files + + checks: + name: Run tests for Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.12"] # test oldest and latest supported versions + runs-on: [ubuntu-latest, macos-latest] # can be extended to other OSes, e.g. 
[ubuntu-latest, macos-latest, windows-latest] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + allow-prereleases: true + + - name: Install package + run: python -m pip install .[dev] + + - name: Test package + run: >- + python -m pytest -ra --cov --cov-report=xml --cov-report=term + --durations=20 + + - name: Upload coverage report + uses: codecov/codecov-action@v3.1.4 diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml deleted file mode 100644 index 3b6e961..0000000 --- a/.github/workflows/release.yaml +++ /dev/null @@ -1,10 +0,0 @@ -name: Bump version and auto-release -on: - push: - branches: - - main -jobs: - call-run-python-release: - uses: openclimatefix/.github/.github/workflows/python-release.yml@main - secrets: - token: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml deleted file mode 100644 index 727de59..0000000 --- a/.github/workflows/workflows.yaml +++ /dev/null @@ -1,18 +0,0 @@ -name: Python package tests - -on: - push: - schedule: - - cron: "0 12 * * 1" -jobs: - call-run-python-tests: - uses: openclimatefix/.github/.github/workflows/python-test.yml@main - with: - # 0 means don't use pytest-xdist - pytest_numcpus: "4" - # pytest-cov looks at this folder - pytest_cov_dir: "src" - # extra things to install - sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin" - # brew_install: "proj geos librttopo" - os_list: '["ubuntu-latest"]' diff --git a/.gitignore b/.gitignore index fa4c1b7..25cf9a4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ __pycache__/ # C extensions *.so -.idea/ + # Distribution / packaging .Python build/ @@ -20,7 +20,6 @@ parts/ sdist/ var/ wheels/ -pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg @@ -50,6 +49,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +cover/ # Translations *.mo @@ -72,6 
+72,7 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ # Jupyter Notebook @@ -82,7 +83,9 @@ profile_default/ ipython_config.py # pyenv -.python-version +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -127,3 +130,29 @@ dmypy.json # Pyre type checker .pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# setuptools_scm +src/*/_version.py + + +# ruff +.ruff_cache/ + +# OS specific stuff +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Common editor files +*~ +*.swp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e92053a..9fa74ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,31 +1,32 @@ -default_language_version: - python: python3 - repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: "v5.0.0" hooks: - # list of supported hooks: https://pre-commit.com/hooks.html - - id: trailing-whitespace - - id: end-of-file-fixer + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml - id: debug-statements - - id: detect-private-key + - id: end-of-file-fixer + - id: mixed-line-ending + - id: name-tests-test + args: ["--pytest-test-first"] + - id: requirements-txt-fixer + - id: trailing-whitespace - # python code formatting/linting - - repo: https://github.com/charliermarsh/ruff-pre-commit - # Ruff version. 
- rev: "v0.0.260" + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.9.1" hooks: + # first, lint + autofix - id: ruff - args: [--fix] - - repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black - args: [--line-length, "100"] - # yaml formatting - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.6 + types_or: [python, pyi, jupyter] + args: ["--fix", "--show-fixes"] + # then, format + - id: ruff-format + + - repo: https://github.com/rbubley/mirrors-prettier + rev: v3.4.2 hooks: - id: prettier - types: [yaml] + types: [yaml] \ No newline at end of file diff --git a/.readthedocs.yml b/.readthedocs.yml deleted file mode 100644 index c99dff2..0000000 --- a/.readthedocs.yml +++ /dev/null @@ -1,8 +0,0 @@ -version: 2 -mkdocs: {} # tell readthedocs to use mkdocs -python: - version: 3.8 - install: - - method: pip - path: . - - requirements: docs/requirements.txt diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..eee3ec3 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,47 @@ +# Code of Conduct + +We value the participation of every member of our community and want to ensure +that every contributor has an enjoyable and fulfilling experience. Accordingly, +everyone who participates in the sat_pred project is expected to show respect and courtesy to other community members at all times. + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers are dedicated to making participation in our project +a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behaviour that contributes to creating a positive environment +include: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +Examples of unacceptable behaviour by participants include: + +- The use of sexualized language or imagery and unwelcome sexual attention or + advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic + address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting + + + +## Attribution + +This Code of Conduct is adapted from the [Turing Data Stories Code of Conduct](https://github.com/alan-turing-institute/TuringDataStories/blob/main/CODE_OF_CONDUCT.md) which is based on the [scona project Code of Conduct](https://github.com/WhitakerLab/scona/blob/master/CODE_OF_CONDUCT.md) +and the [Contributor Covenant](https://www.contributor-covenant.org), version [1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..59cb971 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,53 @@ +See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed +description of best practices for developing scientific packages. 
+ +[spc-dev-intro]: https://learn.scientific-python.org/development/ + +# Setting up a development environment manually + +You can set up a development environment by running: + +```zsh +python3 -m venv venv # create a virtualenv called venv +source ./venv/bin/activate # now `python` points to the virtualenv python +pip install -v -e ".[dev]" # -v for verbose, -e for editable, [dev] for dev dependencies +``` + +# Post setup + +You should prepare pre-commit, which will help you by checking that commits pass +required checks: + +```bash +pip install pre-commit # or brew install pre-commit on macOS +pre-commit install # this will install a pre-commit hook into the git repo +``` + +You can also/alternatively run `pre-commit run` (changes only) or +`pre-commit run --all-files` to check even without installing the hook. + +# Testing + +Use pytest to run the unit checks: + +```bash +pytest +``` + +# Coverage + +Use pytest-cov to generate coverage reports: + +```bash +pytest --cov=sat_pred +``` + +# Pre-commit + +This project uses pre-commit for all style checking. Install pre-commit and run: + +```bash +pre-commit run -a +``` + +to check all files. diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index ab30e9a..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1 +0,0 @@ -include *.txt diff --git a/README.md b/README.md index e138d21..2f0ab5f 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,23 @@ -# A repo for training deterministic models to predict future satellite +[![Actions Status][actions-badge]][actions-link] +# Satellite Image Forecasting with Neural Networks ## Installation -Create and activate a new python environment, e.g. - +After cloning and entering this repo, create a fresh Python environment (e.g. via `uv`, `venv`, `conda`), then install this package and its dependencies: ``` -conda create -n sat_pred python=3.10 -conda activate sat_pred +pip install . 
``` -Clone this repo +### Developer installation +As above, but install in editable mode with the `dev` dependencies: ``` -git clone https://github.com/openclimatefix/sat_pred.git +pip install -e ".[dev]" ``` -Install this package and its dependencies - -``` -cd sat_pred -pip install -e . -``` -You will also need to install the cloudcasting package following the [instructions here](https://github.com/alan-turing-institute/cloudcasting) +## Training If you want to train the earthformer model you should clone and install the earthformer repo as well @@ -34,12 +28,10 @@ cd earth-forecasting-transformer pip install -e . ``` -## Training - You can train a model by running ``` -python sat_pred/train.py +python train.py ``` from the root of the library. @@ -71,7 +63,13 @@ python sat_pred/train.py model=earthformer model_name="earthformer-v1" model.opt will train the model defined in `configs/model/earthformer.yaml` log ther training results to wandb under the name `earthformer-v1`. It will also overwrite the learning rate of the optimiser to 0.0002. - + +[actions-badge]: https://github.com/openclimatefix/sat_pred/workflows/CI/badge.svg +[actions-link]: https://github.com/openclimatefix/sat_pred/actions +[pypi-link]: https://pypi.org/project/sat_pred/ +[pypi-platforms]: https://img.shields.io/pypi/pyversions/sat_pred +[pypi-version]: https://img.shields.io/pypi/v/sat_pred + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..cca05a6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,112 @@ +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[project] +name = "sat-pred" +version = "0.1.0" +authors = [ + { name = "Open Climate Fix", email = "info@openclimatefix.org" }, +] +description = "Repository for the forecasting of satellite data, using models developed by Open Climate Fix and The Alan Turing Institute." 
+readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", +] +dependencies = [ + "lightning", + "torch", + "numpy", + "hydra-core", + "matplotlib", + "pyaml_env", + "cloudcasting @ git+https://github.com/alan-turing-institute/cloudcasting.git", +] + +[project.optional-dependencies] +dev = [ + "pytest >=6", + "pytest-cov >=3", + "pre-commit", +] + +[project.urls] +Homepage = "https://github.com/openclimatefix/sat_pred" +"Bug Tracker" = "https://github.com/openclimatefix/sat_pred/issues" +Discussions = "https://github.com/openclimatefix/sat_pred/discussions" +Changelog = "https://github.com/openclimatefix/sat_pred/releases" + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +filterwarnings = [ + "error", +] +log_cli_level = "INFO" +testpaths = [ + "tests", +] + +[tool.coverage] +run.source = ["sat_pred"] +report.exclude_lines = [ + 'pragma: no cover', + '\.\.\.', + 'if typing.TYPE_CHECKING:', +] + + + +[tool.ruff] +src = ["src"] +exclude = [] +line-length = 100 # how long you want lines to be + +[tool.ruff.format] +docstring-code-format = true # code snippets in docstrings will be formatted + +[tool.ruff.lint] +select = [ + "E", "F", "W", # flake8 + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + "ISC", # flake8-implicit-str-concat + 
"G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "RET", # flake8-return + "RUF", # Ruff-specific + "SIM", # flake8-simplify + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable +] +ignore = [ + "PLR", # Design related pylint codes + "ISC001", # Conflicts with formatter +] +unfixable = [ + "F401", # Would remove unused imports + "F841", # Would remove unused variables +] +flake8-unused-arguments.ignore-variadic-names = true # allow unused *args/**kwargs diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ade5916..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -lightning -torch -numpy -hydra-core -matplotlib -pyaml_env diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 5df253f..0000000 --- a/ruff.toml +++ /dev/null @@ -1,57 +0,0 @@ -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -select = ["B", "E", "F", "D", "I"] -ignore = ["D200","D202","D210","D212","D415","D105",] - -# Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = ["A", "B", "C", "D", "E", "F", "I"] -unfixable = [] - -# Exclude a variety of commonly ignored directories. -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".hg", - ".mypy_cache", - ".nox", - ".pants.d", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", - "tests", -] - -# Same as Black. -line-length = 100 - -# Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -# Assume Python 3.10. -target-version = "py310" -fix = false - -# Group violations by containing file. -format = "github" -ignore-init-module-imports = true - -[mccabe] -# Unlike Flake8, default to a complexity level of 10. -max-complexity = 10 - -[pydocstyle] -# Use Google-style docstrings. 
-convention = "google" - -[per-file-ignores] -"__init__.py" = ["F401", "E402"] diff --git a/run.sh b/run.sh deleted file mode 100644 index 40f42ae..0000000 --- a/run.sh +++ /dev/null @@ -1 +0,0 @@ -python sat_pred/train.py model_name="earthformer" model=earthformer diff --git a/sat_pred/__init__.py b/sat_pred/__init__.py deleted file mode 100644 index 5baa628..0000000 --- a/sat_pred/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""sat_pred""" diff --git a/setup.py b/setup.py deleted file mode 100644 index 704312c..0000000 --- a/setup.py +++ /dev/null @@ -1,24 +0,0 @@ -""" Usual setup file for package """ -# read the contents of your README file -from pathlib import Path - -from setuptools import find_packages, setup - -this_directory = Path(__file__).parent -long_description = (this_directory / "README.md").read_text() -install_requires = (this_directory / "requirements.txt").read_text().splitlines() - -setup( - name="sat_pred", - version="0.0.1", - license="MIT", - description="A starter repo for predicting future satellite", - author="Open Climate Fix", - author_email="info@openclimatefix.org", - company="Open Climate Fix Ltd", - install_requires=install_requires, - long_description=long_description, - long_description_content_type="text/markdown", - include_package_data=True, - packages=find_packages(), -) diff --git a/src/sat_pred/__init__.py b/src/sat_pred/__init__.py new file mode 100644 index 0000000..e24bcbc --- /dev/null +++ b/src/sat_pred/__init__.py @@ -0,0 +1,7 @@ +""" +sat_pred: Repository for the forecasting of satellite data, using models developed by Open Climate Fix and The Alan Turing Institute. 
+""" +from importlib.metadata import version + +__all__ = ("__version__",) +__version__ = version(__name__) diff --git a/sat_pred/load_model_from_checkpoint.py b/src/sat_pred/load_model_from_checkpoint.py similarity index 100% rename from sat_pred/load_model_from_checkpoint.py rename to src/sat_pred/load_model_from_checkpoint.py diff --git a/sat_pred/loss.py b/src/sat_pred/loss.py similarity index 100% rename from sat_pred/loss.py rename to src/sat_pred/loss.py diff --git a/sat_pred/models/earthformer_model.py b/src/sat_pred/models/earthformer_model.py similarity index 100% rename from sat_pred/models/earthformer_model.py rename to src/sat_pred/models/earthformer_model.py diff --git a/sat_pred/models/simvp_model.py b/src/sat_pred/models/simvp_model.py similarity index 100% rename from sat_pred/models/simvp_model.py rename to src/sat_pred/models/simvp_model.py diff --git a/sat_pred/optimizers.py b/src/sat_pred/optimizers.py similarity index 100% rename from sat_pred/optimizers.py rename to src/sat_pred/optimizers.py diff --git a/sat_pred/ssim.py b/src/sat_pred/ssim.py similarity index 100% rename from sat_pred/ssim.py rename to src/sat_pred/ssim.py diff --git a/sat_pred/training_module.py b/src/sat_pred/training_module.py similarity index 100% rename from sat_pred/training_module.py rename to src/sat_pred/training_module.py diff --git a/sat_pred/train.py b/train.py similarity index 100% rename from sat_pred/train.py rename to train.py From 885aae509be5548d7adbb636eebea42ec5e5de12 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Mon, 13 Jan 2025 17:53:50 +0000 Subject: [PATCH 2/2] run formatting --- .github/workflows/ci.yml | 4 +- .pre-commit-config.yaml | 2 +- README.md | 5 +- configs/datamodule/default.yaml | 6 +- configs/model/earthformer.yaml | 104 ++++---- configs/model/finetune_msmae_simvp.yaml | 8 +- configs/model/finetune_simvp.yaml | 4 +- configs/model/simvp.yaml | 4 +- configs/model/simvp_v2.yaml | 4 +- configs/trainer/default.yaml | 2 +- 
scripts/backtest.py | 93 +++---- scripts/validate_model.py | 47 ++-- src/sat_pred/__init__.py | 1 + src/sat_pred/load_model_from_checkpoint.py | 18 +- src/sat_pred/loss.py | 10 +- src/sat_pred/models/earthformer_model.py | 5 +- src/sat_pred/models/simvp_model.py | 279 +++++++++++---------- src/sat_pred/optimizers.py | 14 +- src/sat_pred/ssim.py | 115 ++++----- src/sat_pred/training_module.py | 140 +++++------ train.py | 48 ++-- 21 files changed, 439 insertions(+), 474 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 105bbf0..ed77907 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,8 +29,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.12"] # test oldest and latest supported versions - runs-on: [ubuntu-latest, macos-latest] # can be extended to other OSes, e.g. [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.12"] # test oldest and latest supported versions + runs-on: [ubuntu-latest, macos-latest] # can be extended to other OSes, e.g. [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/checkout@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9fa74ae..f7ffe79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,4 +29,4 @@ repos: rev: v3.4.2 hooks: - id: prettier - types: [yaml] \ No newline at end of file + types: [yaml] diff --git a/README.md b/README.md index 2f0ab5f..b121d57 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ You can train a model by running python train.py ``` -from the root of the library. +from the root of the library. The model and training options used are defined in the config files. 
The most important parts of the config files you may wish to train are: @@ -70,6 +70,3 @@ will train the model defined in `configs/model/earthformer.yaml` log ther traini [pypi-platforms]: https://img.shields.io/pypi/pyversions/sat_pred [pypi-version]: https://img.shields.io/pypi/v/sat_pred - - - diff --git a/configs/datamodule/default.yaml b/configs/datamodule/default.yaml index 9b31be6..ede6b39 100644 --- a/configs/datamodule/default.yaml +++ b/configs/datamodule/default.yaml @@ -1,5 +1,5 @@ _target_: cloudcasting.dataset.SatelliteDataModule -zarr_path: +zarr_path: - /mnt/disks/sat_data/sat_data_all/2008_training_nonhrv.zarr - /mnt/disks/sat_data/sat_data_all/2009_training_nonhrv.zarr - /mnt/disks/sat_data/sat_data_all/2010_training_nonhrv.zarr @@ -13,10 +13,10 @@ history_mins: 165 forecast_mins: 180 sample_freq_mins: 15 train_period: ["2008-01-01 00:00", "2015-12-31 23:55"] # [start, end] -val_period: ["2016-01-01 00:00", "2016-12-31 23:55"] # [start, end] +val_period: ["2016-01-01 00:00", "2016-12-31 23:55"] # [start, end] num_workers: 8 prefetch_factor: 2 batch_size: 1 nan_to_num: true pin_memory: false -persistent_workers: true \ No newline at end of file +persistent_workers: true diff --git a/configs/model/earthformer.yaml b/configs/model/earthformer.yaml index c0cb112..7f572a4 100644 --- a/configs/model/earthformer.yaml +++ b/configs/model/earthformer.yaml @@ -1,65 +1,65 @@ _target_: sat_pred.training_module.TrainingModule -model: - _target_: sat_pred.models.earthformer_model.Earthformer - input_shape: [12, 372, 614, 11] # "NTHWC" - target_shape: [12, 372, 614, 11] # "NTHWC" - base_units: 128 - block_units: null - scale_alpha: 1.0 +model: + _target_: sat_pred.models.earthformer_model.Earthformer + input_shape: [12, 372, 614, 11] # "NTHWC" + target_shape: [12, 372, 614, 11] # "NTHWC" + base_units: 128 + block_units: null + scale_alpha: 1.0 - enc_depth: [1, 1] - dec_depth: [1, 1] - enc_use_inter_ffn: true - dec_use_inter_ffn: true - 
dec_hierarchical_pos_embed: false + enc_depth: [1, 1] + dec_depth: [1, 1] + enc_use_inter_ffn: true + dec_use_inter_ffn: true + dec_hierarchical_pos_embed: false - downsample: 2 - downsample_type: "patch_merge" - upsample_type: "upsample" + downsample: 2 + downsample_type: "patch_merge" + upsample_type: "upsample" - num_global_vectors: 8 - use_dec_self_global: false - dec_self_update_global: true - use_dec_cross_global: false - use_global_vector_ffn: false - use_global_self_attn: true - separate_global_qkv: true - global_dim_ratio: 1 + num_global_vectors: 8 + use_dec_self_global: false + dec_self_update_global: true + use_dec_cross_global: false + use_global_vector_ffn: false + use_global_self_attn: true + separate_global_qkv: true + global_dim_ratio: 1 - enc_attn_patterns: "axial" - dec_self_attn_patterns: "axial" - dec_cross_attn_patterns: "cross_1x1" - dec_cross_last_n_frames: null + enc_attn_patterns: "axial" + dec_self_attn_patterns: "axial" + dec_cross_attn_patterns: "cross_1x1" + dec_cross_last_n_frames: null - attn_drop: 0.1 - proj_drop: 0.1 - ffn_drop: 0.1 - num_heads: 4 + attn_drop: 0.1 + proj_drop: 0.1 + ffn_drop: 0.1 + num_heads: 4 - ffn_activation: "gelu" - gated_ffn: false - norm_layer: "layer_norm" - padding_type: "zeros" - pos_embed_type: "t+h+w" - use_relative_pos: true - self_attn_use_final_proj: true - dec_use_first_self_attn: false + ffn_activation: "gelu" + gated_ffn: false + norm_layer: "layer_norm" + padding_type: "zeros" + pos_embed_type: "t+h+w" + use_relative_pos: true + self_attn_use_final_proj: true + dec_use_first_self_attn: false - z_init_method: "zeros" - checkpoint_level: 0 + z_init_method: "zeros" + checkpoint_level: 0 - initial_downsample_type: "stack_conv" - initial_downsample_activation: "leaky" - initial_downsample_stack_conv_num_layers: 3 - initial_downsample_stack_conv_dim_list: [16, 64, 128] - initial_downsample_stack_conv_downscale_list: [3, 2, 2] - initial_downsample_stack_conv_num_conv_list: [2, 2, 2] + 
initial_downsample_type: "stack_conv" + initial_downsample_activation: "leaky" + initial_downsample_stack_conv_num_layers: 3 + initial_downsample_stack_conv_dim_list: [16, 64, 128] + initial_downsample_stack_conv_downscale_list: [3, 2, 2] + initial_downsample_stack_conv_num_conv_list: [2, 2, 2] - attn_linear_init_mode: "0" - ffn_linear_init_mode: "0" - conv_init_mode: "0" - down_up_linear_init_mode: "0" - norm_init_mode: "0" + attn_linear_init_mode: "0" + ffn_linear_init_mode: "0" + conv_init_mode: "0" + down_up_linear_init_mode: "0" + norm_init_mode: "0" optimizer: _target_: sat_pred.optimizers.AdamWReduceLROnPlateau diff --git a/configs/model/finetune_msmae_simvp.yaml b/configs/model/finetune_msmae_simvp.yaml index 96da902..69fd0db 100644 --- a/configs/model/finetune_msmae_simvp.yaml +++ b/configs/model/finetune_msmae_simvp.yaml @@ -1,14 +1,14 @@ _target_: sat_pred.training_module.TrainingModule -model: +model: from_pretrained: true checkpoint_dir: /home/jamesfulton/repos/sat_pred/checkpoints/ob9v9128 val_best: true optimizer: _target_: sat_pred.optimizers.AdamWReduceLROnPlateau lr: 0.0005 -target_loss: +target_loss: _target_: sat_pred.loss.MultiscaleMAE - scales: + scales: - [1, 1, 1] - [1, 4, 4] - [2, 8, 8] @@ -32,5 +32,3 @@ video_crop_plots: i: 120 j: 160 s: 30 - - diff --git a/configs/model/finetune_simvp.yaml b/configs/model/finetune_simvp.yaml index 3a33ef4..ff700ce 100644 --- a/configs/model/finetune_simvp.yaml +++ b/configs/model/finetune_simvp.yaml @@ -1,5 +1,5 @@ _target_: sat_pred.training_module.TrainingModule -model: +model: from_pretrained: true checkpoint_dir: /home/jamesfulton/repos/sat_pred/checkpoints/ob9v9128 val_best: true @@ -26,5 +26,3 @@ video_crop_plots: i: 120 j: 160 s: 30 - - diff --git a/configs/model/simvp.yaml b/configs/model/simvp.yaml index d720c88..905fdec 100644 --- a/configs/model/simvp.yaml +++ b/configs/model/simvp.yaml @@ -1,5 +1,5 @@ _target_: sat_pred.training_module.TrainingModule -model: +model: _target_: 
sat_pred.models.simvp_model.SimVP num_channels: 11 history_len: 12 @@ -27,5 +27,3 @@ video_crop_plots: i: 120 j: 160 s: 30 - - diff --git a/configs/model/simvp_v2.yaml b/configs/model/simvp_v2.yaml index cb4921d..3cce619 100644 --- a/configs/model/simvp_v2.yaml +++ b/configs/model/simvp_v2.yaml @@ -1,5 +1,5 @@ _target_: sat_pred.training_module.TrainingModule -model: +model: _target_: sat_pred.models.simvp_model.SimVP num_channels: 11 history_len: 12 @@ -32,5 +32,3 @@ video_crop_plots: i: 120 j: 160 s: 30 - - diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index 2181e91..e708f09 100644 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -1,7 +1,7 @@ _target_: lightning.pytorch.trainer.trainer.Trainer accelerator: gpu -devices: [0,] # [0,1] +devices: [0] # [0,1] #strategy: ddp_spawn precision: 16-mixed gradient_clip_val: 0.5 diff --git a/scripts/backtest.py b/scripts/backtest.py index 8fa9f3e..cc2d8a5 100644 --- a/scripts/backtest.py +++ b/scripts/backtest.py @@ -5,25 +5,24 @@ except RuntimeError: pass -from cloudcasting.dataset import load_satellite_zarrs, find_valid_t0_times +import os +from datetime import datetime, timedelta +from glob import glob import hydra -import torch -from pyaml_env import parse_config import numpy as np import pandas as pd +import torch import xarray as xr -from datetime import datetime, timedelta +from cloudcasting.dataset import find_valid_t0_times, load_satellite_zarrs +from numcodecs import Blosc +from pyaml_env import parse_config from torch.utils.data import DataLoader, Dataset from tqdm import tqdm -import os -from glob import glob -from numcodecs import Blosc - checkpoint = "/home/jamesfulton/repos/sat_pred/checkpoints/ob9v9128" save_dir = "/mnt/disks/sat_preds/simvp_preds" -compressor = Blosc(cname='zstd', clevel=5, shuffle=Blosc.BITSHUFFLE) +compressor = Blosc(cname="zstd", clevel=5, shuffle=Blosc.BITSHUFFLE) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ 
-34,7 +33,7 @@ def get_model_from_checkpoints( val_best: bool = True, ): """Load a model from its checkpoint directory - + Args: checkpoint_dir_path: Path to the checkpoint directory val_best: Whether to use the best performing checkpoint found during training, else uses @@ -50,52 +49,45 @@ def get_model_from_checkpoints( # Only one epoch (best) saved per model files = glob(f"{checkpoint_dir_path}/epoch*.ckpt") if len(files) != 1: + msg = f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one." raise ValueError( - f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one." + msg ) checkpoint = torch.load(files[0], map_location="cpu") else: checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu") lightning_wrapped_model.load_state_dict(state_dict=checkpoint["state_dict"]) - + # discard the lightning wrapper on the model - model = lightning_wrapped_model.model + model = lightning_wrapped_model.model # Check for data config data_config = parse_config(f"{checkpoint_dir_path}/data_config.yaml") - return model, model_config, data_config - class MLModel: - def __init__(self, checkpoint_dir_path: str) -> None: - - model, model_config, data_config = get_model_from_checkpoints(checkpoint) self.model = model.to(DEVICE) - self.history_mins = (model_config["model"]['history_len'] - 1) * 15 + self.history_mins = (model_config["model"]["history_len"] - 1) * 15 self.model_config = model_config self.data_config = data_config self.checkpoint_dir_path = checkpoint_dir_path - def __call__(self, X): # The input X is a numpy array with shape (batch_size, channels, time, height, width) X = torch.Tensor(X).to(DEVICE) - + with torch.no_grad(): y_hat = self.model(X).cpu().numpy() - + # Clip the values to be between 0 and 1 - y_hat = y_hat.clip(0, 1) + return y_hat.clip(0, 1) - return y_hat - def backtest_collate_fn( @@ -110,10 +102,11 @@ def backtest_collate_fn( X_all[i] = X ts.append(t) return X_all, 
pd.to_datetime(ts) - + DataIndex = str | datetime | pd.Timestamp | int + class BacktestSatelliteDataset(Dataset): def __init__( self, @@ -149,7 +142,7 @@ def __init__( ) # Only do 30 minute intervals - self.t0_times = self.t0_times[self.t0_times.minute%30==0] + self.t0_times = self.t0_times[self.t0_times.minute % 30 == 0] self.history_mins = history_mins self.sample_freq_mins = sample_freq_mins @@ -157,9 +150,7 @@ def __init__( @staticmethod def _find_t0_times( - date_range: pd.DatetimeIndex, - history_mins: int, - sample_freq_mins: int + date_range: pd.DatetimeIndex, history_mins: int, sample_freq_mins: int ) -> pd.DatetimeIndex: return find_valid_t0_times(date_range, history_mins, 0, sample_freq_mins) @@ -196,14 +187,13 @@ def __getitem__(self, key: DataIndex): return self._get_datetime(t0) - def run_backtest( model: MLModel, dataset: BacktestSatelliteDataset, batch_size: int = 1, num_workers: int = 0, batch_limit: int | None = None, - agg_batches: int = 1 + agg_batches: int = 1, ) -> None: """Calculate the scoreboard metrics for the given model on the validation dataset. 
@@ -232,31 +222,30 @@ def run_backtest( da_y_hats = [] save_batch_num = 0 - attrs_dict = {k:v for k,v in dataset.ds.attrs.items()} + attrs_dict = dict(dataset.ds.attrs.items()) attrs_dict["model_checkpoint"] = model.checkpoint_dir_path for i, (X, t) in tqdm(enumerate(backtest_dataloader), total=loop_steps): - y_hat = model(X) init_times = pd.DatetimeIndex(t) steps = pd.timedelta_range("15min", periods=y_hat.shape[2], freq="15min") da_y_hat = xr.DataArray( - y_hat, - dims=["init_time", "variable", "step", "y_geostationary", "x_geostationary"], + y_hat, + dims=["init_time", "variable", "step", "y_geostationary", "x_geostationary"], coords={ "init_time": init_times, "variable": dataset.ds.variable, "step": steps, "y_geostationary": dataset.ds.y_geostationary, "x_geostationary": dataset.ds.x_geostationary, - } + }, ).chunk( { - "init_time": 1, - "variable":-1, - "step":-1, - "y_geostationary": 100, + "init_time": 1, + "variable": -1, + "step": -1, + "y_geostationary": 100, "x_geostationary": 100, } ) @@ -264,32 +253,28 @@ def run_backtest( da_y_hats.append(da_y_hat) del da_y_hat - - if len(da_y_hats)==agg_batches or i==loop_steps-1: + if len(da_y_hats) == agg_batches or i == loop_steps - 1: da_y_hats = xr.concat(da_y_hats, dim="init_time") da_y_hats.attrs = attrs_dict da_y_hats = da_y_hats.to_dataset(name="sat_pred") - + da_y_hats.to_zarr( - f"{save_dir}/part_{save_batch_num}.zarr", + f"{save_dir}/part_{save_batch_num}.zarr", mode="w", - encoding={var: {'compressor': compressor} for var in da_y_hats.data_vars}, + encoding={var: {"compressor": compressor} for var in da_y_hats.data_vars}, ) - + save_batch_num += 1 da_y_hats = [] - if batch_limit is not None and i == batch_limit: break -if __name__=="__main__": - - +if __name__ == "__main__": os.makedirs(save_dir, exist_ok=False) model = MLModel(checkpoint) @@ -302,12 +287,12 @@ def run_backtest( "/mnt/disks/all_data/sat/2022_nonhrv.zarr", "/mnt/disks/all_data/sat/2023_nonhrv.zarr", ], - start_time=None, + 
start_time=None, end_time=None, history_mins=model.history_mins, sample_freq_mins=15, - nan_to_num=model.data_config['nan_to_num'], - ) + nan_to_num=model.data_config["nan_to_num"], + ) run_backtest( model=model, diff --git a/scripts/validate_model.py b/scripts/validate_model.py index cc90591..105596b 100644 --- a/scripts/validate_model.py +++ b/scripts/validate_model.py @@ -1,13 +1,11 @@ -from cloudcasting.validation import validate -from cloudcasting.models import AbstractModel - import glob import hydra import torch +from cloudcasting.models import AbstractModel +from cloudcasting.validation import validate from pyaml_env import parse_config - checkpoint = "/home/jamesfulton/repos/sat_pred/checkpoints/ob9v9128" WANDB_PROJECT = "cloudcasting" WANDB_RUN_NAME = "simVP_2008-2016" @@ -21,7 +19,7 @@ def get_model_from_checkpoints( val_best: bool = True, ): """Load a model from its checkpoint directory - + Args: checkpoint_dir_path: Path to the checkpoint directory val_best: Whether to use the best performing checkpoint found during training, else uses @@ -37,61 +35,56 @@ def get_model_from_checkpoints( # Only one epoch (best) saved per model files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt") if len(files) != 1: + msg = f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one." raise ValueError( - f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one." 
+ msg ) checkpoint = torch.load(files[0], map_location="cpu", weights_only=True) else: - checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu", weights_only=True) + checkpoint = torch.load( + f"{checkpoint_dir_path}/last.ckpt", map_location="cpu", weights_only=True + ) state_dict = checkpoint["state_dict"] lightning_wrapped_model.load_state_dict(state_dict=state_dict) - + # discard the lightning wrapper on the model - model = lightning_wrapped_model.model + model = lightning_wrapped_model.model # Check for data config data_config = parse_config(f"{checkpoint_dir_path}/data_config.yaml") - return model, model_config, data_config - # We define a new class that inherits from AbstractModel class MLModel(AbstractModel): """A persistence model which predicts a blury version of the most recent frame""" def __init__(self, checkpoint_dir_path: str) -> None: - - model, model_config, data_config = get_model_from_checkpoints(checkpoint) - - super().__init__(history_steps=12) + super().__init__(history_steps=12) self.model = model.to(DEVICE) self.model_config = model_config self.data_config = data_config self.checkpoint_dir_path = checkpoint_dir_path - def forward(self, X): # The input X is a numpy array with shape (batch_size, channels, time, height, width) - + X = torch.Tensor(X).to(DEVICE) - + with torch.no_grad(): y_hat = self.model(X).cpu().numpy() - + # Clip the values to be between 0 and 1 - y_hat = y_hat.clip(0, 1) + return y_hat.clip(0, 1) - return y_hat def hyperparameters_dict(self): - wandb_id = self.checkpoint_dir_path.split("/")[-1] params_dict = { "training_run_link": f"https://wandb.ai/openclimatefix/sat_pred/runs/{wandb_id}", @@ -100,8 +93,8 @@ def hyperparameters_dict(self): return params_dict -if __name__=="__main__": +if __name__ == "__main__": model = MLModel(checkpoint) validate( @@ -109,8 +102,8 @@ def hyperparameters_dict(self): data_path="/mnt/disks/sat_data_all/2022_test_nonhrv.zarr", wandb_project_name=WANDB_PROJECT, 
wandb_run_name=WANDB_RUN_NAME, - batch_size = 2, - num_workers = 10, - batch_limit = None, - nan_to_num = model.data_config['nan_to_num'] + batch_size=2, + num_workers=10, + batch_limit=None, + nan_to_num=model.data_config["nan_to_num"], ) diff --git a/src/sat_pred/__init__.py b/src/sat_pred/__init__.py index e24bcbc..e04da51 100644 --- a/src/sat_pred/__init__.py +++ b/src/sat_pred/__init__.py @@ -1,6 +1,7 @@ """ sat_pred: Repository for the forecasting of satellite data, using models developed by Open Climate Fix and The Alan Turing Institute. """ + from importlib.metadata import version __all__ = ("__version__",) diff --git a/src/sat_pred/load_model_from_checkpoint.py b/src/sat_pred/load_model_from_checkpoint.py index e2cb2b6..a8ced99 100644 --- a/src/sat_pred/load_model_from_checkpoint.py +++ b/src/sat_pred/load_model_from_checkpoint.py @@ -1,18 +1,18 @@ +from glob import glob -import torch import hydra +import torch from pyaml_env import parse_config -from glob import glob - checkpoint = "/home/jamesfulton/repos/sat_pred/checkpoints/ob9v9128" + def get_model_from_checkpoints( checkpoint_dir_path: str, val_best: bool = True, ): """Load a model from its checkpoint directory - + Args: checkpoint_dir_path: Path to the checkpoint directory val_best: Whether to use the best performing checkpoint found during training, else uses @@ -28,20 +28,20 @@ def get_model_from_checkpoints( # Only one epoch (best) saved per model files = glob(f"{checkpoint_dir_path}/epoch*.ckpt") if len(files) != 1: + msg = f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one." raise ValueError( - f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one." 
+ msg ) checkpoint = torch.load(files[0], map_location="cpu") else: checkpoint = torch.load(f"{checkpoint_dir_path}/last.ckpt", map_location="cpu") lightning_wrapped_model.load_state_dict(state_dict=checkpoint["state_dict"]) - + # discard the lightning wrapper on the model - model = lightning_wrapped_model.model + model = lightning_wrapped_model.model # Check for data config data_config = parse_config(f"{checkpoint_dir_path}/data_config.yaml") - - return model, model_config, data_config \ No newline at end of file + return model, model_config, data_config diff --git a/src/sat_pred/loss.py b/src/sat_pred/loss.py index cd5a81d..1876ee8 100644 --- a/src/sat_pred/loss.py +++ b/src/sat_pred/loss.py @@ -3,6 +3,7 @@ import torch from torch.nn import functional as F + class LossFunction(ABC): """Loss function""" @@ -11,18 +12,19 @@ class LossFunction(ABC): @abstractmethod def name(self) -> str: """Return name of the loss function""" - pass @abstractmethod def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Return loss""" - pass + class MultiscaleMAE(LossFunction): """Multiscale Mean Absolute Error""" - def __init__(self, scales: list[tuple[int]]=[(1,1,1),(2,4,4)]): + def __init__(self, scales: list[tuple[int]] | None = None): """Multiscale Mean Absolute Error""" + if scales is None: + scales = [(1, 1, 1), (2, 4, 4)] self.scales = scales @property @@ -33,7 +35,7 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Return loss""" target = target.copy() - target[target==-1] = float('nan') + target[target == -1] = float("nan") loss = 0 diff --git a/src/sat_pred/models/earthformer_model.py b/src/sat_pred/models/earthformer_model.py index 92fbdf8..c97f8cf 100644 --- a/src/sat_pred/models/earthformer_model.py +++ b/src/sat_pred/models/earthformer_model.py @@ -2,13 +2,12 @@ class Earthformer(CuboidTransformerModel): - def forward(self, X, verbose=False): - # The cloudcasting dataloader created batches of shape: + # The 
cloudcasting dataloader created batches of shape: # (batch, channel, time, height, width) # Earthformer expects shape: (batch, time, height, width, channel) X = X.permute(0, 2, 3, 4, 1).contiguous() y_hat = super().forward(X, verbose=verbose) # Transpose back to cloudcasting shape - return y_hat.permute(0, 4, 1, 2, 3) \ No newline at end of file + return y_hat.permute(0, 4, 1, 2, 3) diff --git a/src/sat_pred/models/simvp_model.py b/src/sat_pred/models/simvp_model.py index 7dcd4b5..125fd1c 100644 --- a/src/sat_pred/models/simvp_model.py +++ b/src/sat_pred/models/simvp_model.py @@ -1,48 +1,43 @@ -"""Adapted from https://github.com/A4Bio/SimVP -""" +"""Adapted from https://github.com/A4Bio/SimVP""" import torch -from torch import nn import torch.nn.functional as F +from torch import nn class BasicConv2d(nn.Module): def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - transpose=False, - act_norm=False + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + transpose=False, + act_norm=False, ): - super(BasicConv2d, self).__init__() - + super().__init__() + if transpose: conv_layer = nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, padding=padding, - output_padding=stride //2 + output_padding=stride // 2, ) else: conv_layer = nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding ) - + layers = [conv_layer] - + if act_norm: layers.append(nn.GroupNorm(2, out_channels)) layers.append(nn.LeakyReLU(0.2)) - + self.model = nn.Sequential(*layers) def forward(self, x): @@ -51,19 +46,19 @@ def forward(self, x): class ConvSC(nn.Module): def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True): - super(ConvSC, self).__init__() - + super().__init__() + if 
stride == 1: transpose = False - + self.model = BasicConv2d( - C_in, - C_out, - kernel_size=3, + C_in, + C_out, + kernel_size=3, stride=stride, - padding=1, - transpose=transpose, - act_norm=act_norm + padding=1, + transpose=transpose, + act_norm=act_norm, ) def forward(self, x): @@ -72,56 +67,51 @@ def forward(self, x): class GroupConv2d(nn.Module): def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - groups, - act_norm=False + self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False ): - super(GroupConv2d, self).__init__() - + super().__init__() + if in_channels % groups != 0: groups = 1 - + layers = [ nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, padding=padding, - groups=groups + groups=groups, ) ] - + if act_norm: layers.append(nn.GroupNorm(groups, out_channels)) layers.append(nn.LeakyReLU(0.2, inplace=True)) - + self.model = nn.Sequential(*layers) - + def forward(self, x): return self.model(x) class Inception(nn.Module): - def __init__(self, C_in, C_hid, C_out, incep_ker=[3,5,7,11], groups=8): - super(Inception, self).__init__() + def __init__(self, C_in, C_hid, C_out, incep_ker=None, groups=8): + if incep_ker is None: + incep_ker = [3, 5, 7, 11] + super().__init__() self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0) layers = [] for kernel_size in incep_ker: layers.append( GroupConv2d( - C_hid, - C_out, - kernel_size=kernel_size, - stride=1, - padding=kernel_size//2, - groups=groups, - act_norm=True + C_hid, + C_out, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=groups, + act_norm=True, ) ) self.layers = nn.ModuleList(layers) @@ -134,35 +124,31 @@ def forward(self, x): return y -# I think this might have a problem for odd values of N when reverse=True +# I think this might have a problem for odd values of N when reverse=True def 
stride_generator(N, reverse=False): - strides = [1, 2]*10 - if reverse: + strides = [1, 2] * 10 + if reverse: return list(reversed(strides[:N])) - else: - return strides[:N] + return strides[:N] + - def stride_generator_new(N, reverse=False): - if reverse: - strides = [2, 1] - else: - strides = [1, 2] - - return (strides*((N+1)//2))[:N] + strides = [2, 1] if reverse else [1, 2] + + return (strides * ((N + 1) // 2))[:N] class Encoder(nn.Module): def __init__(self, C_in, C_hid, N_S): - super(Encoder,self).__init__() + super().__init__() strides = stride_generator(N_S) - + layers = [ConvSC(C_in, C_hid, stride=strides[0])] layers.extend([ConvSC(C_hid, C_hid, stride=s) for s in strides[1:]]) - + self.encoder_layers = nn.ModuleList(layers) - - def forward(self, x):# B*4, 3, 128, 128 + + def forward(self, x): # B*4, 3, 128, 128 enc1 = self.encoder_layers[0](x) latent = enc1 for layer in self.encoder_layers[1:]: @@ -171,48 +157,74 @@ def forward(self, x):# B*4, 3, 128, 128 class Decoder(nn.Module): - def __init__(self,C_hid, C_out, N_S): - super(Decoder,self).__init__() - strides = stride_generator(N_S, reverse=True) - + def __init__(self, C_hid, C_out, N_S): + super().__init__() + strides = stride_generator(N_S, reverse=True) + layers = [ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]] - layers.append(ConvSC(2*C_hid, C_hid, stride=strides[-1], transpose=True)) - + layers.append(ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True)) + self.decoder_layers = nn.ModuleList(layers) - + self.readout = nn.Conv2d(C_hid, C_out, 1) - + def forward(self, hid, enc1=None): - for i in range(0,len(self.decoder_layers)-1): + for i in range(len(self.decoder_layers) - 1): hid = self.decoder_layers[i](hid) - hid = hid[..., :enc1.shape[-2], :enc1.shape[-1]] + hid = hid[..., : enc1.shape[-2], : enc1.shape[-1]] Y = self.decoder_layers[-1](torch.cat([hid, enc1], dim=1)) - Y = self.readout(Y) - return Y + return self.readout(Y) + - class Mid_Xnet(nn.Module): - - def 
__init__(self, channel_in, channel_hid, N_T, incep_ker = [3,5,7,11], groups=8): - super(Mid_Xnet, self).__init__() + def __init__(self, channel_in, channel_hid, N_T, incep_ker=None, groups=8): + if incep_ker is None: + incep_ker = [3, 5, 7, 11] + super().__init__() self.N_T = N_T - enc_layers = [Inception(channel_in, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)] - for i in range(1, N_T-1): - enc_layers.append(Inception(channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)) - enc_layers.append(Inception(channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)) + enc_layers = [ + Inception(channel_in, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups) + ] + for _i in range(1, N_T - 1): + enc_layers.append( + Inception( + channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups + ) + ) + enc_layers.append( + Inception( + channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups + ) + ) - dec_layers = [Inception(channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)] - for i in range(1, N_T-1): - dec_layers.append(Inception(2*channel_hid, channel_hid//2, channel_hid, incep_ker= incep_ker, groups=groups)) - dec_layers.append(Inception(2*channel_hid, channel_hid//2, channel_in, incep_ker= incep_ker, groups=groups)) + dec_layers = [ + Inception( + channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups + ) + ] + for _i in range(1, N_T - 1): + dec_layers.append( + Inception( + 2 * channel_hid, + channel_hid // 2, + channel_hid, + incep_ker=incep_ker, + groups=groups, + ) + ) + dec_layers.append( + Inception( + 2 * channel_hid, channel_hid // 2, channel_in, incep_ker=incep_ker, groups=groups + ) + ) self.enc = nn.Sequential(*enc_layers) self.dec = nn.Sequential(*dec_layers) def forward(self, x): B, T, C, H, W = x.shape - x = x.reshape(B, T*C, H, W) + x = x.reshape(B, T * C, H, W) # encoder skips = [] @@ 
-227,60 +239,57 @@ def forward(self, x): for i in range(1, self.N_T): z = self.dec[i](torch.cat([z, skips[-i]], dim=1)) - y = z.reshape(B, T, C, H, W) - return y - + return z.reshape(B, T, C, H, W) class SimVP(nn.Module): def __init__( - self, - num_channels, - history_len, - forecast_len, - spatial_size=(279, 386), - hid_S=16, - hid_T=256, + self, + num_channels, + history_len, + forecast_len, + spatial_size=(279, 386), + hid_S=16, + hid_T=256, N_S=4, - N_T=8, - incep_ker=[3,5,7,11], - groups=8 + N_T=8, + incep_ker=None, + groups=8, ): - super(SimVP, self).__init__() - + if incep_ker is None: + incep_ker = [3, 5, 7, 11] + super().__init__() + self.enc = Encoder(num_channels, hid_S, N_S) - self.hid = Mid_Xnet(history_len*hid_S, hid_T, N_T, incep_ker, groups) + self.hid = Mid_Xnet(history_len * hid_S, hid_T, N_T, incep_ker, groups) self.dec = Decoder(hid_S, num_channels, N_S) self.spatial_size = spatial_size - def forward(self, x_raw): - # Pad out to a multiple of downsample factor - #pad_top = pad_left = 0 - #downsample_factor = (N_S // 2)*2 - #pad_bottom = downsample_factor - (self.spatial_size[0] % downsample_factor) - #pad_right = downsample_factor - (self.spatial_size[1] % downsample_factor) - #x_raw = F.pad(x_raw, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0) + # pad_top = pad_left = 0 + # downsample_factor = (N_S // 2)*2 + # pad_bottom = downsample_factor - (self.spatial_size[0] % downsample_factor) + # pad_right = downsample_factor - (self.spatial_size[1] % downsample_factor) + # x_raw = F.pad(x_raw, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0) # (batch, channel, time, height, width) -> (batch, time, channel, height, width) - x_raw = x_raw.permute(0,2,1,3,4) - + x_raw = x_raw.permute(0, 2, 1, 3, 4) + B, T, C, H, W = x_raw.shape - x = x_raw.reshape(B*T, C, H, W) + x = x_raw.reshape(B * T, C, H, W) embed, skip = self.enc(x) _, C_, H_, W_ = embed.shape z = embed.view(B, T, C_, H_, W_) hid = self.hid(z) - hid = 
hid.reshape(B*T, C_, H_, W_) + hid = hid.reshape(B * T, C_, H_, W_) Y = self.dec(hid, skip) Y = Y.reshape(B, T, C, H, W) - - Y = Y.permute(0,2,1,3,4) + + return Y.permute(0, 2, 1, 3, 4) # Remove padding # Y = Y[..., :self.spatial_size[0]-pad_bottom, :self.spatial_size[1]-pad_right] - return Y \ No newline at end of file diff --git a/src/sat_pred/optimizers.py b/src/sat_pred/optimizers.py index e8c4241..71102fd 100644 --- a/src/sat_pred/optimizers.py +++ b/src/sat_pred/optimizers.py @@ -1,6 +1,8 @@ import torch + from sat_pred.loss import LossFunction + class AdamW: """AdamW optimizer""" @@ -13,7 +15,7 @@ def __call__(self, model): """Return optimizer""" return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs) - + class AdamWReduceLROnPlateau: """AdamW optimizer and reduce on plateau scheduler""" @@ -29,17 +31,15 @@ def __init__( self.opt_kwargs = opt_kwargs def __call__(self, model): - - opt = torch.optim.AdamW( - model.parameters(), lr=self.lr, **self.opt_kwargs - ) + opt = torch.optim.AdamW(model.parameters(), lr=self.lr, **self.opt_kwargs) if isinstance(model.target_loss, str): monitor = f"{model.target_loss}/val" elif isinstance(model.target_loss, LossFunction): monitor = f"{model.target_loss.name}/val" else: - raise ValueError(f"Unknown loss type: {type(model)}") + msg = f"Unknown loss type: {type(model)}" + raise ValueError(msg) sch = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -51,4 +51,4 @@ def __call__(self, model): "monitor": monitor, } - return [opt], [sch] \ No newline at end of file + return [opt], [sch] diff --git a/src/sat_pred/ssim.py b/src/sat_pred/ssim.py index 14bc9cc..f7a26e0 100644 --- a/src/sat_pred/ssim.py +++ b/src/sat_pred/ssim.py @@ -12,12 +12,13 @@ import numpy as np from sat_pred.ssim import SSIM3D import torch + torch.manual_seed(1) # Create some sample data -n_samples = 1 # only 1 sample so compatible with skimage function +n_samples = 1 # only 1 sample so compatible with skimage function n_channels = 5 
-n_timesteps = 1 # only 1 time step so compatible with skimage function +n_timesteps = 1 # only 1 time step so compatible with skimage function x_dim = 400 y_dim = 700 @@ -27,32 +28,31 @@ # Compute SSIM map with this class and squeeze out the extra dimensions -ssim_map1 = SSIM3D()(y_hat, y).numpy().squeeze((0,2)) +ssim_map1 = SSIM3D()(y_hat, y).numpy().squeeze((0, 2)) # Compute SSIM map with skimage _, ssim_map2 = structural_similarity( - y_hat[0, :, 0].numpy(), # remove extra dimensions and convert to numpy - y[0, :, 0].numpy(), # remove extra dimensions and convert to numpy - channel_axis=0, - # The settings below are required to match the two calculations + y_hat[0, :, 0].numpy(), # remove extra dimensions and convert to numpy + y[0, :, 0].numpy(), # remove extra dimensions and convert to numpy + channel_axis=0, + # The settings below are required to match the two calculations data_range=1, gaussian_weights=True, - full=True, + full=True, sigma=1.5, use_sample_covariance=False, ) # The skimage version of SSIM uses reflection padding when applying the gaussian kernel -# whilst our version uses zero padding. We expect the two versions to be the same +# whilst our version uses zero padding. 
We expect the two versions to be the same def trim_border(x, num_pixels): - return x[..., num_pixels:x.shape[-2]-num_pixels, num_pixels:x.shape[-1]-num_pixels] + return x[..., num_pixels : x.shape[-2] - num_pixels, num_pixels : x.shape[-1] - num_pixels] + # If we don't trim the border ~96% of the SSIM values are the same np.isclose( - trim_border(ssim_map1, num_pixels=0), - trim_border(ssim_map2, num_pixels=0), - atol=1e-05 + trim_border(ssim_map1, num_pixels=0), trim_border(ssim_map2, num_pixels=0), atol=1e-05 ).mean() # >> 0.9610714285714286 @@ -60,23 +60,22 @@ def trim_border(x, num_pixels): # In both calculations we have used a window size of 11, so a padding size of 11//2 = 5 np.isclose( - trim_border(ssim_map1, num_pixels=5), - trim_border(ssim_map2, num_pixels=5), - atol=1e-05 + trim_border(ssim_map1, num_pixels=5), trim_border(ssim_map2, num_pixels=5), atol=1e-05 ).mean() # >> 1.0 - ``` """ +from collections.abc import Sequence + import torch -from torch import nn import torch.nn.functional as F -from collections.abc import Sequence +from torch import nn + def create_1d_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor: """Create a 1D gaussian kernel - + Args: kernel_size: The size of the kernel sigma: The standard deviation @@ -84,26 +83,25 @@ def create_1d_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor: ksize_half = (kernel_size - 1) * 0.5 kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) gauss = torch.exp(-0.5 * (kernel / sigma).pow(2)) - return (gauss / gauss.sum()) + return gauss / gauss.sum() def create_2d_gaussian_kernel( - kernel_size: int | list[int], - sigma: float | list[float] - ) -> torch.Tensor: + kernel_size: int | list[int], sigma: float | list[float] +) -> torch.Tensor: """Create a 2D gaussian kernel - + Args: kernel_size: The size of the kernel sigma: The standard deviation of the kernel """ - + if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] - + if 
isinstance(sigma, float): sigma = [sigma, sigma] - + kernel_x = create_1d_gaussian_kernel(kernel_size[0], sigma[0]).unsqueeze(dim=1) kernel_y = create_1d_gaussian_kernel(kernel_size[1], sigma[1]).unsqueeze(dim=0) @@ -112,11 +110,11 @@ def create_2d_gaussian_kernel( class SSIM3D(nn.Module): def __init__( - self, - kernel_size: int | list[int] = 11, - sigma: float | list[float] = 1.5, + self, + kernel_size: int | list[int] = 11, + sigma: float | list[float] = 1.5, k1: float = 0.01, - k2: float = 0.03, + k2: float = 0.03, data_range: float = 1, ): """Module to compute the SSIM between two sequences of images @@ -127,17 +125,17 @@ def __init__( k1: Algorithm parameter, K1 (small constant, see [a]). k2: Algorithm parameter, K2 (small constant, see [a]). data_range: The range of the data - + References: - [a] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image quality - assessment: From error visibility to structural similarity. IEEE Transactions on + [a] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image quality + assessment: From error visibility to structural similarity. IEEE Transactions on Image Processing, 13, 600-612. 
DOI:10.1109/TIP.2003.819861 """ - super(SSIM3D, self).__init__() + super().__init__() assert data_range > 0 assert k1 > 0 assert k2 > 0 - + if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] elif isinstance(kernel_size, Sequence): @@ -147,22 +145,23 @@ def __init__( sigma = [sigma, sigma] elif isinstance(sigma, Sequence): sigma = sigma - + self.c1 = (k1 * data_range) ** 2 self.c2 = (k2 * data_range) ** 2 self.kernel = nn.Parameter( - data=create_2d_gaussian_kernel(kernel_size=kernel_size, sigma=sigma), - requires_grad=False - ) + data=create_2d_gaussian_kernel(kernel_size=kernel_size, sigma=sigma), + requires_grad=False, + ) + + self.pad = [ + 0, + ] + [(k - 1) // 2 for k in kernel_size] - self.pad = [0,] + [(k - 1) // 2 for k in kernel_size] - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Compute the SSIM between two sequences of images - - + + Args: x: The predicted sequence of images y: The true sequence of images @@ -178,19 +177,21 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # whilst only convolving across the spatial dimensions kernel = self.kernel.expand(num_channels, 1, 1, -1, -1) - kernal_inputs = torch.cat([x, y, x**2, y**2, x*y]) + kernal_inputs = torch.cat([x, y, x**2, y**2, x * y]) kernel_outputs = F.conv3d(kernal_inputs, kernel, padding=self.pad, groups=num_channels) del kernal_inputs - - ux, uy, uxx, uyy, uxy = [kernel_outputs[i*batch_size:(i+1)*batch_size] for i in range(5)] - - vx = (uxx - ux * ux) - vy = (uyy - uy * uy) - vxy = (uxy - ux * uy) - + + ux, uy, uxx, uyy, uxy = [ + kernel_outputs[i * batch_size : (i + 1) * batch_size] for i in range(5) + ] + + vx = uxx - ux * ux + vy = uyy - uy * uy + vxy = uxy - ux * uy + a1 = 2 * ux * uy + self.c1 - a2 = 2 * vxy + self.c2 + a2 = 2 * vxy + self.c2 b1 = ux**2 + uy**2 + self.c1 - b2 = vx + vy + self.c2 + b2 = vx + vy + self.c2 - return (a1 * a2) / (b1 * b2) \ No newline at end of file + return (a1 * a2) / (b1 * b2) diff --git 
a/src/sat_pred/training_module.py b/src/sat_pred/training_module.py index 32620a0..4d3eee2 100644 --- a/src/sat_pred/training_module.py +++ b/src/sat_pred/training_module.py @@ -1,19 +1,18 @@ """Training class to wrap model and optimizer""" +import lightning.pytorch as pl import numpy as np import pandas as pd import torch import torch.nn.functional as F -from torch.utils.data import default_collate -import lightning.pytorch as pl - import wandb +from torch.utils.data import default_collate -from sat_pred.ssim import SSIM3D -from sat_pred.optimizers import AdamWReduceLROnPlateau from sat_pred.loss import LossFunction +from sat_pred.optimizers import AdamWReduceLROnPlateau +from sat_pred.ssim import SSIM3D + - class MetricAccumulator: """Dictionary of metrics accumulator. @@ -27,7 +26,7 @@ class MetricAccumulator: def __init__(self) -> None: """Dictionary of metrics accumulator.""" self._metrics = {} - + def __bool__(self) -> None: return self._metrics != {} @@ -51,7 +50,7 @@ def check_nan_and_finite(X: torch.Tensor, y: torch.Tensor, y_hat: torch.Tensor) if X is not None: assert not np.isnan(X.cpu().numpy()).any(), "NaNs in X" assert np.isfinite(X.cpu().numpy()).all(), "infs in X" - + if y is not None: assert not np.isnan(y.cpu().numpy()).any(), "NaNs in y" assert np.isfinite(y.cpu().numpy()).all(), "infs in y" @@ -62,14 +61,14 @@ def check_nan_and_finite(X: torch.Tensor, y: torch.Tensor, y_hat: torch.Tensor) def upload_video( - y: torch.Tensor, - y_hat: torch.Tensor, - video_name: str, - channel_nums: list[int] = [8, 1], - fps: int=4 + y: torch.Tensor, + y_hat: torch.Tensor, + video_name: str, + channel_nums: list[int] | None = None, + fps: int = 4, ) -> None: """Upload prediction video to wandb - + Args: y: The true future satellite sequence y_hat: The predicted future satellite sequence @@ -77,31 +76,32 @@ def upload_video( channel_nums: The channel numbers to log fps: The frames per second of the video """ + if channel_nums is None: + channel_nums = [8, 1] y = 
y.cpu().numpy() y_hat = y_hat.cpu().numpy() channel_frames = [] - + for channel_num in channel_nums: - y_frames = y.transpose(1,0,2,3)[:, channel_num:channel_num+1, ::-1, ::-1] - y_hat_frames = y_hat.transpose(1,0,2,3)[:, channel_num:channel_num+1, ::-1, ::-1] + y_frames = y.transpose(1, 0, 2, 3)[:, channel_num : channel_num + 1, ::-1, ::-1] + y_hat_frames = y_hat.transpose(1, 0, 2, 3)[:, channel_num : channel_num + 1, ::-1, ::-1] channel_frames.append(np.concatenate([y_hat_frames, y_frames], axis=3)) - + channel_frames = np.concatenate(channel_frames, axis=2) channel_frames = channel_frames.clip(0, 1) - channel_frames = np.repeat(channel_frames, 3, axis=1)*255 + channel_frames = np.repeat(channel_frames, 3, axis=1) * 255 channel_frames = channel_frames.astype(np.uint8) wandb.log({video_name: wandb.Video(channel_frames, fps=fps)}) - - -class TrainingModule(pl.LightningModule): + +class TrainingModule(pl.LightningModule): def __init__( self, model: torch.nn.Module, target_loss: str = "MAE", - optimizer = AdamWReduceLROnPlateau(), - video_plot_t0_times: list[str] = None, + optimizer=AdamWReduceLROnPlateau(), + video_plot_t0_times: list[str] | None = None, video_crop_plots=None, multi_gpu: bool = False, ): @@ -113,15 +113,15 @@ def __init__( optimizer: The optimizer to use. Defaults to AdamWReduceLROnPlateau(). 
""" super().__init__() - + assert target_loss in ["MAE", "MSE", "SSIM"] or isinstance(target_loss, LossFunction) self.model = model self._optimizer = optimizer - + self.ssim_func = SSIM3D() self.target_loss = target_loss - + self._accumulated_metrics = MetricAccumulator() self.video_plot_t0_times = video_plot_t0_times @@ -129,29 +129,27 @@ def __init__( self.multi_gpu = multi_gpu def _calculate_common_losses( - self, - y: torch.Tensor, - y_hat: torch.Tensor + self, y: torch.Tensor, y_hat: torch.Tensor ) -> dict[str, torch.Tensor]: """Calculate losses common to train and val - + Args: y: The true future satellite sequence y_hat: The predicted future satellite sequence """ - + losses = {} - - mask = y==-1 + + mask = y == -1 mse_loss = F.mse_loss(y_hat, y, reduction="none")[~mask].mean() mae_loss = F.l1_loss(y_hat, y, reduction="none")[~mask].mean() - ssim_loss = (1-self.ssim_func(y_hat, y))[~mask].mean() # need to maximise SSIM + ssim_loss = (1 - self.ssim_func(y_hat, y))[~mask].mean() # need to maximise SSIM losses = { - "MSE": mse_loss, - "MAE": mae_loss, - "SSIM": ssim_loss, + "MSE": mse_loss, + "MAE": mae_loss, + "SSIM": ssim_loss, } if isinstance(self.target_loss, LossFunction): @@ -160,21 +158,18 @@ def _calculate_common_losses( return losses def _calculate_val_losses( - self, - y: torch.Tensor, - y_hat: torch.Tensor + self, y: torch.Tensor, y_hat: torch.Tensor ) -> dict[str, torch.Tensor]: """Calculate additional validation losses - + Args: y: The true future satellite sequence y_hat: The predicted future satellite sequence """ - losses = {} + return {} + - return losses - def _training_accumulate_log(self, losses): """Internal function to accumulate training batches and log results. 
@@ -196,9 +191,9 @@ def _training_accumulate_log(self, losses): def training_step(self, batch, batch_idx: int) -> None | torch.Tensor: """Run training step""" - + X, y = batch - + y_hat = self.model(X) del X @@ -211,32 +206,30 @@ def training_step(self, batch, batch_idx: int) -> None | torch.Tensor: train_loss = losses[f"{self.target_loss.name}/train"] else: train_loss = losses[f"{self.target_loss}/train"] - + # Occasionally y will be entirely NaN and we have no training targets. So the train loss # will also be NaN. if torch.isnan(train_loss).item(): print("\n\nTraining loss is nan\n\n") if self.multi_gpu: - # For multi-GPU we need to return some kind of loss - return F.l1_loss(y_hat*0, y_hat*0) - else: - # For single GPU we return None so lightning skips this train step - return None - else: - return train_loss - + # For multi-GPU we need to return some kind of loss + return F.l1_loss(y_hat * 0, y_hat * 0) + # For single GPU we return None so lightning skips this train step + return None + return train_loss + def validation_step(self, batch: dict, batch_idx: int): """Run validation step""" X, y = batch y_hat = self.model(X) del X - + losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) # Rename and convert metrics to float losses = {f"{k}/val": v.item() for k, v in losses.items()} - + # Occasionally y will be entirely NaN and we have no training targets. So the val loss # will also be NaN. 
We filter these out non_nan_losses = {k: v for k, v in losses.items() if not np.isnan(v)} @@ -246,26 +239,24 @@ def validation_step(self, batch: dict, batch_idx: int): on_step=False, on_epoch=True, ) - + def on_validation_epoch_start(self): - # Upload videos of the first three validation samples val_dataset = self.trainer.val_dataloaders.dataset - + if self.video_plot_t0_times is not None: dates = pd.to_datetime(list(self.video_plot_t0_times)) - X, y = default_collate([val_dataset[date]for date in dates]) + X, y = default_collate([val_dataset[date] for date in dates]) X = X.to(self.device) y = y.to(self.device) - + with torch.no_grad(): y_hat = self.model(X) assert val_dataset.nan_to_num, val_dataset.nan_to_num - - for i in range(len(dates)): + for i in range(len(dates)): for channel_num in [1, 8]: channel_name = val_dataset.ds.variable.values[channel_num] video_name = f"val_sample_videos/{dates[i]}_{channel_name}" @@ -273,7 +264,7 @@ def on_validation_epoch_start(self): if self.video_crop_plots is not None: dates = pd.to_datetime([x["date"] for x in self.video_crop_plots]) - X, y = default_collate([val_dataset[date]for date in dates]) + X, y = default_collate([val_dataset[date] for date in dates]) X = X.to(self.device) y = y.to(self.device) @@ -281,7 +272,6 @@ def on_validation_epoch_start(self): y_hat = self.model(X) for n in range(len(self.video_crop_plots)): - date = dates[n] channel_num = 8 channel_name = val_dataset.ds.variable.values[channel_num] @@ -292,21 +282,19 @@ def on_validation_epoch_start(self): channel_name = val_dataset.ds.variable.values[channel_num] video_name = f"val_close_up_sample_videos/{date}_{channel_name}_{i=}_{j=}_{s=}" - i_slice = slice(max(0, i-s//2), i+s//2) - j_slice = slice(max(0, j-s//2), j+s//2) + i_slice = slice(max(0, i - s // 2), i + s // 2) + j_slice = slice(max(0, j - s // 2), j + s // 2) upload_video( - y[n, ..., i_slice, j_slice], + y[n, ..., i_slice, j_slice], y_hat[n, ..., i_slice, j_slice], - video_name, - 
channel_nums=[channel_num] + video_name, + channel_nums=[channel_num], ) - - def on_validation_epoch_end(self): # Clear cache at the end of validation if torch.cuda.is_available(): torch.cuda.empty_cache() - + def configure_optimizers(self): - return self._optimizer(self) \ No newline at end of file + return self._optimizer(self) diff --git a/train.py b/train.py index 7a343ca..38a4114 100644 --- a/train.py +++ b/train.py @@ -2,10 +2,14 @@ if __name__ == "__main__": import torch.multiprocessing as mp + mp.set_start_method("spawn", force=True) import os + import hydra +import rich.syntax +import rich.tree import torch from lightning.pytorch import ( Callback, @@ -17,11 +21,8 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers.wandb import WandbLogger -from omegaconf import DictConfig, OmegaConf - -import rich.syntax -import rich.tree from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf from sat_pred.load_model_from_checkpoint import get_model_from_checkpoints from sat_pred.loss import LossFunction @@ -29,24 +30,26 @@ # TODO: is this line needed? torch.set_default_dtype(torch.float32) + def resolve_loss_name(loss): """Return the desired metric to monitor based on the loss being used. 
The adds the option to use something like: monitor: "${resolve_loss_name:${model.output_quantiles}}" """ - + if isinstance(loss, str): return loss - else: - loss = hydra.utils.instantiate(loss, _convert_='all') - if isinstance(loss, LossFunction): - return loss.name - else: - raise ValueError(f"Unknown loss type: {type(loss)}") + loss = hydra.utils.instantiate(loss, _convert_="all") + if isinstance(loss, LossFunction): + return loss.name + msg = f"Unknown loss type: {type(loss)}" + raise ValueError(msg) + OmegaConf.register_new_resolver("resolve_loss_name", resolve_loss_name) + @rank_zero_only def print_config( config: DictConfig, @@ -84,9 +87,8 @@ def print_config( rich.print(tree) -@rank_zero_only - +@rank_zero_only @hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2") def train(config: DictConfig): """Train the model using parameters in the supplied config files. @@ -100,26 +102,22 @@ def train(config: DictConfig): # Set seed for random number generators in pytorch, numpy and python.random if "seed" in config: seed_everything(config.seed, workers=True) - if config.model.model.get("from_pretrained", False): - # Load the model from the checkpoint torch_model, model_config, data_config = get_model_from_checkpoints( - config.model.model.checkpoint_dir, - val_best=config.model.model.val_best + config.model.model.checkpoint_dir, val_best=config.model.model.val_best ) # Overwtie the model config with the loaded model config config.model.model = OmegaConf.create(model_config).model - # Create a new lightning wrapped model + # Create a new lightning wrapped model model: LightningModule = hydra.utils.instantiate(config.model) # Replace the untrained model with the loaded model model.model = torch_model - else: # Instantiate the model model: LightningModule = hydra.utils.instantiate(config.model) @@ -155,7 +153,7 @@ def train(config: DictConfig): # Need to call the .experiment property to initialise the logger wandb_logger.experiment - # 
skip for non-rank-0 processes: + # skip for non-rank-0 processes: # see https://github.com/Lightning-AI/pytorch-lightning/issues/13166#issuecomment-1139765549 if wandb_logger.version is None: break @@ -173,8 +171,8 @@ def train(config: DictConfig): break # Instantiate the datamodule - datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, _convert_='all') - + datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, _convert_="all") + datamodule.zarr_path = list(datamodule.zarr_path) # Instantiate the trainer @@ -187,7 +185,7 @@ def train(config: DictConfig): # Train the model trainer.fit(model=model, datamodule=datamodule) - - -if __name__ == "__main__": + + +if __name__ == "__main__": train()