diff --git a/tests/unit_tests/data_validation/test_TabularDescriptionTables.py b/tests/unit_tests/data_validation/test_TabularDescriptionTables.py
index 3f0615957..0a68e466f 100644
--- a/tests/unit_tests/data_validation/test_TabularDescriptionTables.py
+++ b/tests/unit_tests/data_validation/test_TabularDescriptionTables.py
@@ -82,3 +82,23 @@ def test_table_contents(self):
         datetime_table = next(df for df in result if "Datetime Variable" in df.columns)
         self.assertEqual(len(datetime_table), 2)  # Should have 2 datetime columns
         self.assertTrue("Earliest Date" in datetime_table.columns)
+
+    def test_bool_columns_included_in_categorical(self):
+        df = pd.DataFrame(
+            {
+                "flag": pd.array([True, False, True, False, True], dtype=bool),
+                "active": pd.array([False, False, True, True, False], dtype=bool),
+                "num": [1, 2, 3, 4, 5],
+            }
+        )
+        vm_dataset = vm.init_dataset(input_id="bool_dataset", dataset=df, __log=False)
+        result = TabularDescriptionTables(vm_dataset)
+
+        categorical_table = next(
+            (t for t in result if "Categorical Variable" in t.columns), None
+        )
+        self.assertIsNotNone(categorical_table, "Categorical table should be present")
+        categorical_vars = categorical_table["Categorical Variable"].tolist()
+        self.assertIn("flag", categorical_vars)
+        self.assertIn("active", categorical_vars)
+        self.assertEqual(len(categorical_vars), 2)
diff --git a/validmind/tests/data_validation/TabularDescriptionTables.py b/validmind/tests/data_validation/TabularDescriptionTables.py
index be9ff0c59..c8fb20609 100644
--- a/validmind/tests/data_validation/TabularDescriptionTables.py
+++ b/validmind/tests/data_validation/TabularDescriptionTables.py
@@ -177,7 +177,7 @@ def get_summary_statistics_datetime(dataset, datetime_fields):
 
 def get_categorical_columns(dataset):
     categorical_columns = dataset.df.select_dtypes(
-        include=["object", "category"]
+        include=["object", "category", "bool"]
     ).columns.tolist()
     return categorical_columns
 