What data types can be used as keys in TensorFlow's feed_dict?
Consider computing an inner product in TensorFlow as an example. I was experimenting with the different ways of referring to tensors in a TensorFlow graph when evaluating it in a session with feed_dict. Consider the following code:
import numpy as np
import tensorflow as tf
M = 4
D = 2
D1 = 3
x = tf.placeholder(tf.float32, shape=[M, D], name='data_x') # M x D
W = tf.Variable( tf.truncated_normal([D,D1], mean=0.0, stddev=0.1) ) # (D x D1)
b = tf.Variable( tf.constant(0.1, shape=[D1]) ) # (D1 x 1)
inner_product = tf.matmul(x,W) + b # M x D1
with tf.Session() as sess:
    sess.run( tf.initialize_all_variables() )
    x_val = np.random.rand(M,D)
    #print type(x.name)
    #print x.name
    name = x.name
    ans = sess.run(inner_product, feed_dict={name: x_val})
    ans = sess.run(inner_product, feed_dict={x.name: x_val})
    ans = sess.run(inner_product, feed_dict={x: x_val})
    name_str = unicode('data_x', "utf-8")
    ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
    ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
    ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work
    print ans
The following work:
ans = sess.run(inner_product, feed_dict={name: x_val})
ans = sess.run(inner_product, feed_dict={x.name: x_val})
ans = sess.run(inner_product, feed_dict={x: x_val})
but the last three:
name_str = unicode('data_x', "utf-8")
ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work
don't. I checked what type x.name was, but it still didn't work even when I converted my string to the type the Python interpreter reported. The documentation seems to say that the keys have to be tensors. However, it accepted x.name even though it's not a tensor (it's a <type 'unicode'>). Does anyone know what's going on?
Here is the part of the documentation that says the keys need to be tensors:
The optional feed_dict argument allows the caller to override the value of tensors in the graph. Each key in feed_dict can be one of the following types:
If the key is a Tensor, the value may be a Python scalar, string, list, or numpy ndarray that can be converted to the same dtype as that tensor. Additionally, if the key is a placeholder, the shape of the value will be checked for compatibility with the placeholder. If the key is a SparseTensor, the value should be a SparseTensorValue. Each value in feed_dict must be convertible to a numpy array of the dtype of the corresponding key.
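The value-side rule in that quote ("convertible to a numpy array of the dtype of the corresponding key", plus a shape check for placeholders) can be illustrated with plain NumPy. This is a rough sketch only: `check_feed_value` is a hypothetical helper written for illustration, not part of TensorFlow, and TensorFlow's real checks live inside `Session.run`. It shows why feeding the float64 output of `np.random.rand` into a float32 placeholder of shape [4, 2] is fine.

```python
import numpy as np

def check_feed_value(value, dtype, shape):
    """Hypothetical helper: roughly mimics the documented checks on a feed value.

    Converts `value` to a NumPy array of `dtype` and, if `shape` is given,
    compares it dimension by dimension (None means "any size").
    """
    arr = np.asarray(value, dtype=dtype)
    if shape is not None:
        if arr.ndim != len(shape):
            raise ValueError("rank mismatch: got %d, expected %d" % (arr.ndim, len(shape)))
        for got, want in zip(arr.shape, shape):
            if want is not None and got != want:
                raise ValueError("shape mismatch: got %s, expected %s" % (arr.shape, shape))
    return arr

# float64 data (as produced by np.random.rand) fed to a float32 placeholder of shape [4, 2]
x_val = np.random.rand(4, 2)
fed = check_feed_value(x_val, np.float32, [4, 2])
print(fed.dtype, fed.shape)  # float32 (4, 2)
```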
TensorFlow primarily expects tf.Tensor objects as the keys in the feed dictionary. It will also accept a string (which may be bytes or unicode) if it is equal to the .name property of some tf.Tensor in the session's graph.
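To make that rule concrete without needing TensorFlow installed, here is a toy model of the key lookup. It is not TensorFlow's code: `FakeTensor` and `resolve_feed_key` are hypothetical names, and the graph is assumed to be a plain dict from tensor names to tensor objects. A key is accepted if it is a tensor object, or a string (str or bytes) that exactly matches some tensor's name.

```python
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor: it only carries a .name."""
    def __init__(self, name):
        self.name = name

def resolve_feed_key(key, graph_tensors):
    """Toy model of the feed-key rule (not TensorFlow's implementation).

    `graph_tensors` maps tensor names such as "data_x:0" to tensor objects.
    """
    if isinstance(key, FakeTensor):
        return key                                  # tensor objects are used directly
    if isinstance(key, (str, bytes)):
        name = key.decode("utf-8") if isinstance(key, bytes) else key
        if name in graph_tensors:
            return graph_tensors[name]              # string must equal some tensor's .name
        raise TypeError("%r is not the name of a tensor in the graph" % name)
    raise TypeError("unsupported feed key type: %s" % type(key).__name__)

x = FakeTensor("data_x:0")     # what a placeholder named "data_x" is called as a tensor
graph = {x.name: x}

assert resolve_feed_key(x, graph) is x             # the tensor object itself
assert resolve_feed_key("data_x:0", graph) is x    # the tensor's .name string
assert resolve_feed_key(b"data_x:0", graph) is x   # bytes are decoded first
try:
    resolve_feed_key("data_x", graph)              # an operation name, not a tensor name
except TypeError as err:
    print("rejected:", err)
```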
In your example, x.name works because x is a tf.Tensor and you're evaluating its .name property. "data_x" does not work because it is the name of a tf.Operation (viz. x.op), not the name of a tf.Tensor, which is an output of a tf.Operation. If you print x.name, you'll see that it has the value "data_x:0", which means "the 0th output of the tf.Operation called data_x".
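A minimal sketch of the difference, assuming TensorFlow 1.15 or TensorFlow 2.x used through the `tensorflow.compat.v1` module (placeholders and feed_dict only work in graph mode, hence `disable_eager_execution()`). It rebuilds the question's graph in a fresh default graph, prints the tensor name and the operation name, feeds using the tensor name "data_x:0", and shows that the bare operation name "data_x" is rejected with a TypeError.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # placeholders and feed_dict require graph mode

M, D, D1 = 4, 2, 3
x = tf.placeholder(tf.float32, shape=[M, D], name="data_x")
W = tf.Variable(tf.truncated_normal([D, D1], mean=0.0, stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[D1]))
inner_product = tf.matmul(x, W) + b

print(x.name)     # data_x:0  (the tensor: output 0 of the op)
print(x.op.name)  # data_x    (the operation)

x_val = np.random.rand(M, D)
rejected = False
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Tensor name as key: resolved to the placeholder tensor, so this works.
    ans = sess.run(inner_product, feed_dict={"data_x:0": x_val})
    try:
        # Operation name as key: it names an Operation, not a Tensor.
        sess.run(inner_product, feed_dict={"data_x": x_val})
    except TypeError as err:
        rejected = True
        print("rejected:", err)
```

So in the question, any of x, x.name, or the literal string "data_x:0" (as str or unicode) should work as a key, while "data_x" will not.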