
Testing the model

In order to test the generalization performance of the model, we calculate the mean squared error on the test data:

In [11]: y_pred = linreg.predict(X_test)
In [12]: metrics.mean_squared_error(y_test, y_pred)
Out[12]: 15.010997321630166
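The mean squared error is simply the average of the squared differences between the ground truth and the predictions. The following minimal sketch computes it by hand with NumPy and checks the result against `metrics.mean_squared_error`. The helper `manual_mse` and the toy values are hypothetical illustrations, not part of the housing example.

```python
import numpy as np
from sklearn import metrics


def manual_mse(y_true, y_pred):
    """Hypothetical helper: mean of the squared residuals (y_true - y_pred)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    return np.mean(residuals ** 2)


# Toy values chosen for illustration only (not the housing data)
y_true_toy = [3.0, 5.0, 2.5, 7.0]
y_pred_toy = [2.5, 5.0, 3.0, 8.0]

mse_by_hand = manual_mse(y_true_toy, y_pred_toy)
mse_sklearn = metrics.mean_squared_error(y_true_toy, y_pred_toy)
print(mse_by_hand, mse_sklearn)
```

Here the residuals are 0.5, 0, -0.5 and -1, so both calls should agree on an MSE of 0.375.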

We note that the mean squared error is a little lower on the test set than on the training set. This is good news, as what we care about most is the test error. However, these numbers alone make it hard to judge how good the model really is. Perhaps it's better to plot the data:

In [13]: plt.figure(figsize=(10, 6))
... plt.plot(y_test, linewidth=3, label='ground truth')
... plt.plot(y_pred, linewidth=3, label='predicted')
... plt.legend(loc='best')
... plt.xlabel('test data points')
... plt.ylabel('target value')
Out[13]: <matplotlib.text.Text at 0x7ff46783c7b8>

This produces the following figure:

Linear regression model

This makes more sense! Here we see the ground truth housing prices for all test samples in blue and our predicted housing prices in red. Pretty close, if you ask me. It is interesting to note, though, that the model tends to be off the most for really high or really low housing prices, such as the peak values of data points 12, 18, and 42. We can formalize the amount of variance in the data that we were able to explain by calculating R² (the coefficient of determination):
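R² compares the residual sum of squares of the model with the total sum of squares around the mean of the ground truth: R² = 1 - SS_res / SS_tot. A value of 1 means a perfect fit, while a model that always predicts the mean scores 0. The sketch below computes this by hand and compares it with `metrics.r2_score`; it also shows the equivalent form R² = 1 - MSE / Var(y), which ties R² to the mean squared error. The helper `r_squared` and the toy values are hypothetical.

```python
import numpy as np
from sklearn import metrics


def r_squared(y_true, y_pred):
    """Hypothetical helper: coefficient of determination, 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot


# Toy values chosen for illustration only (not the housing data)
y_true_toy = np.array([3.0, 5.0, 2.5, 7.0])
y_pred_toy = np.array([2.5, 5.0, 3.0, 8.0])

r2_by_hand = r_squared(y_true_toy, y_pred_toy)
r2_sklearn = metrics.r2_score(y_true_toy, y_pred_toy)

# Equivalent form using the MSE and the population variance (np.var uses ddof=0)
mse_toy = metrics.mean_squared_error(y_true_toy, y_pred_toy)
r2_from_mse = 1.0 - mse_toy / np.var(y_true_toy)

print(r2_by_hand, r2_sklearn, r2_from_mse)
```

The MSE form works because dividing both sums of squares by the number of samples turns SS_res into the MSE and SS_tot into the variance of the ground truth.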

In [14]: plt.plot(y_test, y_pred, 'o')
... plt.plot([-10, 60], [-10, 60], 'k--')
... plt.axis([-10, 60, -10, 60])
... plt.xlabel('ground truth')
... plt.ylabel('predicted')

This will plot the ground truth prices, y_test, on the x axis and our predictions, y_pred, on the y axis. We also plot a diagonal line for reference (a black dashed line, 'k--'), whose purpose will become clear shortly. In addition, we want to display the R² score and the mean squared error in a text box:

... scorestr = r'R$^2$ = %.3f' % linreg.score(X_test, y_test)
... errstr = 'MSE = %.3f' % metrics.mean_squared_error(y_test, y_pred)
... plt.text(-5, 50, scorestr, fontsize=12)
... plt.text(-5, 45, errstr, fontsize=12)
Out[14]: <matplotlib.text.Text at 0x7ff4642d0400>
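The text box uses `linreg.score`, which for scikit-learn regressors returns the R² of the predictions on the given data. The following sketch checks this on synthetic data, assuming that `make_regression` stands in for the housing data (the Boston dataset is no longer shipped with recent scikit-learn versions); the sample counts, noise level and seeds are arbitrary choices.

```python
from sklearn import linear_model, metrics
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the housing data (assumption, not the original dataset)
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=42
)

linreg = linear_model.LinearRegression()
linreg.fit(X_train, y_train)
y_pred = linreg.predict(X_test)

# The estimator's score method and the standalone metric should agree
score_method = linreg.score(X_test, y_test)
score_metric = metrics.r2_score(y_test, y_pred)
print(score_method, score_metric)
```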

This produces the following figure, which is a professional way of plotting a model fit:

Model predictions versus ground truth

If our model were perfect, then all data points would lie on the dashed diagonal, since y_pred would always be equal to y_test. Deviations from the diagonal indicate that the model made some errors, or that there is some variance in the data that the model was not able to explain. Indeed, R² indicates that we were able to explain 76 percent of the scatter in the data, with a mean squared error of 15.011. These are some hard numbers we can use to compare the linear regression model to some more complicated ones.
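One way to put such numbers in context is to compute them for a trivial baseline as well. The sketch below, again on synthetic `make_regression` data standing in for the housing prices, compares `LinearRegression` against scikit-learn's `DummyRegressor`, which always predicts the training mean and should therefore score an R² close to 0. The choice of baseline and all data parameters are illustrative assumptions.

```python
from sklearn import linear_model, metrics
from sklearn.datasets import make_regression
from sklearn.dummy import DummyRegressor
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the housing data (assumption, not the original dataset)
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=0
)

models = {
    "baseline (mean)": DummyRegressor(strategy="mean"),
    "linear regression": linear_model.LinearRegression(),
}

results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    y_hat = model.predict(X_test)
    # Store (R^2, MSE) on the test set for each model
    results[name] = (
        metrics.r2_score(y_test, y_hat),
        metrics.mean_squared_error(y_test, y_hat),
    )
    print("%-18s R^2 = %.3f  MSE = %.3f" % (name, *results[name]))
```

A more complicated model would be added to the `models` dictionary in the same way, so that all candidates are judged on the same test split and the same two metrics.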
