使用kNN实现手写体识别

【KNN的总结】本质就是使用测试图片与样本图片进行比较,找到K个最近的图片,在K个图片中选择概率出现最高的那一个,把数字记录下来,这个数字就是最终目标。步骤如下: 1)数据的加载。注意是随机数的加载 有4组,分别为训练数据,训练标签,测试图片,测试标签 2)计算测试图片与训练图片的距离 3)计算K个最近的图片(实际上就是排序) 4)将得到的最近的图片转换为标签,并且对标签按照少数服从多数的原则,得到最终的标签 5)检测概率统计(将测试得到的标签与实际的标签进行比较)

 

可以修改的地方: 1)K值 2)测试图片和训练图片的数目

 

import tensorflow as tf

 

import numpy as np

 

import random

 

from tensorflow.examples.tutorials.mnist import input_data

 

#导入Mnist数据集

 

mnist = input_data.read_data_sets(r”path”,one_hot=True)

 

#属性设置

 

trainnum=55000

 

testnum=10000

 

trainsize=1000

 

testsize=500

 

k=20

 

#下来将数据进行分解

 

trainindex=np.random.choice(trainnum,trainsize,replace=False)

 

testindex=np.random.choice(testnum,testsize,replace=False)

 

traindata=mnist.train.images[trainindex]#训练图片

 

trainlabel=mnist.train.labels[trainindex]#训练标签

 

testdata=mnist.test.images[testindex]#测试图片

 

testlabel=mnist.test.labels[testindex]#测试标签

 

#数据定义好了之后,就需要用tensorflow来定义输入(需要的训练数据就已经定义好了)

 

traindatainput=tf.placeholder(shape=[None,784],dtype=tf.float32)

 

#正确的标签

 

trainlabelinput=tf.placeholder(shape=[None,10],dtype=tf.float32)#到这训练数据的数据和标签就已经生成

 

#再把测试数据和测试标签生成一下

 

testdatainput=tf.placeholder(shape=[None,784],dtype=tf.float32)

 

testlabelinput=tf.placeholder(shape=[None,10],dtype=tf.float32)#到这里测试数据的数据和标签就已经准备好

 

#在数据全部准备完之后,就可以开始进行训练了

 

#计算knn距离

 

f1=tf.expand_dims(testdatainput,1)#将当前的输入数据增加一项这样转换的目的是要用来计算数据应该是一个3维数据(3D)

 

f2=tf.subtract(traindatainput,f1)#就得到了3维数据,测试数据与500个的距离

 

f3=tf.reduce_sum(tf.abs(f2),reduction_indices=2)#这一步完成数据的累加,这里的差值是取绝对值之后的f3是一个(5*500的)

 

f4=tf.negative(f3)#p4完成取反功能

 

f5,f6=tf.nn.top_k(f4,k=20)#选取f4中最大的四个值,相当于f3中最小的四个值,f5存的是最近的距离,f6存入的是最近的值的下标

 

f7=tf.gather(trainlabelinput,f6)#f6存放的是最近的点的下标,根据下标来索引图片标签

 

#最后一步应该是将当前的lbel转换为数字

 

f8=tf.reduce_sum(f7,reduction_indices=1)#将竖直方向的量进行累加,这样少数到时候服从多数,竖直方向相加的值代表了哪个次数最大

 

f9=tf.arg_max(f8,dimension=1)#tf.argmax代表的是找最大的数值所对应的下标

 

with tf.Session()as sess:

 

p1=sess.run(f1,feed_dict={testdatainput:testdata[0:500]})

 

print(‘p1=’,p1.shape)

 

p2=sess.run(f2,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})

 

print(‘p2=’,p2.shape)#P2=(5,5000,784)

 

p3=sess.run(f3,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})

 

print(‘p3=’,p3.shape)

 

print(‘p3[0,0]=’,p3[0,0])

 

p4=sess.run(f4,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})

 

print(‘p4=’,p4.shape)

 

print(‘p4[0,0]’,p4[0,0])

 

p5,p6=sess.run((f5,f6),feed_dict={traindatainput:traindata,testdatainput:testdata[0:500]})

 

#每一张测试图片(5张)分别对应值的最近的4张图片

 

print(‘p5=’,p5.shape)

 

print(‘p6=’,p6.shape)

 

print(‘p5’,p5[0])

 

print(‘p6’,p6[0])#到这里距离和下标已经知道, 但并不知道图片描述的是哪些点,因此需要解析这四个最近点的内容

 

p7=sess.run(f7,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500],trainlabelinput:trainlabel})

 

print(‘p7’,p7.shape)

 

print(‘p7’,p7)

 

p8=sess.run(f8,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500],trainlabelinput:trainlabel})

 

print(‘p8.shape’,p8.shape)

 

print(‘p8[]=’,p8)

 

p9=sess.run(f9,feed_dict={traindatainput:traindata,testdatainput:testdata[0:500],trainlabelinput:trainlabel})

 

print(‘p9.shape’,p9.shape)

 

print(‘p9[]=’,p9)

 

p10=np.argmax(testlabel[0:500],axis=1)#p10代表的是样本标签

 

j=0

 

for i in range(0,500):

 

if p10[i]==p9[i]:

 

j=j+1

 

print(‘acc=’,j*100/500)

发表评论

电子邮件地址不会被公开。 必填项已用*标注