Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,12 @@ cython_debug/
.ruff_cache/

# PyPI configuration file
.pypirc
.pypirc

# Models
**/*.keras
**/*.pt
**/*.csv
**/*.jpg

benchmarking/
38 changes: 37 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,38 @@
# tf_pt_benchmarking
# TensorFlow vs PyTorch

Benchmarking of training and inference performance between TensorFlow and PyTorch.


## Config

```yaml
framework: pt # [pt, tf]
seed: 25 # seed for reproducibility
n_classes: 100 # number of classes
bs: 32 # batch size
imgsz: # image size
- 32
- 32
- 3
lr: 0.001 # learning rate
epochs: 2 # number of epochs
num_workers: 0 # number of parallel workers
output: "output/pt" # output directory path
data: "cifar100" # dataset path
```

## Installation

- `pip install git+https://github.com/infocusp/tf_pt_benchmarking.git`

## How to run?

### Download the data

- `downloadcifar100 --output path/to/data`

### Run benchmarking

- Update the `config.yaml` with required details
- `tfvspt --framework tf --config path/to/config.yaml`
- Benchmarking results will be saved in the `output` folder
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,33 @@ requires-python = ">=3.10"

# Required dependencies ------------------------------------------------------------------------------------------------
dependencies = [

"tqdm >= 4.67.1",
"torch >= 2.6.0",
"PyYAML >= 6.0.2",
"notebook >= 7.3.3",
"pydantic >= 2.10.6",
"torchinfo >= 1.8.0",
"stocaching >= 0.2.0",
"tensorflow >= 2.19.0",
"matplotlib >= 3.10.1",
"torchvision >= 0.21.0",
]

# Optional dependencies ------------------------------------------------------------------------------------------------
[project.optional-dependencies]
dev = [
"yapf",
"isort",
"pre-commit",
"yapf",
]

[project.urls]
"Source" = "https://github.com/infocusp/tf_pt_benchmarking"

[project.scripts]
tfvspt = "tfvspt.main:entrypoint"
downloadcifar100 = "tfvspt.download_data:entrypoint"

# Tools settings -------------------------------------------------------------------------------------------------------
[tool.setuptools] # configuration specific to the `setuptools` build backend.
packages = { find = { where = ["."], include = ["tfvspt", "tfvspt.*"] } }
Expand Down
13 changes: 13 additions & 0 deletions tfvspt/assets/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
framework: tf
seed: 25
n_classes: 100
bs: 32
imgsz:
- 32
- 32
- 3
lr: 0.001
epochs: 20
num_workers: 0
output: "./output"
data: ""
111 changes: 111 additions & 0 deletions tfvspt/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Benchmarking Base."""

import logging

import matplotlib.pyplot as plt
import numpy as np

from tfvspt.config.config import get_config
from tfvspt.logger import get_logger
from tfvspt.utils import save_yaml


class BenchmarkingBase:

def __init__(self, config_path: str):
self.config = get_config(config_path)
self.logger = get_logger(name=self.config.framework.value,
path=str(self.config.output / "logs.log"))
self.stats = self._init_stats()
self.raw_data = None
self.dataloaders = None

def _init_stats(self) -> dict:
keys = [
"data_loading_time", "model_building_time", "training_time",
"model_saving_time", "eval_time"
]
return {key: None for key in keys}

def set_seed(self, seed: int) -> None:
raise NotImplementedError

def _log_stats(self, key: str, value: float) -> None:
self.stats[key] = value
self.logger.info(f"{self.config.framework}, {key}, {value}")

def _get_raw_data(self) -> dict[str, list[str]]:
train = list(map(str, self.config.data.glob("train/*/*.jpg")))
test = list(map(str, self.config.data.glob("test/*/*.jpg")))
self.logger.info(
f"{self.config.framework}, train data, {len(train)}, test data, {len(test)}"
)
return {"train": train, "test": test}

def load_dataloaders(self) -> dict:
raise NotImplementedError

def train(self) -> None:
raise NotImplementedError

def eval(self) -> dict:
raise NotImplementedError

def plot_images(self, images: np.ndarray, name: str) -> None:
grid_size = int(len(images)**0.5)
_, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
for i, ax in enumerate(axes.flatten()):
if i < len(images):
img = (images[i] * 255).astype(np.uint8)
ax.imshow(img)
ax.axis('off')
plt.savefig(str(self.config.output / f"{name}.jpg"))
plt.show()

def plot_history(self, history: dict) -> None:

epochs = range(len(history['loss']))

plt.figure(figsize=(12, 12))
plt.subplot(2, 1, 1)
plt.plot(epochs,
history['accuracy'],
'-o',
label='Train Accuracy',
color='#ff7f0e')
plt.ylabel('Accuracy', size=14)
plt.xlabel('Epoch', size=14)
plt.title("Training Accuracy")

plt.subplot(2, 1, 2)
plt.plot(epochs,
history['loss'],
'-o',
label='Train Loss',
color='#1f77b4')
plt.ylabel('Loss', size=14)
plt.xlabel('Epoch', size=14)
plt.title("Training Loss")

plt.tight_layout()
plt.savefig(str(self.config.output / "history.png"))
plt.show()

def start(self) -> None:
# set seed
self.set_seed(seed=self.config.seed)
# save the config
save_yaml(self.config.model_dump(),
str(self.config.output / "config.yaml"))
# get raw data
self.raw_data = self._get_raw_data()
# load dataloaders
self.dataloaders = self.load_dataloaders()
# train
self.train()
# eval
results = self.eval()
# save results
save_yaml(results, str(self.config.output / "results.yaml"))
# save stats
save_yaml(self.stats, str(self.config.output / "stats.yaml"))
41 changes: 41 additions & 0 deletions tfvspt/config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Config."""

from enum import Enum
from pathlib import Path

from pydantic import BaseModel

from tfvspt.utils import read_yaml


class Framework(str, Enum):
    """Supported frameworks; values are the short names used in CLI/config."""
    tensorflow = "tf"
    pytorch = "pt"


class Config(BaseModel):
    """Validated benchmarking configuration (see assets/config.yaml)."""
    framework: Framework  # which framework to benchmark ("tf"/"pt")
    seed: int  # RNG seed for reproducibility
    n_classes: int  # number of target classes
    bs: int  # batch size
    imgsz: tuple  # image size, e.g. (32, 32, 3)
    lr: float  # learning rate
    epochs: int  # number of training epochs
    num_workers: int  # number of parallel dataloader workers
    output: Path  # output directory (created automatically)
    data: Path  # dataset root directory

    def model_post_init(self, __context) -> None:
        # Ensure the output directory exists before anything writes to it.
        self.output.mkdir(parents=True, exist_ok=True)

    def model_dump(self, *args, **kwargs):
        """Dump the config with YAML-friendly plain-string enum/path values."""
        data = super().model_dump(*args, **kwargs)
        # Read .value from the model field, not the dumped dict: with
        # mode="json" pydantic already emits a plain string there, and
        # calling .value on a str would raise AttributeError.
        data["framework"] = self.framework.value
        # Convert Path objects to strings for YAML serialization.
        for path in ("output", "data"):
            data[path] = str(data[path])
        return data


def get_config(path: str) -> Config:
    """Read the YAML file at *path* and validate it into a Config."""
    raw = read_yaml(path)
    return Config(**raw)
51 changes: 51 additions & 0 deletions tfvspt/download_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Download CIFAR100 Data."""

import argparse
from pathlib import Path

import numpy as np
from PIL import Image
import tensorflow as tf
from tqdm.autonotebook import tqdm


def save_images(x_data: np.ndarray,
                y_data: np.ndarray,
                directory: Path,
                size: tuple = (64, 64)) -> None:
    """Save every (image, label) pair as <directory>/<label>/<index>.jpg.

    Args:
        x_data: Array of HxWxC images.
        y_data: Array of labels shaped (N, 1), as returned by
            tf.keras.datasets.cifar100.load_data().
        directory: Split directory (e.g. <output>/train) to write into.
        size: Target (width, height) to resize each image to; defaults to
            the original hard-coded (64, 64).
    """
    # total= gives tqdm a proper progress bar (zip alone has no length).
    for i, (img, label) in tqdm(enumerate(zip(x_data, y_data)),
                                total=len(x_data)):
        # Convert the image to a PIL Image object and resize it
        img_pil = Image.fromarray(img).resize(size)

        # Create a label-specific folder if it doesn't exist
        label_dir = directory / str(label[0])
        label_dir.mkdir(parents=True, exist_ok=True)

        # Save the image as a .jpg file
        img_filename = str(label_dir / f'{i}.jpg')
        img_pil.save(img_filename)


def entrypoint() -> None:
    """CLI entry point: fetch CIFAR-100 and dump it as per-class JPG folders."""
    parser = argparse.ArgumentParser(description="Download CIFAR100 Dataset")
    parser.add_argument('--output',
                        type=Path,
                        required=True,
                        help='Specify the output path')
    args = parser.parse_args()

    # Returns ((x_train, y_train), (x_test, y_test)).
    train_split, test_split = tf.keras.datasets.cifar100.load_data()

    # Write each split into its own directory tree of label folders.
    for split_name, split_data in zip(["train", "test"],
                                      [train_split, test_split]):
        split_dir = args.output / split_name
        split_dir.mkdir(parents=True, exist_ok=True)
        save_images(*split_data, split_dir)

    print("CIFAR-100 images have been saved as JPG files!")
33 changes: 33 additions & 0 deletions tfvspt/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Logger."""

import logging
import sys


def get_logger(name: str, path: str) -> logging.Logger:
    """Return a logger that writes INFO+ to *path* and echoes to stdout.

    Args:
        name: Logger name (included in each record).
        path: Log file location.

    Returns:
        A configured `logging.Logger`.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # second call here would stack duplicate handlers and double every
    # message; only attach handlers on first configuration.
    if not logger.handlers:
        # File handler: persists the run log next to the other outputs.
        file_handler = logging.FileHandler(path)
        file_handler.setLevel(logging.INFO)

        # Stream handler: mirrors log records to the console.
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)

        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    return logger
45 changes: 45 additions & 0 deletions tfvspt/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Main Script."""

import argparse

from tfvspt.config.config import Framework
from tfvspt.pt.main import PytorchBenchmarking
from tfvspt.tf.main import TensorflowBenchmarking

# Dispatch table: framework short name ("tf"/"pt") -> benchmarking class.
# Framework is a str-mixin Enum, so a Framework member from argparse hashes
# and compares equal to these string keys.
MAPPING = {
    Framework.tensorflow.value: TensorflowBenchmarking,
    Framework.pytorch.value: PytorchBenchmarking,
}


def entrypoint():
    """CLI entry point: parse --framework/--config and run the benchmark."""
    parser = argparse.ArgumentParser(
        description="Select framework and provide configuration")

    # Framework selection; argparse converts the raw string via Framework().
    parser.add_argument(
        '--framework',
        type=Framework,
        required=True,
        help=
        'Specify the machine learning framework to use (e.g., pytorch, tensorflow)'
    )
    # Path to the YAML configuration file.
    parser.add_argument('--config',
                        type=str,
                        required=True,
                        help='Path to the configuration file')

    args = parser.parse_args()

    # Echo the chosen arguments for the run record.
    print(f"Selected framework: {args.framework}")
    print(f"Configuration file: {args.config}")

    # Look up the framework's benchmarking class and run it.
    benchmark_cls = MAPPING[args.framework]
    benchmark_cls(config_path=args.config).start()
Loading