diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index e3ee8ee62ea..ae5d5d36867 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -89,6 +89,9 @@ def __init__( else: self.target_type = [target_type] + for t in self.target_type: + verify_str_arg(t, "target_type", ("attr", "identity", "bbox", "landmarks")) + if not self.target_type and self.target_transform is not None: raise RuntimeError("target_transform is specified but target_type is empty") @@ -185,9 +188,6 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: target.append(self.bbox[index, :]) elif t == "landmarks": target.append(self.landmarks_align[index, :]) - else: - # TODO: refactor with utils.verify_str_arg - raise ValueError(f'Target type "{t}" is not recognized.') if self.transform is not None: X = self.transform(X)