@@ -29,15 +29,21 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
2929 graph .copy_shape (inp , inp_cast .output [0 ])
3030 graph .set_dtype (inp_cast .output [0 ], target_dtype )
3131
32-
33- def _add_cast_to_same_type_to_inputs (graph , node ):
32+ def _add_cast_to_same_type_to_inputs (graph , node , supported_dtypes , target_dtype ):
3433 common_dtype = graph .get_dtype (node .input [0 ])
34+ if common_dtype not in supported_dtypes :
35+ common_dtype = target_dtype
3536
36- for inp in node .input [ 1 :] :
37+ for inp in node .input :
3738 if graph .get_dtype (inp ) != common_dtype :
3839 inp_cast = graph .insert_new_node_on_input (node , "Cast" , inp , to = common_dtype )
3940 graph .copy_shape (inp , inp_cast .output [0 ])
4041 graph .set_dtype (inp_cast .output [0 ], common_dtype )
42+ if graph .is_const (inp ) and graph .get_tensor_value (inp ) == '' :
43+ # Convert '' string constant to -1 int
44+ # https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286
45+ const_node = graph .get_node_by_output (inp )
46+ const_node .set_tensor_value (utils .np .array (- 1 ))
4147
4248
4349@tf_op ("LogicalNot" , onnx_op = "Not" )
@@ -92,8 +98,24 @@ def version_7(cls, ctx, node, **kwargs):
9298
9399 @classmethod
94100 def version_11 (cls , ctx , node , ** kwargs ):
95- # starting with opset-11, equal supports all types (but both operands must be of the same type)
96- _add_cast_to_same_type_to_inputs (ctx , node )
101+ # starting with opset-11, equal supports all numerical types (but both operands must be of the same type)
102+ # string type is not supported
103+ supported_dtypes = [
104+ TensorProto .BOOL ,
105+ TensorProto .DOUBLE ,
106+ TensorProto .FLOAT ,
107+ TensorProto .FLOAT16 ,
108+ TensorProto .INT8 ,
109+ TensorProto .INT16 ,
110+ TensorProto .INT32 ,
111+ TensorProto .INT64 ,
112+ TensorProto .UINT8 ,
113+ TensorProto .UINT16 ,
114+ TensorProto .UINT32 ,
115+ TensorProto .UINT64
116+ ]
117+ target_dtype = TensorProto .INT32
118+ _add_cast_to_same_type_to_inputs (ctx , node , supported_dtypes , target_dtype )
97119 need_not = node .type == "NotEqual"
98120 if need_not :
99121 node .type = "Equal"
0 commit comments