Conversation
```python
self.variation = variation
```

```python
assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
```

Suggested change:
```suggestion
if distribution not in ['gaussian', 'gamma']:
    raise ValueError(f'Unsupported distribution {distribution}')
```
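The suggestion above swaps the `assert` for an explicit `ValueError`. A minimal sketch of why: `assert` statements are stripped when Python runs with `-O`, so a `raise` keeps the validation in optimized runs. The function name `set_distribution` here is hypothetical, just for illustration.

```python
def set_distribution(distribution: str) -> str:
    """Validate a distribution name with an explicit exception instead of
    an assert, which is silently skipped under `python -O`."""
    if distribution not in ('gaussian', 'gamma'):
        raise ValueError(f'Unsupported distribution {distribution}')
    return distribution

print(set_distribution('gamma'))  # 'gamma'
```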
```python
assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
if distribution == 'gamma':
    '''
```

Did you comment out this code chunk using `'''`? Can you remove it if we don't need it here?

Yes, this is an artefact of this being a temporary commit.
```python
    torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).
"""
# p(sample)
# DEBUG_MARKER
```

Remove the debug marker?
```python
concentration = torch.exp(mu)
rate = self.prior_variance
out = Gamma(concentration=concentration,
            rate=rate).log_prob(sample)
```

This might be a bug! The rate and `prior_variance` are different, can you double-check it?

Yes, I meant for the gamma `prior_variance` to represent the rate. Earlier too I set `prior_variance` to the rate, so this should be correct.

Thanks for flagging it though.

Is `prior_variance` the rate of the prior or the log(rate) of the prior?
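The rate-vs-log(rate) question matters because the two conventions give different log-probabilities. A standalone sketch (pure Python, using the same rate parameterization as `torch.distributions.Gamma`; the variable names are illustrative, not from the PR):

```python
import math

def gamma_log_prob(x, concentration, rate):
    """Log-density of Gamma(concentration, rate) at x, in the rate
    parameterization: a*log(b) - lgamma(a) + (a-1)*log(x) - b*x."""
    return (concentration * math.log(rate)
            - math.lgamma(concentration)
            + (concentration - 1.0) * math.log(x)
            - rate * x)

# If prior_variance stores the rate, use it directly; if it stores
# log(rate), it must be exponentiated first -- the results differ.
prior_variance = 0.5
lp_rate     = gamma_log_prob(2.0, concentration=1.5, rate=prior_variance)
lp_log_rate = gamma_log_prob(2.0, concentration=1.5, rate=math.exp(prior_variance))
print(lp_rate, lp_log_rate)  # two different log-probs
```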
```python
rate = torch.exp(self.variational_logstd)
return Gamma(concentration=concentration, rate=rate)
else:
    raise NotImplementedError("Unknown variational distribution type.")
```

Suggested change:
```suggestion
    raise NotImplementedError(f"Unknown variational distribution type {self.distribution}.")
```
```python
def is_observable(name: str) -> bool:
    return any(name.startswith(prefix) for prefix in observable_prefix)
```

```python
utility_string = utility_string.replace(' - ', ' + -')
```

We might want to find a way to parse utilities even when the user does not put spaces around `+` or `-`; let's do this later.
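One way to handle the no-spaces case flagged above: normalize spacing around the binary operators before the existing `' - '` → `' + -'` rewrite. This is a sketch with a hypothetical helper name, and it deliberately ignores edge cases like unary minus and scientific notation:

```python
import re

def normalize_utility(utility_string: str) -> str:
    """Rewrite 'a - b' (with or without spaces around the operator) as
    'a + -b', so the expression splits into additive terms on ' + '.
    Does not handle unary minus or exponent notation like 1e-3."""
    # Force single spaces around binary +/- regardless of user spacing.
    normalized = re.sub(r'\s*([+-])\s*', r' \1 ', utility_string)
    # Then apply the same subtraction rewrite as the code under review.
    return normalized.replace(' - ', ' + -')

print(normalize_utility('price-2*quality'))    # 'price + -2*quality'
print(normalize_utility('price - 2*quality'))  # 'price + -2*quality'
```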
```python
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)
```

```python
additive_term = (coef * obs).sum(dim=-1)
additive_term *= term['sign']
```

Can we just add one single `additive_term *= term['sign']` outside the if-elif chain?

Yes, should be possible, but I'll make sure.
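The refactor suggested above, sketched with plain floats instead of tensors (the `term` dict keys are illustrative stand-ins for the PR's branches):

```python
def evaluate_terms(terms):
    """Sum the additive terms of a utility, applying each term's sign
    once, after the branch that computes the unsigned value."""
    total = 0.0
    for term in terms:
        # Stand-in for the if-elif chain in the code under review.
        if term['kind'] == 'product':
            additive_term = term['coef'] * term['obs']
        elif term['kind'] == 'constant':
            additive_term = term['value']
        else:
            raise ValueError(f"Unknown term kind {term['kind']}")
        # Single sign multiplication, factored out of every branch.
        additive_term *= term['sign']
        total += additive_term
    return total

terms = [
    {'kind': 'product', 'coef': 2.0, 'obs': 3.0, 'sign': 1},
    {'kind': 'constant', 'value': 4.0, 'sign': -1},
]
print(evaluate_terms(terms))  # 2.0
```

Since every branch assigns `additive_term` before the multiply, the behavior is identical to repeating the sign line inside each branch.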
```python
loss = - elbo
return loss
```

```python
# DEBUG_MARKER
```

Let's remove the debug marker.

I did not review the configurations and main in your super-market specific script.

Yes, I'll take care of those.
Temporary working commit for Tianyu to review.