diff --git a/flsim/utils/example_utils.py b/flsim/utils/example_utils.py index 9f12d8f4..065d49dd 100644 --- a/flsim/utils/example_utils.py +++ b/flsim/utils/example_utils.py @@ -8,10 +8,16 @@ # utils for use in the examples and tutorials import random -from typing import Any, Dict, Generator, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Dict, Generator, Iterable, Iterator, List, Optional import torch import torch.nn.functional as F +from tqdm import tqdm +from torchvision import transforms +from torch import nn +from torch.utils.data import Dataset +from torchvision.datasets.cifar import CIFAR10 +from torchvision.datasets.vision import VisionDataset from flsim.data.data_provider import IFLDataProvider, IFLUserData from flsim.data.data_sharder import FLDataSharder, SequentialSharder from flsim.interfaces.data_loader import IFLDataLoader @@ -20,20 +26,37 @@ from flsim.metrics_reporter.tensorboard_metrics_reporter import FLMetricsReporter from flsim.utils.data.data_utils import batchify from flsim.utils.simple_batch_metrics import FLBatchMetrics -from torch import nn -from torch.utils.data import Dataset -from torchvision import transforms -from torchvision.datasets.cifar import CIFAR10 -from torchvision.datasets.vision import VisionDataset -from tqdm import tqdm - -def collate_fn(batch: Tuple) -> Dict[str, Any]: - feature, label = batch +def collate_fn(batch: Any) -> Dict[str, Any]: + """ + Process the given batch and return a dictionary with features and labels. + + Args: + batch (Any): The batch to process, can be a tuple or a dictionary. + + Returns: + Dict[str, Any]: A dictionary containing the processed features and labels. + """ + if isinstance(batch, tuple): + feature, label = batch + elif isinstance(batch, dict): + feature = batch["image"] + label = batch["label"] + else: + raise TypeError("The batch must be a tuple or dict") return {"features": feature, "labels": label} class DataLoader(IFLDataLoader): + + """Data loader class for federated learning. + + This data loader handles the batching and sharding of datasets for + federated learning scenarios. + + Attributes: + SEED (int): Seed for random number generation. + """ SEED = 2137 random.seed(SEED) @@ -163,7 +186,13 @@ def get_num_examples(batch: List) -> int: def fl_training_batch( features: List[torch.Tensor], labels: List[float] ) -> Dict[str, torch.Tensor]: - return {"features": torch.stack(features), "labels": torch.Tensor(labels)} + # Check the type of the first element in labels list to determine if conversion is needed + if not isinstance(labels[0], torch.Tensor): + labels = torch.tensor(labels, dtype=torch.float32) + else: + labels = torch.stack(labels) + + return {"features": torch.stack(features), "labels": labels} class LEAFDataLoader(IFLDataLoader): @@ -230,10 +259,9 @@ def num_train_users(self) -> int: def get_train_user(self, user_index: int) -> IFLUserData: if user_index in self._train_users: return self._train_users[user_index] - else: - raise IndexError( - f"Index {user_index} is out of bound for list with len {self.num_train_users()}" - ) + raise IndexError( + f"Index {user_index} is out of bound for list with len {self.num_train_users()}" + ) def train_users(self) -> Iterable[IFLUserData]: for user_data in self._train_users.values(): @@ -261,7 +289,6 @@ def _create_fl_users( def build_data_provider( local_batch_size, examples_per_user, image_size ) -> DataProvider: - # 1. Create training, eval, and test datasets like in non-federated learning. transform = transforms.Compose( [