Save and load datasets stored in TFRecord file format

This example demonstrates how to load the data from a stored .tfrecord file and to build a data generator.

Save Dataset to TFRecord

At first, a small temporary dataset compounding 5 source cases is created by utilizing the synthetic dataset DatasetSynthetic. The conventional beamforming map is included by specifying sourcemap as the desired feature.

[9]:
import tensorflow as tf
from acoupipe.datasets.synthetic import DatasetSynthetic

d1 = DatasetSynthetic()
d1.save_tfrecord(features=["sourcemap","loc","f"], f=2000, split="training", size=5, name="/tmp/dataset.tfrecord")
100%|██████████| 5/5 [00:02<00:00,  2.08it/s]

Load Dataset from TFRecord File

To parse the data from TFRecord files it is usually necessary to write a custom parser function that parses the file sequentially (see: TFRecord_ documentation for details).

A potential parser function for the 'sourcemap' feature can be similar to:

[8]:

def tfrecord_parser(record): """ parser for tfrecord datasets with 'csmtriu' feature """ parsed = tf.io.parse_single_example( record, { 'sourcemap': tf.io.VarLenFeature(tf.float32), # and further features ... } ) sourcemap = tf.reshape(tf.sparse.to_dense(parsed['sourcemap']),shape=(64,64)) return {'sourcemap': sourcemap} # and further features ...

To ease the data retrieval process, the dataset itself provides a get_tfrecord_parser method to obtain a suitable parser function.

[10]:
parser = d1.get_tfrecord_parser(features=["sourcemap","loc","f"], f=2000, num=0)

With the corresponding parser, one can load the Dataset with the TensorFlow Dataset API

[11]:
import acoular as ac
import matplotlib.pyplot as plt

dataset = tf.data.TFRecordDataset(
    filenames=["/tmp/dataset.tfrecord"],).map(parser)

# create an iterator from the dataset
dataset_iter = iter(dataset)
data = next(dataset_iter)

sourcemap = ac.L_p(data['sourcemap']).T
loc = data['loc']
f = data['f'][0]
extent = d1.config.grid.extend()

plt.figure()
plt.title(f"Beamforming map (f={f} Hz)")
plt.imshow(sourcemap,extent=extent,origin='lower')
for l in loc[:2].numpy().T:
    plt.plot(l[0],l[1],'x')
plt.colorbar(label=r'$L_p$ / dB')
plt.xlabel('x (m)')
plt.ylabel('y (m)')
plt.show()
../../_images/contents_jupyter_load_tfrecord_8_0.png