Skip to content

Commit d9cc29f

Browse files
author
Nicolas Béreux
committed
fix parser for test_dataset and update version
1 parent ceec2f1 commit d9cc29f

3 files changed

Lines changed: 12 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "rbms"
7-
version = "0.4.0"
7+
version = "0.4.1"
88
authors = [
99
{name="Nicolas Béreux", email="nicolas.bereux@gmail.com"},
1010
{name="Aurélien Decelle"},

rbms/dataset/parser.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,17 @@ def add_args_dataset(parser: argparse.ArgumentParser) -> argparse.ArgumentParser
88
dataset_args = parser.add_argument_group("Dataset")
99
dataset_args.add_argument(
1010
"-d",
11-
"--data",
11+
"--dataset",
1212
type=str,
1313
required=True,
14-
help="Name of the dataset ('HGD', 'MNIST', 'BKACE', 'PF00072', 'PF13354'), or path to a data file (type should be .h5 or .fasta)",
14+
help="Path to a data file (type should be .h5 or .fasta)",
15+
)
16+
dataset_args.add_argument(
17+
"--test_dataset",
18+
type=str,
19+
required=False,
20+
default=None,
21+
help="Path to test dataset file (type should be .h5 or .fasta)",
1522
)
1623
dataset_args.add_argument(
1724
"--subset_labels",

rbms/scripts/train_rbm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def train_rbm(args: dict):
3636
num_updates=args["num_updates"], n_save=args["n_save"], spacing=args["spacing"]
3737
)
3838
train_dataset, test_dataset = load_dataset(
39-
dataset_name=args["data"],
39+
dataset_name=args["dataset"],
40+
test_dataset_name=args["test_dataset"],
4041
subset_labels=args["subset_labels"],
4142
use_weights=args["use_weights"],
4243
alphabet=args["alphabet"],

0 commit comments

Comments
 (0)