YD's blog

Posted 二 26 4月 2016

K-Nearest Neighbor algorithms (KNN)

如果我們有一筆資料 $(x_1,y_1),...,(x_n,y_n)$,

且 $x_i \in \mathbb{R}^2, y_i \in {0,1}$ (即 $x_i=(x_1,x_2)$為二維的自變項以及 $y_i = 0 ~or~ 1$的依變項)。如圖所示:

橫軸為 $x_1$、縱軸為 $x_2$、藍色和紅色分別為已知的分類(0或1)。圖中的黃點並非髒污,而是要進行分類的 $x$。

KNN的概念其實很就是去算點與點之間的距離,距離越近表示類型越像(和K-means有異曲同工之妙,但注意,KNN是屬於Supervised learning,這點和K-means不同)。 若 $k=1$,則是根據離黃點最近的「1個點」之分類為分類; $k=3$則是根據離黃點最近的「3個點」之分類為分類並採取多數決(majority vote of k nearest)。以下圖為例:

此例為 $k=3$,意即離黃點最近的三個點分別為1,1,0的分類,依照多數決應將黃點分類為類別1。

若將KNN寫成公式的話:

$d(x_i,x_j)=\sqrt{\sum_{k=1}^{d}(x_{ik}-x_{jk})^2}, x_i=(x_{i1},...,x_{id})$

此為 $k=1$的例子,已知紅點距離最近的點為藍色,因此判斷紅色的分類為1。有興趣的話可以參考下列R code。可以自己調整參數玩玩看

# K-Nearest Neighbor classification ---------------------------------------
library(data.table)
library(ggplot2)
#training data
#simplify to two-dimensional real numbers
dt <- data.table(x1=sample(1:10,30,replace=T),x2=sample(1:10,30,replace=T),label=as.factor(sample(0:1,30,replace=T)))

#new x^(2)=(x1,x2)
newX=c(8,7)
k=1 #k could be a odd number

#find out the label of k smallest distances of x
nearestLabels <- dt$label[!is.na(match(sqrt((newX[1]-dt$x1)^2+(newX[2]-dt$x2)^2),sort(sqrt((newX[1]-dt$x1)^2+(newX[2]-dt$x2)^2))[1:k]))]

#majority vote of k nearest points
tb <- as.data.table(table(nearestLabels))
predictLabel <- paste("predict: ",tb$nearestLabels[match(max(tb$N),tb$N)],sep='')
predict.dt <- data.table(x1=newX[1],x2=newX[2],label=predictLabel)

#bind preditVal to original dt
dt <- rbind(predict.dt,dt)

#plot the outcome
qplot(x1,x2,colour=label,data=dt)

Category: Stat
Tags: Stat R