Skip to content

Commit 31a723e

Browse files
committed
updates
1 parent 4bfbdb1 commit 31a723e

2 files changed

Lines changed: 21 additions & 15 deletions

File tree

src/main/scala/BIDMach/models/GLM.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,18 @@ class GLM(opts:GLM.Opts) extends RegressionModel(opts) {
4444
def mupdate(in:Mat) = {
4545
val targs = targets * in
4646
min(targs, 1f, targs)
47-
val alltargs = if (targmap.asInstanceOf[AnyRef] != null) targmap * targs else targs
4847
val dweights = if (iweight.asInstanceOf[AnyRef] != null) iweight * in else null
49-
mupdate3(in, alltargs, dweights)
48+
mupdate3(in, targs, dweights)
5049
}
5150

5251
def mupdate2(in:Mat, targ:Mat) = mupdate3(in, targ, null)
5352

54-
def mupdate3(in:Mat, targ:Mat, dweights:Mat) = {
55-
val ftarg = full(targ)
53+
def mupdate3(in:Mat, targ:Mat, dweights:Mat) = {
54+
val ftarg = full(targ);
55+
val targs = if (targmap.asInstanceOf[AnyRef] != null) targmap * ftarg else ftarg;
5656
val eta = modelmats(0) * in
5757
GLM.preds(eta, eta, mylinks, linkArray, totflops)
58-
GLM.derivs(eta, ftarg, eta, mylinks, linkArray, totflops)
58+
GLM.derivs(eta, targs, eta, mylinks, linkArray, totflops)
5959
if (dweights.asInstanceOf[AnyRef] != null) eta ~ eta dweights
6060
updatemats(0) ~ eta *^ in
6161
if (mask.asInstanceOf[AnyRef] != null) {
@@ -64,21 +64,21 @@ class GLM(opts:GLM.Opts) extends RegressionModel(opts) {
6464
}
6565

6666
def meval(in:Mat):FMat = {
67-
val targs = targets * in
68-
min(targs, 1f, targs)
69-
val alltargs = if (targmap.asInstanceOf[AnyRef] != null) targmap * targs else targs
70-
val dweights = if (iweight.asInstanceOf[AnyRef] != null) iweight * in else null
71-
meval3(in, alltargs, dweights)
67+
val targs = targets * in;
68+
min(targs, 1f, targs);
69+
val dweights = if (iweight.asInstanceOf[AnyRef] != null) iweight * in else null;
70+
meval3(in, targs, dweights);
7271
}
7372

7473
def meval2(in:Mat, targ:Mat):FMat = meval3(in, targ, null)
7574

7675
def meval3(in:Mat, targ:Mat, dweights:Mat):FMat = {
77-
val ftarg = full(targ)
76+
val ftarg = full(targ);
77+
val targs = if (!(putBack >= 0) && targmap.asInstanceOf[AnyRef] != null) targmap * ftarg else ftarg;
7878
val eta = modelmats(0) * in
7979
GLM.preds(eta, eta, mylinks, linkArray, totflops)
80-
val v = GLM.llfun(eta, ftarg, mylinks, linkArray, totflops)
81-
if (putBack >= 0) {ftarg <-- eta}
80+
val v = GLM.llfun(eta, targs, mylinks, linkArray, totflops)
81+
if (putBack >= 0) {targ <-- eta}
8282
if (dweights.asInstanceOf[AnyRef] != null) {
8383
FMat(sum(v dweights, 2) / sum(dweights))
8484
} else {

src/main/scala/BIDMach/models/Regression.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@ abstract class RegressionModel(override val opts:RegressionModel.Opts) extends M
2727
val data0 = mats(0)
2828
val m = size(data0, 1)
2929
val targetData = mats.length > 1
30-
val d = if (targetData) mats(1).nrows else if (opts.targmap.asInstanceOf[AnyRef] != null) opts.targmap.nrows else opts.targets.nrows
30+
val d = if (opts.targmap.asInstanceOf[AnyRef] != null) {
31+
opts.targmap.nrows
32+
} else if (opts.targets.asInstanceOf[AnyRef] != null) {
33+
opts.targets.nrows
34+
} else {
35+
mats(1).nrows
36+
}
3137
val sdat = (sum(data0,2).t + 0.5f).asInstanceOf[FMat]
3238
sp = sdat / sum(sdat)
3339
println("corpus perplexity=%f" format (math.exp(-(sp ddot ln(sp)))))
@@ -39,9 +45,9 @@ abstract class RegressionModel(override val opts:RegressionModel.Opts) extends M
3945
}
4046
updatemats = new Array[Mat](1)
4147
updatemats(0) = modelmats(0).zeros(modelmats(0).nrows, modelmats(0).ncols)
48+
targmap = if (useGPU && opts.targmap.asInstanceOf[AnyRef] != null) GMat(opts.targmap) else opts.targmap
4249
if (! targetData) {
4350
targets = if (useGPU && opts.targets.asInstanceOf[AnyRef] != null) GMat(opts.targets) else opts.targets
44-
targmap = if (useGPU && opts.targmap.asInstanceOf[AnyRef] != null) GMat(opts.targmap) else opts.targmap
4551
mask = if (useGPU && opts.rmask.asInstanceOf[AnyRef] != null) GMat(opts.rmask) else opts.rmask
4652
}
4753
}

0 commit comments

Comments
 (0)