@@ -1506,8 +1506,10 @@ def concat(
15061506 blocks : typing .List [Block ] = [self , * other ]
15071507 if ignore_index :
15081508 blocks = [block .reset_index () for block in blocks ]
1509-
1510- result_labels = _align_indices (blocks )
1509+ level_names = None
1510+ else :
1511+ level_names , level_types = _align_indices (blocks )
1512+ blocks = [_cast_index (block , level_types ) for block in blocks ]
15111513
15121514 index_nlevels = blocks [0 ].index .nlevels
15131515
@@ -1522,7 +1524,7 @@ def concat(
15221524 result_expr ,
15231525 index_columns = list (result_expr .column_ids )[:index_nlevels ],
15241526 column_labels = aligned_blocks [0 ].column_labels ,
1525- index_labels = result_labels ,
1527+ index_labels = level_names ,
15261528 )
15271529 if ignore_index :
15281530 result_block = result_block .reset_index ()
@@ -1783,16 +1785,40 @@ def block_from_local(data) -> Block:
17831785 )
17841786
17851787
1788+ def _cast_index (block : Block , dtypes : typing .Sequence [bigframes .dtypes .Dtype ]):
1789+ original_block = block
1790+ result_ids = []
1791+ for idx_id , idx_dtype , target_dtype in zip (
1792+ block .index_columns , block .index_dtypes , dtypes
1793+ ):
1794+ if idx_dtype != target_dtype :
1795+ block , result_id = block .apply_unary_op (idx_id , ops .AsTypeOp (target_dtype ))
1796+ result_ids .append (result_id )
1797+ else :
1798+ result_ids .append (idx_id )
1799+
1800+ expr = block .expr .select_columns ((* result_ids , * original_block .value_columns ))
1801+ return Block (
1802+ expr ,
1803+ index_columns = result_ids ,
1804+ column_labels = original_block .column_labels ,
1805+ index_labels = original_block .index_labels ,
1806+ )
1807+
1808+
17861809def _align_block_to_schema (
17871810 block : Block , schema : dict [Label , bigframes .dtypes .Dtype ]
17881811) -> Block :
1789- """For a given schema, remap block to schema by reordering columns and inserting nulls."""
1812+ """For a given schema, remap block to schema by reordering columns, and inserting nulls."""
17901813 col_ids : typing .Tuple [str , ...] = ()
17911814 for label , dtype in schema .items ():
1792- # TODO: Support casting to lcd type - requires mixed type support
17931815 matching_ids : typing .Sequence [str ] = block .label_to_col_id .get (label , ())
17941816 if len (matching_ids ) > 0 :
17951817 col_id = matching_ids [- 1 ]
1818+ col_dtype = block .expr .get_column_type (col_id )
1819+ if dtype != col_dtype :
1820+ # If _align_schema worked properly, this should always be an upcast
1821+ block , col_id = block .apply_unary_op (col_id , ops .AsTypeOp (dtype ))
17961822 col_ids = (* col_ids , col_id )
17971823 else :
17981824 block , null_column = block .create_constant (None , dtype = dtype )
@@ -1810,38 +1836,44 @@ def _align_schema(
18101836 return functools .reduce (reduction , schemas )
18111837
18121838
1813- def _align_indices (blocks : typing .Sequence [Block ]) -> typing .Sequence [Label ]:
1814- """Validates that the blocks have compatible indices and returns the resulting label names."""
1839+ def _align_indices (
1840+ blocks : typing .Sequence [Block ],
1841+ ) -> typing .Tuple [typing .Sequence [Label ], typing .Sequence [bigframes .dtypes .Dtype ]]:
1842+ """Validates that the blocks have compatible indices and returns the resulting label names and dtypes."""
18151843 names = blocks [0 ].index .names
18161844 types = blocks [0 ].index .dtypes
1845+
18171846 for block in blocks [1 :]:
18181847 if len (names ) != block .index .nlevels :
18191848 raise NotImplementedError (
18201849 f"Cannot combine indices with different number of levels. Use 'ignore_index'=True. { constants .FEEDBACK_LINK } "
18211850 )
1822- if block .index .dtypes != types :
1823- raise NotImplementedError (
1824- f"Cannot combine different index dtypes. Use 'ignore_index'=True. { constants .FEEDBACK_LINK } "
1825- )
18261851 names = [
18271852 lname if lname == rname else None
18281853 for lname , rname in zip (names , block .index .names )
18291854 ]
1830- return names
1855+ types = [
1856+ bigframes .dtypes .lcd_type_or_throw (ltype , rtype )
1857+ for ltype , rtype in zip (types , block .index .dtypes )
1858+ ]
1859+ types = typing .cast (typing .Sequence [bigframes .dtypes .Dtype ], types )
1860+ return names , types
18311861
18321862
18331863def _combine_schema_inner (
18341864 left : typing .Dict [Label , bigframes .dtypes .Dtype ],
18351865 right : typing .Dict [Label , bigframes .dtypes .Dtype ],
18361866) -> typing .Dict [Label , bigframes .dtypes .Dtype ]:
18371867 result = dict ()
1838- for label , type in left .items ():
1868+ for label , left_type in left .items ():
18391869 if label in right :
1840- if type != right [label ]:
1870+ right_type = right [label ]
1871+ output_type = bigframes .dtypes .lcd_type (left_type , right_type )
1872+ if output_type is None :
18411873 raise ValueError (
18421874 f"Cannot concat rows with label { label } due to mismatched types. { constants .FEEDBACK_LINK } "
18431875 )
1844- result [label ] = type
1876+ result [label ] = output_type
18451877 return result
18461878
18471879
@@ -1850,15 +1882,20 @@ def _combine_schema_outer(
18501882 right : typing .Dict [Label , bigframes .dtypes .Dtype ],
18511883) -> typing .Dict [Label , bigframes .dtypes .Dtype ]:
18521884 result = dict ()
1853- for label , type in left .items ():
1854- if (label in right ) and (type != right [label ]):
1855- raise ValueError (
1856- f"Cannot concat rows with label { label } due to mismatched types. { constants .FEEDBACK_LINK } "
1857- )
1858- result [label ] = type
1859- for label , type in right .items ():
1885+ for label , left_type in left .items ():
1886+ if label not in right :
1887+ result [label ] = left_type
1888+ else :
1889+ right_type = right [label ]
1890+ output_type = bigframes .dtypes .lcd_type (left_type , right_type )
1891+ if output_type is None :
1892+ raise NotImplementedError (
1893+ f"Cannot concat rows with label { label } due to mismatched types. { constants .FEEDBACK_LINK } "
1894+ )
1895+ result [label ] = output_type
1896+ for label , right_type in right .items ():
18601897 if label not in left :
1861- result [label ] = type
1898+ result [label ] = right_type
18621899 return result
18631900
18641901
0 commit comments