使用**迭代器**獲取Cifar等常用資料集
Cifar
、MNIST
等常用資料集的坑:
- 每次在一臺新的機器上使用它們去訓練模型都需要重新下載(國內網路往往都不給力,需要花費大量的時間,有時還下載不了);
- 即使下載到本地,然而不同的模型對它們的處理方式各不相同,我們又需要花費一些時間去了解如何讀取資料。
為了解決上述的坑,我在Bunch 轉換為 HDF5 檔案:高效儲存 Cifar 等資料集中將一些常用的資料集封裝為 HDF5
檔案。
下面的 X.h5c
可以參考Bunch 轉換為 HDF5 檔案:高效儲存 Cifar 等資料集自己製作,也可以直接下載使用(連結:https://pan.baidu.com/s/1hsbMhv3MDlOES3UDDmOQiw 密碼:qlb7)。
使用方法很簡單:
訪問資料集
# 載入所需要的包
import tables as tb
import numpy as np
xpath = `E:/xdata/X.h5` # 檔案所在路徑
h5 = tb.open_file(xpath)
下面我們來看看此檔案中有那些資料集:
h5.root
/ (RootGroup) "Xinet`s dataset"
children := [`cifar10` (Group), `cifar100` (Group), `fashion_mnist` (Group), `mnist` (Group)]
下面我們以 Cifar
為例,來詳細說明該檔案的使用:
cifar = h5.root.cifar100 # 獲取 cifar100
為了高效使用資料集,我們使用迭代器的方式來獲取它:
class Loader:
"""
方法
========
L 為該類的例項
len(L)::返回 batch 的批數
iter(L)::即為資料迭代器
Return
========
可迭代物件(numpy 物件)
"""
def __init__(self, X, Y, batch_size, shuffle):
```
X, Y 均為類 numpy
```
self.X = X
self.Y = Y
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
n = len(self.X)
idx = np.arange(n)
if self.shuffle:
np.random.shuffle(idx)
for k in range(0, n, self.batch_size):
K = idx[k:min(k + self.batch_size, n)].tolist()
yield np.take(self.X, K, 0), np.take(self.Y, K, 0)
def __len__(self):
return round(len(self.X) / self.batch_size)
下面我們可以使用 Loader
來例項化我們的資料集:
batch_size = 512
train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True)
test_cifar = Loader(cifar.testX, cifar.test_fine_labels, batch_size, False)
讀取一個 Batch 的資料:
for imgs, labels in iter(train_cifar):
break
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype=`U`)
names[:7]
array([`orchid`, `spider`, `rabbit`, `shark`, `shrew`, `clock`, `bed`],
dtype=`<U13`)
視覺化
需要注意,這裡的 Cifar
是 first channel
的,即:
imgs.shape
(512, 3, 32, 32)
names.shape
(512,)
from pylab import plt, mpl
mpl.rcParams[`font.sans-serif`] = [`SimHei`] # 指定預設字型
mpl.rcParams[`axes.unicode_minus`] = False # 解決儲存影像是負號 `-` 顯示為方塊的問題
def show_imgs(imgs, labels):
```
展示 多張圖片
```
imgs = np.transpose(imgs, (0, 2, 3, 1))
n = imgs.shape[0]
h, w = 5, int(n / 5)
fig, ax = plt.subplots(h, w, figsize=(7, 7))
K = np.arange(n).reshape((h, w))
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype=`U`)
names = names.reshape((h, w))
for i in range(h):
for j in range(w):
img = imgs[K[i, j]]
ax[i][j].imshow(img)
ax[i][j].axes.get_yaxis().set_visible(False)
ax[i][j].axes.set_xlabel(names[i][j])
ax[i][j].set_xticks([])
plt.show()
show_imgs(imgs[:25], labels[:25])
$2$ 個深度學習框架 & 資料集
因為,上面的資料集是 NumPy
的 array
形式,故而:
TensorFlow
import tensorflow as tf
for imgs, labels in iter(train_cifar):
imgs = tf.constant(imgs)
labels = tf.constant(labels)
break
imgs
<tf.Tensor `Const:0` shape=(512, 3, 32, 32) dtype=uint8>
labels
<tf.Tensor `Const_1:0` shape=(512,) dtype=int32>
MXNet
from mxnet import nd, cpu, gpu
for imgs, labels in iter(train_cifar):
imgs = nd.array(imgs, ctx = gpu(0))
labels = nd.array(labels, ctx = cpu(0))
break
imgs.context
gpu(0)
labels.context
cpu(0)
Matlab 讀取 HDF
參考:h5read
相關文章
- 資料集的使用-以CIFAR10為例
- mORMot2 獲取資料集1ORM
- 迭代器,迭代器塊和資料管道
- datatables使用ajax獲取資料
- 如獲取獲取關聯資料的文件跟模型的關聯資料集呢模型
- 使用迭代器接收資料並自動停止
- redis叢集資料儲存和獲取原理Redis
- TP5 獲取資料集記錄數
- 在Grails使用Sql獲取資料AISQL
- CIFAR10/CIFAR100資料集介紹---有Python版本的二進位制資料格式說明Python
- python使用cx_Oracle連線oracle資料庫獲取常用資訊PythonOracle資料庫
- 使用commons-beanutils迭代獲取javabean的屬性BeanJava
- Temu api介面 獲取商品詳情 資料採集API
- 教你如何使用API介面獲取資料!API
- AJAX 獲取伺服器響應資料伺服器
- python迭代器資料整理Python
- 使用 C# 獲取 Kubernetes 叢集資源資訊C#
- 使用RxJava從多個資料來源獲取資料RxJava
- ckeditor獲取資料
- UCI資料集整理(附論文常用資料集)
- 如何使用API介面獲取淘寶商品資料API
- 在 Laravel 中使用 GraphQL 一 [獲取資料]Laravel
- dom元素操作獲取等
- 資料獲取,解析,儲存等知識的學習總結
- modbustcp封裝使用獲取裝置資料示例TCP封裝
- 如何教會小白使用API介面獲取商品資料API
- 如何使用js獲取USB掃碼槍資料JS
- 使用商品詳情API介面獲取商品資料API
- 使用Paging Library獲取網路資料
- 轉:使用基本認證從WebServer獲取資料WebServer
- Scrapy爬蟲 - 獲取知乎使用者資料爬蟲
- 使用Python獲取HTTP請求頭資料PythonHTTP
- 使用 useLazyFetch 進行非同步資料獲取非同步
- 獲取Wireshark資料流
- 1.獲取資料
- Modbus ASCII 獲取資料ASCII
- 33個機器學習常用資料集機器學習
- 分散式機器學習常用資料集分散式機器學習