Skip to content

Commit d208ec9

Browse files
author
Feras A. Saad
committed
Merge branch '20191202-rees-dump-models'
2 parents 7cf258d + d82d33f commit d208ec9

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

src/backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,13 @@ def logpdf_joint(self, bdb, generator_id, modelnos, rowid, targets,
285285
`modelno` is a model number or `None`, meaning all models.
286286
"""
287287
raise NotImplementedError
288+
289+
def json_ready_models(self, bdb, population_id, generator_id):
    """Return model information as a data object ready for JSON output.

    Together with the data table, which is deliberately not part of
    the dump, the returned object is intended to be enough to
    simulate any of the models.

    This base implementation always raises `NotImplementedError`.
    """
    raise NotImplementedError()

src/backends/cgpm_backend.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import itertools
1818
import json
1919
import math
20+
import operator
2021

2122
from collections import Counter
2223
from collections import defaultdict
@@ -922,6 +923,90 @@ def logpdf_joint(
922923
multiprocess=self._multiprocess,
923924
)
924925

926+
def json_ready_models(self, bdb, population_id, generator_id):
    """Return a JSON-ready dict describing every model of a generator.

    The result has three keys:

    - "column-statistical-types": variable name -> stattype, read from
      the `bayesdb_variable` table for this population and generator.
    - "categories": the result of `_json_ready_categories`.
    - "models": one `_json_ready_model` blob per engine state.
    """
    # Statistical type of each variable, keyed by variable name.
    cursor = bdb.sql_execute('''
        SELECT name, stattype FROM bayesdb_variable
            WHERE (generator_id IS NULL OR generator_id = ?)
                AND population_id = ?
    ''', (generator_id, population_id))
    stattypes = {}
    for (name, stattype) in cursor:
        stattypes[name] = stattype

    categories = self._json_ready_categories(
        bdb, population_id, generator_id)

    # Dict mapping colno to variable name
    colno_to_name = core.bayesdb_colno_to_variable_names(
        bdb, population_id, generator_id)

    model_blobs = []
    for state in self._engine(bdb, generator_id).states:
        model_blobs.append(self._json_ready_model(state, colno_to_name))

    result = {}
    result["column-statistical-types"] = stattypes
    result["categories"] = categories
    result["models"] = model_blobs
    return result
948+
949+
def _json_ready_model(self, state, name_map):
950+
# state.Zv() is a column partition given as {colnum: viewnum, ...}
951+
column_groups = itertools.groupby(
952+
sorted(state.Zv().items()),
953+
key=operator.itemgetter(1))
954+
column_partition = [
955+
[name_map[colno] for (colno, _) in block]
956+
for (_, block) in column_groups
957+
]
958+
959+
column_crp_hypers = [view.alpha() for view in state.views.values()]
960+
961+
# All row clusters for all views: e.g. [[[0,1], [2,3,4], [5]], [[0,2,3],[1,4,5]]]
962+
# This triple comprehension is tricky so here is some explanation.
963+
# view.Zr() is a row partition given as {rownum: clusternum, ...}.
964+
# Grouping its items() by cluster number gives one group per cluster,
965+
# as (clusternum, [(rownum, clusternum), ...]).
966+
# To get a cluster [rownum, rownum, ...] we just strip off the
967+
# cluster number.
968+
clusters_for_views = [
969+
[
970+
[rowno for (rowno, _) in group]
971+
for (_, group) in itertools.groupby(sorted(view.Zr().items()),
972+
key=operator.itemgetter(1))
973+
]
974+
for view in state.views.values()
975+
]
976+
977+
# Each column has a little dict of hyperparameters
978+
column_hypers = {
979+
name: state.views[state.Zv()[colno]].dims[colno].hypers
980+
for (colno, name) in name_map.iteritems()
981+
}
982+
983+
# Return dump-able blob
984+
return {
985+
"column-partition": column_partition,
986+
"cluster-crp-hyperparameters": column_crp_hypers,
987+
"clusters": clusters_for_views,
988+
"column-hypers": column_hypers
989+
}
990+
991+
def _json_ready_categories(self, bdb, population_id, generator_id):
    """Return {variable name: {code: value}} for categorical variables.

    Reads every (colno, code, value) row recorded for `generator_id`
    in the `bayesdb_cgpm_category` table and collates the rows by
    variable, translating column numbers to variable names.  Only
    variables that have at least one category row appear in the
    result.
    """
    name_map = core.bayesdb_colno_to_variable_names(
        bdb, population_id, generator_id)
    assert len(name_map) > 0
    # All categories for all categorical variables
    raw_categories = sorted(bdb.sql_execute('''
        SELECT colno, code, value FROM bayesdb_cgpm_category
            WHERE generator_id = ?
    ''', (generator_id,)))
    # Collate categories by variable.  The rows are consumed directly
    # rather than through stored itertools.groupby groups: a group
    # iterator is only valid until the groupby object advances, so
    # groups kept for later come back empty.
    categories = {}
    for (colno, code, value) in raw_categories:
        categories.setdefault(name_map[colno], {})[code] = value
    return categories
1009+
9251010
def _unique_rowid(self, rowids):
9261011
if len(set(rowids)) != 1:
9271012
raise ValueError('Multiple-row query: %r' % (list(set(rowids)),))
@@ -1334,7 +1419,7 @@ def _get_modelnos(self, bdb, generator_id, modelnos):
13341419

13351420
def _convert_subproblems_to_kernel(self, bdb, subproblems, backend):
13361421
# Keys are bayeslite subproblems, entries are cgpm where first element
1337-
# is gpmcc kernel name, and secodn element is lovecat kernel name.
1422+
# is gpmcc kernel name, and second element is lovecat kernel name.
13381423
if subproblems is None:
13391424
return None
13401425
conversions = {

src/core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def bayesdb_variable_numbers(bdb, population_id, generator_id):
343343
return [colno for (colno,) in cursor]
344344

345345
def bayesdb_variable_name(bdb, population_id, generator_id, colno):
346-
"""Return the name a population variable."""
346+
"""Return the name of a population variable."""
347347
cursor = bdb.sql_execute('''
348348
SELECT name FROM bayesdb_variable
349349
WHERE population_id = ?
@@ -352,6 +352,15 @@ def bayesdb_variable_name(bdb, population_id, generator_id, colno):
352352
''', (population_id, generator_id, colno))
353353
return cursor_value(cursor)
354354

355+
def bayesdb_colno_to_variable_names(bdb, population_id, generator_id):
    """Return a dict from column number to variable name.

    Covers the rows of `bayesdb_variable` for the given population
    whose generator is either unset or equal to `generator_id`.
    """
    rows = bdb.sql_execute('''
        SELECT colno, name FROM bayesdb_variable
            WHERE population_id = ?
                AND (generator_id IS NULL OR generator_id = ?)
    ''', (population_id, generator_id))
    names_by_colno = {}
    for (colno, name) in rows:
        names_by_colno[colno] = name
    return names_by_colno
363+
355364
def bayesdb_variable_stattype(bdb, population_id, generator_id, colno):
356365
"""Return the statistical type of a population variable."""
357366
sql = '''

tests/test_core.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,3 +682,20 @@ def test_bayesdb_implicit_population_generator_rename():
682682
assert not core.bayesdb_has_generator(bdb, population_id, 't')
683683
assert core.bayesdb_has_generator(bdb, population_id2, 't2')
684684
assert generator_id2 == generator_id
685+
686+
# Not sure where this test belongs. It needs a test bdb such as t1(),
687+
# and something to analyze it, and this test file provides these.
688+
689+
def test_json_ready_models():
    """Check that the cgpm backend can dump its models as JSON.

    Analyzes the t1() fixture, asks the 'cgpm' backend for its
    JSON-ready models, checks that every model reports at least one
    view's clusters, and verifies the result can be written by
    `json.dump`.
    """
    with analyzed_bayesdb_population(t1(), 1, 1) as (bdb, pop_id, gen_id):
        assert len(bdb.backends) > 0
        population_name = core.bayesdb_population_name(bdb, pop_id)
        assert core.bayesdb_has_population(bdb, population_name)
        j = bdb.backends['cgpm'].json_ready_models(bdb, pop_id, gen_id)
        for m in j["models"]:
            assert len(m["clusters"]) > 0
        # Writing the dump confirms the object really is JSON-ready.
        # The file is removed again on close so that test runs do not
        # leave stray files behind; pass delete=False manually when
        # debugging to keep the file and look at the model.
        with tempfile.NamedTemporaryFile(prefix='bayeslite-models') as f:
            json.dump(j, f, indent=2)

0 commit comments

Comments
 (0)