@@ -102,11 +102,11 @@ def __init__(
102102 ):
103103 """Construct a block object, will create default index if no index columns specified."""
104104 index_columns = list (index_columns )
105- if index_labels :
105+ if index_labels is not None :
106106 index_labels = list (index_labels )
107107 if len (index_labels ) != len (index_columns ):
108108 raise ValueError (
109- "'index_columns' and 'index_labels' must have equal length"
109+ f "'index_columns' (size { len ( index_columns ) } ) and 'index_labels' (size { len ( index_labels ) } ) must have equal length"
110110 )
111111 if len (index_columns ) == 0 :
112112 new_index_col_id = guid .generate_guid ()
@@ -1089,6 +1089,46 @@ def summarize(
10891089 labels = self ._get_labels_for_columns (column_ids )
10901090 return Block (expr , column_labels = labels , index_columns = [label_col_id ])
10911091
def corr(self):
    """Build a block that computes pairwise correlation over this block's value columns.

    A ``CorrOp`` binary aggregation is created for every ordered
    (left, right) pair of value columns. The aggregated expression is then
    handed to ``unpivot`` with the value-column labels as row labels and one
    unpivot column per left-hand value column. The returned block reuses this
    block's column labels and is indexed by freshly generated id columns, one
    per column-label level, named after the column-label level names.
    """
    value_ids = self.value_columns

    # Pairs are emitted left-major: all pairings of the first column, then
    # all pairings of the second, and so on.
    pair_aggs = []
    for lhs in value_ids:
        for rhs in value_ids:
            pair_op = ex.BinaryAggregation(
                agg_ops.CorrOp(), ex.free_var(lhs), ex.free_var(rhs)
            )
            pair_aggs.append((pair_op, f"{lhs}-{rhs}"))
    aggregated = self.expr.aggregate(pair_aggs)

    # One new index column id per level of the column labels.
    level_ids = [guid.generate_guid() for _ in range(self.column_labels.nlevels)]

    # Slice the aggregate's column ids into consecutive runs of ``width``
    # columns, each run feeding a single unpivot output column.
    width = len(value_ids)
    unpivot_spec = []
    for run in range(width):
        output_id = guid.generate_guid()
        start = width * run
        source_ids = tuple(aggregated.column_ids[start : start + width])
        unpivot_spec.append((output_id, source_ids))

    row_labels = self._get_labels_for_columns(value_ids)
    unpivoted = aggregated.unpivot(
        row_labels=row_labels,
        index_col_ids=level_ids,
        unpivot_columns=tuple(unpivot_spec),
    )

    return Block(
        unpivoted,
        column_labels=self.column_labels,
        index_columns=level_ids,
        index_labels=self.column_labels.names,
    )
1131+
10921132 def _standard_stats (self , column_id ) -> typing .Sequence [agg_ops .UnaryAggregateOp ]:
10931133 """
10941134 Gets a standard set of stats to preemptively fetch for a column if
@@ -1889,7 +1929,7 @@ def to_pandas(self) -> pd.Index:
18891929 df = expr .session ._rows_to_dataframe (results , dtypes )
18901930 df = df .set_index (index_columns )
18911931 index = df .index
1892- index .names = list (self ._block ._index_labels )
1932+ index .names = list (self ._block ._index_labels ) # type:ignore
18931933 return index
18941934
18951935 def resolve_level (self , level : LevelsType ) -> typing .Sequence [str ]:
0 commit comments