# clockModel.py
"""Clock-rate models. Each model exposes sample(), returning a
(log_clock_rate, log_q) pair, and forward(), returning the log-prior
density of a sampled log clock rate."""
import math

import torch
import torch.nn as nn


class FixedRateModel(nn.Module):
    """Strict clock with a fixed, non-learnable rate."""

    def __init__(self, init_clock_rate=1.0, **kwargs):
        super().__init__()
        self.log_clock_rate = torch.tensor(math.log(init_clock_rate))

    def sample(self, **kwargs):
        # The rate is deterministic, so the variational log-density is 0.
        return self.log_clock_rate, 0.0

    def forward(self, *args):
        # A fixed rate contributes no log-prior term.
        return 0.0
class StrictModel(nn.Module):
    """Strict clock with a learnable log-normal variational posterior over
    the rate and a Normal(mu, sigma) prior on the log clock rate."""

    def __init__(self, init_clock_rate=1, mu=0.5, sigma=1.0,
                 **kwargs):  # Note: what is a good value for clock rate offset
        super().__init__()
        self.mu, self.sigma = mu, sigma
        self.log_clock_rate_offset = math.log(init_clock_rate)
        # Row 0: variational mean; row 1: variational log-standard-deviation.
        self.clock_rate_param = nn.Parameter(torch.zeros(2, 1))
        nn.init.xavier_uniform_(self.clock_rate_param)

    def sample(self, n_particles=1, log_tree_height=0.0):
        # std holds the log-standard-deviation; it is exponentiated below.
        mean, std = self.clock_rate_param[0] + self.log_clock_rate_offset, self.clock_rate_param[1]
        sample_epsilon = torch.randn(n_particles, mean.size(-1))
        # Reparameterized sample of the log clock rate.
        log_clock_rate = mean + std.exp() * sample_epsilon - log_tree_height  # Note: should we use log_tree_height
        # Log-density of the sample under the variational posterior.
        logq_clock_rate = torch.sum(-0.5 * math.log(2 * math.pi) - std - 0.5 * sample_epsilon ** 2, dim=-1)
        return log_clock_rate, logq_clock_rate

    def forward(self, log_clock_rate):
        # Normal(mu, sigma) log-prior density of the sampled log clock rate.
        return torch.sum(
            -0.5 * math.log(2 * math.pi) - math.log(self.sigma) - 0.5 * ((log_clock_rate - self.mu) / self.sigma) ** 2,
            dim=-1)
# Variant of StrictModel with learnable prior parameters (currently disabled).
# class StrictModelPL(nn.Module):
#     def __init__(self, init_clock_rate=1, mu=0.5, sigma=1.0,
#                  **kwargs):  # Note: what is a good value for clock rate offset
#         super().__init__()
#
#         self.mu_param = nn.Parameter(torch.tensor(math.log(mu)))
#         self.sigma_param = nn.Parameter(torch.tensor(sigma))
#
#         self.log_clock_rate_offset = math.log(init_clock_rate)
#
#         self.clock_rate_param = nn.Parameter(torch.zeros(2, 1))
#         nn.init.xavier_uniform_(self.clock_rate_param)
#
#     def sample(self, n_particles=1, log_tree_height=0.0):
#         mean, std = self.clock_rate_param[0] + self.log_clock_rate_offset, self.clock_rate_param[1]
#         sample_epsilon = torch.randn(n_particles, mean.size(-1))
#         log_clock_rate = mean + std.exp() * sample_epsilon - log_tree_height  # Note: should we use log_tree_height
#         logq_clock_rate = torch.sum(-0.5 * math.log(2 * math.pi) - std - 0.5 * sample_epsilon ** 2, dim=-1)
#
#         return log_clock_rate, logq_clock_rate
#
#     def forward(self, log_clock_rate):
#         mu = self.mu_param.exp()
#         sigma = self.sigma_param.exp()
#         return torch.sum(
#             -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((log_clock_rate - mu) / sigma) ** 2,
#             dim=-1)
class RelaxedModel(nn.Module):
    """Relaxed clock prior; the log-prior is not yet implemented."""

    def __init__(self, init_clock_rate=1, mu=0.8, sigma=2.0, **kwargs):
        super().__init__()
        self.mu, self.sigma = math.log(mu), sigma

    def forward(self, log_clock_rate):
        ...


class AutocorrelatedModel(nn.Module):
    """Autocorrelated relaxed clock prior; the log-prior is not yet implemented."""

    def __init__(self, init_clock_rate=1, mu=0.8, sigma=2.0, **kwargs):
        super().__init__()
        self.mu, self.sigma = math.log(mu), sigma

    def forward(self, log_clock_rate):
        ...
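

# A minimal usage sketch, assuming only the sample()/forward() contract
# defined above: draw log clock rates from StrictModel's variational
# posterior and score them under its prior. The particle count and
# hyperparameter values here are illustrative choices, not values from
# the surrounding code.
if __name__ == "__main__":
    strict = StrictModel(init_clock_rate=1.0, mu=0.5, sigma=1.0)
    log_rate, log_q = strict.sample(n_particles=4)   # log_rate: (4, 1), log_q: (4,)
    log_p = strict(log_rate)                         # log-prior term, shape (4,)
    print("log clock rates:", log_rate.squeeze(-1))
    print("log q:", log_q)
    print("log p:", log_p)

    fixed = FixedRateModel(init_clock_rate=2.0)
    log_rate_fixed, log_q_fixed = fixed.sample()
    print("fixed log rate:", log_rate_fixed.item(), "log q:", log_q_fixed)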