Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 12 additions & 34 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,20 @@ on:
branches: [main]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- run: pip install ruff
- name: Ruff check
run: ruff check src/ tests/

test:
runs-on: ubuntu-latest
needs: lint
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
Expand All @@ -29,7 +41,6 @@ jobs:

- name: Run tests
run: |
# Run entropy modeling tests and performance tests (legacy tests have pre-existing issues)
pytest \
tests/test_entropy_parameters.py \
tests/test_context_model.py \
Expand All @@ -48,36 +59,3 @@ jobs:
uses: codecov/codecov-action@v4
with:
file: ./coverage.xml

lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- run: pip install flake8
# Lint only the new entropy modeling and optimization files (legacy files have pre-existing issues)
- name: Lint new source files
run: |
flake8 \
src/entropy_parameters.py \
src/entropy_model.py \
src/context_model.py \
src/channel_context.py \
src/attention_context.py \
src/model_transforms.py \
src/constants.py \
src/precision_config.py \
src/benchmarks.py \
--max-line-length=120
- name: Lint new test files
run: |
flake8 \
tests/test_entropy_parameters.py \
tests/test_context_model.py \
tests/test_channel_context.py \
tests/test_attention_context.py \
tests/test_performance.py \
--max-line-length=120 \
--ignore=E402,W503 # E402: imports after sys.path, W503: PEP8 updated to prefer breaks before operators
388 changes: 388 additions & 0 deletions CLAUDE.md

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[tool.ruff]
line-length = 120
target-version = "py38"

[tool.ruff.lint]
select = ["F", "I", "E", "W"]
ignore = [
"E402", # module-level import not at top — tests use sys.path before imports
"E741", # ambiguous variable name — common in math-heavy code
]

[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["E402"]

[tool.ruff.lint.isort]
known-first-party = [
"model_transforms", "entropy_model", "entropy_parameters",
"context_model", "channel_context", "attention_context",
"constants", "precision_config", "data_loader",
"training_pipeline", "evaluation_pipeline", "experiment",
"compress_octree", "decompress_octree", "octree_coding",
"ds_mesh_to_pc", "ds_pc_octree_blocks", "ds_select_largest",
"ev_compare", "ev_run_render", "mp_report", "mp_run",
"quick_benchmark", "benchmarks", "parallel_process",
"point_cloud_metrics", "map_color", "colorbar",
"cli_train", "test_utils",
]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ pytest~=7.1.0
scipy~=1.8.1
numba~=0.56.0
tensorflow-probability~=0.19.0
ruff>=0.4.0
7 changes: 4 additions & 3 deletions src/attention_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
- Global tokens provide long-range context without full attention
"""

from typing import Any, Dict, Optional, Tuple

import tensorflow as tf
from typing import Tuple, Optional, Dict, Any

from constants import LOG_2_RECIPROCAL

Expand Down Expand Up @@ -668,8 +669,8 @@ def __init__(self,
self.num_attention_layers = num_attention_layers

# Import here to avoid circular dependency
from entropy_parameters import EntropyParameters
from entropy_model import ConditionalGaussian
from entropy_parameters import EntropyParameters

# Hyperprior-based parameter prediction
self.entropy_parameters = EntropyParameters(
Expand Down Expand Up @@ -803,9 +804,9 @@ def __init__(self,
self.num_channel_groups = num_channel_groups
self.num_attention_layers = num_attention_layers

from entropy_parameters import EntropyParameters
from channel_context import ChannelContext
from entropy_model import ConditionalGaussian
from entropy_parameters import EntropyParameters

# Hyperprior parameters
self.entropy_parameters = EntropyParameters(
Expand Down
7 changes: 4 additions & 3 deletions src/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
print(f"Peak memory: {mem.peak_mb:.1f} MB")
"""

import tensorflow as tf
import time
from typing import Callable, Dict, Any, Optional
from dataclasses import dataclass, field
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional

import tensorflow as tf


@dataclass
Expand Down
5 changes: 3 additions & 2 deletions src/channel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
maintaining autoregressive structure across groups.
"""

from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from typing import Tuple, Optional, Dict, Any, List

from constants import LOG_2_RECIPROCAL

Expand Down Expand Up @@ -230,8 +231,8 @@ def __init__(self,
self.channels_per_group = latent_channels // num_groups

# Import here to avoid circular dependency
from entropy_parameters import EntropyParameters
from entropy_model import ConditionalGaussian
from entropy_parameters import EntropyParameters

# Hyperprior-based parameter prediction
self.entropy_parameters = EntropyParameters(
Expand Down
24 changes: 13 additions & 11 deletions src/cli_train.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import tensorflow as tf
import os
import argparse
import glob
import numpy as np
import os

import keras_tuner as kt
import tensorflow as tf

from ds_mesh_to_pc import read_off


def create_model(hp):
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=(2048, 3)))

for i in range(hp.Int('num_layers', 1, 5)):
model.add(tf.keras.layers.Dense(
hp.Int(f'layer_{i}_units', min_value=64, max_value=1024, step=64),
activation='relu'
))

model.add(tf.keras.layers.Dense(3, activation='sigmoid'))

model.compile(
optimizer=tf.keras.optimizers.Adam(
learning_rate=hp.Float('learning_rate', 1e-5, 1e-3, sampling='log')
Expand All @@ -28,7 +30,7 @@ def create_model(hp):

def load_and_preprocess_data(input_dir, batch_size):
file_paths = glob.glob(os.path.join(input_dir, "*.ply"))

def parse_ply_file(file_path):
mesh_data = read_off(file_path)
return mesh_data.vertices
Expand All @@ -47,7 +49,7 @@ def data_generator():
dataset = dataset.shuffle(buffer_size=len(file_paths))
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

return dataset

def tune_hyperparameters(input_dir, output_dir, num_epochs=10):
Expand All @@ -63,10 +65,10 @@ def tune_hyperparameters(input_dir, output_dir, num_epochs=10):

dataset = load_and_preprocess_data(input_dir, batch_size=32)
tuner.search(dataset, epochs=num_epochs, validation_data=dataset)

best_model = tuner.get_best_models(num_models=1)[0]
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

print("Best Hyperparameters:", best_hps.values)
best_model.save(os.path.join(output_dir, 'best_model'))

Expand Down Expand Up @@ -95,4 +97,4 @@ def main():
model.save(os.path.join(args.output_dir, 'trained_model'))

if __name__ == "__main__":
main()
main()
23 changes: 14 additions & 9 deletions src/colorbar.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from typing import Tuple, Callable, Optional, List
from dataclasses import dataclass
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


@dataclass
class ColorbarConfig:
Expand Down Expand Up @@ -67,7 +69,10 @@ def get_colorbar(
else:
# Format numeric labels
formatter = mpl.ticker.FormatStrFormatter(label_format)
cbar.ax.xaxis.set_major_formatter(formatter) if orientation == 'horizontal' else cbar.ax.yaxis.set_major_formatter(formatter)
if orientation == 'horizontal':
cbar.ax.xaxis.set_major_formatter(formatter)
else:
cbar.ax.yaxis.set_major_formatter(formatter)

# Set font sizes
cbar.ax.tick_params(labelsize=font_size)
Expand Down Expand Up @@ -96,10 +101,10 @@ def save_color_mapping(filename: str,
# Create mapping
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cmap = plt.get_cmap(cmap)

values = np.linspace(vmin, vmax, num_samples)
colors = [cmap(norm(v)) for v in values]

# Save to file
Path(filename).parent.mkdir(parents=True, exist_ok=True)
with open(filename, 'w') as f:
Expand All @@ -122,4 +127,4 @@ def save_color_mapping(filename: str,
tick_rotation=45,
extend='both'
)
plt.show()
plt.show()
Loading
Loading