TensorFlow入门

    xiaoxiao2021-03-25  138


    1 导入tensorflow

    import tensorflow as tf

    2 创建节点

    node1 = tf.constant(3.0,tf.float32) # 常量 node2 = tf.constant(4.0) #默认32位浮点 print(node1,node2);

    输出

    [3.0, 4.0]

    3 创建网络

    节点的信息

    sess = tf.Session() #创建网络 print(sess.run([node1,node2])) #输出节点信息

    输出

    ('node3: ', <tf.Tensor 'Add:0' shape=() dtype=float32>)

    4 相加操作

    #节点相加 node3=tf.add(node1,node2) print("node3: ",node3) print("sess.run(node3):",sess.run(node3))

    输出

    ('sess.run(node3):', 7.0)

    5 输入操作

    #输入数据,节点相加 a=tf.placeholder(tf.float32) b=tf.placeholder(tf.float32) add_node=a+b # 相当于add_node=tf.add(a,b) print(sess.run(add_node,{a:3,b:4.5})) print(sess.run(add_node,{a:[1,2],b:[3,4]})) print(sess.run(add_node,{a:[[1,2],[3,4]],b:[[3,4],[5,6]]}))

    输出

    7.5 [ 4. 6.] [[ 4. 6.] [ 8. 10.]]

    6 相乘操作

    #节点乘3 add_triple = add_node * 3 print(sess.run(add_triple,{a:[1,2],b:[3,4]}))

    输出

    [ 12. 18.]

    7 线性回归实例

    7.1 建立方程

    W = tf.Variable([.3],tf.float32) b = tf.Variable([-0.3],tf.float32) x = tf.placeholder(tf.float32) linear_model= W * x + b #l=W*x+b #变量需要初始化 init = tf.global_variables_initializer() sess.run(init) print(sess.run(linear_model,{x:[1,2,3,4]}))

    输出

    [ 0. 0.30000001 0.60000002 0.90000004]

    7.2 计算损失值

    #计算损失值,在这里用误差平方和 y=tf.placeholder(tf.float32) squared_deltas=tf.square(linear_model-y) loss = tf.reduce_sum(squared_deltas) print(sess.run(loss,{x:[1,2,3,4],y:[0,-1,-2,-3]}))

    输出

    23.66

    7.3 对W和b重新赋值

    #对W和b重新赋值,使用tf.assign()方法 fixW = tf.assign(W, [-1.]) fixb = tf.assign(b, [1.]) sess.run([fixW,fixb]) #需要此语句生效 print(sess.run(loss,{x:[1,2,3,4],y:[0,-1,-2,-3]}))

    输出

    0.0

    7.4 使用梯度下降进行回归

    #梯度下降线性回归 optimizer=tf.train.GradientDescentOptimizer(0.001) #0.01为步长,也称学习率 train=optimizer.minimize(loss) #最小化损失函数 #训练10000次 sess.run(init) #变量初始化 for i in range(10000): sess.run(train, {x:[1,2,3,4], y:[0,-1,-2,-3]}) print(sess.run([W,b]))

    输出

    [array([-0.99998975], dtype=float32), array([ 0.99997061], dtype=float32)]

    From[TensorFlow入门]

    转载请注明原文地址: https://ju.6miu.com/read-2704.html

    最新回复(0)