diff --git a/lib/transforms.py b/lib/transforms.py index 7882d37..68d2bcf 100644 --- a/lib/transforms.py +++ b/lib/transforms.py @@ -140,7 +140,7 @@ def __call__(self, coords, feats, labels): ############################## class RandomDropout(object): - def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): + def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.2): """ upright_axis: axis index among x,y,z, i.e. 2 for z """ @@ -148,7 +148,7 @@ def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): self.dropout_application_ratio = dropout_application_ratio def __call__(self, coords, feats, labels): - if random.random() < self.dropout_ratio: + if random.random() < self.dropout_application_ratio: N = len(coords) inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False) return coords[inds], feats[inds], labels[inds]