@@ -322,6 +322,114 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
322322 pd .testing .assert_frame_equal (bf_df .to_pandas (), pd_df , check_dtype = False )
323323
324324
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
    # The condition is a single-column comparison ("int64_col" > 0) and no
    # `other` argument is passed, so where() falls back to its default.
    # The frame under test is materialized with to_pandas() and compared
    # against the result of the same call on the reference frame.
    actual = scalars_df_index.where(scalars_df_index["int64_col"] > 0).to_pandas()
    expected = scalars_pandas_df_index.where(
        scalars_pandas_df_index["int64_col"] > 0
    )

    pandas.testing.assert_frame_equal(actual, expected)
332+
333+
def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
    # where() is expected to reject a frame whose columns are a two-level
    # MultiIndex by raising NotImplementedError with a fixed message.
    # NOTE(review): scalars_pandas_df_index is requested but never read here.
    expected_message = "The dataframe.where() method does not support multi-index and/or multi-column."
    two_level_columns = pd.MultiIndex.from_tuples(
        [("str1", 1), ("str2", 2)], names=["STR", "INT"]
    )

    frame = scalars_df_index[["int64_col", "float64_col"]]
    frame.columns = two_level_columns
    condition = frame["str1"] > 0

    with pytest.raises(NotImplementedError) as excinfo:
        frame.where(condition).to_pandas()

    # Checked after the `with` block so the captured exception is available.
    assert str(excinfo.value) == expected_message
350+
351+
def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index):
    # Single-column condition ("int64_col" > 0) with a scalar `other` (0),
    # applied to two-column frames whose columns index has been given a name.
    selected = ["int64_col", "float64_col"]
    fill_value = 0

    frame_bf = scalars_df_index[selected]
    frame_pd = scalars_pandas_df_index[selected]
    # Name the columns index on both frames, frame under test first.
    for frame in (frame_bf, frame_pd):
        frame.columns.name = "test_name"

    mask_bf = frame_bf["int64_col"] > 0
    mask_pd = frame_pd["int64_col"] > 0

    actual = frame_bf.where(mask_bf, fill_value).to_pandas()
    expected = frame_pd.where(mask_pd, fill_value)
    pandas.testing.assert_frame_equal(actual, expected)
367+
368+
def test_where_series_cond_dataframe_other(scalars_df_index, scalars_pandas_df_index):
    # Single-column condition ("int64_col" > 0) with a frame-shaped `other`:
    # the negation of the two-column frame being filtered.
    selected = ["int64_col", "float64_col"]
    frame_bf = scalars_df_index[selected]
    frame_pd = scalars_pandas_df_index[selected]

    mask_bf = frame_bf["int64_col"] > 0
    mask_pd = frame_pd["int64_col"] > 0
    negated_bf = -frame_bf
    negated_pd = -frame_pd

    actual = frame_bf.where(mask_bf, negated_bf).to_pandas()
    expected = frame_pd.where(mask_pd, negated_pd)
    pandas.testing.assert_frame_equal(actual, expected)
383+
384+
def test_where_dataframe_cond(scalars_df_index, scalars_pandas_df_index):
    # Frame-shaped condition (every selected column compared with 0) and an
    # explicit `other` of None.
    selected = ["int64_col", "float64_col"]
    frame_bf = scalars_df_index[selected]
    frame_pd = scalars_pandas_df_index[selected]

    mask_bf = frame_bf > 0
    mask_pd = frame_pd > 0

    actual = frame_bf.where(mask_bf, None).to_pandas()
    expected = frame_pd.where(mask_pd, None)
    pandas.testing.assert_frame_equal(actual, expected)
397+
398+
def test_where_dataframe_cond_const_other(scalars_df_index, scalars_pandas_df_index):
    # Frame-shaped condition (every selected column compared with 0) with a
    # scalar `other`; the same constant is passed on both sides.
    selected = ["int64_col", "float64_col"]
    replacement = 10

    frame_bf = scalars_df_index[selected]
    frame_pd = scalars_pandas_df_index[selected]

    mask_bf = frame_bf > 0
    mask_pd = frame_pd > 0

    actual = frame_bf.where(mask_bf, replacement).to_pandas()
    expected = frame_pd.where(mask_pd, replacement)
    pandas.testing.assert_frame_equal(actual, expected)
413+
414+
def test_where_dataframe_cond_dataframe_other(
    scalars_df_index, scalars_pandas_df_index
):
    # Frame-shaped condition (every selected column compared with 0) with a
    # frame-shaped `other`: the filtered frame multiplied by 2.
    selected = ["int64_col", "float64_col"]
    frame_bf = scalars_df_index[selected]
    frame_pd = scalars_pandas_df_index[selected]

    mask_bf = frame_bf > 0
    mask_pd = frame_pd > 0
    doubled_bf = frame_bf * 2
    doubled_pd = frame_pd * 2

    actual = frame_bf.where(mask_bf, doubled_bf).to_pandas()
    expected = frame_pd.where(mask_pd, doubled_pd)
    pandas.testing.assert_frame_equal(actual, expected)
431+
432+
325433def test_drop_column (scalars_dfs ):
326434 scalars_df , scalars_pandas_df = scalars_dfs
327435 col_name = "int64_col"
0 commit comments