@@ -2642,9 +2642,37 @@ def raise_if_sum_is_zero(x):
26422642 s = pd .Series ([- 1 ,0 ,1 ,2 ])
26432643 grouper = s .apply (lambda x : x % 2 )
26442644 grouped = s .groupby (grouper )
2645- self .assertRaises (ValueError ,
2645+ self .assertRaises (TypeError ,
26462646 lambda : grouped .filter (raise_if_sum_is_zero ))
26472647
2648+ def test_filter_bad_shapes (self ):
2649+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2650+ s = df ['B' ]
2651+ g_df = df .groupby ('B' )
2652+ g_s = s .groupby (s )
2653+
2654+ f = lambda x : x
2655+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2656+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2657+
2658+ f = lambda x : x == 1
2659+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2660+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2661+
2662+ f = lambda x : np .outer (x , x )
2663+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2664+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2665+
2666+ def test_filter_nan_is_false (self ):
2667+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2668+ s = df ['B' ]
2669+ g_df = df .groupby (df ['B' ])
2670+ g_s = s .groupby (s )
2671+
2672+ f = lambda x : np .nan
2673+ assert_frame_equal (g_df .filter (f ), df .loc [[]])
2674+ assert_series_equal (g_s .filter (f ), s [[]])
2675+
26482676 def test_filter_against_workaround (self ):
26492677 np .random .seed (0 )
26502678 # Series of ints
@@ -2697,6 +2725,29 @@ def test_filter_against_workaround(self):
26972725 new_way = grouped .filter (lambda x : x ['ints' ].mean () > N / 20 )
26982726 assert_frame_equal (new_way .sort_index (), old_way .sort_index ())
26992727
2728+ def test_filter_using_len (self ):
2729+ # BUG GH4447
2730+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2731+ grouped = df .groupby ('B' )
2732+ actual = grouped .filter (lambda x : len (x ) > 2 )
2733+ expected = DataFrame ({'A' : np .arange (2 , 6 ), 'B' : list ('bbbb' ), 'C' : np .arange (2 , 6 )}, index = np .arange (2 , 6 ))
2734+ assert_frame_equal (actual , expected )
2735+
2736+ actual = grouped .filter (lambda x : len (x ) > 4 )
2737+ expected = df .ix [[]]
2738+ assert_frame_equal (actual , expected )
2739+
2740+ # Series have always worked properly, but we'll test anyway.
2741+ s = df ['B' ]
2742+ grouped = s .groupby (s )
2743+ actual = grouped .filter (lambda x : len (x ) > 2 )
2744+ expected = Series (4 * ['b' ], index = np .arange (2 , 6 ))
2745+ assert_series_equal (actual , expected )
2746+
2747+ actual = grouped .filter (lambda x : len (x ) > 4 )
2748+ expected = s [[]]
2749+ assert_series_equal (actual , expected )
2750+
27002751 def test_groupby_whitelist (self ):
27012752 from string import ascii_lowercase
27022753 letters = np .array (list (ascii_lowercase ))
0 commit comments