-
Notifications
You must be signed in to change notification settings - Fork 2
Description
I have quite a few Stan files for time-series models lying around and could contribute a module to this project.
import copy
from abc import ABCMeta, abstractmethod, abstractstaticmethod
from typing import Dict

import numpy as np
import pystan as ps
class BaseStanData(dict):
    """A ``dict`` of named inputs for a Stan program's ``data`` block."""

    def append(self, **kwargs):
        """Return a copy of this data with ``kwargs`` added (or overriding keys).

        The original object is left unchanged.  The copy has the same class as
        ``self``, so subclasses survive ``append``.

        Args:
            **kwargs: Entries to add; existing keys are overwritten.

        Returns:
            A new instance of ``type(self)`` holding the merged entries.
        """
        # dict.copy() would return a plain dict; copy.copy keeps the subclass
        # and does not re-run __init__ (subclass constructors may need args).
        data = copy.copy(self)
        data.update(kwargs)
        return data
class BaseModel(metaclass=ABCMeta):
    """Abstract base for Stan-backed models.

    Subclasses assign ``model_code`` (the Stan program source) and implement
    ``preprocess`` to turn input data into the dict handed to Stan.
    """

    model_code: str  # Stan program source; assigned by concrete subclasses

    def __init__(self, **kwargs):
        """Keep arbitrary options on ``self.kwargs``.

        Args:
            **kwargs: Stored verbatim (TimeSeriesModel.fit passes them through
                to ``pystan.stan``).
        """
        self.kwargs = kwargs

    # abc.abstractstaticmethod is deprecated since Python 3.3; stacking
    # staticmethod over abstractmethod is the documented replacement.
    @staticmethod
    @abstractmethod
    def preprocess(dat: "BaseStanData") -> "BaseStanData":
        """Return the data dict to send to Stan, derived from ``dat``."""
class BaseModelResult(metaclass=ABCMeta):
    """Common base type for whatever a model's ``fit`` returns.

    It declares no abstract methods, so it and its subclasses can be
    instantiated directly; it serves only as a shared type.
    """
class TimeSeriesStanData(BaseStanData):
    """Stan data for a univariate series: ``y`` (the values) and ``T`` (length)."""

    def __init__(self, y: np.ndarray):
        """Build the data dict from a one-dimensional series.

        Args:
            y: Observed series; anything ``np.asarray`` accepts.

        Raises:
            ValueError: If ``y`` is not one-dimensional.
        """
        # Was ``init``: without the dunder this never ran as the constructor,
        # so TimeSeriesStanData(y) fell through to dict(y) and failed.
        super().__init__()
        y = np.asarray(y)
        # Explicit raise instead of assert, which is stripped under ``python -O``.
        if y.ndim != 1:
            raise ValueError('Mismatch dimension. y must be 1 dimensional array')
        self.update({'y': y, 'T': y.shape[0]})
class TimeSeriesModel(BaseModel):
    """Base class for univariate time-series models fitted with PyStan.

    Subclasses supply ``model_code`` (a Stan program reading ``T`` and ``y``)
    and may override ``preprocess`` to add further data entries.
    """

    def __init__(self, **kwargs):
        """Store options that ``fit`` forwards to ``pystan.stan``."""
        # Was ``init``/``super().init``: not a constructor, and BaseModel has
        # no ``init`` method.
        super().__init__(**kwargs)

    def fit(self, y: np.ndarray) -> "TimeSeriesModelResult":
        """Sample the posterior of ``model_code`` given the series ``y``.

        Args:
            y: One-dimensional observed series.

        Returns:
            A TimeSeriesModelResult holding this model and the PyStan fit.
        """
        data = self.preprocess(TimeSeriesStanData(y))
        # NOTE(review): pystan.stan compiles model_code on every call; caching
        # a pystan.StanModel per class would avoid repeated compilation.
        stanfit = ps.stan(model_code=self.model_code, data=data, **self.kwargs)
        return TimeSeriesModelResult(self, stanfit)

    @staticmethod
    def preprocess(dat: TimeSeriesStanData) -> TimeSeriesStanData:
        """Default preprocessing: pass the data through unchanged."""
        return dat
class TimeSeriesModelResult(BaseModelResult):
    """Outcome of ``TimeSeriesModel.fit``: the model and its PyStan fit."""

    def __init__(self, model: "TimeSeriesModel", stanfit):
        """
        Args:
            model: The model that produced this result.
            stanfit: The fit object returned by ``pystan.stan``.
        """
        # Was ``init``: construction with two arguments raised TypeError.
        self.model = model
        self.stanfit = stanfit

    def predict(self, y: np.ndarray) -> np.ndarray:
        """Return predictions for ``y`` by delegating to ``predict_dist``."""
        return self.predict_dist(y)

    def predict_dist(self, y: np.ndarray) -> np.ndarray:
        """Model-specific prediction hook used by ``predict``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement predict_dist"
        )

    def check_fit(self) -> np.ndarray:
        """Return the posterior mean of ``u`` plus the posterior mean of ``v``.

        Requires the Stan program to define parameters named ``u`` and ``v``
        (BayesianStructuralTimeSeries does: level and slope).
        """
        # Original lacked ``self`` and called extract() twice; extract only the
        # two needed parameters, once.
        samples = self.stanfit.extract(pars=['u', 'v'], permuted=True)
        return samples['u'].mean(axis=0) + samples['v'].mean(axis=0)
class BayesianStructuralTimeSeries(TimeSeriesModel):
    """Local linear trend model with latent level ``u`` and slope ``v``.

    At each step the level moves by the previous slope plus a level innovation
    scaled by ``s_level``; the slope is a random walk scaled by ``s_slope``;
    observations are normal around the level with sd ``s_obs``.
    """

    # Every parameter must enter the model: the draft's ``beta`` appeared in
    # no statement and had no prior, which made the posterior improper.
    model_code = """data {
  int<lower=1> T;  // u[1] and v[1] are always indexed, so T >= 1
  vector[T] y;
}
parameters {
  vector[T] u_err;  // standardized level innovations
  vector[T] v_err;  // standardized slope innovations
  real<lower=0> s_obs;
  real<lower=0> s_slope;
  real<lower=0> s_level;
}
transformed parameters {
  vector[T] u;  // level
  vector[T] v;  // slope
  // NOTE(review): the initial level and slope are raw N(0, 1) draws, so y is
  // presumably expected on roughly unit scale -- confirm or standardize y.
  u[1] = u_err[1];
  v[1] = v_err[1];
  for (t in 2:T) {
    u[t] = u[t-1] + v[t-1] + s_level * u_err[t];
    v[t] = v[t-1] + s_slope * v_err[t];
  }
}
model {
  // s_obs, s_slope and s_level get Stan's implicit flat prior on (0, inf).
  u_err ~ normal(0, 1);
  v_err ~ normal(0, 1);
  y ~ normal(u, s_obs);
}"""

    @staticmethod
    def preprocess(dat: TimeSeriesStanData) -> TimeSeriesStanData:
        """Return ``dat`` with an added ``sigma_upper`` entry (set to ``y``).

        NOTE(review): ``sigma_upper`` is not declared in the ``data`` block of
        ``model_code``, so the Stan program never reads it. Confirm the
        intended value (perhaps a scalar bound on the scales?) before relying on it.
        """
        return dat.append(sigma_upper=dat['y'])
The code above is one example.
I tried pushing to a new branch but was denied permission. Can I contribute? I could also add AR, MA, ARMA, and similar models.