How is TensorFlow's while_loop implemented?

Consider the following example:

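import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API (also shipped inside TF 2.x)
tf.disable_v2_behavior()  # graph mode, including the Enter/Merge/Switch-style while_loop

i = tf.get_variable("i", dtype=tf.int32, shape=[], initializer=tf.ones_initializer())
n = tf.constant(10)

def cond(a, n):
    return a < n

def body(a, n):
    a = a + 2
    return a, n  # n is unchanged but must still be returned

ii, nn = tf.while_loop(cond, body, [i, n])
v1 = ii + 3
v2 = nn + 4

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    res = sess.run([v1, v2])
    print(res)  # v1 = 11 + 3 = 14, v2 = 10 + 4 = 14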

It will generate the following graph:

 

Expanding the while scope gives:

Quite messy, isn’t it? To understand what is under the hood, you need some knowledge of TensorFlow control flow, dynamic control flow, dead tensors, and of course the official while_loop documentation.

So far, we have seen two kinds of dependencies between nodes (operations) that determine their execution order: a data dependency, where the output of A is an input of B, so A must be executed before B; and a control dependency, where no data flows from A to B but A is still required to be executed before B. To evaluate a node, TensorFlow finds all the nodes that the node depends on, directly or indirectly, and evaluates them in dependency order. But what if the set of nodes a node depends on changes with the runtime value of some tensor in the graph? One option is to build a separate graph for each case, evaluate the deciding tensor in yet another graph, and then run whichever graph matches. A better option is to put all the alternatives into one graph as sub-graphs, connect every sub-graph to the current node, and add a mechanism that executes only one sub-graph at a time, chosen on the fly from the tensor's value.

Since every node in those sub-graphs is a dependency of the current node, we still need to traverse all of them, but a special tensor called a dead tensor controls whether each one is actually executed. The rule is: if any input of a node is a dead tensor, the node's operation is not executed, and the node outputs a dead tensor on all of its outputs instead (the “Merge” node, described below, is the exception). This way, although all nodes are traversed as usual, only some of them are actually executed.
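To make the dead-tensor rule concrete, here is a minimal pure-Python sketch. It is not TensorFlow code: `DEAD` and `run_op` are hypothetical names for this illustration, and the model keeps only the rule that any dead input makes a node skip its operation and emit a dead output.

```python
# A pure-Python model of dead-tensor propagation (not TensorFlow's executor).
DEAD = object()  # hypothetical sentinel standing in for a dead tensor

executed = []  # records which operations actually ran


def run_op(name, fn, *inputs):
    """Ordinary node: if any input is dead, skip fn and emit a dead output."""
    if any(x is DEAD for x in inputs):
        return DEAD
    executed.append(name)
    return fn(*inputs)


# Two chained nodes fed by a live value: both run.
a = run_op("double", lambda x: x * 2, 5)
b = run_op("inc", lambda x: x + 1, a)

# The same chain fed by a dead value: both are traversed, neither runs.
c = run_op("double_dead", lambda x: x * 2, DEAD)
d = run_op("inc_dead", lambda x: x + 1, c)

print(b, d is DEAD, executed)  # 11 True ['double', 'inc']
```

Both chains are visited, but only the one fed by a live value shows up in `executed`.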

Now back to while_loop. It takes three parameters: cond, body and loop_vars. The third parameter, loop_vars, is a list of tensors. The first two are functions that take as many parameters as there are loop_vars, with matching shapes (in this case, both take two scalar parameters). cond must return a boolean tensor. body describes the operation(s) performed while cond returns true, and returns the updated loop vars. cond and body are each called only once, when while_loop itself is called, and what they do is construct nodes in the graph rather than run the loop. Note that although body does not use the “n” parameter in this example, “n” must still be included in its parameter list and returned along with the other loop var, a.
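As a point of reference, the loop that the graph implements behaves like the plain-Python function below. This is only a sketch of the semantics: `py_while_loop` is a hypothetical name, it runs Python values instead of building graph nodes, and the explicit length check merely illustrates why body must return every loop var (TensorFlow performs its own structure, shape and dtype checks).

```python
def py_while_loop(cond, body, loop_vars):
    """Plain-Python analogue of the loop tf.while_loop builds (not TF code)."""
    loop_vars = tuple(loop_vars)
    while cond(*loop_vars):
        new_vars = tuple(body(*loop_vars))
        # body must hand back every loop var, including ones it does not touch.
        if len(new_vars) != len(loop_vars):
            raise ValueError("body must return as many values as loop_vars")
        loop_vars = new_vars
    return loop_vars


def cond(a, n):
    return a < n


def body(a, n):
    return a + 2, n  # n is unused but must still be returned


ii, nn = py_while_loop(cond, body, [1, 10])
print(ii + 3, nn + 4)  # 14 14
```

Starting from 1, `a` takes the values 1, 3, 5, 7, 9 and 11; the loop stops at 11, so `ii` is 11 and `nn` is still 10.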

Every loop var enters the while scope via an “Enter” node and exits the while scope via an “Exit” node.

The “Enter” node for i:

The “Enter” node for n:

The “Enter” node outputs its input tensor unchanged, making it available inside the loop. After the “Enter” node, each loop var goes through a “Merge” node.
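Roughly speaking, TensorFlow's executor keeps track of which loop frame and which iteration each value belongs to, and “Enter” is where a value moves into the loop's frame. The sketch below models that with a hypothetical `Token` record; `Token`, `enter` and the frame names "root" and "while" are illustrative, not TensorFlow APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Token:
    """A value tagged with the frame and iteration it belongs to (simplified model)."""
    value: object
    frame: str
    iteration: int


def enter(tok, loop_frame):
    # The value itself is unchanged; only its frame/iteration tag changes.
    return Token(tok.value, loop_frame, 0)


i_outer = Token(1, "root", 0)
i_inner = enter(i_outer, "while")
print(i_inner)  # Token(value=1, frame='while', iteration=0)
```

The value 1 is passed through untouched; what changes is that it now belongs to iteration 0 of the loop.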

The “Merge” node for i:

The “Merge” node for n:

Each “Merge” node outputs whichever of its inputs is not a dead tensor. In this case, it outputs the value coming from “Enter” in the first iteration and the value coming from “NextIteration” in later iterations. The loop vars leaving the “Merge” nodes are fed into the sub-graph constructed by the cond function; here, they are connected to the “Less” operation, which produces a boolean tensor:
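The “Merge” behaviour can be sketched in a few lines of Python. This is a simplified model: `DEAD` and `merge` are hypothetical names, and an input that has not arrived yet (such as “NextIteration” in the first iteration) is treated the same as a dead one. The second return value mirrors the `value_index` output of TensorFlow's Merge op, which reports which input was forwarded.

```python
DEAD = object()  # hypothetical sentinel standing in for a dead tensor


def merge(*inputs):
    """Forward the first live input; the output is dead only if all inputs are dead.

    Returns (value, index_of_forwarded_input).
    """
    for idx, x in enumerate(inputs):
        if x is not DEAD:
            return x, idx
    return DEAD, DEAD


# Iteration 0: "Enter" supplies the value, "NextIteration" has nothing yet.
print(merge(1, DEAD))  # (1, 0)
# Later iterations: the "Enter" side is empty, "NextIteration" carries the value.
print(merge(DEAD, 3))  # (3, 1)
```

This is why “Merge” is the exception to the dead-tensor rule: it needs only one live input to produce a live output.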

The result of cond determines the fate of the loop vars: whether they go through another iteration or go to the exit. To do this, the loop vars (after going through “Merge”) are connected to a “Switch” node that is controlled by the tensor returned by cond.

The “Switch” node for i:

The “Switch” node for n:

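If the tensor returned by cond is true, the loop vars flow into the sub-graph constructed by the body function, while the other output of “Switch”, which leads to the “Exit” node, is a dead tensor that disables “Exit” for that iteration. If it is false, the roles are reversed: the body sub-graph receives dead tensors and the live values flow to “Exit”.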
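A “Switch” node can be modelled as a function that routes its data to exactly one of two outputs and makes the other one dead. In this hypothetical sketch, the outputs are returned in the same order as TensorFlow's Switch op (false output first, then true output); in a while loop, the false output leads to “Exit” and the true output leads into body.

```python
DEAD = object()  # hypothetical sentinel standing in for a dead tensor


def switch(data, pred):
    """Route data to one output; the other output is dead.

    Returns (output_false, output_true).
    """
    if data is DEAD or pred is DEAD:
        return DEAD, DEAD
    return (DEAD, data) if pred else (data, DEAD)


to_exit, to_body = switch(3, 3 < 10)    # cond is true: keep looping
print(to_exit is DEAD, to_body)         # True 3
to_exit, to_body = switch(11, 11 < 10)  # cond is false: leave the loop
print(to_exit, to_body is DEAD)         # 11 True
```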

The sub-graph constructed by body for i:

There is no sub-graph for n because body does not use it. Whether or not a loop var is used by body, it passes through a “NextIteration” node (possibly after some operations defined by body), which sends it back to the “Merge” node for the next iteration.
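Continuing the simplified frame/iteration model, “NextIteration” can be sketched as passing a value on unchanged while tagging it for the next iteration of the same frame. `Token` and `next_iteration` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Token:
    """A value tagged with the frame and iteration it belongs to (simplified model)."""
    value: object
    frame: str
    iteration: int


def next_iteration(tok):
    """Pass the value on unchanged, tagged for the next iteration of the same frame."""
    return Token(tok.value, tok.frame, tok.iteration + 1)


# i after body's "add" in iteration 0, and n, which body leaves untouched.
i_tok = Token(1 + 2, "while", 0)
n_tok = Token(10, "while", 0)
print(next_iteration(i_tok))  # Token(value=3, frame='while', iteration=1)
print(next_iteration(n_tok))  # Token(value=10, frame='while', iteration=1)
```

Both loop vars advance to iteration 1, whether or not body changed them.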

The “NextIteration” node for i:

The “NextIteration” node for n:

 

Unlike ordinary nodes, which are executed at most once per sess.run call, the nodes in the while scope (except the “Enter” and “Exit” nodes) may be executed multiple times when evaluating a node that depends on the loop, such as v1 or v2. The tensors returned by while_loop are the outputs of the “Exit” nodes in the while scope.
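Putting the pieces together, the sketch below steps through the example graph one iteration at a time and counts how often each kind of node computes something, for i and n combined. It is a simplified model, not TensorFlow's executor: the names are hypothetical, and nodes that are only traversed with dead inputs (such as body's “add” in the final iteration) are not counted.

```python
from collections import Counter

DEAD = object()   # hypothetical sentinel standing in for a dead tensor
runs = Counter()  # node type -> number of times it computed (i and n combined)


def merge(from_enter, from_next):
    runs["Merge"] += 1
    return from_enter if from_enter is not DEAD else from_next


def switch(data, pred):
    runs["Switch"] += 1
    return (DEAD, data) if pred else (data, DEAD)  # (to Exit, to body)


def simulate(i0, n0):
    """Step the i/n loop one iteration at a time, following the graph's data flow."""
    runs["Enter"] += 2          # one Enter per loop var, run once
    next_i = next_n = DEAD      # nothing has come from NextIteration yet
    iteration = 0
    while True:
        # Merge: Enter supplies iteration 0, NextIteration supplies the rest.
        a = merge(i0 if iteration == 0 else DEAD, next_i)
        n = merge(n0 if iteration == 0 else DEAD, next_n)
        runs["Less"] += 1       # the sub-graph built by cond
        pred = a < n
        exit_a, body_a = switch(a, pred)
        exit_n, body_n = switch(n, pred)
        if not pred:
            # The body branch only sees dead tensors; Exit gets the live values.
            runs["Exit"] += 2
            return exit_a, exit_n
        runs["add"] += 1        # the sub-graph built by body (only for i)
        next_i, next_n = body_a + 2, body_n
        runs["NextIteration"] += 2
        iteration += 1


ii, nn = simulate(1, 10)
print(ii + 3, nn + 4)  # 14 14
print(dict(runs))
```

With i starting at 1, “Less” runs six times (for a = 1, 3, 5, 7, 9, 11) while body's “add” computes only five times, and each “Enter” and “Exit” node runs exactly once.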

The “Exit” node for i:

The “Exit” node for n:

The tensor returned by while_loop for i (ii) is connected to the “add” node that computes v1. The tensor returned for n (nn) is connected to the “add_1” node that computes v2.

The “add” node:

The loop starts when the loop vars i and n are fed into the “Enter” nodes. v1 and v2 can be computed only once the “Exit” nodes output non-dead tensors, because only then are the inputs of the “add” and “add_1” nodes ready.

Hopefully you now understand what is inside while_loop. For more details about its internals, you may refer to this and this.
