본문 바로가기
AI/AI

[AI] K-Nearest Neighbors 직접 구현

by Foxy현 2022. 10. 19.
728x90
반응형

K-Nearest Neighbor은 대표적인 분류 알고리즘입니다.

유사한 속성을 가진 데이터는 유사한 그룹에 속한다는 아이디어로 사용합니다.

 

위의 그림을 보면 모든 데이터는 1,2,3이라는 각각 다른 색상으로 분류되어있다. 하지만

새로 입력한 저 빨간 점에 대한 분류는 어떻게 하는 것일까?

 

이에 KNN이라는 알고리즘을 도입하게 됐는데, 간단히 요점을 나열하자면

  • 유사한 데이터들끼리의 거리는 비교적 가깝다.
  • 분류를 알 수 없는 새로운 데이터는 가장 가까운 이웃 k개의 분류를 확인하여 vote 한다
  • k의 개수가 너무 작으면 과대적합이 일어날 수 있다
  • k의 개수가 너무 많으면 과소적합이 일어날 수 있다

필요한 라이브러리 불러오기

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import math
from collections import Counter
warnings.filterwarnings('ignore')

데이터 준비

center_1 = np.array([1,1])
center_2 = np.array([5,5])
center_3 = np.array([8,1])

data = np.random.randn(10,2) + center_1
data = np.concatenate((data,np.random.randn(10,2)+center_2))
data = np.concatenate((data,np.random.randn(10,2)+center_3))

cluster = np.ones(10)
cluster = np.concatenate((cluster,2*np.ones(10)))
cluster = np.concatenate((cluster,3*np.ones(10)))

data_new=[6,5]
data_new = np.asarray(data_new)

리스트 생성

k=3
dis = [];k_near= [];
index = [];cls= [];values=[]

이제 필요한 준비는 마쳤고, 각 데이터들을 그래프로 그려보면,

라는 그래프가 나온다.

우리의 목표는 k(3) 개의 가장 가까운 점들을 찾아, 그 점들의 카테고리를 확인하여 다수결로 새로운 점의 카테고리를 정하는 일이다.

 

먼저, 점들 간의 거리를 구하기 위해 유클리디언 거리를 이용하겠다.

for i in range(len(data)-1):
    dis.append(np.sqrt((data_new[0]-data[i,0])**2+(data_new[1]-data[i,1])**2))
dis.sort()

각 30개의 점들에 대한 거리를 구하고, 오름차순으로 정렬한다. 이후 우리가 할 일은? 

가장 가까운 k(3) 개의 점을 찾아야 하므로 30개 중에서 가장 맨 앞의 3개를 알면 되는 것이다.

 

for i in range(len(data)-1):
    if np.sqrt((data_new[0]-data[i,0])**2+(data_new[1]-data[i,1])**2)<dis[k]:
        plt.scatter(data[i,0],data[i,1],marker='*',c='b',s=100)
        k_near.append(data[i])
        values.append(np.sqrt((data_new[0]-data[i,0])**2+(data_new[1]-data[i,1])**2))
        index.append(i)
        cls.append(cluster[i])

해석을 하자면, 앞에서 4번째의 배열보다 값이 작다면, if문을 수행하는 것인데, if라는 조건 안에 들면 우리가 찾던

k개의 점 중 하나를 찾게 되는 것이다. 우리는 이 인덱스의 데이터를 저장하여 k_near라는 리스트 안에 저장하여 k점들을 찾을 수 있다.

 

데이터의 몇 번 인덱스가, 그리고 그 거리 값이 무엇인지 확인할 수 있었다.

most = Counter(cls).most_common(1)
print(most[0][0])

최빈값을 구하는 함수를 사용하여 새로운 점의 부류를 확인할 수 있다.

 

Hola

728x90
반응형