The following examples will walk you through the core concepts of YogaDL: storing, fetching, and streaming datasets.

Creating a yogadl.Storage

Most users will interact with a yogadl.Storage object as the mechanism for storing and fetching datasets. The simplest Storage is yogadl.storage.LFSStorage, or “local filesystem storage”. Let’s create one:

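import os
import yogadl
import yogadl.storage

# Create a yogadl.Storage object backed by the local filesystem.
storage_path = "/tmp/yogadl_cache"
os.makedirs(storage_path, exist_ok=True)
lfs_config = yogadl.storage.LFSConfigurations(storage_path)
storage = yogadl.storage.LFSStorage(lfs_config)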

YogaDL also comes with built-in support for GCS via yogadl.storage.GCSStorage and for S3 via yogadl.storage.S3Storage.

Storing a dataset

Let’s create a silly 10-record dataset and store it in the yogadl.Storage. This is done via storage.submit(). During storage.submit(), the entire dataset will be read and written to the storage backend (in this case, to a file).

import tensorflow as tf

# Create a dataset we can store.
records = tf.data.Dataset.range(10)

# Store this dataset as "range" version "1.0".
storage.submit(records, "range", "1.0")
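To make the submit/fetch contract concrete without TensorFlow, here is a dependency-free toy in plain Python. It is a sketch under our own assumptions, not how yogadl.storage.LFSStorage actually lays out its files: `ToyStorage`, `ToyDataRef`, and the JSON-lines file format are all hypothetical. What it illustrates is that submit() reads the whole dataset once and writes every record to the backend, while fetch() hands back a reference whose records are only read when you iterate.

```python
import json
import os
import tempfile
from typing import Iterable, Iterator


class ToyDataRef:
    """Hypothetical stand-in for a DataRef: a file path plus a record count."""

    def __init__(self, path: str, length: int) -> None:
        self.path = path
        self.length = length

    def __len__(self) -> int:
        return self.length

    def read(self) -> Iterator[int]:
        # Records are read lazily, only when someone iterates.
        with open(self.path) as f:
            for line in f:
                yield json.loads(line)


class ToyStorage:
    """Hypothetical local-filesystem storage keyed by (name, version)."""

    def __init__(self, root: str) -> None:
        self.root = root

    def _path(self, name: str, version: str) -> str:
        return os.path.join(self.root, f"{name}-{version}.jsonl")

    def submit(self, records: Iterable[int], name: str, version: str) -> None:
        # Read the entire dataset once and write every record to the backend.
        with open(self._path(name, version), "w") as f:
            for record in records:
                f.write(json.dumps(record) + "\n")

    def fetch(self, name: str, version: str) -> ToyDataRef:
        path = self._path(name, version)
        # Count the records once; the records themselves stay on disk.
        with open(path) as f:
            length = sum(1 for _ in f)
        return ToyDataRef(path, length)


storage = ToyStorage(tempfile.mkdtemp())
storage.submit(range(10), "range", "1.0")
dataref = storage.fetch("range", "1.0")
print(len(dataref), list(dataref.read()))
```

Because the reference only stores where the data lives, the same interface could just as well point at a remote object store.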

Fetching a dataset

Later (possibly in a different process), you can fetch a yogadl.DataRef representing the dataset via storage.fetch().

A DataRef is just a reference to a dataset. In this case, the dataset will be stored in a file on your computer, but a DataRef could just as easily refer to a dataset on some remote machine; the interface would be the same.

To actually access the dataset, you first need to call dataref.stream(), which returns a yogadl.Stream object. Then you can convert the Stream object to a framework-native data loader format (currently only tf.data.Dataset is supported).

import yogadl.tensorflow

# Get the DataRef.
dataref = storage.fetch("range", "1.0")

# Tell the DataRef how to stream the dataset.
stream = dataref.stream(start_offset=5, shuffle=True, shuffle_seed=777)

# Interpret the stream as a tensorflow dataset
records = yogadl.tensorflow.make_tf_dataset(stream)

# It's a real tf.data.Dataset; you can use normal tf.data operations on it.
batches = records.repeat(3).batch(5)

# (this part requires TensorFlow >= 2.0)
for batch in batches:
    print(batch)

This should print:

tf.Tensor([5 1 9 6 7], shape=(5,), dtype=int64)
tf.Tensor([1 7 3 9 8], shape=(5,), dtype=int64)
tf.Tensor([2 6 0 4 5], shape=(5,), dtype=int64)
tf.Tensor([9 5 3 0 8], shape=(5,), dtype=int64)
tf.Tensor([6 7 4 1 2], shape=(5,), dtype=int64)

Notice that:

  • The start_offset is only applied to the first epoch, so in this example .repeat(3) gave us 2.5 epochs of data, since we skipped the first half of the first epoch.

  • The shuffle is a true shuffle: the shuffled stream samples from the whole dataset, without the fixed-size “buffer” that tf.data.Dataset.shuffle() relies on.

  • The shuffle is reproducible because we chose a shuffle seed.

  • Each epoch is reshuffled.
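The four properties above can be modeled in a few lines of plain Python. This is a hypothetical sketch of the described behavior, not YogaDL's implementation: `toy_stream` and `buffered_shuffle` are made-up names, and the buffered version only mimics the idea behind tf.data.Dataset.shuffle() rather than reproducing its exact output.

```python
import random
from typing import Iterator, List, Optional


def toy_stream(
    num_records: int,
    epochs: int,
    start_offset: int = 0,
    shuffle: bool = False,
    shuffle_seed: Optional[int] = None,
) -> Iterator[int]:
    """Yield record indices for several epochs (hypothetical model, not YogaDL code)."""
    # A seeded RNG makes the whole multi-epoch sequence reproducible.
    rng = random.Random(shuffle_seed)
    for epoch in range(epochs):
        order = list(range(num_records))
        if shuffle:
            # True shuffle: a fresh permutation of the whole dataset every epoch.
            rng.shuffle(order)
        # start_offset only trims the first epoch.
        skip = start_offset if epoch == 0 else 0
        yield from order[skip:]


def buffered_shuffle(items: List[int], buffer_size: int, seed: int) -> List[int]:
    """Buffer-based shuffle in the spirit of tf.data.Dataset.shuffle()."""
    rng = random.Random(seed)
    buffer: List[int] = []
    out: List[int] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Only records currently in the buffer can be emitted next.
            out.append(buffer.pop(rng.randrange(len(buffer))))
    # Drain whatever is left in random order.
    while buffer:
        out.append(buffer.pop(rng.randrange(len(buffer))))
    return out


records = list(toy_stream(10, epochs=3, start_offset=5, shuffle=True, shuffle_seed=777))
print(len(records))  # 25, i.e. 2.5 epochs
print(buffered_shuffle(list(range(10)), buffer_size=2, seed=0))
```

With a buffer of size 2, the first record emitted can only ever be record 0 or 1, whereas a true shuffle can start an epoch with any record in the dataset.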

Can I get the same features in fewer steps?

As a matter of fact, you can! To support the common use case of running the same dataset through many different models during model development or hyperparameter search, you can use the storage.cacheable() decorator to decorate a function that returns a dataset.

When the decorated function is called the first time, it will run one time and save its output to storage. On subsequent calls, the original function will not run, but its cached output will be returned instead.

In this way, you get the benefit of caching within a single script, with only a single call against the storage object:

@storage.cacheable("range", "2.0")
def make_records():
    print("Cache not found, making range v2 dataset...")
    records = tf.data.Dataset.range(10).map(lambda x: 2*x)
    return records

# Follow the same steps as before.
dataref = make_records()
stream = dataref.stream()
records = yogadl.tensorflow.make_tf_dataset(stream)
batches = records.repeat(3).batch(5)

for batch in batches:
    print(batch)

The storage.cacheable() decorator is multiprocessing-safe: if two identical processes are configured to use the same storage, only one of them will create and save the dataset. The other one will wait for the dataset to be saved and then read it from the cache.
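One common way to get this "one creator, everyone else waits" behavior is an exclusive lock file plus an atomic rename. The sketch below is our own illustration of that technique, not YogaDL's actual mechanism: `cacheable` here is a hypothetical stand-in for storage.cacheable(), it caches JSON rather than a real dataset, and threads stand in for separate processes (an O_EXCL lock file works across processes as well).

```python
import json
import os
import tempfile
import threading
import time
from typing import Any, Callable


def cacheable(cache_dir: str, name: str, version: str, poll_s: float = 0.01):
    """Hypothetical cache-or-create decorator guarded by a lock file."""
    data_path = os.path.join(cache_dir, f"{name}-{version}.json")
    lock_path = data_path + ".lock"

    def decorator(make_fn: Callable[[], Any]) -> Callable[[], Any]:
        def wrapper() -> Any:
            while True:
                if os.path.exists(data_path):
                    # Cache hit: the original function does not run.
                    with open(data_path) as f:
                        return json.load(f)
                try:
                    # O_EXCL makes creation atomic: only one caller gets the lock.
                    fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
                except FileExistsError:
                    # Someone else is building the dataset; wait and retry.
                    time.sleep(poll_s)
                    continue
                try:
                    # Re-check: another caller may have finished just before we locked.
                    if not os.path.exists(data_path):
                        tmp_path = data_path + ".tmp"
                        with open(tmp_path, "w") as f:
                            json.dump(make_fn(), f)
                        # Atomic rename so readers never see a partial file.
                        os.replace(tmp_path, data_path)
                finally:
                    os.close(fd)
                    os.remove(lock_path)

        return wrapper

    return decorator


calls = []
results = []
cache_dir = tempfile.mkdtemp()


@cacheable(cache_dir, "range", "2.0")
def make_records():
    calls.append(1)
    time.sleep(0.05)  # Simulate slow dataset creation.
    return [2 * x for x in range(10)]


threads = [threading.Thread(target=lambda: results.append(make_records())) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls))  # 1: only one caller built the dataset
```

A production version would also have to deal with stale lock files left behind by a process that crashed mid-write.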

End-to-end training example

Here is an example of how you might use YogaDL to train on the second half of an MNIST dataset. It illustrates the ability to resume training mid-dataset, which is not natively possible with tf.keras. Without YogaDL, you could imitate this behavior using dataset.skip(N), but that is prohibitively expensive for large values of N, because every skipped record must still be read and discarded.


Note: MNIST is such a small dataset that YogaDL will not outperform an example that treats MNIST as an in-memory dataset.

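import math
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import yogadl
import yogadl.storage
import yogadl.tensorflow

# Illustrative batch size; any positive value works here.
BATCH_SIZE = 128

# Configure the yogadl storage.
storage_path = "/tmp/yogadl_cache"
os.makedirs(storage_path, exist_ok=True)
lfs_config = yogadl.storage.LFSConfigurations(storage_path)
storage = yogadl.storage.LFSStorage(lfs_config)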

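@storage.cacheable("mnist", "1.0")
def make_data():
    mnist = tfds.image.MNIST()
    # TFDS builders must download and prepare the data before as_dataset().
    mnist.download_and_prepare()
    dataset = mnist.as_dataset(as_supervised=True)["train"]

    # Apply dataset transformations from the TensorFlow Datasets Keras example.

    def normalize_img(image, label):
        """Normalizes images: `uint8` -> `float32`."""
        return tf.cast(image, tf.float32) / 255., label

    # Return the normalized (image, label) pairs so they get cached.
    return dataset.map(normalize_img)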

# Get the DataRef from the storage via the decorated function.
dataref = make_data()

# Stream the dataset starting halfway through it.
num_batches = math.ceil(len(dataref) / BATCH_SIZE)
batches_to_skip = num_batches // 2
records_to_skip = batches_to_skip * BATCH_SIZE
stream = dataref.stream(
    start_offset=records_to_skip, shuffle=True, shuffle_seed=777
)

# Convert the stream to a tf.data.Dataset object.
dataset = yogadl.tensorflow.make_tf_dataset(stream)

# Apply normal data augmentation and prefetch steps.
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

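# Model is straight from the TensorFlow docs.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(0.001),
    metrics=['accuracy'],
)

# Train starting from the middle of the dataset.
model.fit(dataset)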
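To see what the skip arithmetic above produces and why an offset beats sequential skipping, here is a small plain-Python sketch. It assumes MNIST's 60,000-record training split and the illustrative batch size of 128; `CountingReader` is a hypothetical record source, and its random-access path is a simplification of what an offset-aware store can do, not a description of YogaDL internals.

```python
import itertools
import math


def records_to_skip(num_records: int, batch_size: int) -> int:
    """Same arithmetic as the example: skip the first half of the batches."""
    num_batches = math.ceil(num_records / batch_size)
    batches_to_skip = num_batches // 2
    return batches_to_skip * batch_size


# 60,000 / 128 = 468.75, so there are 469 batches and we skip 234 of them.
skip = records_to_skip(60_000, 128)
print(skip)  # 29952


class CountingReader:
    """Hypothetical record source that counts how many records it touches."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.reads = 0

    def __iter__(self):
        # Sequential access: every record is read in order.
        for i in range(self.n):
            self.reads += 1
            yield i

    def __getitem__(self, i: int) -> int:
        # Random access: jump straight to record i.
        self.reads += 1
        return i


# Sequential skipping (like Dataset.skip(N)) reads every skipped record.
seq = CountingReader(60_000)
first = next(itertools.islice(iter(seq), skip, None))
print(seq.reads)  # 29953

# An offset-aware source reads only the record it needs.
ra = CountingReader(60_000)
first_ra = ra[skip]
print(ra.reads)  # 1
```

Skipping whole batches keeps the resumed run aligned to batch boundaries, so the remaining 30,048 records form 234 full batches plus one partial batch of 96.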

Advanced Use Case: Distributed Training

Sharding a dataset for use with distributed training is easy. If you are using Horovod for distributed training, you only need to alter the arguments of your call to dataref.stream():

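import horovod.tensorflow as hvd

# Horovod must be initialized before hvd.rank() and hvd.size() are used.
hvd.init()

# Fetch the DataRef as before, then give each worker its own shard.
stream = dataref.stream(
    shard_rank=hvd.rank(), num_shards=hvd.size()
)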
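Assuming round-robin sharding, which is the striding pattern the custom DataRef example on this page uses (range(shard_rank, end, num_shards)), each worker would see the record indices computed below. This is an illustrative sketch: `shard_indices` is a hypothetical helper, and whether a particular built-in DataRef assigns records exactly this way is an assumption.

```python
from typing import List


def shard_indices(
    num_records: int,
    shard_rank: int,
    num_shards: int,
    drop_shard_remainder: bool = False,
) -> List[int]:
    """Round-robin record indices for one shard (hypothetical sketch)."""
    end = num_records
    if drop_shard_remainder:
        # Trim the tail so every shard gets the same number of records.
        end = num_records - (num_records % num_shards)
    return list(range(shard_rank, end, num_shards))


for rank in range(3):
    print(rank, shard_indices(10, rank, 3))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7]
# 2 [2, 5, 8]
```

Equal shard sizes matter in synchronous data-parallel training: a worker that runs out of batches early can leave the others waiting at the next collective operation, which is the situation drop_shard_remainder avoids at the cost of a few records.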

Advanced Use Case: Custom DataRef Objects

If you have an advanced use case, such as generating data on one machine and streaming it to another machine for training, and you want to integrate with a platform that accepts datasets as yogadl.DataRef objects, you can implement a custom yogadl.DataRef. Implementing the yogadl.DataRef interface gives you full control over how the platform interacts with your dataset. Here is a toy example of what that might look like:

import os
from typing import Optional

import tensorflow as tf
import yogadl
import yogadl.tensorflow

class RandomDataRef(yogadl.DataRef):
    """
    A DataRef to a non-reproducible dataset that just produces random
    unsigned 32-bit integer values.
    """

    def __len__(self) -> int:
        return 10

    def stream(
        self,
        start_offset: int = 0,
        shuffle: bool = False,
        skip_shuffle_at_epoch_end: bool = False,
        shuffle_seed: Optional[int] = None,
        shard_rank: int = 0,
        num_shards: int = 1,
        drop_shard_remainder: bool = False,
    ) -> yogadl.Stream:
        """
        For custom DataRefs, .stream() will often be a pretty beefy
        function. This example simplifies it by assuming that the dataset
        is non-reproducible, meaning that the shuffle and shuffle_seed
        arguments are meaningless, and shard_rank is only used to
        determine how many records will be yielded during each epoch.
        """

        # The end of the shard's range is the same for every epoch.
        if drop_shard_remainder:
            end = len(self) - (len(self) % num_shards)
        else:
            end = len(self)

        first_epoch = True

        def iterator_fn():
            # Each call to iterator_fn() yields one epoch of data.
            nonlocal first_epoch
            if first_epoch:
                # start_offset only applies to the first epoch.
                first_epoch = False
                start = start_offset + shard_rank
            else:
                start = shard_rank

            for _ in range(start, end, num_shards):
                # Make a uint32 out of 4 random bytes.
                r = os.urandom(4)
                yield r[0] + (r[1] << 8) + (r[2] << 16) + (r[3] << 24)

        # Since we will later convert to tf.data.Dataset,
        # we will supply output_types and shapes. tf.int64 is used
        # because it can hold the full uint32 range produced above.
        return yogadl.Stream(
            iterator_fn,
            len(range(shard_rank, end, num_shards)),
            output_types=tf.int64,
            output_shapes=tf.TensorShape([]),
        )

dataref = RandomDataRef()
stream = dataref.stream()
records = yogadl.tensorflow.make_tf_dataset(stream)
batches = records.batch(5)
for batch in batches:
    print(batch)

© Copyright 2020, Determined AI. Revision 7f4233dd.
