@@ -697,55 +697,95 @@ def test_variable_depends_on():
697697 assert variable_depends_on (y , [y ])
698698
699699
700- def test_truncated_graph_inputs ():
701- """
702- * No conditions
703- n - n - (o)
704-
705- * One condition
706- n - (c) - o
707-
708- * Two conditions where on depends on another, both returned
709- (c) - (c) - o
710-
711- * Additional nodes are present
712- (c) - n - o
713- n - (n) -'
700+ class TestTruncatedGraphInputs :
701+ def test_basic (self ):
702+ """
703+ * No conditions
704+ n - n - (o)
705+
706+ * One condition
707+ n - (c) - o
708+
709+ * Two conditions where on depends on another, both returned
710+ (c) - (c) - o
711+
712+ * Additional nodes are present
713+ (c) - n - o
714+ n - (n) -'
715+
716+ * Disconnected condition not returned
717+ (c) - n - o
718+ c
719+
720+ * Disconnected output is present and returned
721+ (c) - (c) - o
722+ (o)
723+
724+ * Condition on itself adds itself
725+ n - (c) - (o/c)
726+ """
727+ x = MyVariable (1 )
728+ x .name = "x"
729+ y = MyVariable (1 )
730+ y .name = "y"
731+ z = MyVariable (1 )
732+ z .name = "z"
733+ x2 = MyOp (x )
734+ x2 .name = "x2"
735+ y2 = MyOp (y , x2 )
736+ y2 .name = "y2"
737+ o = MyOp (y2 )
738+ o2 = MyOp (o )
739+ # No conditions
740+ assert truncated_graph_inputs ([o ]) == [o ]
741+ # One condition
742+ assert truncated_graph_inputs ([o2 ], [y2 ]) == [y2 ]
743+ # Condition on itself adds itself
744+ assert truncated_graph_inputs ([o ], [y2 , o ]) == [o , y2 ]
745+ # Two conditions where on depends on another, both returned
746+ assert truncated_graph_inputs ([o2 ], [y2 , o ]) == [o , y2 ]
747+ # Additional nodes are present
748+ assert truncated_graph_inputs ([o ], [y ]) == [x2 , y ]
749+ # Disconnected condition
750+ assert truncated_graph_inputs ([o2 ], [y2 , z ]) == [y2 ]
751+ # Disconnected output is present
752+ assert truncated_graph_inputs ([o2 , z ], [y2 ]) == [z , y2 ]
753+
754+ def test_repeated_input (self ):
755+ """Test that truncated_graph_inputs does not return repeated inputs."""
756+ x = MyVariable (1 )
757+ x .name = "x"
758+ y = MyVariable (1 )
759+ y .name = "y"
760+
761+ trunc_inp1 = MyOp (x , y )
762+ trunc_inp1 .name = "trunc_inp1"
763+
764+ trunc_inp2 = MyOp (x , y )
765+ trunc_inp2 .name = "trunc_inp2"
766+
767+ o = MyOp (trunc_inp1 , trunc_inp1 , trunc_inp2 , trunc_inp2 )
768+ o .name = "o"
769+
770+ assert truncated_graph_inputs ([o ], [trunc_inp1 ]) == [trunc_inp2 , trunc_inp1 ]
771+
772+ def test_repeated_nested_input (self ):
773+ """Test that truncated_graph_inputs does not return repeated inputs."""
774+ x = MyVariable (1 )
775+ x .name = "x"
776+ y = MyVariable (1 )
777+ y .name = "y"
778+
779+ trunc_inp = MyOp (x , y )
780+ trunc_inp .name = "trunc_inp"
781+
782+ o1 = MyOp (trunc_inp , trunc_inp , x , x )
783+ o1 .name = "o1"
714784
715- * Disconnected condition not returned
716- (c) - n - o
717- c
785+ assert truncated_graph_inputs ([o1 ], [trunc_inp ]) == [x , trunc_inp ]
718786
719- * Disconnected output is present and returned
720- (c) - (c) - o
721- (o)
787+ # Reverse order of inputs
788+ o2 = MyOp ( x , x , trunc_inp , trunc_inp )
789+ o2 . name = "o2"
722790
723- * Condition on itself adds itself
724- n - (c) - (o/c)
725- """
726- x = MyVariable (1 )
727- x .name = "x"
728- y = MyVariable (1 )
729- y .name = "y"
730- z = MyVariable (1 )
731- z .name = "z"
732- x2 = MyOp (x )
733- x2 .name = "x2"
734- y2 = MyOp (y , x2 )
735- y2 .name = "y2"
736- o = MyOp (y2 )
737- o2 = MyOp (o )
738- # No conditions
739- assert truncated_graph_inputs ([o ]) == [o ]
740- # One condition
741- assert truncated_graph_inputs ([o2 ], [y2 ]) == [y2 ]
742- # Condition on itself adds itself
743- assert truncated_graph_inputs ([o ], [y2 , o ]) == [o , y2 ]
744- # Two conditions where on depends on another, both returned
745- assert truncated_graph_inputs ([o2 ], [y2 , o ]) == [o , y2 ]
746- # Additional nodes are present
747- assert truncated_graph_inputs ([o ], [y ]) == [x2 , y ]
748- # Disconnected condition
749- assert truncated_graph_inputs ([o2 ], [y2 , z ]) == [y2 ]
750- # Disconnected output is present
751- assert truncated_graph_inputs ([o2 , z ], [y2 ]) == [z , y2 ]
791+ assert truncated_graph_inputs ([o2 ], [trunc_inp ]) == [trunc_inp , x ]
0 commit comments