Amateur Hour

10行代码实现kNN(K Nearesr Neighbor)算法

2018-11-28

使用 numpy 库,只需 10 行简单的代码就能实现 k 近邻算法。

算法逻辑

对要分类的点(X)进行下列运算:

  1. 计算 X 与已知分类的所有点的距离(欧氏距离);
  2. 距离按照递增排序;
  3. 选取距离最小的 k 个点;
  4. 计算这 k 个点中,每个分类出现的频率;
  5. 取频率最高的分类为预测分类。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def kNN(X, dataset, labels, k):
# 计算X与所有其他点的差值
diff = np.tile(X, (dataset.shape[0], 1)) - dataset
# 计算欧氏距离
distances = ((diff**2).sum(axis=1))**0.5
# 排序
sorted_idx = distances.argsort()
class_cnt = {}
for i in range(k):
label = labels[sorted_idx[i]]
class_cnt[label] =class_cnt.get(label, 0) + 1
# 找出频率最高的分类作为预测结果
sorted_class = sorted(class_cnt.iteritems(), key=operator.itemgetter(1), reverse=True)
return sorted_class[0][0]

enjoy!