diff --git a/src/main/java/mascot/dynamics/GLM.java b/src/main/java/mascot/dynamics/GLM.java index ffd1f11..a635997 100644 --- a/src/main/java/mascot/dynamics/GLM.java +++ b/src/main/java/mascot/dynamics/GLM.java @@ -117,7 +117,7 @@ public boolean intervalIsDirty(int i){ public double[] getCoalescentRate(int i){ int intervalNr; if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1) - intervalNr = rateShiftsInput.get().getDimension()-2; + intervalNr = rateShiftsInput.get().getDimension()-firstlargerzero-1; else intervalNr = i + firstlargerzero; @@ -212,7 +212,7 @@ public void close(PrintStream out) { public double getNe(int state, int i){ int intervalNr; if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1) - intervalNr = rateShiftsInput.get().getDimension()-2; + intervalNr = rateShiftsInput.get().getDimension()-firstlargerzero-1; else intervalNr = i + firstlargerzero; @@ -224,7 +224,7 @@ public double getNe(int state, int i){ public double getMig(int source, int sink, int i){ int intervalNr; if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1) - intervalNr = rateShiftsInput.get().getDimension()-2; + intervalNr = rateShiftsInput.get().getDimension()-firstlargerzero-1; else intervalNr = i + firstlargerzero; diff --git a/src/main/java/mascot/glmmodel/Covariate.java b/src/main/java/mascot/glmmodel/Covariate.java index 5773cab..4d95229 100644 --- a/src/main/java/mascot/glmmodel/Covariate.java +++ b/src/main/java/mascot/glmmodel/Covariate.java @@ -28,6 +28,10 @@ public Covariate(Double[] values, String id) { this.values = new Double[values.length]; System.arraycopy(values, 0, this.values, 0, values.length); this.ID = id; + // Also populate valuesInput so initAndValidate() and XML serialization work correctly + for (Double v : values) { + valuesInput.get().add(v); + } } public Covariate(List rawValues, String id) { diff --git a/src/test/java/mascot/dynamics/GLMTest.java b/src/test/java/mascot/dynamics/GLMTest.java new file mode 100644 index 0000000..61ee8b0 --- /dev/null +++ b/src/test/java/mascot/dynamics/GLMTest.java @@ -0,0 +1,90 @@ +package mascot.dynamics; + +import beast.base.spec.domain.Real; +import beast.base.spec.inference.parameter.BoolVectorParam; +import beast.base.spec.inference.parameter.RealScalarParam; +import beast.base.spec.inference.parameter.RealVectorParam; +import mascot.glmmodel.Covariate; +import mascot.glmmodel.CovariateList; +import mascot.glmmodel.LogLinear; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class GLMTest { + + /** + * Regression for the single-epoch GLM intervalNr bug. + *

+ * In a single-epoch model (rateShifts = [Infinity]), firstlargerzero == 0 + * and dimension == 1, so the boundary branch in getCoalescentRate / getNe / + * getMig is always taken for any i. The old hardcoded fallback + * intervalNr = dim - 2 evaluated to -1 for dim=1, causing + * ArrayIndexOutOfBoundsException when rates were looked up. The fix + * uses dim - firstlargerzero - 1, which is the correct last valid + * interval index for any rateShifts configuration. + */ + @Test + public void testSingleEpochGLMRatesDoNotThrow() { + int dim = 2; + + // single-epoch rate shifts: dim = 1 with the only value > 0, so firstlargerzero = 0 + RateShifts rateShifts = new RateShifts(); + rateShifts.initByName("value", "1.0"); + + GLM glm = buildGLM(dim, rateShifts); + + double[] coalRate = glm.getCoalescentRate(0); + assertEquals(dim, coalRate.length); + for (double v : coalRate) + assertTrue(Double.isFinite(v), "coalescent rate must be finite, got " + v); + + for (int s = 0; s < dim; s++) { + double ne = glm.getNe(s, 0); + assertTrue(Double.isFinite(ne), "Ne must be finite, got " + ne); + } + + double mig = glm.getMig(0, 1, 0); + assertTrue(Double.isFinite(mig), "migration rate must be finite, got " + mig); + } + + private GLM buildGLM(int dim, RateShifts rateShifts) { + LogLinear migGLM = buildLogLinear(dim * (dim - 1)); + LogLinear neGLM = buildLogLinear(dim); + + GLM glm = new GLM(); + glm.initByName( + "dimension", dim, + "rateShifts", rateShifts, + "migrationGLM", migGLM, + "NeGLM", neGLM, + "types", "a b"); + return glm; + } + + private LogLinear buildLogLinear(int covariateDim) { + Double[] vals = new Double[covariateDim]; + for (int i = 0; i < covariateDim; i++) + vals[i] = 1.0; + + Covariate cov = new Covariate(vals, "cov"); + cov.initAndValidate(); + + CovariateList covList = new CovariateList(); + covList.initByName("covariates", cov); + + RealVectorParam scaler = new RealVectorParam<>(new double[]{0.0}, Real.INSTANCE); + BoolVectorParam indicator = new BoolVectorParam(new boolean[]{true}); + RealScalarParam clock = new RealScalarParam<>(); + clock.initByName("value", "1.0"); + + LogLinear glm = new LogLinear(); + glm.initByName( + "covariateList", covList, + "scaler", scaler, + "indicator", indicator, + "clock", clock); + return glm; + } +} diff --git a/src/test/java/mascot/glmmodel/CovariateTest.java b/src/test/java/mascot/glmmodel/CovariateTest.java new file mode 100644 index 0000000..417cbdf --- /dev/null +++ b/src/test/java/mascot/glmmodel/CovariateTest.java @@ -0,0 +1,44 @@ +package mascot.glmmodel; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class CovariateTest { + + @Test + public void testDoubleArrayConstructorPopulatesValuesInput() { + Double[] vals = new Double[]{1.0, 2.0, 3.0}; + Covariate c = new Covariate(vals, "test"); + assertEquals(3, c.valuesInput.get().size()); + assertEquals(1.0, c.valuesInput.get().get(0)); + assertEquals(2.0, c.valuesInput.get().get(1)); + assertEquals(3.0, c.valuesInput.get().get(2)); + } + + @Test + public void testInitAndValidateAfterDoubleArrayConstructorPreservesValues() { + // Regression for the bug where initAndValidate() rebuilt `values` from an + // empty valuesInput, clobbering the values set by the constructor. + Double[] vals = new Double[]{4.0, 5.0, 6.0}; + Covariate c = new Covariate(vals, "test"); + c.initAndValidate(); + assertEquals(3, c.getDimension()); + assertEquals(4.0, c.getArrayValue(0)); + assertEquals(5.0, c.getArrayValue(1)); + assertEquals(6.0, c.getArrayValue(2)); + } + + @Test + public void testInitAndValidateViaValuesInputStillWorks() { + // Existing code path: default constructor + valuesInput populated externally. + Covariate c = new Covariate(); + c.valuesInput.get().add(7.0); + c.valuesInput.get().add(8.0); + c.setID("test"); + c.initAndValidate(); + assertEquals(2, c.getDimension()); + assertEquals(7.0, c.getArrayValue(0)); + assertEquals(8.0, c.getArrayValue(1)); + } +}