Skip to content

Commit ad5370f

Browse files
committed
Fixed resource leaks in file loading and 'continued' sampling
1 parent 579aa49 commit ad5370f

4 files changed

Lines changed: 87 additions & 91 deletions

File tree

pom.xml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@
101101
<!-- v8.3.0: Support saving sampler from main Java application -->
102102
<!-- v8.4.0: Added support for continued sampling in main, controlled by 'continue' CMD line option -->
103103
<!-- v8.5.0: Support for compressed input files (zip, gz). Improved iteration callback support -->
104-
<version>8.5.0</version>
104+
<!-- v8.5.1: Fixed resource leaks in file loading and 'continued' sampling -->
105+
<version>8.5.1</version>
105106

106107

107108
<name>Partially Collapsed Parallel LDA</name>

src/main/java/cc/mallet/topics/UncollapsedParallelLDA.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,11 +2046,16 @@ private void readObject (ObjectInputStream in) throws IOException, ClassNotFound
20462046
lu.checkAndCreateCurrentLogDir(logSuitePath);
20472047
config.setLoggingUtil(lu);
20482048
if(activeSubconfig==null) {
2049-
activeSubconfig = config.getSubConfigs()[0];
2050-
System.out.println("Active subconfig not set, activating first available (" + activeSubconfig + ") ...");
2049+
String [] subconfs = config.getSubConfigs();
2050+
if(subconfs!= null && subconfs.length > 0) {
2051+
System.out.println("Active subconfig not set, activating first available (" + activeSubconfig + ") ...");
2052+
activeSubconfig = subconfs[0];
2053+
config.activateSubconfig(activeSubconfig);
2054+
System.out.println("Activating subconfig: " + activeSubconfig);
2055+
}
2056+
} else {
2057+
config.activateSubconfig(activeSubconfig);
20512058
}
2052-
System.out.println("Activating subconfig: " + activeSubconfig);
2053-
config.activateSubconfig(activeSubconfig);
20542059

20552060
System.out.println("Done Reading config!");
20562061
} catch (ConfigurationException e) {

src/main/java/cc/mallet/topics/tui/ParallelLDA.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,15 +208,18 @@ void doIteration(LDACommandLineParser cp, LDAConfiguration config, LoggingUtils
208208
}
209209
System.out.println("Scheme: " + whichModel);
210210

211-
InstanceList instances = LDAUtils.loadDataset(config, dataset_fn);
212-
instances.getAlphabet().stopGrowth();
213211

214212
boolean continueSampling = isContinuation(cp);
215213
LDAGibbsSampler model = createModel(config, whichModel);
214+
InstanceList instances = null;
216215
if(continueSampling) {
217216
System.out.println("Continuing sampling from previously stored model...");
218-
initSamplerFromSaved(config, instances, model);
219-
}
217+
initSamplerFromSaved(config, model);
218+
instances = model.getDataset();
219+
} else {
220+
instances = LDAUtils.loadDataset(config, dataset_fn);
221+
instances.getAlphabet().stopGrowth();
222+
}
220223

221224
if(model==null) {
222225
System.out.println("No valid model selected ('" + whichModel + "' is not a recognized model), please select a valid model...");
@@ -414,9 +417,9 @@ void doIteration(LDACommandLineParser cp, LDAConfiguration config, LoggingUtils
414417
System.out.println(new Date() + ": I am done!");
415418
}
416419

417-
private void initSamplerFromSaved(LDAConfiguration config, InstanceList instances, LDAGibbsSampler model) {
420+
private void initSamplerFromSaved(LDAConfiguration config, LDAGibbsSampler model) {
418421
String storedDir = config.getSavedSamplerDirectory(LDAConfiguration.STORED_SAMPLER_DIR_DEFAULT);
419-
LDASamplerWithPhi newModel = LDAUtils.loadStoredSampler(instances, config, storedDir);
422+
LDASamplerWithPhi newModel = LDAUtils.loadStoredSampler(config, storedDir);
420423
// Since the user asked us to continue using this sampler, we assume it is "initiable"
421424
LDASamplerInitiable toInit = (LDASamplerInitiable) model;
422425
toInit.initFrom(newModel);

src/main/java/cc/mallet/util/LDAUtils.java

Lines changed: 67 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,17 @@ public static InstanceList loadInstancesPrune(String inputFile, String stoplistF
292292
*/
293293
public static InstanceList loadInstancesPrune(String inputFile, String stoplistFile, int pruneCount, boolean keepNumbers,
294294
int maxBufSize, boolean keepConnectors, Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
295-
BufferedInputStream in;
296-
try {
297-
in = new BufferedInputStream(streamFromFile(inputFile));
295+
296+
try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))) {
298297
in.mark(Integer.MAX_VALUE);
298+
return loadInstancesPrune(in, stoplistFile, pruneCount, keepNumbers,
299+
maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
299300
} catch (IOException e) {
300301
throw new IllegalArgumentException(e);
301302
}
302-
303-
return loadInstancesPrune(in, stoplistFile, pruneCount, keepNumbers,
304-
maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
305303
}
306-
307-
304+
305+
308306
/**
309307
* Loads instances and prunes away low occurring words
310308
*
@@ -324,7 +322,7 @@ public static InstanceList loadInstancesPrune(BufferedInputStream in, String sto
324322
int dataGroup = 3;
325323
int labelGroup = 2;
326324
int nameGroup = 1; // data, label, name fields
327-
325+
328326
in.mark(Integer.MAX_VALUE);
329327

330328
tokenizer = initTokenizer(stoplistFile, keepNumbers, maxBufSize, keepConnectors);
@@ -475,22 +473,18 @@ public static InputStream streamFromFile(String inputFile) throws FileNotFoundEx
475473
return sout;
476474
}
477475
}
478-
476+
479477
public static InstanceList loadInstancesRaw(String inputFile, String stoplistFile, int keepCount, int maxBufSize,
480-
Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
481-
482-
BufferedInputStream in;
483-
try {
484-
in = new BufferedInputStream(streamFromFile(inputFile));
478+
Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
479+
try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))){
485480
in.mark(Integer.MAX_VALUE);
481+
return loadInstancesRaw(in, stoplistFile, keepCount, maxBufSize, dataAlphabet, targetAlphabet);
486482
} catch (IOException e) {
487483
throw new IllegalArgumentException(e);
488484
}
489-
490-
return loadInstancesRaw(in, stoplistFile, keepCount, maxBufSize, dataAlphabet, targetAlphabet);
491485
}
492-
493-
486+
487+
494488
/**
495489
* Loads instances and keeps the <code>keepCount</code> number of words with
496490
* the highest TF-IDF. Does no preprocessing of the input other than splitting
@@ -512,11 +506,11 @@ public static InstanceList loadInstancesRaw(BufferedInputStream in, String stopl
512506
int dataGroup = 3;
513507
int labelGroup = 2;
514508
int nameGroup = 1; // data, label, name fields
515-
509+
516510
in.mark(Integer.MAX_VALUE);
517511

518512
tokenizer = initRawTokenizer(stoplistFile, maxBufSize);
519-
513+
520514
if (keepCount > 0) {
521515
CsvIterator reader = new CsvIterator(
522516
new InputStreamReader(in),
@@ -609,22 +603,18 @@ public static InstanceList loadInstancesRaw(BufferedInputStream in, String stopl
609603

610604
return instances;
611605
}
612-
613-
606+
607+
614608
public static InstanceList loadInstancesKeep(String inputFile, String stoplistFile, int keepCount, boolean keepNumbers,
615-
int maxBufSize, boolean keepConnectors, Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
616-
617-
BufferedInputStream in;
618-
try {
619-
in = new BufferedInputStream(streamFromFile(inputFile));
609+
int maxBufSize, boolean keepConnectors, Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
610+
try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))){
620611
in.mark(Integer.MAX_VALUE);
612+
return loadInstancesKeep(in, stoplistFile, keepCount, keepNumbers,
613+
maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
614+
621615
} catch (IOException e) {
622616
throw new IllegalArgumentException(e);
623617
}
624-
625-
return loadInstancesKeep(in, stoplistFile, keepCount, keepNumbers,
626-
maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
627-
628618
}
629619

630620
/**
@@ -750,7 +740,7 @@ public static InstanceList loadInstancesKeep(BufferedInputStream in, String stop
750740

751741
/**
752742
* Re-creates the pipe that is used if loading with TF-IDF
753-
* This is ugly as hell, but I wanted ti to be as similar as
743+
* This is ugly as hell, but I wanted it to be as similar as
754744
* possible as when using loadDataset
755745
*
756746
* @param inputFile Input file to load
@@ -770,63 +760,60 @@ public static TfIdfPipe getTfIdfPipe(String inputFile, String stoplistFile, int
770760
int labelGroup = 2;
771761
int nameGroup = 1; // data, label, name fields
772762

773-
tokenizer = initTokenizer(stoplistFile, keepNumbers, maxBufSize, keepConnectors);
774-
775-
BufferedInputStream in;
776-
try {
777-
in = new BufferedInputStream(streamFromFile(inputFile));
778-
} catch (IOException e) {
779-
throw new IllegalArgumentException(e);
780-
}
781-
782763
if (keepCount > 0) {
783-
CsvIterator reader = new CsvIterator(
784-
new InputStreamReader(in),
785-
lineRegex,
786-
dataGroup,
787-
labelGroup,
788-
nameGroup);
764+
tokenizer = initTokenizer(stoplistFile, keepNumbers, maxBufSize, keepConnectors);
765+
try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))) {
766+
CsvIterator reader = new CsvIterator(
767+
new InputStreamReader(in),
768+
lineRegex,
769+
dataGroup,
770+
labelGroup,
771+
nameGroup);
772+
773+
ArrayList<Pipe> pipes = new ArrayList<Pipe>();
774+
Alphabet alphabet = null;
775+
if(dataAlphabet==null) {
776+
alphabet = new Alphabet();
777+
} else {
778+
alphabet = dataAlphabet;
779+
}
789780

790-
ArrayList<Pipe> pipes = new ArrayList<Pipe>();
791-
Alphabet alphabet = null;
792-
if(dataAlphabet==null) {
793-
alphabet = new Alphabet();
794-
} else {
795-
alphabet = dataAlphabet;
796-
}
781+
CharSequenceLowercase csl = new CharSequenceLowercase();
782+
SimpleTokenizer st = tokenizer.deepClone();
783+
StringList2FeatureSequence sl2fs = new StringList2FeatureSequence(alphabet);
784+
TfIdfPipe tfIdfPipe = new TfIdfPipe(alphabet, null);
797785

798-
CharSequenceLowercase csl = new CharSequenceLowercase();
799-
SimpleTokenizer st = tokenizer.deepClone();
800-
StringList2FeatureSequence sl2fs = new StringList2FeatureSequence(alphabet);
801-
TfIdfPipe tfIdfPipe = new TfIdfPipe(alphabet, null);
786+
pipes.add(csl);
787+
pipes.add(st);
788+
pipes.add(sl2fs);
789+
if (keepCount > 0) {
790+
pipes.add(tfIdfPipe);
791+
}
802792

803-
pipes.add(csl);
804-
pipes.add(st);
805-
pipes.add(sl2fs);
806-
if (keepCount > 0) {
807-
pipes.add(tfIdfPipe);
808-
}
793+
Pipe serialPipe = new SerialPipes(pipes);
809794

810-
Pipe serialPipe = new SerialPipes(pipes);
795+
Iterator<Instance> iterator = serialPipe.newIteratorFrom(reader);
811796

812-
Iterator<Instance> iterator = serialPipe.newIteratorFrom(reader);
797+
int count = 0;
813798

814-
int count = 0;
799+
// We aren't really interested in the instance itself,
800+
// just the total feature counts.
801+
while (iterator.hasNext()) {
802+
count++;
803+
if (count % 100000 == 0) {
804+
System.out.println(count);
805+
}
806+
iterator.next();
807+
}
815808

816-
// We aren't really interested in the instance itself,
817-
// just the total feature counts.
818-
while (iterator.hasNext()) {
819-
count++;
820-
if (count % 100000 == 0) {
821-
System.out.println(count);
809+
if (keepCount > 0) {
810+
tfIdfPipe.addPrunedWordsToStoplist(tokenizer, keepCount);
811+
return tfIdfPipe;
822812
}
823-
iterator.next();
813+
} catch (IOException e) {
814+
throw new IllegalArgumentException(e);
824815
}
825816

826-
if (keepCount > 0) {
827-
tfIdfPipe.addPrunedWordsToStoplist(tokenizer, keepCount);
828-
return tfIdfPipe;
829-
}
830817
} else {
831818
return null;
832819
}
@@ -2411,7 +2398,7 @@ public static InstanceList loadInstancesStrings(String [] doclines, String class
24112398
return instances;
24122399
}
24132400

2414-
public static LDASamplerWithPhi loadStoredSampler(InstanceList trainingset, LDAConfiguration config, String saveDir) {
2401+
public static LDASamplerWithPhi loadStoredSampler(LDAConfiguration config, String saveDir) {
24152402
String configHash = getConfigSetHash(config);
24162403
if(!saveDir.endsWith(File.separator)) saveDir = saveDir + File.separator;
24172404
String samplerFn = saveDir + buildSamplerSaveFilename(configHash);

0 commit comments

Comments
 (0)