A Stacked Autoencoders Example in Pylearn2

Posted by std1984 on 2014-07-31
Environment: Ubuntu 12.04

1. First, download the MNIST data (training and test sets)

cd /u01/lisa/data/mnist
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
gunzip train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
gunzip train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
gunzip t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gunzip t10k-labels-idx1-ubyte.gz
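After gunzip the four files are raw IDX files. Their header is a big-endian 32-bit magic number (2051 for image files, 2049 for label files) followed by one big-endian 32-bit size per dimension. The following is a minimal sketch for checking the downloads before training. read_idx_header is a hypothetical helper, not part of pylearn2. The sketch also assumes that pylearn2's MNIST dataset class reads these files from ${PYLEARN2_DATA_PATH}/mnist, which would make the author's /u01/lisa/data the value of PYLEARN2_DATA_PATH.

```python
import os
import struct

def read_idx_header(path):
    """Return (type_code, dims) read from the header of an uncompressed IDX file."""
    with open(path, "rb") as f:
        # Magic number: two zero bytes, a data-type byte (0x08 = unsigned byte),
        # then one byte giving the number of dimensions.
        zero, type_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0:
            raise ValueError("%s does not look like an IDX file" % path)
        # One big-endian unsigned 32-bit size per dimension.
        dims = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
    return type_code, dims

if __name__ == "__main__":
    # Assumption: pylearn2's MNIST class looks in ${PYLEARN2_DATA_PATH}/mnist.
    data_dir = os.path.join(os.environ.get("PYLEARN2_DATA_PATH", "."), "mnist")
    for name in ("train-images-idx3-ubyte", "train-labels-idx1-ubyte",
                 "t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte"):
        path = os.path.join(data_dir, name)
        if os.path.exists(path):
            print(name, read_idx_header(path))
        else:
            print(name, "not found in", data_dir)
```

For the training images this should report type code 8 (unsigned byte) and dimensions (60000, 28, 28). The training labels should give (60000,), and both t10k files should have 10000 in their first dimension.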

2. Edit dae_l1.yaml

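Start Python in interactive mode from the directory that contains dae_l1.yaml (the tutorial directory used in step 3) and run:
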
layer1_yaml = open('dae_l1.yaml', 'r').read()
hyper_params_l1 = {'train_stop' : 50000, 'batch_size' : 100, 'monitoring_batches' : 5, 'nhid' : 500, 'max_epochs' : 10, 'save_path' : '.'}
# fill the %(...)i / %(...)s placeholders in the template with these values
layer1_yaml = layer1_yaml % (hyper_params_l1)
print(layer1_yaml)


Then replace the entire contents of dae_l1.yaml with the printed output.
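Copying the printed text back by hand is error-prone. The following minimal sketch performs the same % substitution and writes the result straight into the file. fill_yaml_template and the .orig backup convention are hypothetical. On the first run it keeps an untouched copy of the template. This matters because once dae_l1.yaml is overwritten, its %(...) placeholders are gone, and a rerun with different hyperparameters would have nothing to fill.

```python
import os
import shutil

def fill_yaml_template(path, params, template_suffix=".orig"):
    """Fill the %(name)x placeholders of a pylearn2 yaml template and overwrite path.

    On the first call the untouched template is copied to path + template_suffix
    (a hypothetical convention); later calls always fill from that copy, so the
    hyperparameters can be changed and the file regenerated.
    """
    template_path = path + template_suffix
    if not os.path.exists(template_path):
        shutil.copyfile(path, template_path)
    with open(template_path) as f:
        template = f.read()
    filled = template % params  # same substitution as layer1_yaml % (hyper_params_l1)
    with open(path, "w") as f:
        f.write(filled)
    return filled
```

From the tutorial directory, fill_yaml_template('dae_l1.yaml', hyper_params_l1) replaces the copy-and-paste step. The notebook that ships with this tutorial takes a different approach and skips the file entirely. It passes the filled string to pylearn2.config.yaml_parse.load() and calls main_loop() on the returned Train object.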

3. Change to the example directory:
cd ~/pylearn2/pylearn2/scripts/tutorials/stacked_autoencoders
Run the training script:
python ~/pylearn2/pylearn2/scripts/train.py dae_l1.yaml
The output log is as follows:
/home/jerry/pylearn2/pylearn2/utils/call_check.py:98: UserWarning: the `one_hot` parameter is deprecated. To get one-hot encoded targets, request that they live in `VectorSpace` through the `data_specs` parameter of MNIST's iterator method. `one_hot` will be removed on or after September 20, 2014.
  return to_call(**kwargs)
/home/jerry/.local/lib/python2.7/site-packages/theano/sandbox/rng_mrg.py:1183: UserWarning: MRG_RandomStreams Can't determine #streams from size (Shape.0), guessing 60*256
  nstreams = self.n_streams(size)
Parameter and initial learning rate summary:
        vb: 0.001
        hb: 0.001
        W: 0.001
        Wprime: 0.001
/home/jerry/pylearn2/pylearn2/models/model.py:71: UserWarning: The Model subclass seems not to call the Model constructor. This behavior may be considered an error on or after 2014-11-01.
  warnings.warn("The " + str(type(self)) + " Model subclass "
Compiling sgd_update...
Compiling sgd_update done. Time elapsed: 7.379370 seconds
compiling begin_record_entry...
compiling begin_record_entry done. Time elapsed: 0.103046 seconds
Monitored channels:
        learning_rate
        objective
        total_seconds_last_epoch
        training_seconds_this_epoch
Compiling accum...
graph size: 19
Compiling accum done. Time elapsed: 0.876798 seconds
Monitoring step:
        Epochs seen: 0
        Batches seen: 0
        Examples seen: 0
        learning_rate: 0.001
        objective: 89.1907964264
        total_seconds_last_epoch: 0.0
        training_seconds_this_epoch: 0.0
Time this epoch: 19.928861 seconds
......
Monitoring step:
        Epochs seen: 10
        Batches seen: 5000
        Examples seen: 500000
        learning_rate: 0.001
        objective: 11.9511445315
        total_seconds_last_epoch: 35.828732
        training_seconds_this_epoch: 22.296131
Saving to ./dae_l1.pkl...
Saving to ./dae_l1.pkl done. Time elapsed: 0.936124 seconds
Saving to ./dae_l1.pkl...
Saving to ./dae_l1.pkl done. Time elapsed: 0.886536 seconds
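The objective channel is the reconstruction cost that the first denoising autoencoder minimizes. The input is corrupted, encoded and decoded, and the result is compared with the clean input. The following minimal numpy sketch evaluates that cost once. Its settings are assumptions that have not been checked against dae_l1.yaml:

- masking (binomial) corruption at level 0.2
- a tanh encoder
- a linear decoder with its own weight matrix, matching the separate W and Wprime in the log
- a squared error summed over pixels and averaged over the batch

pylearn2 itself builds this computation symbolically with Theano, so this is an illustration only.

```python
import numpy as np

def dae_objective(X, W, hb, Wprime, vb, corruption_level=0.2, rng=None):
    """Squared reconstruction error of a denoising autoencoder on a batch X.

    Assumed shapes: X (batch, nvis), W (nvis, nhid), Wprime (nhid, nvis).
    """
    rng = np.random.RandomState(0) if rng is None else rng
    # Masking noise (assumed): each input is zeroed with probability corruption_level.
    mask = rng.binomial(n=1, p=1.0 - corruption_level, size=X.shape)
    X_tilde = X * mask
    H = np.tanh(X_tilde.dot(W) + hb)   # encoder (tanh assumed)
    X_hat = H.dot(Wprime) + vb         # linear decoder with untied weights (assumed)
    # The error is measured against the clean input, not the corrupted one.
    return ((X - X_hat) ** 2).sum(axis=1).mean()

if __name__ == "__main__":
    rng = np.random.RandomState(42)
    nvis, nhid = 784, 500
    X = rng.uniform(0, 1, size=(100, nvis))  # random stand-in for an MNIST batch
    W = rng.uniform(-0.05, 0.05, size=(nvis, nhid))
    Wprime = rng.uniform(-0.05, 0.05, size=(nhid, nvis))
    print(dae_objective(X, W, np.zeros(nhid), Wprime, np.zeros(nvis), rng=rng))
```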

4. Inspect the learned parameters
>>> from pylearn2.utils import serial
>>> serial.load('dae_l1.pkl')

>>>
>>> model = serial.load('dae_l1.pkl')
>>>
>>> dir(model)
['__call__', '__class__', '__delattr__', '__dict__', '__doc__', '__format__', '__getattribute__', '__getstate__', '__hash__', '__init__', '__metaclass__', '__module__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_disallow_censor_updates', '_ensure_extensions', '_hidden_activation', '_hidden_input', '_initialize_hidbias', '_initialize_visbias', '_initialize_w_prime', '_initialize_weights', '_modify_updates', '_overrides_censor_updates', '_params', '_test_batch_size', 'act_dec', 'act_enc', 'censor_updates', 'continue_learning', 'corruptor', 'cpu_only', 'dataset_yaml_src', 'decode', 'encode', 'enforce_constraints', 'extensions', 'fn', 'free_energy', 'function', 'get_default_cost', 'get_input_dim', 'get_input_source', 'get_input_space', 'get_lr_scalers', 'get_monitoring_channels', 'get_monitoring_data_specs', 'get_output_dim', 'get_output_space', 'get_param_values', 'get_param_vector', 'get_params', 'get_target_source', 'get_target_space', 'get_test_batch_size', 'get_weights', 'get_weights_format', 'get_weights_topo', 'get_weights_view_shape', 'hidbias', 'input_space', 'inverse', 'irange', 'libv', 'modify_updates', 'monitor', 'nhid', 'output_space', 'perform', 'print_versions', 'reconstruct', 'redo_theano', 'register_names_to_del', 'rng', 's_rng', 'score', 'set_batch_size', 'set_input_space', 'set_param_values', 'set_param_vector', 'set_visible_size', 'tag', 'tied_weights', 'train_all', 'train_batch', 'upward_pass', 'visbias', 'w_prime', 'weights', 'yaml_src']
>>>
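Among the methods listed, get_params() returns the model's Theano shared variables. These are the same vb, hb, W and Wprime named in the training log. Each has a name and a get_value() method that returns a numpy array. The helper below is hypothetical, not part of pylearn2, and maps each name to its shape. It relies only on those two attributes, so it can be exercised without Theano.

```python
def summarize_params(params):
    """Map each parameter's name to the shape of its current value."""
    summary = {}
    for p in params:
        value = p.get_value()  # numpy array for a Theano shared variable
        summary[p.name] = tuple(value.shape)
    return summary

# Usage with the model loaded above (requires pylearn2 and Theano):
#   from pylearn2.utils import serial
#   model = serial.load('dae_l1.pkl')
#   for name, shape in sorted(summarize_params(model.get_params()).items()):
#       print(name, shape)
```

With nhid = 500 and 784-pixel MNIST inputs, one would expect W to be 784x500 and Wprime 500x784 for the first layer.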

5. As in step 2, edit dae_l2.yaml

layer2_yaml = open('dae_l2.yaml', 'r').read()
hyper_params_l2 = {'train_stop' : 50000, 'batch_size' : 100, 'monitoring_batches' : 5, 'nvis' : 500, 'nhid' : 500, 'max_epochs' : 10, 'save_path' : '.'}
# fill the %(...)i / %(...)s placeholders in the template with these values
layer2_yaml = layer2_yaml % (hyper_params_l2)
print(layer2_yaml)
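The only new key compared with step 2 is nvis. The second autoencoder reads the 500-dimensional hidden code produced by the first layer, so its nvis must equal layer 1's nhid. In the pylearn2 tutorial this wiring is presumably done by a TransformerDataset that runs MNIST through dae_l1.pkl; check that against your dae_l2.yaml. The sketch below checks the two dictionaries before training. check_layer_sizes is a hypothetical helper name.

```python
MNIST_NVIS = 28 * 28  # 784 pixels per MNIST image

def check_layer_sizes(layer_params, nvis_first=MNIST_NVIS):
    """Raise ValueError if a layer's input size differs from the previous layer's output.

    Returns the size of the top hidden representation.
    """
    expected = nvis_first
    for i, hp in enumerate(layer_params, start=1):
        nvis = hp.get('nvis', expected)  # layer 1's dict has no 'nvis' key
        if nvis != expected:
            raise ValueError("layer %d: nvis=%d but the previous layer outputs %d"
                             % (i, nvis, expected))
        expected = hp['nhid']
    return expected
```

Called as check_layer_sizes([hyper_params_l1, hyper_params_l2]), it returns 500 for the values used here. If you change nhid in step 2, nvis here has to change with it.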

6. Run dae_l2.yaml to train the second layer:
python ~/pylearn2/pylearn2/scripts/train.py dae_l2.yaml
/home/jerry/pylearn2/pylearn2/utils/call_check.py:98: UserWarning: the `one_hot` parameter is deprecated. To get one-hot encoded targets, request that they live in `VectorSpace` through the `data_specs` parameter of MNIST's iterator method. `one_hot` will be removed on or after September 20, 2014.
  return to_call(**kwargs)
/home/jerry/.local/lib/python2.7/site-packages/theano/sandbox/rng_mrg.py:1183: UserWarning: MRG_RandomStreams Can't determine #streams from size (Shape.0), guessing 60*256
  nstreams = self.n_streams(size)
Parameter and initial learning rate summary:
        vb: 0.001
        hb: 0.001
        W: 0.001
        Wprime: 0.001
/home/jerry/pylearn2/pylearn2/models/model.py:71: UserWarning: The Model subclass seems not to call the Model constructor. This behavior may be considered an error on or after 2014-11-01.
  warnings.warn("The " + str(type(self)) + " Model subclass "
Compiling sgd_update...
Compiling sgd_update done. Time elapsed: 0.339660 seconds
compiling begin_record_entry...
compiling begin_record_entry done. Time elapsed: 0.023657 seconds
Monitored channels:
        learning_rate
        objective
        total_seconds_last_epoch
        training_seconds_this_epoch
Compiling accum...
graph size: 19
Compiling accum done. Time elapsed: 0.189965 seconds
Monitoring step:
        Epochs seen: 0
        Batches seen: 0
        Examples seen: 0
        learning_rate: 0.001
        objective: 52.2956323286
        total_seconds_last_epoch: 0.0
        training_seconds_this_epoch: 0.0
Time this epoch: 17.452593 seconds
......
Monitoring step:
        Epochs seen: 10
        Batches seen: 5000
        Examples seen: 500000
        learning_rate: 0.001
        objective: 4.33433924602
        total_seconds_last_epoch: 30.433518
        training_seconds_this_epoch: 19.303109
Saving to ./dae_l2.pkl...
Saving to ./dae_l2.pkl done. Time elapsed: 0.607150 seconds
Saving to ./dae_l2.pkl...
Saving to ./dae_l2.pkl done. Time elapsed: 0.588375 seconds
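To compare runs, for example the layer-2 objective falling from 52.30 to 4.33 above, it helps to pull the monitored channels out of the console output. This minimal sketch assumes the log was captured to a text file, for instance with python train.py dae_l2.yaml 2>&1 | tee dae_l2.log (the file name is hypothetical). It also assumes each block starts with a "Monitoring step:" line followed by indented "name: value" lines, as in the output above.

```python
import re

# An indented "name: number" line; names such as "Epochs seen" may contain spaces.
_CHANNEL = re.compile(r"^\s+([A-Za-z_][\w ]*?):\s*(-?[\d.]+(?:[eE][-+]?\d+)?)\s*$")

def parse_monitor_log(text):
    """Return one dict of channel name -> float per 'Monitoring step:' block."""
    steps = []
    current = None
    for line in text.splitlines():
        if line.strip() == "Monitoring step:":
            current = {}
            steps.append(current)
            continue
        m = _CHANNEL.match(line)
        if current is not None and m:
            current[m.group(1)] = float(m.group(2))
        elif current is not None and not line.startswith((" ", "\t")):
            current = None  # any unindented line ends the current block
    return steps
```

For example, [s['objective'] for s in parse_monitor_log(open('dae_l2.log').read())] lists the objective at each monitoring step.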


7. As in step 2, edit dae_mlp.yaml

mlp_yaml = open('dae_mlp.yaml', 'r').read()
hyper_params_mlp = {'train_stop' : 50000, 'valid_stop' : 60000, 'batch_size' : 100, 'max_epochs' : 50, 'save_path' : '.'}
# fill the %(...)i / %(...)s placeholders in the template with these values
mlp_yaml = mlp_yaml % (hyper_params_mlp)
print(mlp_yaml)
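(Note: the original dae_mlp.yaml has no save_path or save_freq entries, so the fine-tuned parameters are never saved. Add these two entries to the top-level !obj:pylearn2.train.Train mapping, keeping the commas between entries valid. save_freq : 1 saves the model after every epoch.
save_path : './dae_mlp.pkl',
save_freq : 1
)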
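A missing save_path means nothing is written, and no error is raised. A quick check of the filled yaml text before a 50-epoch run can therefore save time. This sketch uses a hypothetical helper and only looks for lines that begin with save_path: or save_freq:. It is a rough heuristic, not a YAML parser, and cannot tell whether a key sits in the Train block or inside a nested extension.

```python
import re

def missing_save_keys(yaml_text, keys=("save_path", "save_freq")):
    """Return the keys that never appear as 'key:' at the start of a line."""
    return [k for k in keys
            if not re.search(r"^\s*%s\s*:" % re.escape(k), yaml_text, re.MULTILINE)]
```

Running missing_save_keys(mlp_yaml) after the substitution above should return an empty list once the two entries have been added.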

8. Run supervised fine-tuning:
python ~/pylearn2/pylearn2/scripts/train.py dae_mlp.yaml
/home/jerry/pylearn2/pylearn2/utils/call_check.py:98: UserWarning: the `one_hot` parameter is deprecated. To get one-hot encoded targets, request that they live in `VectorSpace` through the `data_specs` parameter of MNIST's iterator method. `one_hot` will be removed on or after September 20, 2014.
  return to_call(**kwargs)
Parameter and initial learning rate summary:
        vb: 0.05
        hb: 0.05
        W: 0.05
        Wprime: 0.05
        vb: 0.05
        hb: 0.05
        W: 0.05
        Wprime: 0.05
        softmax_b: 0.05
        softmax_W: 0.05
Compiling sgd_update...
Compiling sgd_update done. Time elapsed: 17.156073 seconds
compiling begin_record_entry...
compiling begin_record_entry done. Time elapsed: 0.056943 seconds
Monitored channels:
        learning_rate
        momentum
        total_seconds_last_epoch
        training_seconds_this_epoch
        valid_objective
        valid_y_col_norms_max
        valid_y_col_norms_mean
        valid_y_col_norms_min
        valid_y_max_max_class
        valid_y_mean_max_class
        valid_y_min_max_class
        valid_y_misclass
        valid_y_nll
        valid_y_row_norms_max
        valid_y_row_norms_mean
        valid_y_row_norms_min
Compiling accum...
graph size: 63
Compiling accum done. Time elapsed: 8.821601 seconds
Monitoring step:
        Epochs seen: 0
        Batches seen: 0
        Examples seen: 0
        learning_rate: 0.05
        momentum: 0.5
        total_seconds_last_epoch: 0.0
        training_seconds_this_epoch: 0.0
        valid_objective: 2.30245763578
        valid_y_col_norms_max: 0.0650026130651
        valid_y_col_norms_mean: 0.0641744853852
        valid_y_col_norms_min: 0.0624679393698
        valid_y_max_max_class: 0.105532125739
        valid_y_mean_max_class: 0.102753872501
        valid_y_min_max_class: 0.101059172742
        valid_y_misclass: 0.9031
        valid_y_nll: 2.30245763578
        valid_y_row_norms_max: 0.0125483545665
        valid_y_row_norms_mean: 0.00897718040255
        valid_y_row_norms_min: 0.00411555936503
Time this epoch: 18.159817 seconds
......
Monitoring step:
        Epochs seen: 50
        Batches seen: 25000
        Examples seen: 2500000
        learning_rate: 0.0183943399319
        momentum: 0.539357429719
        total_seconds_last_epoch: 21.789649
        training_seconds_this_epoch: 19.881821
        valid_objective: 0.0667691463031
        valid_y_col_norms_max: 1.93649990002
        valid_y_col_norms_mean: 1.93614117524
        valid_y_col_norms_min: 1.93520053981
        valid_y_max_max_class: 0.999997756761
        valid_y_mean_max_class: 0.980073621031
        valid_y_min_max_class: 0.548149309784
        valid_y_misclass: 0.02
        valid_y_nll: 0.0667691463031
        valid_y_row_norms_max: 0.546499525611
        valid_y_row_norms_mean: 0.264354016013
        valid_y_row_norms_min: 0.101427414171
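Some of these numbers can be checked by hand:

- At epoch 0 the network has learned nothing yet. Its valid_objective of 2.30246 is almost exactly ln 10 (about 2.30259), the negative log-likelihood of a uniform guess over 10 digit classes. The misclassification rate of 0.9031 is likewise near chance (0.9).
- The learning rate falls from 0.05 to 0.01839 over 25000 batches. This is consistent with dividing it by a constant factor after every batch, as pylearn2's ExponentialDecay update callback does. The factor 1.00004 below is an inferred assumption that reproduces the logged value, not something read from dae_mlp.yaml.
- If the validation set is examples 50000 to 60000 (train_stop to valid_stop), a valid_y_misclass of 0.02 means about 200 misclassified digits.

```python
import math

def decayed_lr(base_lr, decay_factor, n_batches, min_lr=0.0):
    """Learning rate after n_batches updates of per-batch exponential decay."""
    return max(base_lr / decay_factor ** n_batches, min_lr)

# Values from the step 8 log; decay_factor = 1.00004 is an assumption.
print("chance-level NLL for 10 classes:", math.log(10))   # logged valid_objective at epoch 0: 2.30245763578
print("lr after 25000 batches:", decayed_lr(0.05, 1.00004, 25000))  # logged learning_rate: 0.0183943399319
print("validation errors:", round(0.02 * (60000 - 50000)))  # from valid_y_misclass 0.02
```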

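9. This completes the whole training process.
  Hyperparameters can be tuned in the yaml files. The learned parameters are stored in three files: dae_l1.pkl, dae_l2.pkl and dae_mlp.pkl.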
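The fine-tuned network in dae_mlp.pkl consists of the two pretrained encoders stacked under a softmax layer. That is why the step 8 log lists two vb/hb/W/Wprime groups plus softmax_W and softmax_b. The following minimal numpy sketch runs its forward pass. It assumes tanh hidden units and that the weight arrays have already been extracted from the model, for example through each parameter's get_value(). The activation and the argument layout are assumptions. It is also an assumption that only the encoder half (W, hb) of each autoencoder is used at prediction time.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def stacked_forward(X, encoders, softmax_W, softmax_b, act=np.tanh):
    """Class probabilities for X of shape (batch, 784).

    encoders: list of (W, hb) pairs, W of shape (nvis, nhid), one per pretrained layer.
    """
    H = X
    for W, hb in encoders:
        H = act(H.dot(W) + hb)  # encoder only; decoder weights (Wprime, vb) unused (assumed)
    return softmax(H.dot(softmax_W) + softmax_b)

if __name__ == "__main__":
    rng = np.random.RandomState(0)
    X = rng.uniform(0, 1, size=(5, 784))  # random stand-in for MNIST digits
    enc = [(rng.uniform(-0.05, 0.05, (784, 500)), np.zeros(500)),
           (rng.uniform(-0.05, 0.05, (500, 500)), np.zeros(500))]
    P = stacked_forward(X, enc, rng.uniform(-0.005, 0.005, (500, 10)), np.zeros(10))
    print(P.argmax(axis=1))  # predicted digit for each row
```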

Source: ITPUB Blog, link: http://blog.itpub.net/16582684/viewspace-1243187/. Please credit the source when reposting; otherwise legal liability will be pursued.
