This repository was archived by the owner on Feb 27, 2026. It is now read-only.

GMM Training with Mini-Batch #57

@hashim19


Hi, first of all thank you for the amazing repository.

I am trying to train a GMM with mini-batches. After going over #51, #19, and #7, I realized that I need to write my own dataset loader. Here is a sample of it (each file of my dataset is stored as a .pkl file, so I wrote a `PKL_dataset` class):

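```python
import os
import pickle

import numpy as np
from torch.utils.data import Dataset


def open_pkl(path):
    # open_pkl is not shown in the post; this is an assumed minimal version
    # that loads one pickled (num_frames, num_features) array
    with open(path, "rb") as f:
        return pickle.load(f)


# Dataset class
class PKL_dataset(Dataset):

    def __init__(self, dataset_pth, data_label):
        self.data_dir = dataset_pth
        self.files_ls = os.listdir(dataset_pth)
        self.len = len(self.files_ls)
        self.label = data_label

    def __len__(self):
        return self.len

    def transform(self, data):
        # Pad short files with each feature's mean; truncate long ones to 2000 frames
        if data.shape[0] < 2000:
            return np.pad(data, [(0, 2000 - data.shape[0]), (0, 0)], 'mean')
        else:
            return data[:2000]

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.files_ls[idx])
        pkl_data = open_pkl(file_path)
        transformed_pkl_data = self.transform(pkl_data)
        return transformed_pkl_data
```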
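The `transform` step pads files shorter than 2000 frames with each feature's mean and truncates longer ones. A minimal standalone sketch of that same logic, using a hypothetical `pad_or_truncate` function and a tiny target length so the result can be checked by hand:

```python
import numpy as np


def pad_or_truncate(data: np.ndarray, target_len: int) -> np.ndarray:
    """Return exactly target_len rows: pad short inputs with each column's mean,
    or keep only the first target_len rows of long inputs."""
    if data.shape[0] < target_len:
        # Pad only along the frame axis (axis 0); mode="mean" fills with the column mean
        return np.pad(data, [(0, target_len - data.shape[0]), (0, 0)], mode="mean")
    return data[:target_len]


short = np.array([[1.0, 10.0], [3.0, 30.0]])  # 2 frames, 2 features
print(pad_or_truncate(short, 4))  # last two rows are [2.0, 20.0], the column means

long = np.arange(12, dtype=float).reshape(6, 2)  # 6 frames, 2 features
print(pad_or_truncate(long, 4).shape)  # (4, 2)
```

One design consequence worth keeping in mind: the padded rows are synthetic frames sitting exactly at the column mean, so if frames are later treated as individual samples, short files contribute extra weight at that point.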

I then create the `GaussianMixture` estimator like this:

```python
gmm = GaussianMixture(
    num_components=ncomp,
    covariance_type='diag',
    batch_size=32,
    covariance_regularization=0.1,
    init_strategy='kmeans++',
    trainer_params=dict(accelerator='gpu', devices=1, max_epochs=100),
)
```

and pass a `DataLoader` over my dataset to `fit` like this:

```python
history = gmm.fit(pkl_dataloader)
```

This raises the following error:

```
Traceback (most recent call last):
  File "asvspoof2021_baseline.py", line 65, in <module>
    gmm_bona = train_gmm(data_label=data_labels[0], features=features,
  File "/home/hashim/PhD/Audio_Spoof_Detection/Baseline-CQCC-GMM/python/gmm.py", line 218, in train_gmm
    history = gmm.fit(pkl_dataloader)
  File "/home/hashim/PhD/AsvSpoof2021/asvspoof_venv/lib/python3.8/site-packages/pycave/bayes/gmm/estimator.py", line 128, in fit
    num_features = len(data[0])
TypeError: 'DataLoader' object is not subscriptable
```

It looks like `fit` does not accept a `DataLoader`: it indexes `data[0]` to infer the number of features, and a `DataLoader` is not subscriptable.
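The failing line, `num_features = len(data[0])`, reads the feature dimension from the first item of `data`, so whatever is passed must support `[]` indexing and each item should be one feature vector. A minimal NumPy-only sketch of that inference, assuming tabular input of shape (num_samples, num_features) as that line suggests. `IterOnlyLoader` is a hypothetical stand-in for `torch.utils.data.DataLoader` (iterable but not indexable), and the sizes are made up:

```python
import numpy as np


class IterOnlyLoader:
    """Hypothetical stand-in for a DataLoader: iterable, but has no __getitem__."""

    def __init__(self, items):
        self.items = items

    def __iter__(self):
        return iter(self.items)


def infer_num_features(data):
    # Mirrors the failing line from the traceback: num_features = len(data[0])
    return len(data[0])


num_frames, num_features = 2000, 20  # hypothetical sizes
files = [np.zeros((num_frames, num_features)) for _ in range(3)]

# An iteration-only object cannot be indexed, giving the same kind of TypeError
try:
    infer_num_features(IterOnlyLoader(files))
except TypeError as err:
    print(err)  # 'IterOnlyLoader' object is not subscriptable

# Indexing a list of per-file (2000, D) matrices returns the frame count, not D
print(infer_num_features(files))  # 2000

# Stacking every frame as a row gives a (num_samples, num_features) matrix
stacked = np.concatenate(files, axis=0)
print(stacked.shape, infer_num_features(stacked))  # (6000, 20) 20
```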

However, if I do not use a data loader, the training process gets killed because it runs out of memory.
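For a rough sense of the footprint when every file is held in memory at once: after the transform each file is 2000 frames by D features, so the total is files × 2000 × D × bytes per value. The file count (20,000) and feature dimension (60) below are hypothetical, not figures from this post, and float64 (8 bytes per value) is also an assumption, since `np.pad` keeps whatever dtype the pickles contain:

```python
def dataset_bytes(num_files: int, frames_per_file: int, num_features: int,
                  bytes_per_value: int = 8) -> int:
    """Bytes needed to hold every padded/truncated file in memory at once."""
    return num_files * frames_per_file * num_features * bytes_per_value


# Hypothetical sizes: 20,000 files, 2000 frames each, 60 features, float64
total = dataset_bytes(20_000, 2000, 60)
print(f"{total / 1e9:.1f} GB")  # 19.2 GB
```

Under these assumed sizes the array alone exceeds the 12 GB GPU mentioned in this post; storing float32 (`bytes_per_value=4`) would halve it.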

Here is a snapshot of my system. My GPU is an NVIDIA RTX A2000 (12 GB).

[Screenshot: system]
