Skip to content
This repository was archived by the owner on Oct 1, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions flsim/utils/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@
# utils for use in the examples and tutorials

import random
from typing import Any, Dict, Generator, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, Iterator, List, Optional

import torch
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import transforms
from torch import nn
from torch.utils.data import Dataset
from torchvision.datasets.cifar import CIFAR10
from torchvision.datasets.vision import VisionDataset
from flsim.data.data_provider import IFLDataProvider, IFLUserData
from flsim.data.data_sharder import FLDataSharder, SequentialSharder
from flsim.interfaces.data_loader import IFLDataLoader
Expand All @@ -20,20 +26,37 @@
from flsim.metrics_reporter.tensorboard_metrics_reporter import FLMetricsReporter
from flsim.utils.data.data_utils import batchify
from flsim.utils.simple_batch_metrics import FLBatchMetrics
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10
from torchvision.datasets.vision import VisionDataset
from tqdm import tqdm


def collate_fn(batch: Tuple) -> Dict[str, Any]:
feature, label = batch
def collate_fn(batch: Any) -> Dict[str, Any]:
    """Normalize a raw batch into a features/labels dictionary.

    Args:
        batch (Any): Either a ``(feature, label)`` tuple or a dict
            carrying ``"image"`` and ``"label"`` keys.

    Returns:
        Dict[str, Any]: Dictionary with ``"features"`` and ``"labels"``.

    Raises:
        TypeError: If ``batch`` is neither a tuple nor a dict.
    """
    # Dispatch on the batch container type; a tuple can never be a dict,
    # so the check order does not affect behavior.
    if isinstance(batch, dict):
        return {"features": batch["image"], "labels": batch["label"]}
    if isinstance(batch, tuple):
        feats, targets = batch
        return {"features": feats, "labels": targets}
    raise TypeError("The batch must be a tuple or dict")


class DataLoader(IFLDataLoader):

"""Data loader class for federated learning.

This data loader handles the batching and sharding of datasets for
federated learning scenarios.

Attributes:
SEED (int): Seed for random number generation.
"""
SEED = 2137
random.seed(SEED)

Expand Down Expand Up @@ -163,7 +186,13 @@ def get_num_examples(batch: List) -> int:
def fl_training_batch(
    features: List[torch.Tensor], labels: List[float]
) -> Dict[str, torch.Tensor]:
    """Assemble a federated-learning training batch.

    Args:
        features: Per-example feature tensors, stacked along new dim 0.
        labels: Per-example labels; either raw numbers or 0-d/1-d tensors.

    Returns:
        Dict with "features" (stacked feature tensor) and "labels"
        (a single label tensor; float32 when built from raw numbers).
    """
    # Labels may arrive as plain scalars or as tensors; use the matching
    # conversion. Guard the labels[0] probe so an empty list does not
    # raise IndexError before torch can report the real problem.
    if labels and isinstance(labels[0], torch.Tensor):
        label_tensor = torch.stack(labels)
    else:
        label_tensor = torch.tensor(labels, dtype=torch.float32)
    return {"features": torch.stack(features), "labels": label_tensor}


class LEAFDataLoader(IFLDataLoader):
Expand Down Expand Up @@ -230,10 +259,9 @@ def num_train_users(self) -> int:
def get_train_user(self, user_index: int) -> IFLUserData:
    """Return the training data for the user at ``user_index``.

    Args:
        user_index: Key identifying the training user.

    Returns:
        IFLUserData: The stored per-user training data.

    Raises:
        IndexError: If ``user_index`` is not a known training user.
    """
    # Guard-clause form: return on hit, otherwise raise. The previous
    # text carried both an `else: raise` and a duplicated flat `raise`
    # (diff-merge residue); keep exactly one raise path.
    if user_index in self._train_users:
        return self._train_users[user_index]
    raise IndexError(
        f"Index {user_index} is out of bound for list with len {self.num_train_users()}"
    )

def train_users(self) -> Iterable[IFLUserData]:
for user_data in self._train_users.values():
Expand Down Expand Up @@ -261,7 +289,6 @@ def _create_fl_users(
def build_data_provider(
local_batch_size, examples_per_user, image_size
) -> DataProvider:

# 1. Create training, eval, and test datasets like in non-federated learning.
transform = transforms.Compose(
[
Expand Down