Using scikit-learn

    xiaoxiao  2021-12-15

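    My graduation thesis is on text classification and needs the scikit-learn toolkit, so, building on earlier work by others, I ran the following experiments: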

    These experiments follow the official scikit-learn website.

    1. Data preparation

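    For classification we use the Iris dataset, which ships with scikit-learn. Iris is a standard dataset for classification experiments, introduced by Fisher (1936); it is also called the iris flower dataset and is a classic example of multivariate data. It contains 150 samples in 3 classes of 50 samples each, and every sample has 4 attributes: sepal length, sepal width, petal length, and petal width. The task is to predict from these 4 attributes which of the three species (Setosa, Versicolour, Virginica) a flower belongs to.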
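    A quick way to confirm the description above is to load the dataset and inspect its shape and class counts. This is a small sketch using load_iris from scikit-learn; it also checks that the labels are stored in class order, which is why the data must be shuffled before splitting.

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
print(iris.data.shape)          # (150, 4): 150 samples, 4 attributes each
print(iris.feature_names)       # sepal/petal length and width, in cm
print(np.bincount(iris.target)) # [50 50 50]: 50 samples per class

# Labels are stored class by class: rows 0-49 are class 0, 50-99 class 1, 100-149 class 2
print((iris.target[:50] == 0).all(),
      (iris.target[50:100] == 1).all(),
      (iris.target[100:] == 2).all())
```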

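    Note that the Iris samples are stored in class order: samples 1-50 are class 0, samples 51-100 are class 1, and samples 101-150 are class 2. Before splitting into a training set and a test set we therefore have to shuffle them. The function below takes two arrays, x and y, and shuffles them together with the same permutation, so every sample keeps its label.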

        import numpy as np

        def shuffle_in_unison(a, b):
            # Shuffle a and b with one shared random permutation so rows stay paired
            assert len(a) == len(b)
            shuffled_a = np.empty(a.shape, dtype=a.dtype)
            shuffled_b = np.empty(b.shape, dtype=b.dtype)
            permutation = np.random.permutation(len(a))
            # Row old_index of the input moves to position permutation[old_index]
            for old_index, new_index in enumerate(permutation):
                shuffled_a[new_index] = a[old_index]
                shuffled_b[new_index] = b[old_index]
            return shuffled_a, shuffled_b
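    The loop above can also be written with NumPy fancy indexing: draw one permutation and use it to index both arrays. The sketch below is a hypothetical variant (the name shuffle_in_unison_fast and the seed parameter are assumptions, not part of the original code); it uses numpy.random.default_rng so that the shuffle can be made reproducible by passing a seed.

```python
import numpy as np

def shuffle_in_unison_fast(a, b, seed=None):
    """Hypothetical vectorized variant: index both arrays with one shared permutation."""
    assert len(a) == len(b)
    rng = np.random.default_rng(seed)     # seeded generator makes the shuffle repeatable
    permutation = rng.permutation(len(a)) # random ordering of 0..len(a)-1
    return a[permutation], b[permutation]

# Toy data: row i is [2i, 2i+1] and its label is i
x = np.arange(10).reshape(5, 2)
y = np.arange(5)
xs, ys = shuffle_in_unison_fast(x, y, seed=0)
print(xs)
print(ys)
# Pairs survive the shuffle: each row still starts with 2 * its label
print((xs[:, 0] == 2 * ys).all())  # True
```

    Indexing with a permutation array returns copies in a single vectorized step, which avoids the explicit Python loop.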

    Next we load the Iris data, shuffle it, and split it into 100 training samples and 50 test samples:

        from sklearn.datasets import load_iris

        iris = load_iris()

        def load_data():
            # Shuffle samples and labels together, then take 100 for training and 50 for testing
            iris.data, iris.target = shuffle_in_unison(iris.data, iris.target)
            x_train, x_test = iris.data[:100], iris.data[100:]
            y_train, y_test = iris.target[:100].reshape(-1, 1), iris.target[100:].reshape(-1, 1)
            return x_train, y_train, x_test, y_test
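    scikit-learn can also shuffle and split in one call with sklearn.model_selection.train_test_split. This is a swapped-in alternative, not what the original code does: the sketch below assumes a stratified split (stratify=iris.target), which keeps the three classes roughly balanced in both parts, and a fixed random_state so the split is reproducible.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
# Shuffle and split in one step; test_size=50 gives 100 training and 50 test samples.
# stratify keeps the class proportions similar in the training and test sets.
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=50, shuffle=True,
    stratify=iris.target, random_state=0)
print(x_train.shape, x_test.shape)  # (100, 4) (50, 4)
```

    The labels come back as 1-D arrays, so the reshape(-1, 1) and later ravel() calls are not needed with this approach.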

    2. Classification with SVM, a decision tree, kNN, naive Bayes, and ensemble methods (bagging, random forest, AdaBoost and GBDT). The complete code is:

        import numpy as np
        from sklearn.datasets import load_iris
        from sklearn import tree, svm, naive_bayes, neighbors
        from sklearn.ensemble import (BaggingClassifier, AdaBoostClassifier,
                                      RandomForestClassifier, GradientBoostingClassifier)

        iris = load_iris()

        def shuffle_in_unison(a, b):
            # Shuffle a and b with one shared random permutation so rows stay paired
            assert len(a) == len(b)
            shuffled_a = np.empty(a.shape, dtype=a.dtype)
            shuffled_b = np.empty(b.shape, dtype=b.dtype)
            permutation = np.random.permutation(len(a))
            for old_index, new_index in enumerate(permutation):
                shuffled_a[new_index] = a[old_index]
                shuffled_b[new_index] = b[old_index]
            return shuffled_a, shuffled_b

        def load_data():
            # Shuffle, then use the first 100 samples for training and the last 50 for testing
            iris.data, iris.target = shuffle_in_unison(iris.data, iris.target)
            x_train, x_test = iris.data[:100], iris.data[100:]
            y_train, y_test = iris.target[:100].reshape(-1, 1), iris.target[100:].reshape(-1, 1)
            return x_train, y_train, x_test, y_test

        x_train, y_train, x_test, y_test = load_data()

        clfs = {
            'svm': svm.SVC(),
            'decision_tree': tree.DecisionTreeClassifier(),
            'naive_gaussian': naive_bayes.GaussianNB(),
            'naive_mul': naive_bayes.MultinomialNB(),
            'K_neighbor': neighbors.KNeighborsClassifier(),
            'bagging_knn': BaggingClassifier(neighbors.KNeighborsClassifier(),
                                             max_samples=0.5, max_features=0.5),
            'bagging_tree': BaggingClassifier(tree.DecisionTreeClassifier(),
                                              max_samples=0.5, max_features=0.5),
            'random_forest': RandomForestClassifier(n_estimators=50),
            'adaboost': AdaBoostClassifier(n_estimators=50),
            'gradient_boost': GradientBoostingClassifier(n_estimators=50, learning_rate=1.0,
                                                         max_depth=1, random_state=0),
        }

        def try_different_method(clf):
            # Fit on the training split and print mean accuracy on the test split;
            # ravel() turns the (n, 1) label columns back into the 1-D arrays sklearn expects
            clf.fit(x_train, y_train.ravel())
            score = clf.score(x_test, y_test.ravel())
            print('the score is :', score)

        for clf_key, clf in clfs.items():
            print('the classifier is :', clf_key)
            try_different_method(clf)
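    For these classifiers, clf.score returns the mean accuracy, i.e. the fraction of test samples whose predicted label equals the true label. The sketch below checks this for an SVC; the split via train_test_split with random_state=0 is an assumption made only to keep the example short and reproducible.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=50, random_state=0)

clf = SVC().fit(x_train, y_train)                 # fit returns the fitted estimator
manual = np.mean(clf.predict(x_test) == y_test)   # fraction of correct predictions
print(clf.score(x_test, y_test), manual)          # the two numbers agree
```

    With 50 test samples, each misclassified sample lowers the score by 1/50 = 0.02, which is why the scores in the results come in steps of 0.02.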

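    All classifiers reached high accuracy. One run produced the following scores (mean accuracy on the 50 test samples). Because shuffle_in_unison draws an unseeded random permutation, each run uses a different split, so the exact numbers change from run to run: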

        classifier        test accuracy
        decision_tree     0.92
        naive_gaussian    0.94
        gradient_boost    0.92
        svm               0.94
        random_forest     0.92
        bagging_knn       0.92
        naive_mul         0.80
        K_neighbor        0.92
        bagging_tree      0.90
        adaboost          0.92
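    Since a single 50-sample test set is small and the split changes on every run, differences of one or two misclassified samples between classifiers are not very meaningful. One common way to get a steadier comparison is k-fold cross-validation with sklearn.model_selection.cross_val_score. This is an extra technique not used in the original experiment; the sketch assumes 5 folds and only three of the classifiers above for brevity.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.svm import SVC

iris = load_iris()
models = {'svm': SVC(), 'naive_gaussian': GaussianNB(), 'naive_mul': MultinomialNB()}
results = {}
for name, model in models.items():
    # cv=5 with a classifier uses stratified 5-fold splitting:
    # every sample lands in a test fold exactly once
    scores = cross_val_score(model, iris.data, iris.target, cv=5)
    results[name] = scores
    print(f'{name}: mean={scores.mean():.3f} std={scores.std():.3f}')
```

    Reporting the mean and spread over folds shows how much of the gap between classifiers could simply be split-to-split noise.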

    When reposting, please credit the original: https://ju.6miu.com/read-1000253.html
