版权声明:本文为博主原创文章,未经博主允许不得转载。
久闻Pylearn大名,最近由于某些原因,想拿Pylearn2来做做实验。Python这门神奇的语言几年前学过点皮毛,对于一个习惯使用C++/Java的人来说,使用Python简直就是一种折磨,特别是阅读别人的代码。这种弱类型语言最让人纠结的就是不好跟踪代码(应该是我不会跟踪)。不喜欢Python还有一个原因,Python虽有多线程,但是此多线程非彼多线程,不管你开多少个线程,它只使用CPU的一个核,多核对它来说完全没用,而且线程越多速度反而越慢。好了,不废话了,回归正题。
Pylearn2的安装及环境搭建这里就不赘述了,网上资料很多的。对Pylearn2的使用说明中文基本没有(反正我没搜到),英文的还是有一写的。为了节省新手(我也是)的时间,这里简单介绍一下使用Pylearn构建一个简单的分类器的方法,会了一个应该其他的就好办了,入门是关键。
本文基于pylearn2-practice进行介绍(之所以不根据Quick-start example来讲是因为它结构比较复杂,而且只有训练,没有预测部分,看了后还是不知道如何使用自己的数据,可能是本人智商不够)。
pylearn2-practice文件组织如下
├── adult │ ├── test.csv │ ├── test_v.csv │ └── train_v.csv ├── adult_dataset.py ├── adult.yaml ├── predict.py └── README.md
如何使用见README.md
我们先看看训练的配置文件adult.yaml
!obj:pylearn2.train.Train { #使用类pylearn2.train.Train的实例来训练 dataset: &train !obj:adult_dataset.AdultDataset { #使用adult_dataset.AdultDataset来提供训练数据 path: 'adult/train_v.csv', #训练数据文件路径 one_hot: 1 }, model: !obj:pylearn2.models.softmax_regression.SoftmaxRegression { #选择模型, 可以选择的模型包括RBM, softmax_regression, SVM等等
#具体参看http://deeplearning.net/software/pylearn2/features.html#features中的Models n_classes: 2, irange: 0., nvis: 123, }, algorithm: !obj:pylearn2.training_algorithms.bgd.BGD {#选择训练算法,具体参看
#http://deeplearning.net/software/pylearn2/features.html#features中的Training algorithms batch_size: 10000, line_search_mode: 'exhaustive', conjugate: 1, monitoring_dataset: { 'train' : *train, 'valid' : !obj:adult_dataset.AdultDataset { path: 'adult/test_v.csv', one_hot: 1 }, 'test' : !obj:adult_dataset.AdultDataset { path: 'adult/test.csv', one_hot: 1 } }, termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased { channel_name: "valid_y_misclass" } }, extensions: [ !obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { channel_name: 'valid_y_misclass', save_path: "softmax_regression_best.pkl" }, ], save_path: "softmax_regression.pkl", #训练得到的模型保存到softmax_regression.pkl这个文件 save_freq: 1 }
细心的你可能会发现上面的配置文件其实是给类的构造函数指定参数, 事实上Pylearn2就是使用PyYAML根据上面配置信息实例化一个Train对象.
我们再看看adult_dataset.py
[python] view plain copy import csv import numpy as np import os from pylearn2.datasets.dense_design_matrix import DenseDesignMatrix from pylearn2.utils import serial from pylearn2.utils.string_utils import preprocess class AdultDataset( DenseDesignMatrix ): def __init__(self, path = 'train.csv', one_hot = False, with_labels = True, start = None, stop = None, preprocessor = None, fit_preprocessor = False, fit_test_preprocessor = False): self.no_classes = 2 # won't work TODO self.test_args = locals() self.test_args['which_set'] = 'test' self.test_args['fit_preprocessor'] = fit_test_preprocessor del self.test_args['start'] del self.test_args['stop'] del self.test_args['self'] path = preprocess(path) X, y = self._load_data( path, with_labels ) if start is not None: assert which_set != 'test' assert isinstance(start, int) assert isinstance(stop, int) assert start >= 0 assert start < stop assert stop <= X.shape[0] X = X[start:stop, :] if y is not None: y = y[start:stop, :] super(AdultDataset, self).__init__(X=X, y=y) if preprocessor: preprocessor.apply(self, can_fit=fit_preprocessor) def _load_data(self, path, expect_labels): assert path.endswith('.csv') data = np.loadtxt( path, delimiter = ',', dtype = 'int' ) if expect_labels: y = data[:,0] X = data[:,1:] # TODO: if one_hot # 10 is number of possible y values one_hot = np.zeros((y.shape[0], self.no_classes ),dtype='float32') for i in xrange( y.shape[0] ): label = y[i] if label == 1: one_hot[i,1] = 1. else: one_hot[i,0] = 1. y = one_hot else: X = data y = None return X, y Adultataset继承于DenseDesignMatrix, 通过函数_load_data将训练/测试数据加载到X, y中, x基本没有变化, 但是y进行了相应的变换, 类别数决定了y的维数, 每一维代表一种类别,假如某条数据属于第N类,则y的第N维为1, 其他维为0(这里第N维从1开始计数,而不是从0开始计数). 然后将X, y赋给父类的X,y, 如果设置了preprocessor还需要执行preprocessor.
再看看predict.py.
[python] view plain copy import sys import os from pylearn2.utils import serial from pylearn2.config import yaml_parse from adult_dataset import AdultDataset try: model_path = sys.argv[1] test_path = sys.argv[2] out_path = sys.argv[3] except IndexError: print "Usage: predict.py <model file> <test file> <output file>" quit() try: model = serial.load( model_path ) except Exception, e: print model_path + "doesn't seem to be a valid model path, I got this error when trying to load it: " print e #dataset = yaml_parse.load( model.dataset_yaml_src ) #dataset = dataset.get_test_set() # or maybe specify test in yaml dataset = AdultDataset( path = test_path, one_hot = True ) # use smallish batches to avoid running out of memory batch_size = 100 model.set_batch_size(batch_size) # dataset must be multiple of batch size of some batches will have # different sizes. theano convolution requires a hard-coded batch size m = dataset.X.shape[0] extra = batch_size - m % batch_size assert (m + extra) % batch_size == 0 import numpy as np if extra > 0: dataset.X = np.concatenate((dataset.X, np.zeros((extra, dataset.X.shape[1]), dtype=dataset.X.dtype)), axis=0) assert dataset.X.shape[0] % batch_size == 0 X = model.get_input_space().make_batch_theano() Y = model.fprop(X) from theano import tensor as T y = T.argmax(Y, axis=1) from theano import function f = function([X], y) y = [] for i in xrange(dataset.X.shape[0] / batch_size): x_arg = dataset.X[i*batch_size:(i+1)*batch_size,:] if X.ndim > 2: x_arg = dataset.get_topological_view(x_arg) y.append(f(x_arg.astype(X.dtype))) y = np.concatenate(y) assert y.ndim == 1 assert y.shape[0] == dataset.X.shape[0] # discard any zero-padding that was used to give the batches uniform size y = y[:m] class_mapping = { 0: -1, 1: 1 } out = open(out_path, 'w') for i in xrange(y.shape[0]): p = y[i] p = class_mapping[p] out.write( '%d\n' % ( p )) out.close() 第一步从测试数据文件中获取数据dataset = AdultDataset( path = test_path, one_hot = True ),然后就是对测试数据分批次(如果测试数据不是特别多, 内存够大的话就没必要分批次),
接着是构建 Theano function(深入了解见Theano说明文档)
X = model.get_input_space().make_batch_theano() #X为输入的特征 Y = model.fprop(X) #fprop(x)使用模型预测
from theano import tensor as T
y = T.argmax(Y, axis=1) from theano import function f = function([X], y)
然后就是调用 f 对 x 进行预测了
到这里,分类器构建完成, 这么看来,其实还是蛮简单的, 虽然简单,但是我可是花了大半天的功夫, 一开始花了大量时间看Quick-start example(结果还是一头雾水).由此说明一篇好的文档非常重要, 不然你就慢慢去摸索吧.
感谢zygmuntz的无私奉献