Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ class Experiment:
quadrant_threshold: int = 0
sparsity: int = 1 # repr_loss scaled up by sparsity, applied every 1/sparsity
latent_noise_std: float = 0
latent_masking: bool = False
latent_masking_incentive: float = 0.1


baseline = Experiment(
Expand Down
34 changes: 11 additions & 23 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,16 @@ def _hyperparameter_search(*experiments_iterable: Experiment) -> None:
df = []
for train_result in train_results:
last_step = max(step_result.step for step_result in train_result.step_results)
reconstruction_loss_p2 = np.mean([
step_result.reconstruction_loss_p2
for step_result in train_result.step_results
if step_result.step >= last_step * 0.9
])
df.append(dict(tag=train_result.tag, reconstruction_loss_p2=reconstruction_loss_p2))
reconstruction_loss_p2 = np.mean(
[
step_result.reconstruction_loss_p2
for step_result in train_result.step_results
if step_result.step >= last_step * 0.9
]
)
df.append(
dict(tag=train_result.tag, reconstruction_loss_p2=reconstruction_loss_p2)
)
df = pd.DataFrame(df)
fig, ax = plt.subplots()
sns.barplot(data=df, x="tag", y="reconstruction_loss_p2", ax=ax)
Expand Down Expand Up @@ -438,8 +442,6 @@ def _train(experiment: Experiment) -> TrainResult:
target_latent_fn: Callable
repr_loss_mask_fn: Callable

latent_use_mask_fn = torch.nn.Sigmoid()

if experiment.loss_quadrants == "all":
repr_loss_mask_fn = lambda x: torch.ones(x.shape[0])
repr_loss_scale = 1.0
Expand All @@ -455,7 +457,6 @@ def _train(experiment: Experiment) -> TrainResult:
raise ValueError(
f"Loss quadrant must be 'all', 'bin_sum' or 'bin_val', got {experiment.loss_quadrants}."
)
print(f"repr_loss_scale = {repr_loss_scale}")

if experiment.use_class:
reconstruction_loss_fn = torch.nn.CrossEntropyLoss()
Expand Down Expand Up @@ -492,9 +493,6 @@ def _train(experiment: Experiment) -> TrainResult:
step_results = []
encoder_to_decoder_idx = list(range(len(models)))
for step in range(experiment.num_batches):
# TODO: Delete this if?
if experiment.dropout_prob is not None and step == 9000:
pass
if experiment.shuffle_decoders:
random.shuffle(encoder_to_decoder_idx)

Expand Down Expand Up @@ -527,14 +525,6 @@ def _train(experiment: Experiment) -> TrainResult:
vector_input = vector

latent_repr = encoder(vector_input)
if experiment.latent_masking:
latent_repr = latent_repr[: experiment.preferred_rep_size]
repr_mask = latent_use_mask_fn(latent_repr)
# TODO: Unused variable?
repr_use_loss = torch.mean(repr_mask)
else:
repr_use_loss = torch.Tensor(0)

noise = torch.normal(
mean=0, std=experiment.latent_noise_std, size=latent_repr.shape
)
Expand All @@ -555,7 +545,6 @@ def _train(experiment: Experiment) -> TrainResult:
target_latent_fn(latent_repr),
)
# Scaling here to compensate for quadrant sparsity
# TODO: Should we roll this into `experiment.representation_loss`?
representation_loss *= repr_loss_scale
loss = reconstruction_loss
if experiment.representation_loss is not None:
Expand Down Expand Up @@ -690,8 +679,7 @@ def _make_diagonal_repr_fn(rep_size: int) -> Callable:
def diagonal_repr_target(_input: torch.Tensor) -> torch.Tensor:
assert rep_size + 1 <= _input.shape[1]
dir_1 = _input[:, :rep_size]
# TODO: Wrap this so that we don't use N+1 variables?
dir_2 = _input[:, 1 : rep_size + 1]
dir_2 = torch.concat([_input[:, 1:rep_size], _input[:, :1]], dim=1)
repr_target = (dir_1 + dir_2) / np.sqrt(2)
return repr_target

Expand Down