使用 numpy 库,只需 10 行简单的代码就能实现 k 近邻算法。
算法逻辑
对要分类的点(X)进行下列运算:
- 计算 X 与已知分类的所有点的距离(欧氏距离);
- 距离按照递增排序;
- 选取距离最小的 k 个点;
- 计算这 k 个点中,每个分类出现的频率;
- 取频率最高的分类为预测分类。
代码实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| def kNN(X, dataset, labels, k): # 计算X与所有其他点的差值 diff = np.tile(X, (dataset.shape[0], 1)) - dataset # 计算欧氏距离 distances = ((diff**2).sum(axis=1))**0.5 # 排序 sorted_idx = distances.argsort() class_cnt = {} for i in range(k): label = labels[sorted_idx[i]] class_cnt[label] =class_cnt.get(label, 0) + 1 # 找出频率最高的分类作为预测结果 sorted_class = sorted(class_cnt.iteritems(), key=operator.itemgetter(1), reverse=True) return sorted_class[0][0]
|
enjoy!