Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Lint

on:
push:
branches: [main]
pull_request:

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- run: uv python install 3.10
- run: uv pip install ruff
- run: uv run ruff check .
- run: uv run ruff format --check .
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
3 changes: 2 additions & 1 deletion mattergen/common/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from mattergen.denoiser import GemNetTDenoiser
from mattergen.diffusion.lightning_module import DiffusionLightningModule

TensorOrStringType = TypeVar("TensorOrStringType", torch.Tensor, list[str])


Expand All @@ -37,7 +38,7 @@ def _compute_property_scalers(
property_names = [p.name for p in property_embeddings.values() if not isinstance(p.scaler, torch.nn.Identity)]
if len(property_names) == 0:
return
for batch in tqdm(datamodule.train_dataloader(), desc=f"Fitting property scalers"):
for batch in tqdm(datamodule.train_dataloader(), desc="Fitting property scalers"):
for property_name in property_names:
# concat all values in train dataset for this given property
property_values[property_name].append(batch[property_name])
Expand Down
1 change: 0 additions & 1 deletion mattergen/common/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, overload


from torch import Tensor
from torch_geometric.data import Batch, Data
from typing_extensions import TypeGuard
Expand Down
1 change: 1 addition & 0 deletions mattergen/common/data/dataset_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from numpy.typing import NDArray

from mattergen.common.data.dataset import BaseDataset

# Dataset transforms
Expand Down
3 changes: 0 additions & 3 deletions mattergen/diffusion/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@

from __future__ import annotations

from collections import deque
from typing import Any, Dict, Generic, Optional, Protocol, Sequence, TypeVar, Union

import numpy as np
import pytorch_lightning as pl
import torch
from hydra.errors import InstantiationException
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.optim import AdamW, Optimizer
from tqdm import tqdm

from mattergen.diffusion.config import Config
from mattergen.diffusion.data.batched_data import BatchedData
Expand Down
2 changes: 1 addition & 1 deletion mattergen/diffusion/sampling/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import torch

from mattergen.diffusion.sampling.pc_sampler import Diffusable, PredictorCorrector
from mattergen.common.data.collate import collate
from mattergen.diffusion.sampling.pc_sampler import Diffusable, PredictorCorrector

BatchTransform = Callable[[Diffusable], Diffusable]

Expand Down
2 changes: 0 additions & 2 deletions mattergen/diffusion/timestep_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import torch

from mattergen.diffusion.corruption.sde_lib import SDE


class TimestepSampler(Protocol):
min_t: float
Expand Down
2 changes: 1 addition & 1 deletion mattergen/diffusion/wrapped/wrapped_normal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from mattergen.diffusion.corruption.sde_lib import SDE, maybe_expand
from mattergen.diffusion.corruption.sde_lib import SDE
from mattergen.diffusion.data.batched_data import BatchedData
from mattergen.diffusion.training.field_loss import aggregate_per_sample

Expand Down
2 changes: 0 additions & 2 deletions mattergen/evaluation/metrics/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
import numpy.typing
from pandas import DataFrame
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from tqdm import tqdm

from mattergen.evaluation.metrics.core import BaseAggregateMetric, BaseMetric, BaseMetricsCapability
from mattergen.evaluation.metrics.structure import StructureMetricsCapability
from mattergen.evaluation.reference.correction_schemes import TRI110Compatibility2024
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.utils.globals import DEFAULT_STABILITY_THRESHOLD
from mattergen.evaluation.utils.logging import logger
Expand Down
2 changes: 1 addition & 1 deletion mattergen/evaluation/metrics/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import cached_property
from inspect import getmembers, isclass
from pathlib import Path
from typing import Literal, Sequence, Type, TypeVar
from typing import Literal, Type, TypeVar

import numpy.typing
import pandas as pd
Expand Down
1 change: 0 additions & 1 deletion mattergen/evaluation/metrics/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pymatgen.core.composition import Element
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from scipy.stats import wasserstein_distance
from smact.screening import pauling_test
from tqdm import tqdm

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

import gzip
import os
import pickle
import shutil
import weakref
from collections import defaultdict
from functools import cached_property
from pathlib import Path
from tempfile import mkdtemp
from typing import Any, DefaultDict, Iterator, Mapping
from typing import DefaultDict, Iterator, Mapping

import lmdb # type: ignore [import]
from monty.json import MontyDecoder
Expand Down
3 changes: 1 addition & 2 deletions mattergen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import ase.io
import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pymatgen.core.structure import Structure
Expand All @@ -21,6 +20,7 @@
from mattergen.common.data.condition_factory import ConditionLoader
from mattergen.common.data.num_atoms_distribution import NUM_ATOMS_DISTRIBUTIONS
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.data_classes import ProgressCallback
from mattergen.common.utils.data_utils import lattice_matrix_to_params_torch
from mattergen.common.utils.eval_utils import (
MatterGenCheckpointInfo,
Expand All @@ -32,7 +32,6 @@
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector
from mattergen.common.utils.data_classes import ProgressCallback


def draw_samples_from_sampler(
Expand Down
1 change: 0 additions & 1 deletion mattergen/scripts/run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
import logging

import hydra
Expand Down
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,17 @@ explicit = true
[build-system]
requires = ["setuptools <81"]
build-backend = "setuptools.build_meta"

[tool.ruff]
line-length = 120
extend-exclude = ["*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "I"]
ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]

[tool.uv]
# Install with: uv sync