Weight decay (L2 penalty in neural networks)

We have already used regularization, perhaps unknowingly, in the previous chapter: the neural network we trained using the caret and nnet packages used a weight decay of 0.10. Weight decay adds an L2 penalty on the size of the weights to the training objective, which shrinks the weights toward zero and discourages overly complex fits. We can investigate the effect of weight decay by varying its value and tuning it with cross-validation.
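Conceptually, the decay argument adds a penalty of decay * sum(weights^2) to the error that the network minimizes, so larger weights are punished more heavily. The snippet below is a minimal, illustrative sketch of this penalized objective; the penalized_loss helper and the toy numbers are ours for illustration only:

penalized_loss <- function(error, weights, decay) {
  ## unpenalized training error plus the L2 (weight-decay) penalty
  error + decay * sum(weights ^ 2)
}

## identical fit error, but larger weights incur a larger penalty
penalized_loss(error = 0.30, weights = c(0.5, -1.2, 2.0), decay = 0.1)  # 0.869
penalized_loss(error = 0.30, weights = c(0.1, -0.2, 0.3), decay = 0.1)  # 0.314

With that intuition in place, the following steps tune the decay value with cross-validation: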

  1. Load the data as before. Then create a local cluster so that the cross-validation can run in parallel (a quick check of the registered backend follows this code block):
set.seed(1234)
## same data as in the previous chapter
dataDirectory <- "../data"
if (!file.exists(file.path(dataDirectory, "train.csv")))
{
  link <- 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/mnist_csv.zip'
  if (!file.exists(paste(dataDirectory, '/mnist_csv.zip', sep = "")))
    download.file(link, destfile = paste(dataDirectory, '/mnist_csv.zip', sep = ""))
  unzip(paste(dataDirectory, '/mnist_csv.zip', sep = ""), exdir = dataDirectory)
  if (file.exists(paste(dataDirectory, '/test.csv', sep = "")))
    file.remove(paste(dataDirectory, '/test.csv', sep = ""))
}

digits.train <- read.csv("../data/train.csv")

## convert to factor
digits.train$label <- factor(digits.train$label, levels = 0:9)

## draw 6,000 random rows: the first 5,000 for training, the remaining 1,000 for testing
sample <- sample(nrow(digits.train), 6000)
train <- sample[1:5000]
test <- sample[5001:6000]

digits.X <- digits.train[train, -1]   # predictors (pixel columns)
digits.y <- digits.train[train, 1]    # labels
test.X <- digits.train[test, -1]
test.y <- digits.train[test, 1]

## try various weight decays and number of iterations
## register backend so that different decays can be
## estimated in parallel
library(doSNOW)   # also attaches foreach and snow (makeCluster, clusterEvalQ)
cl <- makeCluster(5)                          # start five local worker processes
clusterEvalQ(cl, {source("cluster_inc.R")})   # source the helper script on each worker
registerDoSNOW(cl)                            # register the cluster as the foreach backend
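Before launching the computationally expensive training run, it can be worth confirming that the parallel backend registered correctly. This optional check uses getDoParWorkers() from the foreach package, which doSNOW attaches:

foreach::getDoParWorkers()   # should report 5 registered workers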
  2. Train a neural network on the digit classification task, varying the weight-decay penalty between 0 (no penalty) and 0.10. We also loop over two settings for the maximum number of iterations: 100 and 150. Note that this code is computationally intensive and takes some time to run:
set.seed(1234)
digits.decay.m1 <- lapply(c(100, 150), function(its) {
  caret::train(digits.X, digits.y,
               method = "nnet",
               tuneGrid = expand.grid(
                 .size = c(10),
                 .decay = c(0, 0.1)),
               trControl = caret::trainControl(method = "cv", number = 5, repeats = 1),
               MaxNWts = 10000,
               maxit = its)
})
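Once training has finished, the worker processes can be released; this is a one-line piece of housekeeping, assuming the cl object created earlier is still in scope:

stopCluster(cl)   # shut down the five worker processes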
  3. Examining the results, we see that when we limit training to only 100 iterations, the non-regularized and regularized models have the same cross-validated accuracy of 0.56, which is not very good on this data:
digits.decay.m1[[1]]
Neural Network

5000 samples
784 predictor
10 classes: '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'

No pre-processing
Resampling: Cross-Validated (5 fold)
Summary of sample sizes: 4000, 4001, 4000, 3998, 4001
Resampling results across tuning parameters:

decay Accuracy Kappa
0.0 0.56 0.51
0.1 0.56 0.51

Tuning parameter 'size' was held constant at a value of 10
Accuracy was used to select the optimal model using the
largest value.
The final values used for the model were size = 10 and decay = 0.1.
  4. Examine the results with 150 iterations to see whether the regularized or non-regularized model performs better:
digits.decay.m1[[2]]
Neural Network

5000 samples
784 predictor
10 classes: '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'

No pre-processing
Resampling: Cross-Validated (5 fold)
Summary of sample sizes: 4000, 4002, 3998, 4000, 4000
Resampling results across tuning parameters:

decay Accuracy Kappa
0.0 0.64 0.60
0.1 0.63 0.59

Tuning parameter 'size' was held constant at a value of 10
Accuracy was used to select the optimal model using the
largest value.
The final values used for the model were size = 10 and decay = 0.

Overall, the models allowed more iterations outperform those allowed fewer iterations, regardless of regularization. However, comparing the two models trained for 150 iterations, the non-regularized model (accuracy = 0.64) performs marginally better than the regularized model (accuracy = 0.63), although the difference is small.
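If you prefer to extract these cross-validated accuracies programmatically rather than reading the printed summaries, the objects returned by caret::train keep them in their results data frame. A short sketch:

## cross-validated accuracy by decay value, for 100 and then 150 iterations
lapply(digits.decay.m1, function(m) m$results[, c("decay", "Accuracy")])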

These results highlight that regularization is often most useful for more complex models that have greater flexibility to fit (and overfit) the data. In models that are already appropriately sized, or even too simple, for the data, regularization is likely to decrease performance. When developing a new model architecture, you should avoid adding regularization until the model performs well on the training data; if you add regularization first and the model performs poorly on the training data, you cannot tell whether the problem lies in the architecture or in the regularization. In the next section, we'll discuss ensembles and model averaging, the last forms of regularization highlighted in this book.
