Skip to content

Commit f53e356

Browse files
committed
Fix bug in truncated_graph_inputs
It could return duplicated truncated inputs before the changes
1 parent 16d1cbe commit f53e356

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

pytensor/graph/basic.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,14 +1056,20 @@ def truncated_graph_inputs(
10561056
truncated_inputs.append(node)
10571057
# no more actions are needed
10581058
return truncated_inputs
1059+
10591060
blockers: Set[Variable] = set(ancestors_to_include)
10601061
# enforce O(1) check for node in ancestors to include
10611062
ancestors_to_include = blockers.copy()
10621063

10631064
while candidates:
10641065
# on any new candidate
10651066
node = candidates.pop()
1066-
# check if the node is independent, never go above blockers
1067+
1068+
# There was a repeated reference to this node, we have already investigated it
1069+
if node in truncated_inputs:
1070+
continue
1071+
1072+
# check if the node is independent, never go above blockers;
10671073
# blockers are independent nodes and ancestors to include
10681074
if node in ancestors_to_include:
10691075
# The case where node is in ancestors to include so we check if it depends on others
@@ -1073,30 +1079,25 @@ def truncated_graph_inputs(
10731079
# should be added to truncated_inputs
10741080
truncated_inputs.append(node)
10751081
if dependent:
1076-
# if the ancestors to include is still dependent we need to go above,
1077-
# the search is not yet finished
1078-
# the node _has_ to have owner to be dependent
1079-
# so we do not check it
1080-
# and populate search to go above
1082+
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
10811083
# owner can never be None for a dependent node
10821084
candidates.extend(node.owner.inputs)
10831085
else:
10841086
# A regular node to check
10851087
dependent = variable_depends_on(node, blockers)
1086-
# all regular nodes fall to blockes
1088+
# all regular nodes fall to blockers
10871089
# 1. it is dependent - further search irrelevant
10881090
# 2. it is independent - the search node is inside the closure
10891091
blockers.add(node)
10901092
# if we've found an independent node and it is not in blockers so far
1091-
# it is a new indepenent node not present in ancestors to include
1092-
if not dependent:
1093-
# we've found an independent node
1094-
# do not search beyond
1095-
truncated_inputs.append(node)
1096-
else:
1097-
# populate search otherwise
1093+
# it is a new independent node not present in ancestors to include
1094+
if dependent:
1095+
# populate search if it's not an independent node
10981096
# owner can never be None for a dependent node
10991097
candidates.extend(node.owner.inputs)
1098+
else:
1099+
# otherwise, do not search beyond
1100+
truncated_inputs.append(node)
11001101
return truncated_inputs
11011102

11021103

tests/graph/test_basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,3 +749,22 @@ def test_truncated_graph_inputs():
749749
assert truncated_graph_inputs([o2], [y2, z]) == [y2]
750750
# Disconnected output is present
751751
assert truncated_graph_inputs([o2, z], [y2]) == [z, y2]
752+
753+
754+
def test_truncated_graph_inputs_repeated_input():
755+
"""Test that truncated_graph_inputs does not return repeated inputs."""
756+
x = MyVariable(1)
757+
x.name = "x"
758+
y = MyVariable(1)
759+
y.name = "y"
760+
761+
trunc_inp1 = MyOp(x, y)
762+
trunc_inp1.name = "trunc_inp1"
763+
764+
trunc_inp2 = MyOp(x, y)
765+
trunc_inp2.name = "trunc_inp2"
766+
767+
o = MyOp(trunc_inp1, trunc_inp1, trunc_inp2, trunc_inp2)
768+
o.name = "o"
769+
770+
assert truncated_graph_inputs([o], [trunc_inp1]) == [trunc_inp2, trunc_inp1]

0 commit comments

Comments
 (0)