Large production pipelines in TensorFlow are quite difficult to pull off. Training small models is easy, and we mostly do this at first, but as soon as we get to the rest of the pipeline, complexity rapidly mounts.

One reason is that the “Computation Graph” abstraction used by TensorFlow is a close, but not exact, match for the ML model we expect to train and use. How so?

Typically a model will be used in at least three ways:

  • Training – finding the correct weights or parameters for the model given some training data. Often done periodically as new data arrives.
  • Evaluation – calculating various metrics during training on a different data set to evaluate training quality or for cross validation.
  • Serving – on-demand prediction for new data.

There could be more modes. For example, we could re-train an existing model or apply the model to a large amount of data in batch mode.

While the conceptual model is the same, these use cases might need different computational graphs.
For example, if we use TensorFlow Serving, we cannot load models that contain Python function operations (`tf.py_func`).
Another example is evaluation metrics and debug operations like `tf.Assert`, which we might not want to run when serving, for performance reasons.

It turns out we need three to five different graphs to represent our one model. There are a couple of ways to do this, and picking the right one is not straightforward.

Can’t we just build a graph and update it as we go?

TensorFlow graphs in Python are append-only. TensorFlow operations implicitly create graph nodes, and there are no operations to remove nodes. Even if we try to overwrite a node by reusing its name:

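import tensorflow as tf  # TensorFlow 1.x

with tf.Session() as sess:
    my_sqrt = tf.sqrt(4.0, name='my_sqrt')
    # "Override" the node by reusing the Python variable and the op name
    my_sqrt = tf.sqrt(2.0, name='my_sqrt')
    # Print the names of all nodes in the graph (a private attribute,
    # fine for a quick look)
    print(sess.graph._nodes_by_name.keys())
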
TensorFlow just adds a suffix to the name of the new operation (and of the constant it creates for its input):

[u'my_sqrt_1/x', u'my_sqrt_1', u'my_sqrt/x', u'my_sqrt']

What’s done cannot be undone, so to speak.
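To make the append-only behavior concrete, here is a toy, pure-Python model of a graph that only supports appending nodes and uniquifies repeated names with a numeric suffix. The class and method names are hypothetical; TensorFlow's real logic lives in `tf.Graph.unique_name` and covers more cases (for example, name scopes such as `my_sqrt/x`).

```python
class AppendOnlyGraph:
    """Toy graph: nodes can be added but never removed or replaced."""

    def __init__(self):
        self._node_names = []  # nodes in creation order
        self._used = set()     # every name handed out so far

    def _unique_name(self, name):
        # Reuse the requested name if it is free, else append the first free "_N".
        candidate, suffix = name, 0
        while candidate in self._used:
            suffix += 1
            candidate = "%s_%d" % (name, suffix)
        self._used.add(candidate)
        return candidate

    def add_node(self, name):
        unique = self._unique_name(name)
        self._node_names.append(unique)
        return unique

    def node_names(self):
        return list(self._node_names)


graph = AppendOnlyGraph()
first = graph.add_node("my_sqrt")
second = graph.add_node("my_sqrt")  # "overriding" just appends a new node
print(graph.node_names())  # ['my_sqrt', 'my_sqrt_1']
```

Note that the class deliberately has no `remove_node`: as in TensorFlow, the only way to "change" the graph is to add more to it.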

So what can we do? We can try to create more than one graph.

Creating multiple graphs with the same code

This is the method that we usually find in the documentation. A dropout example might look like this:

# is_training is a plain Python bool, checked once while the graph is built
if is_training:
    activations = tf.nn.dropout(activations, 0.7)  # 0.7 is keep_prob

TensorFlow even ships with tools like `tf.variable_scope` that make creating different graphs easier.

This approach has a big drawback, however: the serialized graph can no longer be used without the code that produced it. Even a small change (like renaming a variable) will break the model in production, so to revert to an older model version we also need to revert to the older code. This is not always practical with larger repositories, and in any case it requires some operational effort.
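A rename is enough to break things because weights are matched to variables by name when a checkpoint is restored (this is how `tf.train.Saver` works by default). The toy below, with hypothetical `save_checkpoint` and `restore` helpers standing in for the real saver, shows the failure mode in plain Python:

```python
def save_checkpoint(variables):
    """Toy checkpoint: a copy of the {variable name: value} mapping."""
    return dict(variables)


def restore(checkpoint, variable_names):
    """Restore variables by name, failing if any name is missing."""
    missing = [name for name in variable_names if name not in checkpoint]
    if missing:
        raise KeyError("not found in checkpoint: " + ", ".join(missing))
    return {name: checkpoint[name] for name in variable_names}


# Weights saved by the old code.
checkpoint = save_checkpoint({"dense/kernel": [[0.5]], "dense/bias": [0.1]})

# The old code restores fine.
old_weights = restore(checkpoint, ["dense/kernel", "dense/bias"])

# The new code renamed the layer "dense" -> "fc"; the same checkpoint no longer loads.
try:
    restore(checkpoint, ["fc/kernel", "fc/bias"])
    renamed_ok = True
except KeyError:
    renamed_ok = False
```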

In addition, training and evaluation don’t use the same graph (even if they share weights) and require awkward coordination to mesh together.

So we still want one graph, but we want to use it for both training and evaluation. And maybe serving… But definitely training and evaluation.

Using one graph with conditional logic

TensorFlow does have a way to encode different behaviors into a single graph – the `tf.cond` operation.

is_train = tf.placeholder(tf.bool)
dropout = tf.nn.dropout(activations, 0.7)
activations = tf.cond(is_train,
                      lambda: dropout,
                      lambda: activations)

The big advantage is that all of the logic now lives in one graph; for instance, we can see all of it in TensorBoard. Our serialized models now work for both training and evaluation.
There are two sources of complexity that make the picture less rosy though – laziness and queues.

Laziness and the conditional operator

Let’s say we have an expensive operation that we would only like to run during evaluation. If we put it behind a conditional operator, we would expect it to run only at evaluation time. The same expectation applies when a branch mutates state, as is the case with batch normalization, whose moving averages should be updated only during training.

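This is not always how it works in TF. The `tf.cond` operation is like a box of laziness, but it protects only what’s inside: operations created inside the branch functions run only when their branch is taken, while operations created outside and merely referenced by a branch are inputs to the conditional and run whenever it is evaluated.
So this code works correctly: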

# Good - dropout inside the conditional
is_train = tf.placeholder(tf.bool)
activations = tf.cond(is_train,
                      lambda: tf.nn.dropout(activations, 0.7),
                      lambda: activations)

But this will run the dropout even if `is_train == False`.

# Bad - dropout outside the conditional is evaluated every time!
is_train = tf.placeholder(tf.bool)
do_activations = tf.nn.dropout(activations, 0.7)
activations = tf.cond(is_train,
                      lambda: do_activations,
                      lambda: activations)
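One way to see why the placement matters is a toy dataflow executor. In the sketch below (pure Python with hypothetical `Node`/`run`/`build_graph` names; an analogy, not TensorFlow's executor), every ancestor of the fetched node runs, except nodes created inside a cond branch, which are guarded by the predicate and skipped when their branch is not taken:

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class Node:
    fn: Callable
    inputs: List[str] = field(default_factory=list)
    # (predicate name, branch) for nodes created inside a cond branch.
    guard: Optional[Tuple[str, bool]] = None


def run(graph, fetch, feeds):
    """Evaluate `fetch`; return (value, names of nodes that actually ran)."""
    values = dict(feeds)
    executed = []

    def visit(name):
        if name in values:
            return values[name]
        node = graph[name]
        if node.guard is not None:
            pred, branch = node.guard
            if visit(pred) != branch:
                values[name] = None  # untaken branch: never executed
                return None
        args = [visit(i) for i in node.inputs]
        values[name] = node.fn(*args)
        executed.append(name)
        return values[name]

    return visit(fetch), executed


def build_graph(dropout_inside_cond):
    guard = ("is_train", True) if dropout_inside_cond else None
    return {
        "activations": Node(lambda: 1.0),
        # Stand-in for tf.nn.dropout (our "expensive" op).
        "dropout": Node(lambda a: a * 0.5, ["activations"], guard),
        # Picks the true- or false-branch value, like tf.cond.
        "cond": Node(lambda p, t, f: t if p else f,
                     ["is_train", "dropout", "activations"]),
    }


good_value, good_ran = run(build_graph(True), "cond", {"is_train": False})
bad_value, bad_ran = run(build_graph(False), "cond", {"is_train": False})
print(good_ran)  # ['activations', 'cond']
print(bad_ran)   # ['activations', 'dropout', 'cond']
```

Both graphs return the same value in evaluation mode; only the "bad" one pays for the dropout.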

Queues and the conditional operator

Queues are the preferred (and best-performing) way to get data into TensorFlow. Typically, training and evaluation are done simultaneously on different inputs, so we might try the approach above to get both into the same graph.

tf.cond(is_eval,
        lambda: tf.train.shuffle_batch(eval_tensors, batch_size=1024,
                                       capacity=100000, min_after_dequeue=10000),
        lambda: tf.train.shuffle_batch(train_tensors, batch_size=1024,
                                       capacity=100000, min_after_dequeue=10000))

It doesn’t work, however. TensorFlow complains that the `operation has been marked as not fetchable` and crashes.

The issue is that TensorFlow does not allow us to enqueue conditionally, and the `shuffle_batch` operation creates the queue (together with the ops that fill it) and the dequeue operation at the same time.
To avoid this, we need to split the operation into an unconditional part that creates the queues and a conditional part that dequeues from the correct queue.

Here is an example (the full code is quite long, so I kept only the relevant parts):

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

def create_queue(tensors, capacity, ...):
    ...
    # Only creates the queue; the ops that fill it are omitted here
    queue = data_flow_ops.RandomShuffleQueue(
            capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
            dtypes=types, shapes=shapes, shared_name=shared_name)
    return queue

def create_dequeue(queue, ...):
    ...
    # The dequeue op is the only part that goes inside tf.cond
    dequeued = queue.dequeue_up_to(batch_size, name=name)
    ...
    return dequeued

def merge_queues(is_train_tensor, train_tensors, test_tensors, ...):
    # Both queues are created unconditionally, outside tf.cond
    train_queue = create_queue(tensors=train_tensors,
                               capacity=...)
    test_queue = create_queue(tensors=test_tensors,
                              capacity=...)

    # Only the dequeue is conditional: pull from the queue that matches the mode
    input_values = tf.cond(is_train_tensor,
                           lambda: create_dequeue(train_queue, ...),
                           lambda: create_dequeue(test_queue, ...))
    return input_values
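The essential design choice, creating both queues unconditionally and making only the dequeue conditional, can be illustrated without TensorFlow. The sketch below is a deterministic pure-Python analogy: `collections.deque` stands in for `RandomShuffleQueue`, the helper names mirror the ones above but are hypothetical, and the queues are filled up front rather than by background queue-runner threads.

```python
import collections
import random


def create_queue(examples, seed):
    """Unconditional part: build and fill a shuffled queue."""
    items = list(examples)
    random.Random(seed).shuffle(items)
    return collections.deque(items)


def create_dequeue(queue, batch_size):
    """Take up to batch_size items, like dequeue_up_to."""
    return [queue.popleft() for _ in range(min(batch_size, len(queue)))]


def merge_queues(is_train, train_queue, test_queue, batch_size):
    """Conditional part: only the choice of queue depends on the flag."""
    source = train_queue if is_train else test_queue
    return create_dequeue(source, batch_size)


train_queue = create_queue(range(10), seed=0)
test_queue = create_queue(range(100, 105), seed=1)

train_batch = merge_queues(True, train_queue, test_queue, batch_size=4)
test_batch = merge_queues(False, train_queue, test_queue, batch_size=8)
```

Because both queues always exist, switching modes never creates or destroys anything; it only changes where the next batch comes from.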

So we can get what we want with the conditional operator, but the code is more complex and harder to understand. Operations should be easier, though: we have simple serialized graphs and simple monitoring.

Could we avoid conditions entirely and somehow work around the append-only limitation?

Working with saved graphs

Most pipelines serialize graphs, if only for serving: one of the most important things we do with a model is serve its serialized representation. TensorFlow Serving is outside the scope of this post, but the general idea is that to get a full-featured server we just need to run:

bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --model_name=my_model --model_base_path=/my/model/path

Running our training graph in TensorFlow Serving is not the best idea, however. Performance suffers from running unnecessary operations, and `tf.py_func` operations can’t even be loaded by the server.
Luckily, the serialized graph is not like the append-only graph we had when we started. It is just a collection of Protobuf objects, so we can create new, modified versions of it. As an example, below is a simplified and annotated version of the `convert_variables_to_constants` function in `graph_util_impl.py`, which (unsurprisingly) converts variables into constants. This is useful because constants can be faster to serve in some cases.

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def convert_variables_to_constants(sess, input_graph_def, output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  # Keep only the nodes needed to compute the output nodes; this drops
  # unneeded parts such as saving and assignment ops
  inference_graph = extract_sub_graph(input_graph_def, output_node_names)

  ...
  # Here we find the variables we want to convert
  for node in inference_graph.node:
    if node.op in ["Variable", "VariableV2"]:
      # Collect the variable (respecting the whitelist/blacklist); the current
      # values are then read with sess.run into found_variables
      ...

  ...

  # Here we create a new graph with the variables replaced by constants
  output_graph_def = graph_pb2.GraphDef()
  ...
  for input_node in inference_graph.node:
    output_node = node_def_pb2.NodeDef()
    if input_node.name in found_variables:
      # Make output_node a new Const node holding the variable's weights
      ...
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])
  print("Converted %d variables to const ops." % how_many_converted)

  return output_graph_def
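The same copy-and-transform pattern works on any graph held as plain data. The sketch below is a simplified pure-Python analogy rather than TensorFlow code: each node is a dict instead of a `NodeDef` protobuf, the op types are only illustrative, the helpers merely reuse the names `extract_sub_graph` and `convert_variables_to_constants` for readability, and a `variable_values` dict stands in for reading the weights with `sess.run`. It first prunes everything the outputs do not depend on (such as training and saving ops), then swaps variables for constants:

```python
import copy


def extract_sub_graph(graph_def, output_names):
    """Keep only the nodes the output nodes depend on, in original order."""
    by_name = {node["name"]: node for node in graph_def}
    keep, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in keep:
            keep.add(name)
            stack.extend(by_name[name]["inputs"])
    return [node for node in graph_def if node["name"] in keep]


def convert_variables_to_constants(variable_values, graph_def, output_names):
    """Return a new, pruned graph where variables are replaced by constants."""
    output_graph_def = []
    how_many_converted = 0
    for node in extract_sub_graph(graph_def, output_names):
        if node["op"] in ("Variable", "VariableV2"):
            output_graph_def.append({"name": node["name"], "op": "Const",
                                     "inputs": [],
                                     "value": variable_values[node["name"]]})
            how_many_converted += 1
        else:
            # Copy so the input graph is never modified.
            output_graph_def.append(copy.deepcopy(node))
    return output_graph_def, how_many_converted


def make_node(name, op, *inputs):
    return {"name": name, "op": op, "inputs": list(inputs)}


training_graph = [
    make_node("x", "Placeholder"),
    make_node("label", "Placeholder"),
    make_node("w", "VariableV2"),
    make_node("b", "VariableV2"),
    make_node("matmul", "MatMul", "x", "w"),
    make_node("y", "Add", "matmul", "b"),
    make_node("loss", "SquaredDifference", "y", "label"),
    make_node("train_op", "ApplyGradientDescent", "w", "loss"),
    make_node("save", "SaveV2", "w", "b"),
]

serving_graph, converted = convert_variables_to_constants(
    {"w": [[2.0]], "b": [0.5]}, training_graph, ["y"])
print([n["name"] for n in serving_graph])  # ['x', 'w', 'b', 'matmul', 'y']
print(converted)  # 2
```

The training graph is left untouched; the serving graph is a new object built from copies, which is exactly what the append-only Python graph API would not let us do.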

TensorFlow actually ships with a few tools for manipulating saved graphs, such as the `quantize_graph` tool and the `freeze_graph` tool (which uses the function above). You can use them if they fit your needs, but make sure that they work with your serialization format.

Serialization formats

At least at the moment (TF is at 1.2), TensorFlow uses many different serialization formats, and manipulating serialized graphs is quite complicated. The formats currently used are:

  • `GraphDef` – defines only the graph structure, without weights or serving considerations. Graph freezing code is written for this format.
  • `Saver` – adds operations to load weights from files into a graph.
  • `SessionBundle` – adds a signature for serving. Most graph manipulation code is written for this format.
  • `SavedModel` – the future of graph storage (for ML models). Allows versioning, signatures, the works. Support for manipulation is not great for now.

As of TensorFlow 1.2, most graph manipulation tools work with `GraphDef` objects and serving works with `SavedModel` objects; older versions rely more on `SessionBundle`. The documentation usually points to the correct input format.

Our approach here at Taboola

To simplify operations and make experiments easy to reason about, we try to avoid dependencies between code and models.

Our general strategy is to build a super-graph that can both train and evaluate, already at train time. For code reuse between training and evaluation we use conditional operations, and we prepare the graph for serving using serialized graph manipulation.

Online experiments are a big part of our workflow, and we have serving machines around the world, so lowering operational complexity is well worth the effort. This approach has worked well for us: so far we have shipped two big pipelines with it, training, evaluating, and deploying dozens of models.
