Skip to content

Commit 47afe2f

Browse files
committed
Fix bug: column partitions were were incorrectly mapped to view partitions.
As with the previous bug fix -- I suspect that again, itertools.groupby was used incorrectly. To avoid any mistakes, I am now referring to view indices directly, via a named list of view indices.
1 parent e2042f7 commit 47afe2f

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

src/backends/cgpm_backend.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -947,17 +947,19 @@ def json_ready_models(self, bdb, population_id, generator_id):
947947
}
948948

949949
def _json_ready_model(self, state, name_map):
950+
# We work off the same view-indeces everywhere.
951+
view_indices = sorted(state.views.keys())
950952
# state.Zv() is a column partition given as {colnum: viewnum, ...}
951-
column_groups = itertools.groupby(
952-
sorted(state.Zv().items()),
953-
key=operator.itemgetter(1))
954953
column_partition = [
955-
[name_map[colno] for (colno, _) in block]
956-
for (_, block) in column_groups
954+
[name_map[colno] for colno, current_view_index in state.Zv().items()
955+
if current_view_index==view_index]
956+
for view_index in view_indices
957957
]
958958

959-
column_crp_hypers = [view.alpha() for view in state.views.values()]
960-
959+
column_crp_hypers = [
960+
state.views[view_index].alpha()
961+
for view_index in view_indices
962+
]
961963
# All row clusters for all views: e.g. [[[0,1], [2,3,4], [5]], [[0,2,3],[1,4,5]]]
962964
# This triple comprehension is tricky so here is some explanation.
963965
# view.Zr() is a row partition given as {rownum: clusternum, ...}.
@@ -969,12 +971,12 @@ def _json_ready_model(self, state, name_map):
969971
[
970972
sorted([
971973
rowid
972-
for (rowid, current_cluster_index) in view.Zr().items()
974+
for (rowid, current_cluster_index) in state.views[view_index].Zr().items()
973975
if current_cluster_index==cluster_index
974976
])
975-
for cluster_index in sorted(set(view.Zr().values()))
977+
for cluster_index in sorted(set(state.views[view_index].Zr().values()))
976978
]
977-
for view in state.views.values()
979+
for view_index in view_indices
978980
]
979981

980982
# Each column has a little dict of hyperparameters

0 commit comments

Comments
 (0)