import os

import numpy as np
def unpickle(file):
    """Load a single pickled object from ``file`` and return it.

    Args:
        file: Path of the pickle file to read (opened in binary mode).

    Returns:
        The object stored in the file. The file is decoded with
        ``encoding='bytes'``, so 8-bit strings pickled by Python 2
        (including dictionary keys) come back as ``bytes`` objects.
    """
    import pickle

    # NOTE: unpickling can execute arbitrary code; only call this on
    # files obtained from a trusted source.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
def load_CIFAR10(file):
    """Load the training and test batches stored in directory ``file``.

    Reads ``data_batch_1`` through ``data_batch_5`` and ``test_batch``
    from the directory using ``unpickle``.

    Args:
        file: Directory that contains the batch files.

    Returns:
        A tuple ``(dataTrain, labelTrain, dataTest, labelTest)`` of lists.
        Each list holds the items of the ``b"data"`` / ``b"labels"``
        entries of the corresponding batches; the five training batches
        are concatenated in order.
    """
    dataTrain = []
    labelTrain = []
    for i in range(1, 6):
        # os.path.join rather than a hard-coded "\\" so the batch path
        # also resolves on non-Windows systems.
        dic = unpickle(os.path.join(file, "data_batch_" + str(i)))
        dataTrain.extend(dic[b"data"])
        labelTrain.extend(dic[b"labels"])

    dic = unpickle(os.path.join(file, "test_batch"))
    dataTest = list(dic[b"data"])
    labelTest = list(dic[b"labels"])
    return dataTrain, labelTrain, dataTest, labelTest
# Load every batch from disk, then turn each of the four returned lists
# (train data, train labels, test data, test labels) into a NumPy array.
Xtr, Ytr, Xte, Yte = map(
    np.asarray, load_CIFAR10('tedata/cifar-10-batches-py')
)
class NearestNeighbor(object):
    """Brute-force 1-nearest-neighbour classifier using L2 distance."""

    def __init__(self):
        pass

    def train(self, X, y):
        """Store the training set; no computation happens here.

        Args:
            X: 2-D array with one training example per row.
            y: 1-D array holding the label of each row of ``X``.
        """
        self.xtr = X
        self.ytr = y

    def predict(self, X):
        """Predict a label for every row of ``X``.

        Args:
            X: 2-D array with one example per row, with the same number
                of columns as the training data.

        Returns:
            1-D array (dtype of the training labels) where entry ``i`` is
            the label of the training row closest to ``X[i]`` in
            Euclidean distance.
        """
        num_test = X.shape[0]
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        for i in range(num_test):
            # Compute the difference in float64. With unsigned integer
            # inputs (e.g. uint8 pixel data) a plain ``self.xtr - X[i]``
            # and the following square wrap around modulo 2**bits, which
            # yields wrong distances and therefore wrong neighbours.
            diff = np.subtract(self.xtr, X[i, :], dtype=np.float64)
            distances = np.sqrt(np.sum(np.square(diff), axis=1))
            # argmin returns the first index on ties.
            Ypred[i] = self.ytr[np.argmin(distances)]
        return Ypred
# Fit the classifier on the training split and label the test split.
nn = NearestNeighbor()
nn.train(Xtr, Ytr)
Yte_predict = nn.predict(Xte)

# Accuracy is the fraction of test rows whose predicted label matches.
print('accuracy: %f' % np.mean(Yte == Yte_predict))
这段代码运行起来比较耗时,我跑了大约一个小时才得到结果。主要原因是 predict 过程中要对 10000 条测试数据逐条计算它与全部训练样本之间的距离。
转载请注明原文地址: https://ju.6miu.com/read-668711.html