- Machine Learning for OpenCV
- Michael Beyeler
- 2021-07-02 19:47:24
Testing the model
To test the generalization performance of the model, we calculate the mean squared error on the test data:
In [11]: y_pred = linreg.predict(X_test)
In [12]: metrics.mean_squared_error(y_test, y_pred)
Out[12]: 15.010997321630166
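For reference, the number in Out[12] is the average of the squared differences between the ground truth and the predictions, MSE = (1/n) Σ (y_i - ŷ_i)². The following minimal sketch uses four made-up values in place of y_test and y_pred. It recomputes the MSE with NumPy and checks the result against metrics.mean_squared_error. It also takes the square root (the RMSE), which is in the same units as the target. For the test MSE above, the RMSE is roughly 3.87.

```python
import numpy as np
from sklearn import metrics

# Hypothetical toy values standing in for y_test and y_pred;
# these are NOT the housing data from the text.
y_true = np.array([24.0, 22.0, 35.0, 33.0])
y_hat = np.array([25.0, 21.0, 31.0, 35.0])

# MSE is the mean of the squared residuals (y_true - y_hat)
residuals = y_true - y_hat
mse_manual = np.mean(residuals ** 2)

# scikit-learn's helper should give the same number
mse_sklearn = metrics.mean_squared_error(y_true, y_hat)

# RMSE (square root of MSE) is in the same units as the target
rmse = np.sqrt(mse_manual)

print(mse_manual, mse_sklearn, rmse)
```

Squaring the residuals keeps positive and negative errors from cancelling out. It also penalizes large misses more heavily than small ones.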
We note that the mean squared error is a little lower on the test set than on the training set. This is good news, as we care mostly about the test error. However, these numbers alone make it hard to judge how good the model actually is. Perhaps it's better to plot the data:
In [13]: plt.figure(figsize=(10, 6))
... plt.plot(y_test, linewidth=3, label='ground truth')
... plt.plot(y_pred, linewidth=3, label='predicted')
... plt.legend(loc='best')
... plt.xlabel('test data points')
... plt.ylabel('target value')
Out[13]: <matplotlib.text.Text at 0x7ff46783c7b8>
This produces the following figure:

This makes more sense! Here we see the ground truth housing prices for all test samples in blue and our predicted housing prices in red. Pretty close, if you ask me. It is interesting to note, though, that the model tends to be off the most for very high or very low housing prices, such as the peak values of data points 12, 18, and 42. We can formalize the amount of variance in the data that we were able to explain by calculating R², which we will display alongside a scatter plot of predicted versus ground truth values:
In [14]: plt.plot(y_test, y_pred, 'o')
... plt.plot([-10, 60], [-10, 60], 'k--')
... plt.axis([-10, 60, -10, 60])
... plt.xlabel('ground truth')
... plt.ylabel('predicted')
This plots the ground truth prices, y_test, on the x axis and our predictions, y_pred, on the y axis. We also plot a diagonal line for reference (a black dashed line, 'k--'); its purpose will become clear shortly. We also want to display the R² score and the mean squared error as text annotations in the plot:
... scorestr = r'R$^2$ = %.3f' % linreg.score(X_test, y_test)
... errstr = 'MSE = %.3f' % metrics.mean_squared_error(y_test, y_pred)
... plt.text(-5, 50, scorestr, fontsize=12)
... plt.text(-5, 45, errstr, fontsize=12)
Out[14]: <matplotlib.text.Text at 0x7ff4642d0400>
This produces the following figure, which is a professional way of plotting a model fit:

If our model were perfect, all data points would lie on the dashed diagonal, since y_pred would always equal y_test. Deviations from the diagonal indicate that the model made some errors, or that there is some variance in the data that the model was not able to explain. Indeed, R² indicates that we were able to explain 76 percent of the scatter in the data, with a mean squared error of 15.011. These are hard numbers we can use to compare the linear regression model to more complicated ones.
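The R² reported by linreg.score compares the model's squared errors with those of a baseline that always predicts the mean target value: R² = 1 - SS_res / SS_tot. The following minimal sketch uses a handful of made-up values in place of y_test and y_pred. It recomputes R² by hand and checks the result against scikit-learn's metrics.r2_score. It then uses np.argsort on the absolute residuals to pick out the worst-predicted points. These are the same kind of outliers as data points 12, 18, and 42 in the earlier plot. The variable k is an illustrative choice, not something from the text.

```python
import numpy as np
from sklearn import metrics

# Hypothetical toy values standing in for y_test and y_pred;
# these are NOT the housing data from the text.
y_true = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
y_hat = np.array([12.0, 19.0, 30.0, 37.0, 44.0])

residuals = y_true - y_hat

# Unexplained scatter: squared distance of each point from its prediction
ss_res = np.sum(residuals ** 2)
# Total scatter: squared distance of each point from the mean target value
ss_tot = np.sum((y_true - y_true.mean()) ** 2)
r2_manual = 1.0 - ss_res / ss_tot

# scikit-learn's r2_score computes the same coefficient of determination
r2_sklearn = metrics.r2_score(y_true, y_hat)

# Indices of the k points with the largest absolute error
# (vertical distance from the diagonal), largest first
k = 2
worst = np.argsort(np.abs(residuals))[::-1][:k]

print(r2_manual, r2_sklearn, worst)
```

With the real arrays in place of y_true and y_hat, this should reproduce the R² shown in the figure. An R² below 0 would mean the model does worse than always predicting the mean.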