I'm having difficulties working with the `tf.contrib.data.Dataset` API and wondered if some of you could help. I wanted to transform the entire skip-gram pre-processing of word2vec into this paradigm to play with the API a little bit; it involves the following operations:
- Sequences of tokens are loaded dynamically (to avoid loading the whole dataset in memory at once), so we start with a Stream (in Scala's sense: the data is not all in memory but is loaded when access is needed) of sequences of tokens: `seq_tokens`.
- From each of these `seq_tokens` we extract skip-grams with a Python function that returns a list of tuples `(token, context)` (a minimal sketch follows this list).
- Select the column of tokens as features and the column of contexts as labels.
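To make the second step concrete, here is a minimal sketch of the kind of skip-gram helper I have in mind (the `window` parameter and the exact name are just illustrative):

```python
def skip_gram(seq_tokens, window=2):
    """Build (token, context) pairs from one sequence of tokens.

    For each position, every token at most `window` positions away
    (excluding the token itself) is taken as a context.
    """
    pairs = []
    for i, token in enumerate(seq_tokens):
        for j in range(max(0, i - window), min(len(seq_tokens), i + window + 1)):
            if j != i:
                pairs.append((token, seq_tokens[j]))
    return pairs
```

For example, `skip_gram(["a", "b", "c"], window=1)` gives `[("a", "b"), ("b", "a"), ("b", "c"), ("c", "b")]`.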
In pseudo-code, to make it clearer, it would look like the snippet below. We should take advantage of the framework's parallelism rather than loading the data ourselves, so I would first load in memory only the indices of the sequences, then load the sequences inside a `map` (since not all lines are processed at once, the data is loaded lazily and there's no OOM to fear), and apply a function to those sequences of tokens that creates a varying number of skip-grams, which then need to be flattened. In the end, I would formally end up with data of shape (#lines = number of skip-grams generated, #columns = 2).
```
data = range(1:N)
  .map(i => load(i): Seq[String])  // load: Int -> Seq[String], dynamically loads a sequence of tokens (sequences have varying lengths)
  .flat_map(s => skip_gram(s))     // skip_gram: Seq[String] -> Seq[(String, String)], with varying output length

features = data[0] // features
labels   = data[1] // labels
```

I've naively tried to do this with the Dataset API, but I'm stuck. I can do something like:
```python
iterator = (
    tf.contrib.data.Dataset.range(N)
    .map(lambda i: tf.py_func(load_data, [i], [tf.int32, tf.int32]))  # (1)
    .flat_map(?)                                                      # (2)
    .make_one_shot_iterator()
)
```

(1) TensorFlow's not happy here because the loaded sequences have different lengths...
(2) I haven't managed to do the skip-gram part yet. I just want to call a Python function that computes a sequence (of variable size) of skip-grams and flatten it, so that if the return type is a matrix, each of its rows should become a new row of the output Dataset.
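For (2), what I imagine is something along the lines of the sketch below; `skip_gram_matrix` is a hypothetical helper (assuming token ids rather than strings, and an illustrative window size of 2), and I'm not sure that `from_tensor_slices` inside `flat_map` is the right way to express the flattening:

```python
import numpy as np
import tensorflow as tf

def skip_gram_matrix(seq, window=2):
    # Hypothetical helper: turns one int32 sequence into an int32 matrix
    # of shape (num_skip_grams, 2), one (token, context) pair per row.
    pairs = [(seq[i], seq[j])
             for i in range(len(seq))
             for j in range(max(0, i - window), min(len(seq), i + window + 1))
             if j != i]
    return np.asarray(pairs, dtype=np.int32).reshape(-1, 2)

dataset = (
    tf.contrib.data.Dataset.range(N)  # N and load_data as above
    # If load_data returns a single variable-length int32 vector,
    # Tout should be a single dtype, not a list of two.
    .map(lambda i: tf.py_func(load_data, [i], tf.int32))
    # flat_map expects a function returning a Dataset; slicing the
    # (num_skip_grams, 2) matrix along its first axis would flatten
    # the pairs into one row per skip-gram.
    .flat_map(lambda seq: tf.contrib.data.Dataset.from_tensor_slices(
        tf.py_func(skip_gram_matrix, [seq], tf.int32)))
)
iterator = dataset.make_one_shot_iterator()
```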
Thanks a lot if anyone has an idea, and don't hesitate to ask if I forgot to mention any useful details...