Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ __pycache__/
*.pt
*.pth
*.jsonl
eval_manifest.json
pipeline_manifest.json
mid_checkpoint.safetensors
mid_checkpoint.merkle.json
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ tqdm==4.67.3
# This default torch wheel is CPU-only (or CUDA, depending on your platform).
# For an Intel GPU (Iris Xe / Arc), install the XPU build instead — see README
# "Locally on an Intel GPU":
# pip install torch --index-url https://download.pytorch.org/whl/xpu
# pip install torch --index-url https://download.pytorch.org/whl/xpu
236 changes: 236 additions & 0 deletions src/artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import hashlib
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

if TYPE_CHECKING:
import torch

MERKLE_CHUNK_SIZE_BYTES = 1024 * 1024

CHECKPOINT_STATE_PATH = "mid_checkpoint.pt"
CHECKPOINT_WEIGHTS_PATH = "mid_checkpoint.safetensors"
CHECKPOINT_MERKLE_PATH = "mid_checkpoint.merkle.json"


def hash_json(data: Any) -> str:
encoded = json.dumps(data, sort_keys=True).encode()
return hashlib.sha256(encoded).hexdigest()


def compute_sha256_bytes(
*,
data: Optional[Union[bytes, bytearray]] = None,
file_path: Optional[Union[str, Path]] = None,
) -> bytes:
if (data is None) == (file_path is None):
raise ValueError("Exactly one of data or file_path must be provided")

h = hashlib.sha256()
if data is not None:
h.update(data)
return h.digest()

with Path(file_path).open("rb") as f:
while chunk := f.read(1024 * 1024):
h.update(chunk)
return h.digest()


def compute_sha256(
*,
data: Optional[Union[bytes, bytearray]] = None,
file_path: Optional[Union[str, Path]] = None,
) -> str:
return compute_sha256_bytes(data=data, file_path=file_path).hex()


def model_parameters_sha256(model: "torch.nn.Module") -> str:
h = hashlib.sha256()
for param in model.parameters():
h.update(param.detach().cpu().numpy().tobytes())
return h.hexdigest()


def _merkle_parent(left: bytes, right: bytes) -> bytes:
return compute_sha256_bytes(data=left + right)


def merkle_root_from_leaf_hashes(leaf_hashes: List[str]) -> str:
if not leaf_hashes:
return compute_sha256(data=b"")

level = [bytes.fromhex(leaf) for leaf in leaf_hashes]
while len(level) > 1:
next_level = []
for i in range(0, len(level), 2):
left = level[i]
right = level[i + 1] if i + 1 < len(level) else left
next_level.append(_merkle_parent(left, right))
level = next_level
return level[0].hex()


def build_merkle_manifest(
file_path: Union[str, Path],
*,
chunk_size: int = MERKLE_CHUNK_SIZE_BYTES,
) -> Dict[str, Any]:
if chunk_size <= 0:
raise ValueError("chunk_size must be a positive integer")

path = Path(file_path)
chunks = []
offset = 0
file_hasher = hashlib.sha256()
with path.open("rb") as f:
while chunk := f.read(chunk_size):
file_hasher.update(chunk)
chunks.append(
{
"index": len(chunks),
"offset": offset,
"size": len(chunk),
"sha256": compute_sha256(data=chunk),
}
)
offset += len(chunk)

leaf_hashes = [chunk["sha256"] for chunk in chunks]
return {
"artifact": path.name,
"size_bytes": offset,
"sha256": file_hasher.hexdigest(),
"chunk_size_bytes": chunk_size,
"chunk_count": len(chunks),
"merkle_root": merkle_root_from_leaf_hashes(leaf_hashes),
"chunks": chunks,
}


def write_merkle_manifest(
file_path: Union[str, Path],
output_path: Union[str, Path],
*,
chunk_size: int = MERKLE_CHUNK_SIZE_BYTES,
) -> Dict[str, Any]:
manifest = build_merkle_manifest(file_path, chunk_size=chunk_size)
output = Path(output_path)
with output.open("w", encoding="utf-8") as f:
json.dump(manifest, f, indent=2)
return manifest


def generate_merkle_proof(
file_path: Union[str, Path],
chunk_index: int,
*,
chunk_size: int = MERKLE_CHUNK_SIZE_BYTES,
) -> List[Dict[str, Any]]:
manifest = build_merkle_manifest(file_path, chunk_size=chunk_size)
if manifest["chunk_count"] == 0:
raise ValueError("Cannot generate a Merkle proof for an empty file")
if chunk_index < 0 or chunk_index >= manifest["chunk_count"]:
raise IndexError("chunk_index out of range")

level = [bytes.fromhex(chunk["sha256"]) for chunk in manifest["chunks"]]
proof = []
index = chunk_index
while len(level) > 1:
if len(level) % 2 == 1:
level.append(level[-1])

sibling_index = index ^ 1
proof.append(
{
"sibling_sha256": level[sibling_index].hex(),
"sibling_position": "left" if sibling_index < index else "right",
}
)

next_level = []
for i in range(0, len(level), 2):
next_level.append(_merkle_parent(level[i], level[i + 1]))
index //= 2
level = next_level
return proof


def verify_merkle_proof(
chunk_bytes: bytes,
proof: List[Dict[str, Any]],
expected_root: str,
) -> bool:
try:
current = compute_sha256_bytes(data=chunk_bytes)
expected = bytes.fromhex(expected_root)
except (TypeError, ValueError):
return False

for step in proof:
try:
sibling = bytes.fromhex(step["sibling_sha256"])
position = step["sibling_position"]
except (KeyError, TypeError, ValueError):
return False

if len(sibling) != hashlib.sha256().digest_size:
return False
if position == "left":
current = _merkle_parent(sibling, current)
elif position == "right":
current = _merkle_parent(current, sibling)
else:
return False

return current == expected


def _stable_cpu_state_dict(model: "torch.nn.Module") -> Dict[str, "torch.Tensor"]:
state = model.state_dict()
return {
name: tensor.detach().cpu().contiguous()
for name, tensor in sorted(state.items(), key=lambda item: item[0])
}


def save_model_safetensors(
model: "torch.nn.Module",
output_path: Union[str, Path] = CHECKPOINT_WEIGHTS_PATH,
*,
metadata: Optional[Dict[str, str]] = None,
) -> Path:
try:
from safetensors.torch import save_file
except ImportError as exc:
raise RuntimeError(
"safetensors is required to write stable model artifacts. "
"Install dependencies with `pip install -r requirements.txt`."
) from exc

output = Path(output_path)
save_file(_stable_cpu_state_dict(model), str(output), metadata=metadata)
return output


def load_model_safetensors(
model: "torch.nn.Module",
input_path: Union[str, Path] = CHECKPOINT_WEIGHTS_PATH,
*,
device: Optional["torch.device"] = None,
) -> "torch.nn.Module":
input_file = Path(input_path)
if not input_file.exists():
raise FileNotFoundError(f"Safetensors artifact not found: {input_file}")

try:
from safetensors.torch import load_file
except ImportError as exc:
raise RuntimeError(
"safetensors is required to read stable model artifacts. "
"Install dependencies with `pip install -r requirements.txt`."
) from exc

state = load_file(str(input_file), device=str(device) if device is not None else "cpu")
model.load_state_dict(state)
return model
14 changes: 6 additions & 8 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@
from main import set_seed
from config import TRAIN_CONFIG
from device import get_device
from artifacts import CHECKPOINT_WEIGHTS_PATH, hash_json, load_model_safetensors, model_parameters_sha256

DEVICE = get_device()

def hash_model(model):
h = hashlib.sha256()
for p in model.parameters():
h.update(p.data.cpu().numpy().tobytes())
return h.hexdigest()
return model_parameters_sha256(model)

def hash_dict(d):
encoded = json.dumps(d, sort_keys=True).encode()
return hashlib.sha256(encoded).hexdigest()
return hash_json(d)

if __name__ == "__main__":
set_seed(TRAIN_CONFIG["seed"])
Expand Down Expand Up @@ -70,9 +67,10 @@ def hash_dict(d):

eval_data_hash = hashlib.sha256(dataset.encoded.numpy().tobytes()).hexdigest()

# Build manifesthash is computed over content, not including itself
# Build manifest: hash is computed over content, not including itself.
manifest = {
"model_checkpoint_hash": model_hash,
"model_checkpoint_source": checkpoint_source,
"eval_dataset": eval_data_hash,
"eval_loss": loss.item(),
"perplexity": perplexity,
Expand All @@ -86,4 +84,4 @@ def hash_dict(d):
json.dump(manifest, f, indent=2)

print(f"\n ~> Manifest saved to {os.path.normpath(manifest_path)}")
print(json.dumps(manifest, indent=2))
print(json.dumps(manifest, indent=2))
20 changes: 16 additions & 4 deletions src/global_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@
import sys
import platform
import os
from pathlib import Path
from dataset import TinyDataset
from config import TRAIN_CONFIG, get_config_hash
from artifacts import (
CHECKPOINT_MERKLE_PATH,
CHECKPOINT_STATE_PATH,
CHECKPOINT_WEIGHTS_PATH,
build_merkle_manifest,
compute_sha256,
hash_json,
)

def hash_dict(d):
# Sort keys to ensure deterministic JSON stringification
encoded = json.dumps(d, sort_keys=True).encode()
return hashlib.sha256(encoded).hexdigest()
return hash_json(d)

def generate_global_manifest():
if not os.path.exists("eval_manifest.json"):
Expand Down Expand Up @@ -54,6 +62,10 @@ def generate_global_manifest():
"2_training_config_hash": config_hash,
"3_dataset_hash": dataset_hash,
"4_model_checkpoint_hash": model_hash,
"4_model_checkpoint_artifact": model_artifact,
"4_model_checkpoint_merkle_root": model_merkle["merkle_root"],
"4_model_checkpoint_chunk_size_bytes": model_merkle["chunk_size_bytes"],
"4_model_checkpoint_chunk_count": model_merkle["chunk_count"],
"5_eval_manifest_hash": eval_hash,
}

Expand All @@ -63,8 +75,8 @@ def generate_global_manifest():
with open("pipeline_manifest.json", "w") as f:
json.dump(global_manifest, f, indent=2)

print("\n ༼ つ ◕_◕ ༽つ Global Manifest Sealed:")
print("\n Global Manifest Sealed:")
print(json.dumps(global_manifest, indent=2))

if __name__ == "__main__":
generate_global_manifest()
generate_global_manifest()
6 changes: 3 additions & 3 deletions src/gpu_reproducibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Trains the deterministic NanoGPT twice from scratch, with no checkpoint reuse,
and asserts that the two runs produce identical loss curves and bitwise-identical
parameters. On CPU this reproduces the Phase 1 baseline; on a CUDA GPU it is the
Phase 3 claim — that with a pinned cuBLAS workspace and deterministic cuDNN, the
Phase 3 claim: with a pinned cuBLAS workspace and deterministic cuDNN, the
*same* GPU yields the *same bits* run to run.

Run from the ``src`` directory:
Expand Down Expand Up @@ -89,9 +89,9 @@ def main():

ok = losses_match and params_match and (hash1 == hash2)
if ok:
print("\n(❁ ´◡`❁) PASSED: same device is bitwise reproducible.")
print("\n[PASS] same device is bitwise reproducible.")
else:
print("\n(╯°□°)╯︵ ┻━┻ FAILED: entropy detected on this device.")
print("\n[FAIL] entropy detected on this device.")

_write_proof(losses1, losses2, hash1, hash2, ok)
return ok
Expand Down
Loading