cs231n 最近邻（Nearest Neighbor）分类

    xiaoxiao2021-04-13  42

    # Python 3 script: nearest-neighbour (L2) classifier on the CIFAR-10 python batches (cs231n).
import os
import pickle

import numpy as np


def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    The file is read with ``encoding='bytes'``, so string keys come back as
    ``bytes`` -- which is why the callers look up ``b"data"`` / ``b"labels"``.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only
    open files that come from a trusted source.
    """
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def load_CIFAR10(file):
    """Read the five training batches and the test batch from directory *file*.

    Args:
        file: directory containing ``data_batch_1`` ... ``data_batch_5`` and
            ``test_batch``.

    Returns:
        ``(dataTrain, labelTrain, dataTest, labelTest)`` as Python lists with
        one entry per row of each batch's ``b"data"`` array and one entry per
        item of its ``b"labels"`` sequence, in file order.
    """
    # os.path.join uses the platform's separator; a hard-coded "\\" only
    # works on Windows and yields a non-existent file name elsewhere.
    dataTrain = []
    labelTrain = []
    for i in range(1, 6):
        dic = unpickle(os.path.join(file, "data_batch_" + str(i)))
        dataTrain.extend(dic[b"data"])
        labelTrain.extend(dic[b"labels"])

    dic = unpickle(os.path.join(file, "test_batch"))
    dataTest = list(dic[b"data"])
    labelTest = list(dic[b"labels"])
    return dataTrain, labelTrain, dataTest, labelTest


class NearestNeighbor(object):
    """1-nearest-neighbour classifier using Euclidean (L2) distance."""

    def __init__(self):
        pass

    def train(self, X, y):
        """Memorise the training set.

        Args:
            X: 2-D array-like, one training example per row.
            y: labels, one per row of ``X``.
        """
        self.xtr = X
        self.ytr = np.asarray(y)
        # Compute in float64. If X is an unsigned integer array (raw pixels are
        # typically uint8 -- confirm for your data), ``xtr - x`` and its square
        # wrap around and give wrong distances. For 0-255 integer data all
        # products and partial sums stay far below 2**53, so float64 is exact.
        # Cost: one float64 copy of the training set is kept in memory.
        self._xtr_f = np.asarray(X, dtype=np.float64)
        # Squared norm of every training row, reused for every query batch.
        self._xtr_sq = np.einsum('ij,ij->i', self._xtr_f, self._xtr_f)

    def predict(self, X, batch_size=1000):
        """Predict a label for every row of ``X``.

        Args:
            X: 2-D array-like, one test example per row, same width as the
                training data.
            batch_size: number of test rows handled per matrix product; bounds
                the temporary (batch_size x num_train) score matrix.

        Returns:
            1-D array of predicted labels with the training labels' dtype.
            Ties go to the lowest training index (``np.argmin`` semantics).

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        X = np.asarray(X)
        num_test = X.shape[0]
        # output dtype matches the training labels' dtype
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        for start in range(0, num_test, batch_size):
            stop = start + batch_size
            batch = X[start:stop].astype(np.float64)
            # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2. The query term ||b||^2 is
            # constant across training rows and sqrt is monotonic, so neither
            # affects the argmin; both are skipped.
            scores = self._xtr_sq[np.newaxis, :] - 2.0 * (batch @ self._xtr_f.T)
            Ypred[start:stop] = self.ytr[np.argmin(scores, axis=1)]
        return Ypred


if __name__ == "__main__":
    Xtr, Ytr, Xte, Yte = load_CIFAR10('tedata/cifar-10-batches-py')
    Xtr = np.asarray(Xtr)
    Xte = np.asarray(Xte)
    Ytr = np.asarray(Ytr)
    Yte = np.asarray(Yte)
    # Rows are presumably already flattened images (the original notes give
    # 50000 x 3072 for training), so no reshape is done here -- confirm.

    nn = NearestNeighbor()  # create a Nearest Neighbor classifier
    nn.train(Xtr, Ytr)  # train the classifier on the training images and labels
    Yte_predict = nn.predict(Xte)  # predict labels on the test images
    # print the classification accuracy: the fraction of test examples whose
    # predicted label matches the true label
    print('accuracy: %f' % (np.mean(Yte_predict == Yte)))

    这段代码运行比较耗时，我跑了一个小时才得到结果。主要原因是 predict 过程要对 10000 条测试数据逐条计算它与全部 50000 条训练数据的距离，计算量很大。

    转载请注明原文地址: https://ju.6miu.com/read-668711.html

    最新回复(0)