我們現在來看看, 可不可以讓電腦辨識, 這是哪個亞種的鳶尾花?
from sklearn.datasets import load_iris
iris = load_iris()
準備輸入及輸出數據, 注意 4 個特徵我們只用了兩個。
x = iris.datay = iris.target
X = x[:, 2:]Y = y
切分訓練及測試資料。
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
看一下整筆數據的分佈。
plt.scatter(X[:,0], X[:,1], c=Y, cmap='Paired')
再一次, 三部曲打造函數學習機。
from sklearn.svm import SVC
clf = SVC()
clf.fit(x_train, y_train)
SVC()
y_predict= clf.predict(x_test)
看看我們模型預測和真實狀況差多少?
y_predict - y_test
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
在測試資料中是全對!! 我們畫圖來看看整體表現如何?
x0 = np.linspace(0, 7.5, 500)y0 = np.linspace(0, 2.7, 500)xm, ym = np.meshgrid(x0, y0)P = np.c_[xm.ravel(), ym.ravel()]z = clf.predict(P)Z = z.reshape(xm.shape)plt.contourf(xm, ym, Z, alpha=0.3)plt.scatter(X[:,0], X[:,1], c=Y)