Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions experimental/shoshin/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def get_data_config():
config.num_classes = 2
config.batch_size = 64
# Number of slices into which train and val will be split.
config.num_splits = 5
config.num_splits = 100
# Ratio of splits that will be considered out-of-distribution from each
# combination, e.g. when num_splits == 5 and ood_ratio == 0.4, 2 out 5
# slices will be excluded for every combination of training data.
config.ood_ratio = 0.4
# Indices of data splits to include in training. All by default.
config.included_splits_idx = (0, 1, 2, 3, 4)
# Index of the data split to use for training.
config.split_id = 0
# Subgroup IDs. Specify them in an experiment config. For example, for
# Waterbirds, the subgroup IDs might be ('0_1', '1_0') for landbirds on water
# and waterbirds on land, respectively.
Expand All @@ -64,6 +64,13 @@ def get_data_config():
# Waterbirds, the subgroup proportions might be (0.05, 0.05), meaning each
# subgroup will represent 5% of the dataset.
config.subgroup_proportions = ()
config.split_seed = 0
config.initial_sample_seed = 0
config.split_proportion = 1.0

# Leave one out training
config.loo_id = ''
config.loo_training = False

# Proportion of training set to sample initially. Rest is considered the pool
# for active sampling.
Expand Down Expand Up @@ -151,6 +158,9 @@ def get_config() -> ml_collections.ConfigDict:
# Round of active sampling being performed
config.round_idx = -1

# Keep predictions of individual models
config.keep_individual_predictions = True

# Whether to generate bias table (from stage one models) or prediction table
# (from stage two models)
config.generate_bias_table = True
Expand Down
162 changes: 139 additions & 23 deletions experimental/shoshin/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
import dataclasses
import json
import os
from typing import Any, Dict, Iterator, Optional, Tuple, List, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

Expand Down Expand Up @@ -71,8 +72,8 @@ def get_dataset(name: str):

@dataclasses.dataclass
class Dataloader:
train_splits: tf.data.Dataset # Result of tfds.load with 'split' arg.
val_splits: tf.data.Dataset # Result of tfds.load with 'split' arg.
train_splits: List[tf.data.Dataset] # Result of tfds.load with 'split' arg.
val_splits: List[tf.data.Dataset] # Result of tfds.load with 'split' arg.
train_ds: tf.data.Dataset # Dataset with all the train splits combined.
train_sample_ds: Optional[tf.data.Dataset] = None # Subsample of train set.
eval_ds: Optional[Dict[
Expand Down Expand Up @@ -120,6 +121,15 @@ def get_train_ids(dataloader: Dataloader):
return ids_train


def filter_ids_fn(hash_table, value=1):
  """Builds a dataset predicate that keeps examples whose id maps to `value`.

  Args:
    hash_table: Object exposing `lookup(example_ids)`; the result of the
      lookup is compared against `value`.
    value: Table value an example id must map to for the example to be kept.

  Returns:
    A function taking `(feats, label, example_ids)` and returning the result
    of `hash_table.lookup(example_ids) == value`.
  """

  def _id_maps_to_value(feats, label, example_ids):
    # Only the example ids take part in the decision.
    del feats, label
    looked_up = hash_table.lookup(example_ids)
    return looked_up == value

  return _id_maps_to_value


class CardiotoxFingerprintDataset(tfds.core.GeneratorBasedBuilder):
"""DatasetBuilder for cardiotoxicity fingerprint dataset."""

Expand Down Expand Up @@ -455,7 +465,12 @@ def get_waterbirds_dataset(
initial_sample_proportion: float,
subgroup_ids: List[str],
subgroup_proportions: List[float],
split_proportion: Optional[float] = 0.7,
initial_sample_seed: Optional[int] = 0,
split_seed: Optional[int] = 0,
is_training: Optional[bool] = True,
loo_training: Optional[bool] = False,
loo_id: Optional[str] = None,
) -> Dataloader:
"""Returns datasets for training, validation, and possibly test sets.

Expand All @@ -466,41 +481,135 @@ def get_waterbirds_dataset(
subgroup_ids: List of strings of IDs indicating subgroups.
subgroup_proportions: List of floats indicating proportion that each
subgroup should take in initial training dataset.
split_proportion: Relative size of each split
initial_sample_seed: Seed used for subsampling from the original dataset
split_seed: Seed used to sample the splits
is_training: Dataset used for evaluation (in this case as_supervised is
set to True in the val/test)
loo_training: Train splits with only a single sample removed
loo_id: the id of the removed sample

Returns:
A tuple containing the split training data, split validation data, the
combined training dataset, and a dictionary mapping evaluation dataset names
to their respective combined datasets.
"""
split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
reduced_datset_sz = int(100 * initial_sample_proportion)

meta_data_df = pd.read_csv(_WATERBIRDS_METADATA_DIR)
reduced_datset_sz = int(
(meta_data_df['split'] == 0).sum() * initial_sample_proportion)

builder_kwargs = {
'subgroup_ids': subgroup_ids,
'subgroup_proportions': subgroup_proportions
}
val_splits = tfds.load(
'waterbirds_dataset',
split=[
f'validation[{k}%:{k+split_size_in_pct}%]'
for k in range(0, reduced_datset_sz, split_size_in_pct)
],
data_dir=DATA_DIR,
builder_kwargs=builder_kwargs,
try_gcs=False,
as_supervised=is_training)

train_splits = tfds.load(
train_ds = tfds.load(
'waterbirds_dataset',
split=[
f'train[{k}%:{k+split_size_in_pct}%]'
for k in range(0, reduced_datset_sz, split_size_in_pct)
],
split='train',
data_dir=DATA_DIR,
builder_kwargs=builder_kwargs,
try_gcs=False,
as_supervised=is_training)
as_supervised=True,
with_info=False)

def create_ids_table(meta_data_df,
                     initial_sample_size,
                     initial_sample_seed,
                     split_proportion,
                     num_splits,
                     split_seed,
                     train=True) -> List[tf.lookup.StaticHashTable]:
  """Builds one id-membership hash table per split.

  Rows of `meta_data_df` with `split == 0` are treated as the training set.
  An initial subsample of their `img_filename` ids is drawn, and each split
  then draws its own sample from that subsample. Every returned table maps
  the selected ids to 1 and any other id to 0.

  Args:
    meta_data_df: DataFrame with `split` and `img_filename` columns.
    initial_sample_size: Number of ids drawn for the initial subsample.
    initial_sample_seed: Random state for the initial subsample.
    split_proportion: Fraction of the subsample drawn for each split.
    num_splits: Number of tables to build.
    split_seed: Offset added to the split index to seed each split's draw.
    train: If True, tables contain the ids drawn for each split; otherwise
      they contain the held-out ids.

  Returns:
    A list of `num_splits` static hash tables.
  """
  train_rows = meta_data_df[meta_data_df['split'] == 0]  # Get trainset
  all_ids = train_rows['img_filename']
  sampled_ids = all_ids.sample(
      n=initial_sample_size, random_state=initial_sample_seed)

  tables = []
  for split_idx in range(num_splits):
    drawn_ids = sampled_ids.sample(
        frac=split_proportion, random_state=split_idx + split_seed)
    # NOTE(review): the `split_proportion >= 1` case is treated as part of
    # the validation (not train) branch, matching the original comment about
    # using the remaining data as validation -- confirm against the repo.
    if train:
      selected_ids = drawn_ids
    elif split_proportion >= 1:
      # If the split is the entire subsample we use the remaining data as
      # validation, only needed for loo training.
      selected_ids = all_ids[~all_ids.isin(sampled_ids)]
    else:
      selected_ids = sampled_ids[~sampled_ids.isin(drawn_ids)]

    keys = tf.convert_to_tensor(selected_ids, dtype=tf.string)
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=keys,
        values=tf.ones(shape=keys.shape, dtype=tf.int64),
        key_dtype=tf.string,
        value_dtype=tf.int64)
    tables.append(tf.lookup.StaticHashTable(initializer, default_value=0))
  return tables

def create_loo_ids_table(meta_data_df,
                         initial_sample_size,
                         initial_sample_seed,
                         loo_id,
                         num_splits,
                         train=True) -> List[tf.lookup.StaticHashTable]:
  """Builds id-membership hash tables for leave-one-out training.

  Rows of `meta_data_df` with `split == 0` are treated as the training set.
  An initial subsample of their `img_filename` ids is drawn and `loo_id` is
  removed from it. Every returned table maps the selected ids to 1 and any
  other id to 0.

  Args:
    meta_data_df: DataFrame with `split` and `img_filename` columns.
    initial_sample_size: Number of ids drawn for the initial subsample.
    initial_sample_seed: Random state for the initial subsample.
    loo_id: Id of the example left out of training.
    num_splits: Number of tables to return.
    train: If True, tables contain the subsample without `loo_id`; otherwise
      they contain the subsample ids that were removed.

  Returns:
    A list of `num_splits` static hash tables, all built from the same
    initializer.
  """
  train_rows = meta_data_df[meta_data_df['split'] == 0]  # Get trainset
  all_ids = train_rows['img_filename']
  sampled_ids = all_ids.sample(
      n=initial_sample_size, random_state=initial_sample_seed)

  without_loo = sampled_ids[sampled_ids != loo_id]
  if train:
    selected_ids = without_loo
  else:
    selected_ids = sampled_ids[~sampled_ids.isin(without_loo)]

  keys = tf.convert_to_tensor(selected_ids, dtype=tf.string)
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys=keys,
      values=tf.ones(shape=keys.shape, dtype=tf.int64),
      key_dtype=tf.string,
      value_dtype=tf.int64)
  return [
      tf.lookup.StaticHashTable(initializer, default_value=0)
      for _ in range(num_splits)
  ]

if loo_training:
train_ids = create_loo_ids_table(
meta_data_df,
initial_sample_size=reduced_datset_sz,
initial_sample_seed=0,
loo_id=loo_id,
num_splits=num_splits,
train=True)
val_ids = create_loo_ids_table(
meta_data_df,
initial_sample_size=reduced_datset_sz,
initial_sample_seed=0,
loo_id=loo_id,
num_splits=num_splits,
train=False)
else:
train_ids = create_ids_table(
meta_data_df,
initial_sample_size=reduced_datset_sz,
initial_sample_seed=initial_sample_seed,
split_proportion=split_proportion,
num_splits=num_splits,
split_seed=split_seed,
train=True)
val_ids = create_ids_table(
meta_data_df,
initial_sample_size=reduced_datset_sz,
initial_sample_seed=initial_sample_seed,
split_proportion=split_proportion,
num_splits=num_splits,
split_seed=split_seed,
train=False)

train_splits = [
train_ds.filter(filter_ids_fn(ids_tab)) for ids_tab in train_ids
]

val_splits = [train_ds.filter(filter_ids_fn(ids_tab)) for ids_tab in val_ids]

train_sample = tfds.load(
'waterbirds_dataset',
Expand All @@ -520,8 +629,15 @@ def get_waterbirds_dataset(
as_supervised=is_training,
with_info=False)

train_ds = gather_data_splits(list(range(num_splits)), train_splits)
val_ds = gather_data_splits(list(range(num_splits)), val_splits)
val_ds = tfds.load(
'waterbirds_dataset',
split='validation',
data_dir=DATA_DIR,
builder_kwargs=builder_kwargs,
try_gcs=False,
as_supervised=is_training,
with_info=False)

eval_datasets = {
'val': val_ds,
'test': test_ds,
Expand Down
29 changes: 18 additions & 11 deletions experimental/shoshin/train_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,17 @@ def main(_) -> None:
logging.info('Running Round %d of Training.', config.round_idx)
if config.round_idx == 0:
# If initial round of sampling, sample randomly initial_sample_proportion
dataloader = dataset_builder(config.data.num_splits,
config.data.initial_sample_proportion,
config.data.subgroup_ids,
config.data.subgroup_proportions)
dataloader = dataset_builder(
config.data.num_splits,
config.data.initial_sample_proportion,
config.data.subgroup_ids,
config.data.subgroup_proportions,
initial_sample_seed=config.data.initial_sample_seed,
split_seed=config.data.split_seed,
split_proportion=config.data.split_proportion,
loo_training=config.data.loo_training,
loo_id=config.data.loo_id,
)
else:
# If latter round, keep track of split generated in last round of active
# sampling
Expand All @@ -127,17 +134,14 @@ def main(_) -> None:
logging.info(
'Error: Bias table not found')
return
# Training a single model on a combination of data splits.
included_splits_idx = [int(i) for i in config.data.included_splits_idx]
train_ds = data.gather_data_splits(included_splits_idx,
dataloader.train_splits)
val_ds = data.gather_data_splits(included_splits_idx,
dataloader.val_splits)
# Training a single model on a specific data split.
train_ds = dataloader.train_splits[int(config.data.split_id)]
val_ds = dataloader.val_splits[int(config.data.split_id)]
dataloader.train_ds = train_ds
dataloader.eval_ds['val'] = val_ds
experiment_name = 'stage_2' if config.train_bias else 'stage_1'

_ = train_tf_lib.train_and_evaluate(
model = train_tf_lib.train_and_evaluate(
train_as_ensemble=config.train_stage_2_as_ensemble,
dataloader=dataloader,
model_params=model_params,
Expand All @@ -150,6 +154,9 @@ def main(_) -> None:
ensemble_dir=FLAGS.ensemble_dir,
example_id_to_bias_table=example_id_to_bias_table)

if config.keep_individual_predictions:
train_tf_lib.create_predictions_table(model, dataloader, output_dir)


if __name__ == '__main__':
app.run(main)
38 changes: 38 additions & 0 deletions experimental/shoshin/train_tf_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from absl import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import data # local file import from experimental.shoshin
import models # local file import from experimental.shoshin
Expand Down Expand Up @@ -535,3 +536,40 @@ def train_and_evaluate(
example_id_to_bias_table=example_id_to_bias_table)
evaluate_model(two_head_model, checkpoint_dir, dataloader.eval_ds)
return two_head_model


def create_predictions_table(model, dataloader, save_dir, save_table=True):
  """Generates a table mapping example ID and train membership to prediction.

  Predictions are collected first for `dataloader.eval_ds['val']` (flagged
  `in_train == 0`) and then for `dataloader.train_ds` (flagged
  `in_train == 1`).

  Args:
    model: Trained model. `model.predict(ds)` must return a mapping whose
      `'main'` entry can be indexed with `[..., 1]`.
    dataloader: Dataclass object containing training and validation data.
      Each dataset must yield `(feats, label, example_id)` batches.
    save_dir: Directory in which predictions table will be saved as CSV.
    save_table: Boolean for whether or not to save table.

  Returns:
    A pandas dataframe with columns `example_id`, `in_train` and
    `predictions_label`.
  """

  table_name = 'predictions_table'
  dfs = []
  for is_train, ds in enumerate([dataloader.eval_ds['val'],
                                 dataloader.train_ds]):
    # NOTE(review): ids and predictions come from two separate passes over
    # `ds`; this assumes the dataset yields examples in the same order on
    # every iteration (i.e. no reshuffling) -- confirm with the data pipeline.
    predictions = model.predict(ds)
    example_ids = np.concatenate(list(
        ds.map(
            lambda feats, label, example_id: example_id).as_numpy_iterator()))
    # Build the membership flag as an integer column. Deriving it from
    # `np.ones_like(example_ids)` inherits the (string/object) dtype of the
    # ids, which either fails on multiplication or yields an object column.
    in_train = np.full(len(example_ids), is_train, dtype=np.int64)
    dfs.append(
        pd.DataFrame({
            'example_id': example_ids,
            'in_train': in_train,
            'predictions_label': predictions['main'][..., 1],
        }))
  df = pd.concat(dfs)
  if save_table:
    df.to_csv(os.path.join(save_dir, table_name + '.csv'), index=False)
  return df