從零開始的 TensorFlow:第 4 章、深刻理解第一個神經網路 (一)

yuanshang發表於2019-07-18

簡介

從本章開始,將記錄對官方文件中基本分類篇的剖析和理解。

基本分類文件主要對計算機視覺的基礎進行演示和簡要講解,其實質就是第一章的新手程式碼。就是對 70000 張服裝影象資料集進行訓練和測試的過程

匯入庫


# 引入未來版本新特性,作用是在使用 Python2 時能夠相容 Python3 的語法
from __future__ import absolute_import, division, print_function, unicode_literals

# 匯入 TensorFlow 和 tf.keras
import tensorflow as tf
from tensorflow import keras

# 匯入輔助庫
import numpy as np
import matplotlib.pyplot as plt

# 輸出 TensorFlow 的版本號
print(tf.__version__)

注意:剛複製這段程式碼到 IDE 中時,可能提示 matplotlib 包未安裝。所以需要安裝一下 matplotlib 包:

pip install matplotlib

關於 __future__ 官方地址 --> 傳送門

train_images、train_labels、test_images、test_labels 簡介

  • train_images:用來訓練的 60000 張圖片

    圖片

  • train_labels:用來訓練的 60000 個圖片分類(0-9),共 10 種

    圖片

  • test_images:用來測試的 10000 張圖片

  • test_labels:用來測試的 10000 個圖片分類

訓練資料和測試資料的獲取

# 從 tf.keras 中獲取 fashion_mnist 物件
fashion_mnist = keras.datasets.fashion_mnist

# 呼叫 fashion_mnist 物件的 load_data() 方法,獲取訓練資料和測試資料
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

下面我們來詳細檢視 load_data() 方法的原始碼

# Python2 相容 python3 的語法
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# 匯入 gzip 和 os 標準庫
import gzip
import os

# 匯入 numpy
import numpy as np

# 匯入工具類
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export

@keras_export('keras.datasets.fashion_mnist.load_data')
def load_data():
  # 根據系統生成合理相對路徑 datasets\\fashion-mnist
  dirname = os.path.join('datasets', 'fashion-mnist')

  # 資料包網路位置
  base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  # 資料包名稱
  files = [
      'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
      't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
  ]

  # 下載網路資料到本地,並記錄資料在本地的絕對地址
  paths = []
  for fname in files:

    # 重點就是這個 get_file 方法,實現了網路路徑到本地路徑的轉換(俗稱下載)
    paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))

  with gzip.open(paths[0], 'rb') as lbpath:
    y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[1], 'rb') as imgpath:
    x_train = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

  with gzip.open(paths[2], 'rb') as lbpath:
    y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[3], 'rb') as imgpath:
    x_test = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

  return (x_train, y_train), (x_test, y_test)

通過上面,我們知道了從網路資料到本地資料的核心方法:get_file

## 傳進來的引數如下:
# fname: 檔名稱
# origin: 遠端檔案地址
# cache_subdir:本地相對地址

@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  # cache_dir 如何是 None 則,生成 家目錄 + .keras。如 window 就是 C:\Users\Administrator\.keras
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')

  # 如果傳入 md5_hash 則優先使用
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)

  # 如果家目錄沒有寫許可權,那就到在根目錄下建立 tmp 目錄
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')

  # 拼接 datadir_base 和 cache_subdir 生成完整本地絕對路徑
  datadir = os.path.join(datadir_base, cache_subdir)

  # 檢測本地絕對路徑目錄存在嗎,不存在就建立
  if not os.path.exists(datadir):
    os.makedirs(datadir)

  # 檔名是否啟用擴充套件優化
  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False

  # 檢測本地現存資料的 hash 與遠端 hash 是否一致,不一致則重新下載
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  # 啟用下載,並宣告進度條
  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        # 進行檔案下載,主要運用 Python 的 urllib.request 庫,進行下載和本地寫入
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  # 最後返回本地資料的絕對路徑
  return fpath

關於 urlretrieve 方法,我們看一下

if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):
    def chunk_read(response, chunk_size=8192, reporthook=None):
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    # 重點在這裡,運用 urlopen 下載資料
    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        # 將資料寫入本地
        fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve

關於 urlopen 方法 --> 傳送門

結果

  • 命令列執行

    Python

  • 下載到本地檔案的位置

    Python

我們是一群被時空壓迫的孩子。 ---- 愛因斯坦

相關文章