機器學習 - 多項式迴歸
多項式迴歸
如果您的資料點明顯不適合線性迴歸(穿過所有資料點的直線),那麼多項式迴歸可能就是理想的選擇。
多項式迴歸與線性迴歸一樣,使用變數 x 和 y 之間的關係來找到穿過資料點的最佳直線。

它是如何工作的?
Python 提供了用於查詢資料點之間關係和繪製多項式迴歸線的方法。我們將向您展示如何使用這些方法,而不是深入研究數學公式。
在下面的示例中,我們記錄了 18 輛汽車透過某個收費站的情況。
我們記錄了汽車的速度,以及透過的時間(小時)。
x 軸表示一天中的小時,y 軸表示速度。
示例
首先繪製散點圖
import matplotlib.pyplot as plt
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
plt.scatter(x, y)
plt.show()
結果
示例
匯入 numpy
和 matplotlib
,然後繪製多項式迴歸線
import numpy
import matplotlib.pyplot as plt
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
myline = numpy.linspace(1, 22, 100)
plt.scatter(x, y)
plt.plot(myline, mymodel(myline))
plt.show()
結果
示例解釋
匯入所需的模組。
您可以在我們的 NumPy 教程中瞭解 NumPy 模組。
您可以在我們的 SciPy 教程中瞭解 SciPy 模組。
import numpy
import matplotlib.pyplot as plt
建立表示 x 和 y 軸值的陣列
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
NumPy 有一個方法可以讓我們建立多項式模型
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
然後指定直線的顯示方式,我們從位置 1 開始,到位置 22 結束
myline = numpy.linspace(1, 22, 100)
繪製原始散點圖
plt.scatter(x, y)
繪製多項式迴歸線
plt.plot(myline, mymodel(myline))
顯示圖表
plt.show()
R-平方
瞭解 x 軸和 y 軸變數之間的關係有多好很重要,如果沒有關係,多項式迴歸就無法用於任何預測。
這種關係用一個稱為 r-squared 的值來衡量。
r-squared 值範圍從 0 到 1,其中 0 表示沒有關係,1 表示 100% 相關。
Python 和 Sklearn 模組可以為您計算此值,您只需將 x 和 y 陣列輸入其中即可
示例
我的資料在多項式迴歸中的擬合度如何?
import numpy
from sklearn.metrics import r2_score
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
print(r2_score(y, mymodel(x)))
親自嘗試 »
注意:結果 0.94 表明存在非常好的關係,我們可以將多項式迴歸用於未來的預測。
預測未來值
現在我們可以利用收集到的資訊來預測未來值。
示例:讓我們嘗試預測一輛在 17:00 左右透過收費站的汽車的速度。
為此,我們需要上面示例中的相同 mymodel
陣列
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
示例
預測 17:00 透過的汽車的速度
import numpy
from sklearn.metrics import r2_score
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
speed = mymodel(17)
print(speed)
執行示例 »
該示例預測的速度為 88.87,我們也可以從圖表中讀出這個值。

擬合不佳?
讓我們建立一個多項式迴歸不適合用於預測未來值的示例。
示例
這些 x 軸和 y 軸的值對於多項式迴歸來說應該擬合得很差。
import numpy
import matplotlib.pyplot as plt
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
myline = numpy.linspace(2, 95, 100)
plt.scatter(x, y)
plt.plot(myline, mymodel(myline))
plt.show()
結果
那麼 r-squared 值是多少?
示例
您應該得到一個非常低的 r-squared 值。
import numpy
from sklearn.metrics import r2_score
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
print(r2_score(y, mymodel(x)))
親自嘗試 »
結果:0.00995 表明關係非常糟糕,並且告訴我們該資料集不適合多項式迴歸。