Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/main/java/mascot/dynamics/GLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public boolean intervalIsDirty(int i){
public double[] getCoalescentRate(int i){
int intervalNr;
if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1)
intervalNr = rateShiftsInput.get().getDimension()-2;
intervalNr = rateShiftsInput.get().getDimension()-firstlargerzero-1;
else
intervalNr = i + firstlargerzero;

Expand Down Expand Up @@ -212,7 +212,7 @@ public void close(PrintStream out) {
public double getNe(int state, int i){
int intervalNr;
if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1)
intervalNr = rateShiftsInput.get().getDimension()-2;
intervalNr = rateShiftsInput.get().getDimension()-firstlargerzero-1;
else
intervalNr = i + firstlargerzero;

Expand All @@ -224,7 +224,7 @@ public double getNe(int state, int i){
public double getMig(int source, int sink, int i){
int intervalNr;
if (i >= rateShiftsInput.get().getDimension()-firstlargerzero-1)
intervalNr = rateShiftsInput.get().getDimension()-2;
intervalNr = rateShiftsInput.get().getDimension()-firstlargerzero-1;
else
intervalNr = i + firstlargerzero;

Expand Down
4 changes: 4 additions & 0 deletions src/main/java/mascot/glmmodel/Covariate.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ public Covariate(Double[] values, String id) {
this.values = new Double[values.length];
System.arraycopy(values, 0, this.values, 0, values.length);
this.ID = id;
// Also populate valuesInput so initAndValidate() and XML serialization work correctly
for (Double v : values) {
valuesInput.get().add(v);
}
}

public Covariate(List<String> rawValues, String id) {
Expand Down
90 changes: 90 additions & 0 deletions src/test/java/mascot/dynamics/GLMTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package mascot.dynamics;

import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.BoolVectorParam;
import beast.base.spec.inference.parameter.RealScalarParam;
import beast.base.spec.inference.parameter.RealVectorParam;
import mascot.glmmodel.Covariate;
import mascot.glmmodel.CovariateList;
import mascot.glmmodel.LogLinear;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class GLMTest {

/**
* Regression for the single-epoch GLM intervalNr bug.
* <p>
* In a single-epoch model (rateShifts = [Infinity]), firstlargerzero == 0
* and dimension == 1, so the boundary branch in getCoalescentRate / getNe /
* getMig is always taken for any i. The old hardcoded fallback
* intervalNr = dim - 2 evaluated to -1 for dim=1, causing
* ArrayIndexOutOfBoundsException when rates were looked up. The fix
* uses dim - firstlargerzero - 1, which is the correct last valid
* interval index for any rateShifts configuration.
*/
@Test
public void testSingleEpochGLMRatesDoNotThrow() {
int dim = 2;

// single-epoch rate shifts: dim = 1 with the only value > 0, so firstlargerzero = 0
RateShifts rateShifts = new RateShifts();
rateShifts.initByName("value", "1.0");

GLM glm = buildGLM(dim, rateShifts);

double[] coalRate = glm.getCoalescentRate(0);
assertEquals(dim, coalRate.length);
for (double v : coalRate)
assertTrue(Double.isFinite(v), "coalescent rate must be finite, got " + v);

for (int s = 0; s < dim; s++) {
double ne = glm.getNe(s, 0);
assertTrue(Double.isFinite(ne), "Ne must be finite, got " + ne);
}

double mig = glm.getMig(0, 1, 0);
assertTrue(Double.isFinite(mig), "migration rate must be finite, got " + mig);
}

private GLM buildGLM(int dim, RateShifts rateShifts) {
LogLinear migGLM = buildLogLinear(dim * (dim - 1));
LogLinear neGLM = buildLogLinear(dim);

GLM glm = new GLM();
glm.initByName(
"dimension", dim,
"rateShifts", rateShifts,
"migrationGLM", migGLM,
"NeGLM", neGLM,
"types", "a b");
return glm;
}

private LogLinear buildLogLinear(int covariateDim) {
Double[] vals = new Double[covariateDim];
for (int i = 0; i < covariateDim; i++)
vals[i] = 1.0;

Covariate cov = new Covariate(vals, "cov");
cov.initAndValidate();

CovariateList covList = new CovariateList();
covList.initByName("covariates", cov);

RealVectorParam<Real> scaler = new RealVectorParam<>(new double[]{0.0}, Real.INSTANCE);
BoolVectorParam indicator = new BoolVectorParam(new boolean[]{true});
RealScalarParam<Real> clock = new RealScalarParam<>();
clock.initByName("value", "1.0");

LogLinear glm = new LogLinear();
glm.initByName(
"covariateList", covList,
"scaler", scaler,
"indicator", indicator,
"clock", clock);
return glm;
}
}
44 changes: 44 additions & 0 deletions src/test/java/mascot/glmmodel/CovariateTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package mascot.glmmodel;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class CovariateTest {

@Test
public void testDoubleArrayConstructorPopulatesValuesInput() {
Double[] vals = new Double[]{1.0, 2.0, 3.0};
Covariate c = new Covariate(vals, "test");
assertEquals(3, c.valuesInput.get().size());
assertEquals(1.0, c.valuesInput.get().get(0));
assertEquals(2.0, c.valuesInput.get().get(1));
assertEquals(3.0, c.valuesInput.get().get(2));
}

@Test
public void testInitAndValidateAfterDoubleArrayConstructorPreservesValues() {
// Regression for the bug where initAndValidate() rebuilt `values` from an
// empty valuesInput, clobbering the values set by the constructor.
Double[] vals = new Double[]{4.0, 5.0, 6.0};
Covariate c = new Covariate(vals, "test");
c.initAndValidate();
assertEquals(3, c.getDimension());
assertEquals(4.0, c.getArrayValue(0));
assertEquals(5.0, c.getArrayValue(1));
assertEquals(6.0, c.getArrayValue(2));
}

@Test
public void testInitAndValidateViaValuesInputStillWorks() {
// Existing code path: default constructor + valuesInput populated externally.
Covariate c = new Covariate();
c.valuesInput.get().add(7.0);
c.valuesInput.get().add(8.0);
c.setID("test");
c.initAndValidate();
assertEquals(2, c.getDimension());
assertEquals(7.0, c.getArrayValue(0));
assertEquals(8.0, c.getArrayValue(1));
}
}
Loading