今日学んだこと#0010
tensroflowのAPI
tensorflow実装のモデルをpytorchで実装するために,必要なtensorflowのAPIをまとめていく.
主に,下記サイトを参考にした. dev.classmethod.jp
型取得:get_shape()
>>> import tensorflow as tf >>> tensor = tf.constant([[1, 1], [1, 1]]) >>> tensor.get_shape() TensorShpae([Dimension(2), Dimension(2)]) >>> tensor.get_shape().as_list() [2, 2]
縮約操作:reduce_sum()
>>> sess = tf.Session() >>> t = tf.constant([[1, 1, 1], [2, 2, 2]]) >>> sess.run(tf.reduce_sum(t, axis=[0], keep_dims=True)) array([[3, 3, 3]], dtype=int32) >>> sess.run(tf.reduce_sum(t, axis=[1], keep_dims=True)) array([[3], [6]], dtype=int32)
各要素を二乗:tf.sqaure()
>>> t = tf.constant([[1], [2], [4]]) >>> sess.run(tf.square(t)) array([[ 1], [ 4], [16]], dtype=int32)
要素ごとに最大値を取得:tf.maximum()
>>> t = tf.constant([[100, 2], [3, 4]]) >>> t_ = tf.constant([[10, 20], [30, 40]]) >>> sess.run(tf.maximum(t, t_)) array([[100, 20], [ 30, 40]], dtype=int32)
要素ごとに平方根を取得:tf.sqrt()
>>> t = tf.constant([[1., 2.], [3., 4.]]) >>> sess.run(tf.sqrt(t)) array([[1. , 1.4142135], [1.7320508, 2. ]], dtype=float32)