Differences Between TensorFlow 1.x and TensorFlow 2.0

Posted by Galois on 2020-03-02

Historical Background of TF 1.x

TensorFlow 1.x was primarily a framework for working with static computational graphs. Nodes in the computational graph are Tensors, which hold n-dimensional arrays when the graph is run; edges in the graph represent functions that operate on those Tensors when the graph is run to actually perform useful computation.
Before TensorFlow 2.0, we had to split working with a graph into two phases:

  1. Build a computational graph that describes the computation you want to perform. This phase doesn't actually perform any computation; it just builds up a symbolic representation of the computation. This phase typically defines one or more "placeholder" objects that represent inputs to the graph.
  2. Run the computational graph many times. Each time the graph runs (e.g. for one gradient-descent step), you specify which parts of the graph you want to compute, and pass a "feed_dict" dictionary that supplies concrete values for any "placeholder"s in the graph (see the minimal sketch below).
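A minimal sketch of these two phases (assuming a TF 1.x installation, e.g. 1.15):

import tensorflow as tf  # assumes TF 1.x (e.g. 1.15)

# Phase 1: build a symbolic graph; nothing is computed here.
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = a + b

# Phase 2: run the graph, feeding concrete values for the placeholders.
with tf.Session() as sess:
    print(sess.run(c, feed_dict={a: 2.0, b: 3.0}))  # 5.0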

The New Paradigm in TF 2.0

With TensorFlow 2.0, we can simply adopt a "more Pythonic" functional style that directly mirrors PyTorch and NumPy operations, instead of the two-phase paradigm built around computational graphs. Among other things, this makes TF code much easier to debug.

The main difference between the TF 1.x and 2.0 approaches is that the 2.0 approach doesn't use tf.Session, tf.run, placeholder, or feed_dict.
For details on the differences between the two versions and on converting between them, see TensorFlow's migration guide.
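For instance, in TF 2.x (a minimal sketch; eager execution is the default) the same kind of computation runs immediately, with no Session or placeholders:

import tensorflow as tf  # assumes TF 2.x

# Operations execute eagerly and return concrete values right away.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = x * 2.0 + 1.0
print(y.numpy())  # [[3. 5.] [7. 9.]]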

A simple example: the flatten function

TF 1.x

import numpy as np
import tensorflow as tf  # assumes TF 1.x (e.g. 1.15)

# Device on which to place Tensors; switch to '/gpu:0' if a GPU is available.
device = '/cpu:0'

def flatten(x):
    """
    Input:
    - TensorFlow Tensor of shape (N, D1, ..., DM)

    Output:
    - TensorFlow Tensor of shape (N, D1 * ... * DM)
    """
    N = tf.shape(x)[0]
    return tf.reshape(x, (N, -1))

def test_flatten():
    # Clear the current TensorFlow graph.
    tf.reset_default_graph()

    # Stage I: Define the TensorFlow graph describing our computation.
    # In this case the computation is trivial: we just want to flatten
    # a Tensor using the flatten function defined above.

    # Our computation will have a single input, x. We don't know its
    # value yet, so we define a placeholder which will hold the value
    # when the graph is run. We then pass this placeholder Tensor to
    # the flatten function; this gives us a new Tensor which will hold
    # a flattened view of x when the graph is run. The tf.device
    # context manager tells TensorFlow whether to place these Tensors
    # on CPU or GPU.
    with tf.device(device):
        x = tf.placeholder(tf.float32)
        x_flat = flatten(x)

    # At this point we have just built the graph describing our computation,
    # but we haven't actually computed anything yet. If we print x and x_flat
    # we see that they don't hold any data; they are just TensorFlow Tensors
    # representing values that will be computed when the graph is run.
    print('x: ', type(x), x)
    print('x_flat: ', type(x_flat), x_flat)
    print()

    # We need to use a TensorFlow Session object to actually run the graph.
    with tf.Session() as sess:
        # Construct concrete values of the input data x using numpy
        x_np = np.arange(24).reshape((2, 3, 4))
        print('x_np:\n', x_np, '\n')

        # Run our computational graph to compute a concrete output value.
        # The first argument to sess.run tells TensorFlow which Tensor
        # we want it to compute the value of; the feed_dict specifies
        # values to plug into all placeholder nodes in the graph. The
        # resulting value of x_flat is returned from sess.run as a
        # numpy array.
        x_flat_np = sess.run(x_flat, feed_dict={x: x_np})
        print('x_flat_np:\n', x_flat_np, '\n')

        # We can reuse the same graph to perform the same computation
        # with different input data
        x_np = np.arange(12).reshape((2, 3, 2))
        print('x_np:\n', x_np, '\n')
        x_flat_np = sess.run(x_flat, feed_dict={x: x_np})
        print('x_flat_np:\n', x_flat_np)
test_flatten()

TF 2.0

import numpy as np
import tensorflow as tf  # assumes TF 2.x, where eager execution is the default

def flatten(x):
    """
    Input:
    - TensorFlow Tensor of shape (N, D1, ..., DM)
    Output:
    - TensorFlow Tensor of shape (N, D1 * ... * DM)
    """
    N = tf.shape(x)[0]
    return tf.reshape(x, (N, -1))

def test_flatten():
    # Construct concrete values of the input data x using numpy
    x_np = np.arange(24).reshape((2, 3, 4))
    print('x_np:\n', x_np, '\n')
    # Compute a concrete output value.
    x_flat_np = flatten(x_np)
    print('x_flat_np:\n', x_flat_np, '\n')
test_flatten()
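If graph-level performance is still wanted in TF 2.0, the same eager function can be compiled with tf.function (a short sketch, assuming TF 2.x and the flatten defined above):

import numpy as np
import tensorflow as tf

flatten_graph = tf.function(flatten)  # traces flatten into a graph on first call

x_np = np.arange(24).reshape((2, 3, 4))
x_flat = flatten_graph(x_np)          # runs the traced graph
print(x_flat.shape)                   # (2, 12)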

Baidu Netdisk: TF 1.x vs TF 2.0 function comparison table (extraction code: 5du8)
TF 1.0 vs TF 2.0

  • Static graphs vs. dynamic graphs
    • TF 1.0: tf.Session, feed_dict, and placeholder have been removed
    • TF 1.0: make_one_shot_iterator / make_initializable_iterator have been removed
    • TF 2.0: eager mode, @tf.function, and AutoGraph

Eager mode vs. Session

# TensorFlow 1.X
outputs = session.run(f(placeholder), feed_dict={placeholder: input})
# TensorFlow 2.0
outputs = f(input)
  • Eager mode with graph execution (via @tf.function, vs. Session):
    • good performance
    • can be exported as a SavedModel
  • E.g., AutoGraph converts Python control flow into graph ops (see the sketch after this list):
    • for/while -> tf.while_loop
    • if -> tf.cond
    • for _ in dataset -> dataset.reduce
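A small sketch of AutoGraph (assuming TF 2.x): inside a @tf.function, the Python for loop below is rewritten into graph control flow, roughly a tf.while_loop:

import tensorflow as tf

@tf.function
def sum_to(n):
    # AutoGraph converts this Python loop into graph control flow
    # (roughly a tf.while_loop) when the function is traced.
    total = tf.constant(0)
    for i in tf.range(n):
        total += i
    return total

print(sum_to(tf.constant(5)).numpy())  # 0 + 1 + 2 + 3 + 4 = 10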

API Changes

  • TensorFlow now has about 2,000 API endpoints, 500 of them in the root namespace
  • Some namespaces were created but do not contain all of the related APIs
    • tf.round is not under tf.math
  • Some endpoints are in the root namespace but rarely used, e.g. tf.zeta
  • Some are frequently used but not in the root namespace, e.g. tf.manip
  • Some namespaces are nested too deeply
    • tf.saved_model.signature_constants.CLASSIFY_INPUTS -> tf.saved_model.CLASSIFY_INPUTS
  • Duplicated APIs (see the sketch after this list)
    • tf.layers -> tf.keras.layers
    • tf.losses -> tf.keras.losses
    • tf.metrics -> tf.keras.metrics
  • Some APIs have prefixes that should instead be subnamespaces
    • tf.string_strip -> tf.strings.strip
  • Reorganization
    • tf.debugging, tf.dtypes, tf.io, tf.quantization, etc.
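A few of these renames in code (a short sketch, assuming TF 2.x):

import tensorflow as tf

# tf.string_strip (1.x) -> tf.strings.strip (2.0)
print(tf.strings.strip(tf.constant("  hello  ")).numpy())  # b'hello'

# tf.round is also exposed under tf.math in 2.0
print(tf.math.round(tf.constant(2.7)).numpy())  # 3.0

# tf.layers / tf.losses / tf.metrics are consolidated under tf.keras
layer = tf.keras.layers.Dense(4)
loss_fn = tf.keras.losses.MeanSquaredError()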
This work is licensed under a CC license; when reposting, you must credit the author and link back to this article.