diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..fe58f73b --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,17 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - run: uv python install 3.10 + - run: uv pip install ruff + - run: uv run ruff check . + - run: uv run ruff format --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..700cbc40 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/mattergen/common/data/callback.py b/mattergen/common/data/callback.py index 9191e8c3..80df7407 100644 --- a/mattergen/common/data/callback.py +++ b/mattergen/common/data/callback.py @@ -11,6 +11,7 @@ from mattergen.denoiser import GemNetTDenoiser from mattergen.diffusion.lightning_module import DiffusionLightningModule + TensorOrStringType = TypeVar("TensorOrStringType", torch.Tensor, list[str]) @@ -37,7 +38,7 @@ def _compute_property_scalers( property_names = [p.name for p in property_embeddings.values() if not isinstance(p.scaler, torch.nn.Identity)] if len(property_names) == 0: return - for batch in tqdm(datamodule.train_dataloader(), desc=f"Fitting property scalers"): + for batch in tqdm(datamodule.train_dataloader(), desc="Fitting property scalers"): for property_name in property_names: # concat all values in train dataset for this given property property_values[property_name].append(batch[property_name]) diff --git a/mattergen/common/data/collate.py b/mattergen/common/data/collate.py index e4b0e22f..b99db774 100644 --- a/mattergen/common/data/collate.py +++ b/mattergen/common/data/collate.py @@ -4,7 +4,6 @@ import warnings from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, overload - from torch import Tensor from torch_geometric.data import Batch, Data from typing_extensions import TypeGuard diff --git a/mattergen/common/data/dataset_transform.py b/mattergen/common/data/dataset_transform.py index 46620cb0..710eb7a3 100644 --- a/mattergen/common/data/dataset_transform.py +++ b/mattergen/common/data/dataset_transform.py @@ -3,6 +3,7 @@ import numpy as np from numpy.typing import NDArray + from mattergen.common.data.dataset import BaseDataset # Dataset transforms diff --git a/mattergen/diffusion/lightning_module.py b/mattergen/diffusion/lightning_module.py index e4066d0a..dcce7f00 100644 --- a/mattergen/diffusion/lightning_module.py +++ b/mattergen/diffusion/lightning_module.py @@ -3,10 +3,8 @@ from __future__ import annotations -from collections import deque from typing import Any, Dict, Generic, Optional, Protocol, Sequence, TypeVar, Union -import numpy as np import pytorch_lightning as pl import torch from hydra.errors import InstantiationException @@ -14,7 +12,6 @@ from omegaconf import DictConfig from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.optim import AdamW, Optimizer -from tqdm import tqdm from mattergen.diffusion.config import Config from mattergen.diffusion.data.batched_data import BatchedData diff --git a/mattergen/diffusion/sampling/classifier_free_guidance.py b/mattergen/diffusion/sampling/classifier_free_guidance.py index ecd0f39c..9790221d 100644 --- a/mattergen/diffusion/sampling/classifier_free_guidance.py +++ b/mattergen/diffusion/sampling/classifier_free_guidance.py @@ -5,8 +5,8 @@ import torch -from mattergen.diffusion.sampling.pc_sampler import Diffusable, PredictorCorrector from mattergen.common.data.collate import collate +from mattergen.diffusion.sampling.pc_sampler import Diffusable, PredictorCorrector BatchTransform = Callable[[Diffusable], Diffusable] diff --git a/mattergen/diffusion/timestep_samplers.py b/mattergen/diffusion/timestep_samplers.py index 8c0f0ce2..b57ce238 100644 --- a/mattergen/diffusion/timestep_samplers.py +++ b/mattergen/diffusion/timestep_samplers.py @@ -5,8 +5,6 @@ import torch -from mattergen.diffusion.corruption.sde_lib import SDE - class TimestepSampler(Protocol): min_t: float diff --git a/mattergen/diffusion/wrapped/wrapped_normal_loss.py b/mattergen/diffusion/wrapped/wrapped_normal_loss.py index f1085dab..1d799896 100644 --- a/mattergen/diffusion/wrapped/wrapped_normal_loss.py +++ b/mattergen/diffusion/wrapped/wrapped_normal_loss.py @@ -5,7 +5,7 @@ import torch -from mattergen.diffusion.corruption.sde_lib import SDE, maybe_expand +from mattergen.diffusion.corruption.sde_lib import SDE from mattergen.diffusion.data.batched_data import BatchedData from mattergen.diffusion.training.field_loss import aggregate_per_sample diff --git a/mattergen/evaluation/metrics/energy.py b/mattergen/evaluation/metrics/energy.py index 2e21b1b0..cdeb6a71 100644 --- a/mattergen/evaluation/metrics/energy.py +++ b/mattergen/evaluation/metrics/energy.py @@ -9,12 +9,10 @@ import numpy.typing from pandas import DataFrame from pymatgen.analysis.phase_diagram import PhaseDiagram -from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from tqdm import tqdm from mattergen.evaluation.metrics.core import BaseAggregateMetric, BaseMetric, BaseMetricsCapability from mattergen.evaluation.metrics.structure import StructureMetricsCapability -from mattergen.evaluation.reference.correction_schemes import TRI110Compatibility2024 from mattergen.evaluation.reference.reference_dataset import ReferenceDataset from mattergen.evaluation.utils.globals import DEFAULT_STABILITY_THRESHOLD from mattergen.evaluation.utils.logging import logger diff --git a/mattergen/evaluation/metrics/evaluator.py b/mattergen/evaluation/metrics/evaluator.py index 5c922662..64016a7b 100644 --- a/mattergen/evaluation/metrics/evaluator.py +++ b/mattergen/evaluation/metrics/evaluator.py @@ -7,7 +7,7 @@ from functools import cached_property from inspect import getmembers, isclass from pathlib import Path -from typing import Literal, Sequence, Type, TypeVar +from typing import Literal, Type, TypeVar import numpy.typing import pandas as pd diff --git a/mattergen/evaluation/metrics/structure.py b/mattergen/evaluation/metrics/structure.py index fa63be85..b3716c4c 100644 --- a/mattergen/evaluation/metrics/structure.py +++ b/mattergen/evaluation/metrics/structure.py @@ -16,7 +16,6 @@ from pymatgen.core.composition import Element from pymatgen.core.structure import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from scipy.stats import wasserstein_distance from smact.screening import pauling_test from tqdm import tqdm diff --git a/mattergen/evaluation/reference/reference_dataset_serializer.py b/mattergen/evaluation/reference/reference_dataset_serializer.py index 4ddaee79..2b998a48 100644 --- a/mattergen/evaluation/reference/reference_dataset_serializer.py +++ b/mattergen/evaluation/reference/reference_dataset_serializer.py @@ -3,14 +3,13 @@ import gzip import os -import pickle import shutil import weakref from collections import defaultdict from functools import cached_property from pathlib import Path from tempfile import mkdtemp -from typing import Any, DefaultDict, Iterator, Mapping +from typing import DefaultDict, Iterator, Mapping import lmdb # type: ignore [import] from monty.json import MontyDecoder diff --git a/mattergen/generator.py b/mattergen/generator.py index 465f45c1..8c4c5912 100644 --- a/mattergen/generator.py +++ b/mattergen/generator.py @@ -9,7 +9,6 @@ import ase.io import hydra -import torch from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf from pymatgen.core.structure import Structure @@ -21,6 +20,7 @@ from mattergen.common.data.condition_factory import ConditionLoader from mattergen.common.data.num_atoms_distribution import NUM_ATOMS_DISTRIBUTIONS from mattergen.common.data.types import TargetProperty +from mattergen.common.utils.data_classes import ProgressCallback from mattergen.common.utils.data_utils import lattice_matrix_to_params_torch from mattergen.common.utils.eval_utils import ( MatterGenCheckpointInfo, @@ -32,7 +32,6 @@ from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device from mattergen.diffusion.lightning_module import DiffusionLightningModule from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector -from mattergen.common.utils.data_classes import ProgressCallback def draw_samples_from_sampler( diff --git a/mattergen/scripts/run.py b/mattergen/scripts/run.py index 882b9b35..200d4939 100644 --- a/mattergen/scripts/run.py +++ b/mattergen/scripts/run.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import json import logging import hydra diff --git a/pyproject.toml b/pyproject.toml index 2ce25044..f9301e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,3 +128,17 @@ explicit = true [build-system] requires = ["setuptools <81"] build-backend = "setuptools.build_meta" + +[tool.ruff] +line-length = 120 +extend-exclude = ["*.ipynb"] + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.uv] +# Install with: uv sync