你可以给什么样的数据类型作为TensorFlow的关键字?

为了举例,考虑计算张量流中的内积。 我试图在TensorFlow的图表中用不同的方式来引用图表中的事物,当用一个使用feed的会话对它进行评估时。 考虑下面的代码:

import numpy as np
import tensorflow as tf

M = 4
D = 2
D1 = 3
x = tf.placeholder(tf.float32, shape=[M, D], name='data_x') # M x D
W = tf.Variable( tf.truncated_normal([D,D1], mean=0.0, stddev=0.1) ) # (D x D1)
b = tf.Variable( tf.constant(0.1, shape=[D1]) ) # (D1 x 1)
inner_product = tf.matmul(x,W) + b # M x D1
with tf.Session() as sess:
    sess.run( tf.initialize_all_variables() )
    x_val = np.random.rand(M,D)
    #print type(x.name)
    #print x.name
    name = x.name
    ans = sess.run(inner_product, feed_dict={name: x_val})
    ans = sess.run(inner_product, feed_dict={x.name: x_val})
    ans = sess.run(inner_product, feed_dict={x: x_val})
    name_str = unicode('data_x', "utf-8")
    ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
    ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
    ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work
    print ans

以下工作:

ans = sess.run(inner_product, feed_dict={name: x_val})
ans = sess.run(inner_product, feed_dict={x.name: x_val})
ans = sess.run(inner_product, feed_dict={x: x_val})

但最后三个:

name_str = unicode('data_x', "utf-8")
ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work
ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work

别。 我检查了为什么类型x.name是,但即使当我将它转换为类型python解释器时,它仍然无法正常工作。 我的文档似乎认为键必须是张量。 但是,它接受x.name而不是张量(它是一个<type 'unicode'> x.name <type 'unicode'> ),是否有人知道发生了什么?


我可以粘贴文档说它需要是一个张量:

可选的feed_dict参数允许调用者覆盖图中张量的值。 feed_dict中的每个键都可以是以下类型之一:

如果键是张量,则该值可能是Python标量,字符串,列表或numpy ndarray,可以将其转换为与该张量相同的dtype。 此外,如果键是占位符,则将检查值的形状是否与占位符兼容。 如果密钥是SparseTensor,则该值应该是SparseTensorValue。 feed_dict中的每个值必须可转换为相应键的dtype的numpy数组。


TensorFlow主要期望tf.Tensor对象作为Feed词典中的键。 如果它等于会话图中某个tf.Tensor.name属性,它也会接受一个字符串(可能是bytesunicode )。

在你的例子中, x.name起作用,因为x是一个tf.Tensor ,你正在评估它的.name属性。 "data_val"不起作用,因为它是一个名称tf.Operation (即x.op ),而不是一个名称tf.Tensor ,这是一个输出tf.Operation 。 如果你打印x.name ,你会发现它的值为"data_val:0" ,意思是“ tf.Operation的第0个输出叫做"data_val"

链接地址: http://www.djcxy.com/p/32103.html

上一篇: What data types can you give as keys to feed in TensorFlow?

下一篇: How to add regularizations in TensorFlow?