상세 컨텐츠

본문 제목

[혼자 공부하는 머신러닝+딮러닝] 회귀 알고리즘과 모델 규제 - k 최근접 이웃 회귀

AI

by 래모 2023. 10. 28. 18:32

본문

지도 학습 알고리즘은 크게 분류와 회귀로 나뉜다.

 

회귀란?

: 임의의 어떤 숫자를 예측하는 문제

ex) 내년도 경제 성장률 예측, 배달 도착시간 예측, 농어의 무게를 예측

 

k-최근접 알고리즘이란?

예측하려는 샘플에 가장 가까운 샘플 k개를 선택한 후 이 샘플들의 클래스를 확인하여 다수 클래스를 새로운 샘플의 클래스로 예측한다.

k-최근접 이웃 회귀도 이와 비슷!!

  • 분류와 똑같이 예측하려는 샘플에 가장 가까운 샘플 k개를 선택 (회귀이기 때문에 이웃한 샘플의 타깃은 어떤 클래스가 아니라 임의의 수치)
  • 이웃 샘플의 수치를 사용해 새로운 샘플x의 타깃을 예측함 => 수치들의 평균 구하기!

 

KNeighborsRegressor

사이킷런에서 k-최근접 이웃 회귀 알고리즘을 구현한 클래스

객체를 생성하고 fit메서드로 회귀 모델을 훈련함

 

농어 데이터를 준비하고 이를 train_test_split으로 train과 test로 나누었다

그 이후에 KNeighborsRegressor의 fit으로 훈련을 시켜준다.

 

from sklearn.model_selection import train_test_split

train_input, test_input, train_target, test_target = train_test_split(perch_length, perch_weight, random_state = 42)

train_input = train_input.reshape(-1,1)
test_input = test_input.reshape(-1,1)
# 모두 하나의 배열로 하게 바꿈 즉 (42,1)

from sklearn.neighbors import KNeighborsRegressor

knr = KNeighborsRegressor()
knr.fit(train_input, train_target)

knr.score(test_input, test_target) #0.992809406101064

위 score함수를 통해 나온 값은 간단히 말해 정답을 맞힌 개수의 비율이다.

회귀의 경우 이 점수를 결정계수(R^2)라고 부른다

 

R^2 = 1 -  { (타깃 - 예측)^2의 합 } / { (타깃 - 평균)^2의 합 }

 

타깃의 평균 정도를 예측하려는 수준이라면 R^2은 0에 가까워지고

예측이 타킷에 아주 가까워지면 1에 가까운 값이 된다.

 

mean_absolute_error

: 타깃과 예측의 절댓값 오차를 평균하여 반환

 

from sklearn.metrics import mean_absolute_error

test_prediction = knr.predict(test_input)
mae = mean_absolute_error(test_target, test_prediction)
print(mae) # 19.157142857142862

 

과대적합 vs 과소적합

knr.score(test_input, test_target) # 0.992809406101064
knr.score(train_input, train_target)# 0.9698823289099254
과대적합(overfitting) 과소적합(underfitting)
훈련 세트에선 점수가 굉장히 좋았는데 테스트 세트에서는 점수가 나쁜 경우 훈련세트에 과대적합되었다고 말함 훈련세트보다 테스트 세트의 점수가 높거나 두 점수 모두 너무 낮은 경우 모델이 훈련 세트에 과소적합되었다고 말함

 

우리가 이전에 했던 과소적합의 상황이 일어남

이럴때 어떻게?

=> 모델을 복잡하게 만들자!

 

k-최근접 이웃 알고리즘으로 모델을 더 복잡하게 만드는 방법은 이웃의 개수 k를 줄이는 것

knr.n_neighbors = 3 # 기본값은 5임
knr.fit(train_input, train_target)

print(knr.score(train_input, train_target)) # 0.9804899950518966

print(knr.score(test_input, test_target)) # 0.9746459963987609

 

관련글 더보기