diff --git a/.gitignore b/.gitignore index 1800114..b79f322 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,12 @@ cython_debug/ .ruff_cache/ # PyPI configuration file -.pypirc \ No newline at end of file +.pypirc + +# Models +**/*.keras +**/*.pt +**/*.csv +**/*.jpg + +benchmarking/ \ No newline at end of file diff --git a/README.md b/README.md index 513c6a0..0dc97b8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,38 @@ -# tf_pt_benchmarking +# Tensorflow VS Pytorch + Benchmarking of training and inference performance between Tensorflow and Pytorch. + + +## Config + +```yaml +framework: pt # [pt, tf] +seed: 25 # seed for reproducibility +n_classes: 100 # number of classes +bs: 32 # batch size +imgsz: # image size + - 32 + - 32 + - 3 +lr: 0.001 # learning rate +epochs: 2 # number of epochs +num_workers: 0 # number of parallel workers +output: "output/pt" # output directory path +data: "cifar100" # dataset path +``` + +## Installation + +- `pip install git+https://github.com/infocusp/tf_pt_benchmarking.git` + +## How to run ? 
+ +### Download the data + +- `downloadcifar100 --output path/to/data` + +### Run benchmarking + +- Update the `config.yaml` with required details +- `tfvspt --framework tf --config path/to/config.yaml` +- Benchmarking results will be saved in the `output` folder \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6305ef2..2844683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,20 +14,33 @@ requires-python = ">=3.10" # Required dependencies ------------------------------------------------------------------------------------------------ dependencies = [ - + "tqdm >= 4.67.1", + "torch >= 2.6.0", + "PyYAML >= 6.0.2", + "notebook >= 7.3.3", + "pydantic >= 2.10.6", + "torchinfo >= 1.8.0", + "stocaching >= 0.2.0", + "tensorflow >= 2.19.0", + "matplotlib >= 3.10.1", + "torchvision >= 0.21.0", ] # Optional dependencies ------------------------------------------------------------------------------------------------ [project.optional-dependencies] dev = [ + "yapf", "isort", "pre-commit", - "yapf", ] [project.urls] "Source" = "https://github.com/infocusp/tf_pt_benchmarking" +[project.scripts] +tfvspt = "tfvspt.main:entrypoint" +downloadcifar100 = "tfvspt.download_data:entrypoint" + # Tools settings ------------------------------------------------------------------------------------------------------- [tool.setuptools] # configuration specific to the `setuptools` build backend. 
class BenchmarkingBase:
    """Shared benchmarking pipeline: config, logging, timing stats, plots.

    Framework-specific subclasses (TensorFlow / PyTorch) implement the
    ``set_seed``, ``load_dataloaders``, ``train`` and ``eval`` hooks;
    ``start`` orchestrates the whole run and persists the config,
    results and stage timings as YAML files under ``config.output``.
    """

    def __init__(self, config_path: str):
        self.config = get_config(config_path)
        self.logger = get_logger(name=self.config.framework.value,
                                 path=str(self.config.output / "logs.log"))
        self.stats = self._init_stats()
        # Populated by start(): raw image paths and framework dataloaders.
        self.raw_data = None
        self.dataloaders = None

    def _init_stats(self) -> dict:
        """Return the stage-timing dict with every tracked key set to None."""
        # BUGFIX: "model_loading_time" was missing even though both
        # framework subclasses log it via _log_stats() during eval();
        # pre-declaring it keeps the stats schema complete.
        keys = [
            "data_loading_time", "model_building_time", "training_time",
            "model_saving_time", "model_loading_time", "eval_time"
        ]
        return {key: None for key in keys}

    def set_seed(self, seed: int) -> None:
        """Seed every RNG the framework uses (subclass hook)."""
        raise NotImplementedError

    def _log_stats(self, key: str, value: float) -> None:
        """Record one timing stat and echo it to the log."""
        self.stats[key] = value
        self.logger.info(f"{self.config.framework}, {key}, {value}")

    def _get_raw_data(self) -> dict[str, list[str]]:
        """Collect train/test image paths laid out as <split>/<label>/*.jpg."""
        train = list(map(str, self.config.data.glob("train/*/*.jpg")))
        test = list(map(str, self.config.data.glob("test/*/*.jpg")))
        self.logger.info(
            f"{self.config.framework}, train data, {len(train)}, test data, {len(test)}"
        )
        return {"train": train, "test": test}

    def load_dataloaders(self) -> dict:
        """Build the framework's train/test dataloaders (subclass hook)."""
        raise NotImplementedError

    def train(self) -> None:
        """Train and save the model (subclass hook)."""
        raise NotImplementedError

    def eval(self) -> dict:
        """Evaluate the saved model and return metrics (subclass hook)."""
        raise NotImplementedError

    def plot_images(self, images: np.ndarray, name: str) -> None:
        """Save a square grid of sample images as ``<output>/<name>.jpg``.

        Assumes *images* is NHWC with float pixel values in [0, 1]
        (hence the * 255 rescale) — TODO confirm against both callers.
        """
        grid_size = int(len(images)**0.5)
        _, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
        for i, ax in enumerate(axes.flatten()):
            if i < len(images):
                img = (images[i] * 255).astype(np.uint8)
                ax.imshow(img)
                ax.axis('off')
        plt.savefig(str(self.config.output / f"{name}.jpg"))
        plt.show()

    def plot_history(self, history: dict) -> None:
        """Plot per-epoch training accuracy and loss to ``history.png``."""

        epochs = range(len(history['loss']))

        plt.figure(figsize=(12, 12))
        plt.subplot(2, 1, 1)
        plt.plot(epochs,
                 history['accuracy'],
                 '-o',
                 label='Train Accuracy',
                 color='#ff7f0e')
        plt.ylabel('Accuracy', size=14)
        plt.xlabel('Epoch', size=14)
        plt.title("Training Accuracy")

        plt.subplot(2, 1, 2)
        plt.plot(epochs,
                 history['loss'],
                 '-o',
                 label='Train Loss',
                 color='#1f77b4')
        plt.ylabel('Loss', size=14)
        plt.xlabel('Epoch', size=14)
        plt.title("Training Loss")

        plt.tight_layout()
        plt.savefig(str(self.config.output / "history.png"))
        plt.show()

    def start(self) -> None:
        """Run the full benchmark: seed, load data, train, eval, persist."""
        # set seed
        self.set_seed(seed=self.config.seed)
        # save the config
        save_yaml(self.config.model_dump(),
                  str(self.config.output / "config.yaml"))
        # get raw data
        self.raw_data = self._get_raw_data()
        # load dataloaders
        self.dataloaders = self.load_dataloaders()
        # train
        self.train()
        # eval
        results = self.eval()
        # save results
        save_yaml(results, str(self.config.output / "results.yaml"))
        # save stats
        save_yaml(self.stats, str(self.config.output / "stats.yaml"))
# Function to save images as JPG
def save_images(x_data: np.ndarray, y_data: np.ndarray,
                directory: Path) -> None:
    """Write every image in *x_data* to ``directory/<label>/<index>.jpg``.

    Each image is upscaled to 64x64 before saving, and one sub-folder
    is created per class label.
    """
    for index, (pixels, label) in tqdm(enumerate(zip(x_data, y_data))):
        # CIFAR labels arrive as shape-(1,) arrays; the scalar names the
        # class directory.
        class_dir = directory / str(label[0])
        class_dir.mkdir(parents=True, exist_ok=True)

        # Upscale via PIL and write the JPG.
        resized = Image.fromarray(pixels).resize((64, 64))
        resized.save(str(class_dir / f'{index}.jpg'))
def get_logger(name: str, path: str) -> logging.Logger:
    """Return a logger named *name* writing INFO+ to *path* and to stdout.

    Args:
        name: Logger name (loggers are process-wide singletons per name).
        path: Log-file destination for the file handler.

    Returns:
        The configured ``logging.Logger``.
    """

    # Create (or fetch) the logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # BUGFIX: logging.getLogger caches loggers by name, so a second call
    # used to stack a second pair of handlers and emit every record
    # twice. Return the already-configured logger instead.
    if logger.handlers:
        return logger

    # Define a formatter for log messages
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # File handler: logs to the specified log file
    file_handler = logging.FileHandler(path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Stream handler: logs to stdout (console)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(formatter)

    # Add handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
class ClassificationDataset(Dataset):
    """Map-style dataset over images laid out as ``<split>/<label>/<name>.jpg``.

    Optionally keeps the transformed (pre-augmentation) tensors in a
    shared-memory cache so that across epochs/workers each image is
    decoded and resized only once.
    """

    def __init__(
        self,
        config: Config,
        paths: list[str],
        transforms=None,
        augmentations=None,
        cache: bool = False,
    ) -> None:
        self.config = config
        self.paths = paths
        self.transforms = transforms
        self.augmentations = augmentations
        self.cache = None
        if cache:
            # Cache slots hold CHW float tensors; config.imgsz is stored
            # HWC, hence the axis shuffle below.
            self.cache = SharedCache(
                size_limit_gib=get_shm_size(),
                dataset_len=len(self.paths),
                data_dims=(self.config.imgsz[-1], *self.config.imgsz[:-1]),
                dtype=torch.float32,
            )

    def __len__(self) -> int:
        # BUGFIX: was annotated "-> None"; __len__ must return an int.
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Return ``(image, label)``; label is the parent directory name."""
        path = self.paths[idx]
        label = int(path.split(os.path.sep)[-2])
        image = None
        if self.cache:
            image = self.cache.get_slot(idx)
        if image is None:
            # Cache miss: decode, apply deterministic transforms, then
            # store the transformed tensor for subsequent epochs
            # (standard stocaching get_slot / set_slot-on-miss pattern).
            image = Image.open(path).convert('RGB')
            if self.transforms:
                image = self.transforms(image)
            if self.cache:
                self.cache.set_slot(idx, image)
        # Random augmentations run on every access, cached or not.
        if self.augmentations:
            image = self.augmentations(image)
        return image, label
class PytorchBenchmarking(BenchmarkingBase):
    """PyTorch side of the TF-vs-PT benchmark.

    Implements the :class:`BenchmarkingBase` hooks — seeding, dataloader
    construction, training, TorchScript export and evaluation — timing
    every stage through ``_log_stats``.
    """

    def __init__(self, config_path):
        super().__init__(config_path)
        # Prefer the GPU whenever one is visible.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"{self.config.framework}, device, {self.device}")

    def set_seed(self, seed):
        """Seed python, numpy, torch and CUDA for a reproducible run."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        # Deterministic CuDNN kernels: reproducible results, possibly slower.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.logger.info(f"{self.config.framework}, seed, {self.config.seed}")

    def get_transforms(self):
        """Deterministic preprocessing: resize to configured HxW, to tensor."""
        return transforms.Compose([
            transforms.Resize(self.config.imgsz[:2]),
            transforms.ToTensor(),
        ])

    def get_augmentations(self):
        """Random flips — applied to training samples only."""
        return transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ])

    def _get_dataset(self, data: str) -> ClassificationDataset:
        """Build the dataset for split *data*; augment/cache only 'train'."""
        is_train = data == "train"
        return ClassificationDataset(
            config=self.config,
            paths=self.raw_data[data],
            transforms=self.get_transforms(),
            augmentations=self.get_augmentations() if is_train else None,
            cache=is_train,
        )

    def load_dataloaders(self) -> dict:
        """Build the train/test DataLoaders and record the elapsed time."""
        tic = time.time()
        shared = dict(batch_size=self.config.bs,
                      num_workers=self.config.num_workers)
        dataloaders = {
            "train": DataLoader(dataset=self._get_dataset("train"),
                                shuffle=True,
                                **shared),
            "test": DataLoader(dataset=self._get_dataset("test"),
                               shuffle=False,
                               **shared),
        }
        self._log_stats("data_loading_time", time.time() - tic)
        return dataloaders

    def plot_images(self, dataloader, name: str) -> None:
        """Plot one batch from *dataloader* as a grid (NCHW -> NHWC)."""
        batch, _ = next(iter(dataloader))
        return super().plot_images(batch.numpy().transpose((0, 2, 3, 1)),
                                   name)

    def _train(self, dataloader: DataLoader, model: torch.nn.Module,
               criterion: torch.nn.Module,
               optimizer: torch.optim.Optimizer) -> dict:
        """Run the training loop; return per-epoch loss/accuracy history."""

        model.to(self.device)

        history = {"loss": [], "accuracy": []}
        for epoch in range(self.config.epochs):

            model.train()
            running_loss, correct, total = 0.0, 0, 0

            self.logger.info(f"Epoch {epoch + 1} / {self.config.epochs}")
            for images, targets in tqdm(dataloader):

                images = images.to(self.device)
                targets = targets.to(self.device)

                # Standard step: zero grads, forward, loss, backward, update.
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # Accumulate running metrics for the epoch summary.
                running_loss += loss.item()
                total += targets.size(0)
                correct += (torch.argmax(outputs,
                                         1) == targets).sum().item()

            epoch_loss = running_loss / len(dataloader)
            epoch_accuracy = round(correct / total, 4)

            history["loss"].append(epoch_loss)
            history["accuracy"].append(epoch_accuracy)

            self.logger.info(
                f"Epoch {epoch + 1}/{self.config.epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}"
            )

        return history

    def _eval(self, dataloader: DataLoader, model: torch.nn.Module,
              criterion: torch.nn.Module) -> dict:
        """Evaluate *model* on *dataloader*; return {'loss', 'accuracy'}."""

        model.to(self.device)
        model.eval()
        val_loss, correct, total = 0.0, 0, 0

        with torch.no_grad():
            for images, targets in tqdm(dataloader):

                images = images.to(self.device)
                targets = targets.to(self.device)

                outputs = model(images)
                val_loss += criterion(outputs, targets).item()

                total += targets.size(0)
                correct += (torch.argmax(outputs,
                                         1) == targets).sum().item()

        val_loss /= len(dataloader)
        val_accuracy = round(correct / total, 4)

        return {"loss": val_loss, "accuracy": val_accuracy}

    def train(self) -> None:
        """Build, train, plot and TorchScript-export the model (all timed)."""
        # plot a sample training batch
        self.plot_images(dataloader=self.dataloaders["train"],
                         name="train_data")

        # build the model (timed)
        tic = time.time()
        model = Model(n_classes=self.config.n_classes)
        self._log_stats("model_building_time", time.time() - tic)
        # torchinfo expects NCHW; config.imgsz is HWC, hence the shuffle.
        model_summary = summary(model,
                                input_size=(self.config.bs,
                                            self.config.imgsz[-1],
                                            *self.config.imgsz[:-1]))
        self.logger.info(f"{self.config.framework}, {model_summary}")

        optimizer = torch.optim.Adam(model.parameters(), lr=self.config.lr)
        criterion = torch.nn.CrossEntropyLoss()

        # fit (timed)
        tic = time.time()
        history = self._train(
            dataloader=self.dataloaders["train"],
            model=model,
            criterion=criterion,
            optimizer=optimizer,
        )
        self._log_stats("training_time", time.time() - tic)

        # plot the history
        self.plot_history(history=history)

        # export as TorchScript so eval() can reload without the class (timed)
        tic = time.time()
        torch.jit.script(model).save(str(self.config.output / "cifar100.pt"))
        self._log_stats("model_saving_time", time.time() - tic)

    def eval(self) -> dict:
        """Reload the exported model and benchmark test-set evaluation."""
        # plot a sample test batch
        self.plot_images(dataloader=self.dataloaders["test"], name="test_data")

        # load the TorchScript model (timed)
        tic = time.time()
        model = torch.jit.load(str(self.config.output / "cifar100.pt"))
        self._log_stats("model_loading_time", time.time() - tic)

        # evaluate (timed)
        tic = time.time()
        results = self._eval(
            dataloader=self.dataloaders["test"],
            model=model,
            criterion=torch.nn.CrossEntropyLoss(),
        )
        self.logger.info(f"{self.config.framework}, test results, {results}")
        self._log_stats("eval_time", time.time() - tic)

        return results
class ClassificationDataset:
    """tf.data input pipeline over ``<split>/<label>/<name>.jpg`` files."""

    def __init__(self, config: Config) -> None:
        self.config = config

    def preprocess_data(self, path: str):
        """Decode one JPEG path into (float image in [0, 1], one-hot label)."""
        raw = tf.io.read_file(path)
        image = tf.image.decode_jpeg(raw, channels=3)
        image = tf.image.resize(image, self.config.imgsz[:2])
        image = tf.cast(image / 255.0, tf.float32)
        # The parent directory name encodes the integer class label.
        label = tf.strings.split(path, os.path.sep)[-2]
        label = tf.strings.to_number(label, tf.int32)
        return image, tf.one_hot(label, self.config.n_classes)

    def augment_data(self, image, label):
        """Random horizontal/vertical flips; labels pass through untouched."""
        image = tf.image.random_flip_left_right(image)
        return tf.image.random_flip_up_down(image), label

    def get_dataloader(
        self,
        paths: list,
        augment: bool = False,
        shuffle: bool = False,
        repeat: bool = False,
    ):
        """Assemble the pipeline: decode+cache, then optional augment,
        optional shuffle, batch, prefetch, optional repeat."""
        ds = tf.data.Dataset.from_tensor_slices(paths)
        # Cache the decoded (pre-augmentation) images so every file is
        # read and resized only once across epochs.
        ds = ds.map(self.preprocess_data, num_parallel_calls=AUTOTUNE).cache()
        if augment:
            ds = ds.map(self.augment_data, num_parallel_calls=AUTOTUNE)
        if shuffle:
            # Buffer covers the whole split for a full shuffle.
            ds = ds.shuffle(len(paths))
        ds = ds.batch(self.config.bs).prefetch(AUTOTUNE)
        if repeat:
            ds = ds.repeat()
        return ds
get_steps_per_epoch(self) -> None: + return math.ceil(len(self.raw_data["train"]) / self.config.bs) + + def train(self) -> None: + + # plot train images + self.plot_images(dataloader=self.dataloaders["train"], + name="train_data") + + # load model + st = time.time() + model = get_model(input_shape=self.config.imgsz, + n_classes=self.config.n_classes) + self._log_stats("model_building_time", time.time() - st) + self.logger.info(f"{self.config.framework}, {model.summary()}") + + # optimizer + optimizer = tf.keras.optimizers.Adam(learning_rate=self.config.lr) + # loss + loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) + # compile model + model.compile( + optimizer=optimizer, + loss=loss, + metrics=["accuracy"], + ) + + # train model + st = time.time() + history = model.fit( + self.dataloaders["train"], + epochs=self.config.epochs, + shuffle=False, + steps_per_epoch=self.get_steps_per_epoch(), + callbacks=self.get_callbacks(), + ) + self._log_stats("training_time", time.time() - st) + + # plot the history + self.plot_history(history=history.history) + + # save model + st = time.time() + model.save(str(self.config.output / "cifar100.keras")) + self._log_stats("model_saving_time", time.time() - st) + + def eval(self) -> dict: + # plot test images + self.plot_images(dataloader=self.dataloaders["test"], name="test_data") + + # load model + st = time.time() + model = tf.keras.models.load_model( + str(self.config.output / "cifar100.keras")) + self._log_stats("model_loading_time", time.time() - st) + + # evaluate + st = time.time() + results = model.evaluate(self.dataloaders["test"], return_dict=True) + self.logger.info(f"{self.config.framework}, test results, {results}") + self._log_stats("eval_time", time.time() - st) + + return results diff --git a/tfvspt/tf/model.py b/tfvspt/tf/model.py new file mode 100644 index 0000000..477ea7b --- /dev/null +++ b/tfvspt/tf/model.py @@ -0,0 +1,28 @@ +"""Tensorflow Model.""" + +from tensorflow.keras import layers +from 
def read_yaml(path: str) -> dict:
    """Load and return the parsed content of the YAML file at *path*.

    Uses ``yaml.safe_load`` (never ``yaml.load``) so untrusted config
    files cannot execute arbitrary Python.
    """
    # BUGFIX: the file handle was opened without ever being closed;
    # the context manager guarantees it is released.
    with open(path) as file:
        return yaml.safe_load(file)


def save_yaml(data: Any, path: str) -> None:
    """Serialize *data* to *path* as block-style YAML."""
    with open(path, 'w') as file:
        yaml.dump(data, file, default_flow_style=False)