
I'm using TF 2.2 and I'm trying to use tf.data to create a pipeline.

The following works fine:

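    import tensorflow as tf

    AUTOTUNE = tf.data.experimental.AUTOTUNE

    def load_image(filePath, label):
        # tf.print runs for every element; a plain print() inside a
        # tf.data function only runs once, while the function is traced.
        tf.print('Loading File:', filePath)
        raw_bytes = tf.io.read_file(filePath)
        image = tf.io.decode_image(raw_bytes, expand_animations=False)
        return image, label

    # TrainDS Pipeline (getDataset() and size are defined elsewhere in the script)
    trainDS = getDataset()
    trainDS = trainDS.shuffle(size['train'])
    trainDS = trainDS.map(load_image, num_parallel_calls=AUTOTUNE)

    for d in trainDS:
        print('Image: {} - Label: {}'.format(d[0], d[1]))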

I would like to use load_image() with Dataset.interleave(), so I tried:

    # TrainDS Pipeline
    trainDS = getDataset()
    trainDS = trainDS.shuffle(size['train'])
    trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

    for d in trainDS:
        print('Image: {} - Label: {}'.format(d[0], d[1]))

But I'm getting the following error:

    Exception has occurred: TypeError
    `map_func` must return a `Dataset` object. Got <class 'tuple'>
      File "/data/dev/train_daninhas.py", line 44, in <module>
        trainDS = trainDS.interleave(lambda x, y: load_image_with_label(x, y), cycle_length=4)

How can I adapt my code so that Dataset.interleave() works with load_image() and reads the images in parallel?

  • Did the below solution work? Commented Jun 3, 2020 at 12:13
  • @ParthasarathySubburaj Yes. I was still working on that. Thanks. Commented Jun 3, 2020 at 17:12

1 Answer


As the error suggests, you need to modify load_image so that it returns a Dataset object. Here is an example with two images that shows how to do it in TensorFlow 2.2.0:

    import tensorflow as tf

    filenames = ["./img1.jpg", "./img2.jpg"]
    labels = ["A", "B"]

    def load_image(filePath, label):
        # tf.print runs for every element; a plain print() would only
        # run once, while the function is traced.
        tf.print('Loading File:', filePath)
        raw_bytes = tf.io.read_file(filePath)
        image = tf.io.decode_image(raw_bytes, expand_animations=False)
        # interleave() expects a Dataset, so wrap the (image, label) pair in one.
        return tf.data.Dataset.from_tensors((image, label))

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.interleave(lambda x, y: load_image(x, y), cycle_length=4)

    for i in dataset.as_numpy_iterator():
        image = i[0]
        label = i[1]
        print(image.shape)
        print(label.decode())

    # (275, 183, 3)
    # A
    # (275, 183, 3)
    # B
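Since the question asks about reading in parallel: as far as I know, interleave() fetches from its input datasets synchronously unless num_parallel_calls is passed. Below is a minimal sketch of the same idea with that argument added. The make_dummy_images helper and the tiny generated PNG files are hypothetical stand-ins so the snippet runs without real images, and AUTOTUNE is taken from tf.data.experimental, which is where TF 2.2 exposes it.

```python
import os
import tempfile

import tensorflow as tf

# TF 2.2 exposes AUTOTUNE under tf.data.experimental.
AUTOTUNE = tf.data.experimental.AUTOTUNE


def make_dummy_images(n):
    """Write n tiny PNG files to a temp dir (hypothetical stand-ins for real images)."""
    tmp_dir = tempfile.mkdtemp()
    paths = []
    for i in range(n):
        pixels = tf.fill([4, 4, 3], tf.cast(i, tf.uint8))
        path = os.path.join(tmp_dir, "img{}.png".format(i))
        tf.io.write_file(path, tf.io.encode_png(pixels))
        paths.append(path)
    return paths


def load_image(file_path, label):
    raw_bytes = tf.io.read_file(file_path)
    image = tf.io.decode_image(raw_bytes, expand_animations=False)
    # interleave() needs a Dataset, so wrap the single (image, label) pair.
    return tf.data.Dataset.from_tensors((image, label))


filenames = make_dummy_images(4)
labels = ["A", "B", "C", "D"]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.interleave(
    load_image,
    cycle_length=4,
    num_parallel_calls=AUTOTUNE,  # without this, inputs are fetched sequentially
)

# Collect (shape, label) pairs so the result is easy to inspect.
results = [(image.shape, label.decode())
           for image, label in dataset.as_numpy_iterator()]
```

Because each file yields exactly one element here, the map(load_image, num_parallel_calls=AUTOTUNE) version from the question should already load images in parallel; interleave() tends to pay off more when each input expands into many elements, for example one record file per input.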