|
17 | 17 | import itertools |
18 | 18 | import json |
19 | 19 | import math |
| 20 | +import operator |
20 | 21 |
|
21 | 22 | from collections import Counter |
22 | 23 | from collections import defaultdict |
@@ -922,6 +923,90 @@ def logpdf_joint( |
922 | 923 | multiprocess=self._multiprocess, |
923 | 924 | ) |
924 | 925 |
|
| 926 | + def json_ready_models(self, bdb, population_id, generator_id): |
| 927 | + stattypes = { |
| 928 | + name: stattype for (name, stattype) in |
| 929 | + bdb.sql_execute(''' |
| 930 | + SELECT name, stattype FROM bayesdb_variable |
| 931 | + WHERE (generator_id IS NULL OR generator_id = ?) |
| 932 | + AND population_id = ? |
| 933 | + ''', (generator_id, population_id)) |
| 934 | + } |
| 935 | + |
| 936 | + categories = self._json_ready_categories(bdb, population_id, generator_id) |
| 937 | + |
| 938 | + # Dict mapping colno to variable name |
| 939 | + name_map = core.bayesdb_colno_to_variable_names( |
| 940 | + bdb, population_id, generator_id) |
| 941 | + states = self._engine(bdb, generator_id).states |
| 942 | + model_blobs = [self._json_ready_model(s, name_map) for s in states] |
| 943 | + return { |
| 944 | + "column-statistical-types": stattypes, |
| 945 | + "categories": categories, |
| 946 | + "models": model_blobs |
| 947 | + } |
| 948 | + |
| 949 | + def _json_ready_model(self, state, name_map): |
| 950 | + # state.Zv() is a column partition given as {colnum: viewnum, ...} |
| 951 | + column_groups = itertools.groupby( |
| 952 | + sorted(state.Zv().items()), |
| 953 | + key=operator.itemgetter(1)) |
| 954 | + column_partition = [ |
| 955 | + [name_map[colno] for (colno, _) in block] |
| 956 | + for (_, block) in column_groups |
| 957 | + ] |
| 958 | + |
| 959 | + column_crp_hypers = [view.alpha() for view in state.views.values()] |
| 960 | + |
| 961 | + # All row clusters for all views: e.g. [[[0,1], [2,3,4], [5]], [[0,2,3],[1,4,5]]] |
| 962 | + # This triple comprehension is tricky so here is some explanation. |
| 963 | + # view.Zr() is a row partition given as {rownum: clusternum, ...}. |
| 964 | + # Grouping its items() by cluster number gives one group per cluster, |
| 965 | + # as (clusternum, [(rownum, clusternum), ...]). |
| 966 | + # To get a cluster [rownum, rownum, ...] we just strip off the |
| 967 | + # cluster number. |
| 968 | + clusters_for_views = [ |
| 969 | + [ |
| 970 | + [rowno for (rowno, _) in group] |
| 971 | + for (_, group) in itertools.groupby(sorted(view.Zr().items()), |
| 972 | + key=operator.itemgetter(1)) |
| 973 | + ] |
| 974 | + for view in state.views.values() |
| 975 | + ] |
| 976 | + |
| 977 | + # Each column has a little dict of hyperparameters |
| 978 | + column_hypers = { |
| 979 | + name: state.views[state.Zv()[colno]].dims[colno].hypers |
| 980 | + for (colno, name) in name_map.iteritems() |
| 981 | + } |
| 982 | + |
| 983 | + # Return dump-able blob |
| 984 | + return { |
| 985 | + "column-partition": column_partition, |
| 986 | + "cluster-crp-hyperparameters": column_crp_hypers, |
| 987 | + "clusters": clusters_for_views, |
| 988 | + "column-hypers": column_hypers |
| 989 | + } |
| 990 | + |
| 991 | + def _json_ready_categories(self, bdb, population_id, generator_id): |
| 992 | + name_map = core.bayesdb_colno_to_variable_names(bdb, population_id, generator_id) |
| 993 | + assert len(name_map) > 0 |
| 994 | + # All categories for all categorical variables |
| 995 | + raw_categories = sorted(bdb.sql_execute(''' |
| 996 | + SELECT colno, code, value FROM bayesdb_cgpm_category |
| 997 | + WHERE generator_id = ? |
| 998 | + ''', (generator_id,))) |
| 999 | + # Collate categories by variable |
| 1000 | + groups = { |
| 1001 | + (colno, group) |
| 1002 | + for (colno, group) in |
| 1003 | + itertools.groupby(raw_categories, key=operator.itemgetter(0)) |
| 1004 | + } |
| 1005 | + return { |
| 1006 | + name_map[colno]: {code: value for(_, code, value) in group} |
| 1007 | + for (colno, group) in groups |
| 1008 | + } |
| 1009 | + |
925 | 1010 | def _unique_rowid(self, rowids): |
926 | 1011 | if len(set(rowids)) != 1: |
927 | 1012 | raise ValueError('Multiple-row query: %r' % (list(set(rowids)),)) |
@@ -1334,7 +1419,7 @@ def _get_modelnos(self, bdb, generator_id, modelnos): |
1334 | 1419 |
|
1335 | 1420 | def _convert_subproblems_to_kernel(self, bdb, subproblems, backend): |
1336 | 1421 | # Keys are bayeslite subproblems, entries are cgpm where first element |
1337 | | - # is gpmcc kernel name, and secodn element is lovecat kernel name. |
| 1422 | + # is gpmcc kernel name, and second element is lovecat kernel name. |
1338 | 1423 | if subproblems is None: |
1339 | 1424 | return None |
1340 | 1425 | conversions = { |
|
0 commit comments