Introduction
Starting with this chapter, I will record my analysis and understanding of the Basic Classification section of the official documentation.
The Basic Classification guide demonstrates and briefly explains the fundamentals of computer vision; in essence it is the beginner code from Chapter 1: the process of training and testing on a dataset of 70,000 clothing images.
- Official documentation address
Importing the libraries
# Import future-version features so that code written for Python 2 is compatible with Python 3 syntax
from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Import helper libraries
import numpy as np
import matplotlib.pyplot as plt
# Print the TensorFlow version number
print(tf.__version__)
Note: when you first copy this code into your IDE, it may report that the matplotlib package is not installed. If so, install it:
pip install matplotlib
About __future__
Official documentation --> link
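As a quick illustration (my own addition, not from the official guide), here is a minimal sketch of what these imports change when the code runs under Python 2; under Python 3 they are effectively no-ops:

# With this import, `print` is a function even on Python 2, so the same
# source runs unchanged on Python 2 and Python 3
from __future__ import print_function, division

print("TensorFlow tutorial", "basic classification", sep=" - ")
# `division` makes `/` true division on Python 2 as well (3 / 2 == 1.5, not 1)
print(3 / 2)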
Overview of train_images, train_labels, test_images and test_labels
- train_images: the 60,000 images used for training
- train_labels: the 60,000 training labels (0-9), covering 10 classes
- test_images: the 10,000 images used for testing
- test_labels: the 10,000 test labels
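The integer labels 0-9 stand for clothing categories. The dataset itself does not ship the category names, so the official guide defines them by hand; a sketch of that mapping (usable once the data has been loaded in the next section):

# Human-readable class names for labels 0-9, following the Fashion-MNIST convention
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Example (after load_data() has run): name of the first training label
# print(class_names[train_labels[0]])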
Fetching the training and test data
# Get the fashion_mnist object from tf.keras
fashion_mnist = keras.datasets.fashion_mnist
# Call its load_data() method to obtain the training and test data
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
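To confirm the shapes described in the list above, here is a quick check I added myself (not part of the official guide):

# Sanity check of the arrays returned by load_data()
print(train_images.shape)  # (60000, 28, 28): 60,000 training images of 28x28 pixels
print(train_labels.shape)  # (60000,): one integer label (0-9) per training image
print(test_images.shape)   # (10000, 28, 28): 10,000 test images
print(test_labels.shape)   # (10000,): one label per test image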
Let's now look at the source code of the load_data() method in detail:
# Syntax compatibility between Python 2 and Python 3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import the gzip and os standard libraries
import gzip
import os

# Import numpy
import numpy as np

# Import utility functions
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.datasets.fashion_mnist.load_data')
def load_data():
  # Build the relative cache sub-path datasets/fashion-mnist in an OS-appropriate way
  dirname = os.path.join('datasets', 'fashion-mnist')
  # Remote location of the data files
  base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  # Names of the data files
  files = [
      'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
      't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
  ]

  # Download the remote files to the local cache and record their absolute local paths
  paths = []
  for fname in files:
    # The key call is get_file, which turns a remote path into a local one (i.e. downloads the file)
    paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))

  with gzip.open(paths[0], 'rb') as lbpath:
    y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[1], 'rb') as imgpath:
    x_train = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

  with gzip.open(paths[2], 'rb') as lbpath:
    y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(paths[3], 'rb') as imgpath:
    x_test = np.frombuffer(
        imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)

  return (x_train, y_train), (x_test, y_test)
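The offset=8 and offset=16 arguments above skip the IDX file headers (magic number plus item count for label files; magic number, count, rows and columns for image files). As a sketch, assuming the files have already been cached by a previous load_data() call under the default ~/.keras directory, the label parsing can be reproduced by hand:

import gzip
import os

import numpy as np

# Assumed cache location: default ~/.keras plus the cache_subdir used by load_data()
cache_dir = os.path.join(os.path.expanduser('~'), '.keras', 'datasets', 'fashion-mnist')
labels_path = os.path.join(cache_dir, 'train-labels-idx1-ubyte.gz')

with gzip.open(labels_path, 'rb') as f:
    raw = f.read()

# The first 8 bytes are the IDX header (magic number + item count); the rest are the labels
labels = np.frombuffer(raw, np.uint8, offset=8)
print(labels.shape)  # expected: (60000,)
print(labels[:10])   # first ten label values, each in 0-9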
From the above, we can see that the core method for getting remote data onto the local disk is get_file.
## The arguments passed in are:
# fname: the file name
# origin: the remote file URL
# cache_subdir: the relative path under the local cache directory
@keras_export('keras.utils.get_file')
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
  # If cache_dir is None, default to <home directory>/.keras,
  # e.g. C:\Users\Administrator\.keras on Windows
  if cache_dir is None:
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
  # If md5_hash is given (and file_hash is not), use it as the file hash
  if md5_hash is not None and file_hash is None:
    file_hash = md5_hash
    hash_algorithm = 'md5'
  datadir_base = os.path.expanduser(cache_dir)
  # If the cache base directory is not writable, fall back to /tmp/.keras
  if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
  # Join datadir_base and cache_subdir to form the full local path
  datadir = os.path.join(datadir_base, cache_subdir)
  # Create the local directory if it does not exist yet
  if not os.path.exists(datadir):
    os.makedirs(datadir)

  # If untar is requested, the download path gets a .tar.gz suffix and the archive is extracted later
  if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
  else:
    fpath = os.path.join(datadir, fname)

  download = False
  # If the file already exists locally, verify its hash against the one provided;
  # re-download when the hashes do not match
  if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if file_hash is not None:
      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
        print('A local file was found, but it seems to be '
              'incomplete or outdated because the ' + hash_algorithm +
              ' file hash does not match the original value of ' + file_hash +
              ' so we will re-download the data.')
        download = True
  else:
    download = True

  # Perform the download, reporting progress with a progress bar
  if download:
    print('Downloading data from', origin)

    class ProgressTracker(object):
      # Maintain progbar for the lifetime of download.
      # This design was chosen for Python 2.7 compatibility.
      progbar = None

    def dl_progress(count, block_size, total_size):
      if ProgressTracker.progbar is None:
        if total_size == -1:
          total_size = None
        ProgressTracker.progbar = Progbar(total_size)
      else:
        ProgressTracker.progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
      try:
        # Download the file; urlretrieve (from Python's urllib.request) does the
        # actual fetching and writing to disk
        urlretrieve(origin, fpath, dl_progress)
      except HTTPError as e:
        raise Exception(error_msg.format(origin, e.code, e.msg))
      except URLError as e:
        raise Exception(error_msg.format(origin, e.errno, e.reason))
    except (Exception, KeyboardInterrupt) as e:
      if os.path.exists(fpath):
        os.remove(fpath)
      raise
    ProgressTracker.progbar = None

  if untar:
    if not os.path.exists(untar_fpath):
      _extract_archive(fpath, datadir, archive_format='tar')
    return untar_fpath

  if extract:
    _extract_archive(fpath, datadir, archive_format)

  # Finally, return the absolute local path of the file
  return fpath
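Since this function is also exposed publicly as keras.utils.get_file, you can call it directly to cache any remote file. A minimal sketch, reusing the Fashion-MNIST label file URL and cache sub-directory from load_data() above:

from tensorflow import keras

# Download (or reuse a cached copy of) one Fashion-MNIST file and get its local path;
# cache_subdir is chosen to match what load_data() uses
path = keras.utils.get_file(
    'train-labels-idx1-ubyte.gz',
    origin='https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz',
    cache_subdir='datasets/fashion-mnist')

print(path)  # e.g. ~/.keras/datasets/fashion-mnist/train-labels-idx1-ubyte.gz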
Let's take a look at the urlretrieve method:
if sys.version_info[0] == 2:

  def urlretrieve(url, filename, reporthook=None, data=None):

    def chunk_read(response, chunk_size=8192, reporthook=None):
      # Read the Content-Length header to learn the total size, if the server provides it
      content_type = response.info().get('Content-Length')
      total_size = -1
      if content_type is not None:
        total_size = int(content_type.strip())
      count = 0
      while True:
        chunk = response.read(chunk_size)
        count += 1
        if reporthook is not None:
          reporthook(count, chunk_size, total_size)
        if chunk:
          yield chunk
        else:
          break

    # The key step: fetch the data with urlopen
    response = urlopen(url, data)
    with open(filename, 'wb') as fd:
      for chunk in chunk_read(response, reporthook=reporthook):
        # Write each chunk to the local file
        fd.write(chunk)
else:
  from six.moves.urllib.request import urlretrieve
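On Python 3 the standard library already provides this function, so the else branch simply imports it. A sketch of the same reporthook pattern with the standard urllib.request.urlretrieve (the URL reuses one of the Fashion-MNIST files from above):

from urllib.request import urlretrieve

def report(count, block_size, total_size):
    # Same signature that the dl_progress hook above expects:
    # number of blocks transferred, block size in bytes, total size (-1 if unknown)
    done = count * block_size
    if total_size > 0:
        print('\rdownloaded %d / %d bytes' % (min(done, total_size), total_size), end='')

url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz'
urlretrieve(url, 't10k-labels-idx1-ubyte.gz', report)
print('\ndone')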
About the urlopen method --> link
Results
- Running it from the command line
- Location of the files downloaded to the local machine
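If you want to check the download location yourself, a small sketch (assuming the default cache directory; on Windows this sits under C:\Users\<user>\.keras):

import os

# Default local cache location used by load_data(): <home>/.keras/datasets/fashion-mnist
data_dir = os.path.join(os.path.expanduser('~'), '.keras', 'datasets', 'fashion-mnist')

for name in sorted(os.listdir(data_dir)):
    full_path = os.path.join(data_dir, name)
    print(full_path, os.path.getsize(full_path), 'bytes')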