<https://www.r-project.org/conferences/useR-2013/Tutorials/kuhn/user_caret_2up.pdf>
Cross validation:
========================================================
- CV is central to tuning the model parameters
- Data is divided into training and test sets
- Use the Training set to 'build' the model
- Use the Test set to 'validate' the model
5-fold CV:
- Divide the Training set into 5 equally sized partitions
- Use one of the partitions for validation and the rest for training
- 'Repeated' CV uses resampling to take different looks at the data
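As a concrete illustration (not part of the original slides), **caret**'s **createFolds** function can build such partitions by hand; a minimal sketch using the built-in iris data:
```r
# A minimal sketch of a manual 5-fold split with caret's createFolds()
# (illustrative only; train() performs this internally)
library(caret)

y <- iris$Species                  # example class labels
folds <- createFolds(y, k = 5)     # list of 5 held-out index vectors

# Use fold 1 for validation and the remaining folds for training
val_idx   <- folds[[1]]
train_idx <- unlist(folds[-1])
length(val_idx)
length(train_idx)
```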
Cross validation (II):
========================================================
- Best estimates for model parameters are obtained by optimisation of an objective (cost) function
- Examples of objective functions: classification accuracy, regression RMSE
- Optimisation is usually by gradient descent, information-theoretic criteria for tree-based models, back-propagation for deep neural networks, etc.
- Cross-validation accuracy is calculated by averaging across all the resamplings
- 'Prediction accuracy' is obtained by applying the fitted model to held-out Test data
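As a minimal sketch (not from the original slides), the two objective functions mentioned above are simple to write down; **caret** computes them internally during tuning:
```r
# Classification accuracy and regression RMSE as simple R functions
accuracy <- function(obs, pred) mean(obs == pred)
rmse     <- function(obs, pred) sqrt(mean((obs - pred)^2))

accuracy(c("a", "b", "a"), c("a", "b", "b"))   # 0.667
rmse(c(1, 2, 3), c(1.1, 1.9, 3.2))             # ~0.141
```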
Why we need caret
========================================================
|Object class |Package |predict() function syntax |
|:----------|:-------|:----------------------------------------|
|lda |MASS |predict(obj) (no options needed) |
|glm |stats |predict(obj, type = "response") |
|gbm |gbm |predict(obj, type = "response", n.trees) |
|mda |mda |predict(obj, type = "posterior") |
|rpart |rpart |predict(obj, type = "prob") |
|Weka |RWeka |predict(obj, type = "probability") |
|LogitBoost |caTools |predict(obj, type = "raw", nIter) |
https://www.r-project.org/conferences/useR-2013/Tutorials/kuhn/user_caret_2up.pdf
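By contrast, a model fitted with **caret**'s **train** function is queried through one uniform **predict** interface, whatever the underlying package. A minimal sketch (using the built-in iris data and an rpart model, not part of the original slides):
```r
# The same predict() calls work for any method supported by train()
library(caret)

fit <- train(Species ~ ., data = iris, method = "rpart")

head(predict(fit, newdata = iris))                  # predicted classes
head(predict(fit, newdata = iris, type = "prob"))   # class probabilities
```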
Available Models
========================================================
<https://topepo.github.io/caret/available-models.html>
Load the **CARET** package:
```r
library(caret)
```
Other required packages are **doMC** (parallel processing) and **corrplot** (correlation matrix plots):
```r
library(doMC)
library(corrplot)
```
Example data set
=======================================================
type:section
Wheat seeds data set
=======================================================
The seeds data set (<https://archive.ics.uci.edu/ml/datasets/seeds>) contains morphological measurements on the kernels of three varieties of wheat: Kama, Rosa and Canadian.
Load the data into your R session using:
load("data/wheat_seeds/wheat_seeds.Rda")
What objects have been loaded into our R session?
```r
ls()
```
```
[1] "morphometrics" "variety"
```
Wheat seeds data set: predictors
======================================================
The **morphometrics** data.frame contains seven variables describing the morphology of the seeds.
```r
str(morphometrics)
```
```
'data.frame': 210 obs. of 7 variables:
$ area : num 15.3 14.9 14.3 13.8 16.1 ...
$ perimeter : num 14.8 14.6 14.1 13.9 15 ...
$ compactness : num 0.871 0.881 0.905 0.895 0.903 ...
$ kernLength : num 5.76 5.55 5.29 5.32 5.66 ...
$ kernWidth : num 3.31 3.33 3.34 3.38 3.56 ...
$ asymCoef : num 2.22 1.02 2.7 2.26 1.35 ...
$ grooveLength: num 5.22 4.96 4.83 4.8 5.17 ...
```
Wheat seeds data set: class labels
======================================================
The class labels of the seeds are in the factor **variety**.
```r
summary(variety)
```
```
Canadian Kama Rosa
70 70 70
```
Partition data
======================================================
type:section
Training and test set
======================================================
![](img/cross-validation.png)
Partition data into training and test set
======================================================
```r
set.seed(42)
trainIndex <- createDataPartition(y=variety, times=1, p=0.7, list=F)
varietyTrain <- variety[trainIndex]
morphTrain <- morphometrics[trainIndex,]
varietyTest <- variety[-trainIndex]
morphTest <- morphometrics[-trainIndex,]
```
Class distributions are balanced across the splits
====================================================
Training set
```r
summary(varietyTrain)
```
```
Canadian Kama Rosa
49 49 49
```
Test set
```r
summary(varietyTest)
```
```
Canadian Kama Rosa
21 21 21
```
Assess data quality
======================================================
type:section
Identification of near zero variance predictors
======================================================
The function **nearZeroVar** identifies predictors that have a single unique value (i.e. zero variance). It also diagnoses predictors having both of the following characteristics:
* very few unique values relative to the number of samples
* the ratio of the frequency of the most common value to the frequency of the 2nd most common value is large.
Such zero and near zero-variance predictors have a deleterious impact on modelling and may lead to unstable fits.
Identification of near zero variance predictors cont.
======================================================
```r
nearZeroVar(morphTrain, saveMetrics = T)
```
```
freqRatio percentUnique zeroVar nzv
area 1.5 93.87755 FALSE FALSE
perimeter 1.0 85.03401 FALSE FALSE
compactness 1.0 93.19728 FALSE FALSE
kernLength 1.5 91.83673 FALSE FALSE
kernWidth 1.5 91.15646 FALSE FALSE
asymCoef 1.0 98.63946 FALSE FALSE
grooveLength 1.0 77.55102 FALSE FALSE
```
Are all predictors on the same scale?
======================================================
```r
featurePlot(x = morphTrain,
            y = varietyTrain,
            plot = "box",
            ## Pass in options to bwplot()
            scales = list(y = list(relation="free"),
                          x = list(rot = 90)),
            layout = c(4,2))
```
Feature plots
======================================================
<img src="caret-figure/unnamed-chunk-15-1.png" title="plot of chunk unnamed-chunk-15" alt="plot of chunk unnamed-chunk-15" width="100%" style="display: block; margin: auto;" />
Predictors on different scales
=====================================================
The variables in this data set are on different scales. In this situation it is important to **centre** and **scale** each predictor.
- A predictor variable is **centred** by subtracting the mean of the predictor from each value.
- To **scale** a predictor variable, each value is divided by its standard deviation.
After centring and scaling, each predictor has a mean of 0 and a standard deviation of 1.
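A minimal sketch of centring and scaling with **caret**'s **preProcess** function (later we simply pass `preProcess = c("center", "scale")` to **train**, which applies the same transformation automatically):
```r
# Estimate centring/scaling parameters on the training data only,
# then apply them to both the training and test predictors
pp <- preProcess(morphTrain, method = c("center", "scale"))

morphTrainScaled <- predict(pp, morphTrain)
morphTestScaled  <- predict(pp, morphTest)

round(colMeans(morphTrainScaled), 10)   # means are (approximately) 0
apply(morphTrainScaled, 2, sd)          # standard deviations are 1
```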
Pairwise correlation between predictors
=====================================================
Examine pairwise correlations of predictors to identify redundancy in the data set:
```r
corMat <- cor(morphTrain)
corrplot(corMat, order="hclust", tl.cex=1)
```
Pairwise correlation between predictors cont.
=====================================================
<img src="caret-figure/unnamed-chunk-17-1.png" title="plot of chunk unnamed-chunk-17" alt="plot of chunk unnamed-chunk-17" width="100%" style="display: block; margin: auto;" />
Find highly correlated predictors
=====================================================
```r
highCorr <- findCorrelation(corMat, cutoff=0.75)
length(highCorr)
```
```
[1] 4
```
```r
names(morphTrain)[highCorr]
```
```
[1] "area" "kernWidth" "perimeter" "kernLength"
```
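If we wanted to remove this redundancy before modelling, the indices returned by **findCorrelation** give the columns to drop; a minimal sketch (the models below keep all seven predictors, so this step is optional here):
```r
# Drop the highly correlated predictors from both splits
morphTrainFilt <- morphTrain[, -highCorr]
morphTestFilt  <- morphTest[, -highCorr]
names(morphTrainFilt)
```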
Model training and parameter tuning
====================================================
type: section
Models to evaluate
======================================================
- **svmRadialCost** with one tuning parameter **C**
- **svmRadialSigma** with two tuning parameters: **sigma** and **C**
To find out more about a particular model, use:
```r
getModelInfo("svmRadialSigma")
```
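As an aside (not in the original slides), a quicker way to see just the tuning parameters, rather than the full model definition returned by **getModelInfo**, is **modelLookup**:
```r
modelLookup("svmRadialCost")    # one tuning parameter: C
modelLookup("svmRadialSigma")   # two tuning parameters: sigma and C
```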
Parameter tuning using cross-validation
======================================================
![](img/cross-validation.png)
Parallel processing
======================================================
We will use repeated cross-validation to find the best values of our tuning parameters, trying 10 values of each.
Repeated cross-validation can readily be parallelised to speed up execution. All we need to do is register a parallel backend; **CARET** will then use it to run the cross-validation in parallel.
```r
registerDoMC(detectCores())
getDoParWorkers()
```
```
[1] 4
```
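**doMC** relies on forking and is not available on Windows. A hedged alternative sketch using the **doParallel** package, after which the rest of the code is unchanged:
```r
# Register a socket-based parallel backend (works on all platforms)
library(doParallel)

cl <- makePSOCKcluster(detectCores())
registerDoParallel(cl)
getDoParWorkers()
# stopCluster(cl)   # release the workers when finished
```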
Resampling
======================================================
The resampling method is specified using the **trainControl** function. To repeat five-fold cross validation a total of five times we would use:
```r
train_ctrl <- trainControl(method="repeatedcv",
                           number = 5,
                           repeats = 5)
```
Resampling cont.
======================================================
To make the analysis reproducible we need to supply a seed for each resampling iteration. Five-fold CV repeated five times gives 25 resampling iterations, plus one extra seed for the final model fit (26 in total); each iteration needs one seed per candidate value of the tuning parameter (here 10, matching **tuneLength**).
```r
set.seed(42)
seeds <- vector(mode = "list", length = 26)
for(i in 1:25) seeds[[i]] <- sample.int(1000, 10)
seeds[[26]] <- sample.int(1000, 1)

train_ctrl <- trainControl(method="repeatedcv",
                           number = 5,
                           repeats = 5,
                           seeds = seeds)
```
Train svmRadialCost model
========================================================
The **train** function is used to tune a model
```r
rcFit <- train(morphTrain, varietyTrain,
               method="svmRadialCost",
               preProcess = c("center", "scale"),
               #tuneGrid=tuneParam,
               tuneLength=10,
               trControl=train_ctrl)

rcFit
```
Train svmRadialCost model cont.
========================================================
```
Support Vector Machines with Radial Basis Function Kernel
147 samples
7 predictor
3 classes: 'Canadian', 'Kama', 'Rosa'
Pre-processing: centered (7), scaled (7)
Resampling: Cross-Validated (5 fold, repeated 5 times)
Summary of sample sizes: 118, 117, 117, 117, 119, 118, ...
Resampling results across tuning parameters:
C Accuracy Kappa
0.25 0.9211790 0.8817025
0.50 0.9238456 0.8856728
1.00 0.9278456 0.8916728
2.00 0.9277997 0.8916162
4.00 0.9195632 0.8792973
8.00 0.9182266 0.8772683
16.00 0.9128506 0.8691940
32.00 0.9185090 0.8776993
64.00 0.9089918 0.8634461
128.00 0.9090837 0.8635816
Accuracy was used to select the optimal model using the largest value.
The final value used for the model was C = 1.
```
Train svmRadialCost model cont.
=========================================================
<img src="caret-figure/unnamed-chunk-26-1.png" title="plot of chunk unnamed-chunk-26" alt="plot of chunk unnamed-chunk-26" width="100%" style="display: block; margin: auto;" />
Train svmRadialSigma model
======================================================
If we set **tuneLength** to 10, the svmRadialSigma model will be evaluated with 10 different values of **C**. The svmRadialSigma model is set up to evaluate a maximum of six values of **sigma**, so each resampling iteration needs a total of 60 seeds (10 × 6).
```r
set.seed(42)
seeds <- vector(mode = "list", length = 26)
for(i in 1:25) seeds[[i]] <- sample.int(1000, 60)
seeds[[26]] <- sample.int(1000, 1)

train_ctrl <- trainControl(method="repeatedcv",
                           number = 5,
                           repeats = 5,
                           seeds = seeds)
```
Train svmRadialSigma model cont.
========================================================
The **train** function is used to tune a model
```r
rsFit <- train(morphTrain, varietyTrain,
               method="svmRadialSigma",
               preProcess = c("center", "scale"),
               #tuneGrid=tuneParam,
               tuneLength=10,
               trControl=train_ctrl)

rsFit
```
Train svmRadialSigma model cont.
========================================================
```
Support Vector Machines with Radial Basis Function Kernel
147 samples
7 predictor
3 classes: 'Canadian', 'Kama', 'Rosa'
Pre-processing: centered (7), scaled (7)
Resampling: Cross-Validated (5 fold, repeated 5 times)
Summary of sample sizes: 118, 117, 117, 119, 117, 117, ...
Resampling results across tuning parameters:
sigma C Accuracy Kappa
0.03298587 0.25 0.9127455 0.8690039
0.03298587 0.50 0.9139803 0.8707992
0.03298587 1.00 0.9249787 0.8872711
0.03298587 2.00 0.9279278 0.8917430
0.03298587 4.00 0.9320657 0.8979946
0.03298587 8.00 0.9318325 0.8976274
0.03298587 16.00 0.9319737 0.8977663
0.03298587 32.00 0.9399737 0.9097827
0.03298587 64.00 0.9413071 0.9117827
0.03298587 128.00 0.9401609 0.9100772
0.11220186 0.25 0.9275993 0.8912570
0.11220186 0.50 0.9317833 0.8975240
0.11220186 1.00 0.9304959 0.8955636
0.11220186 2.00 0.9415895 0.9122590
0.11220186 4.00 0.9403021 0.9102937
0.11220186 8.00 0.9346864 0.9018007
0.11220186 16.00 0.9305944 0.8956279
0.11220186 32.00 0.9361182 0.9039525
0.11220186 64.00 0.9348768 0.9021248
0.11220186 128.00 0.9305944 0.8956908
0.19141785 0.25 0.9237865 0.8855445
0.19141785 0.50 0.9345419 0.9016186
0.19141785 1.00 0.9388243 0.9081049
0.19141785 2.00 0.9430640 0.9144600
0.19141785 4.00 0.9333530 0.8997595
0.19141785 8.00 0.9263547 0.8892215
0.19141785 16.00 0.9252611 0.8876242
0.19141785 32.00 0.9266897 0.8897655
0.19141785 64.00 0.9211232 0.8814091
0.19141785 128.00 0.9210739 0.8813310
0.27063384 0.25 0.9278325 0.8916086
0.27063384 0.50 0.9331133 0.8994356
0.27063384 1.00 0.9388276 0.9080804
0.27063384 2.00 0.9347323 0.9018646
0.27063384 4.00 0.9304959 0.8954346
0.27063384 8.00 0.9197833 0.8793671
0.27063384 16.00 0.9185977 0.8776302
0.27063384 32.00 0.9157406 0.8733349
0.27063384 64.00 0.9144072 0.8713351
0.27063384 128.00 0.9144072 0.8713351
0.34984983 0.25 0.9277833 0.8914907
0.34984983 0.50 0.9331133 0.8994356
0.34984983 1.00 0.9374450 0.9059735
0.34984983 2.00 0.9346371 0.9017149
0.34984983 4.00 0.9183547 0.8772711
0.34984983 8.00 0.9157373 0.8733367
0.34984983 16.00 0.9185944 0.8776533
0.34984983 32.00 0.9130246 0.8692947
0.34984983 64.00 0.9130246 0.8692947
0.34984983 128.00 0.9130246 0.8692947
0.42906582 0.25 0.9291166 0.8934907
0.42906582 0.50 0.9316847 0.8973104
0.42906582 1.00 0.9360164 0.9038402
0.42906582 2.00 0.9346371 0.9017313
0.42906582 4.00 0.9155928 0.8731255
0.42906582 8.00 0.9144039 0.8713367
0.42906582 16.00 0.9117865 0.8674486
0.42906582 32.00 0.9075501 0.8610938
0.42906582 64.00 0.9075501 0.8610938
0.42906582 128.00 0.9075501 0.8610938
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were sigma = 0.1914179 and C = 2.
```
Train svmRadialSigma model cont.
=========================================================
<img src="caret-figure/unnamed-chunk-31-1.png" title="plot of chunk unnamed-chunk-31" alt="plot of chunk unnamed-chunk-31" width="100%" style="display: block; margin: auto;" />
Model comparison
=========================================================
type:section
Make a list of our models
========================================================
```r
model_list <- list(radialCost=rcFit,
                   radialSigma=rsFit)
```
Collect resampling results for each model
========================================================
```r
resamps <- resamples(model_list)
resamps
```
```
Call:
resamples.default(x = model_list)

Models: radialCost, radialSigma
Number of resamples: 25
Performance metrics: Accuracy, Kappa
Time estimates for: everything, final model fit
```
Summarize resampling results
========================================================
```r
summary(resamps)
```
```
Call:
summary.resamples(object = resamps)
Models: radialCost, radialSigma
Number of resamples: 25
Accuracy
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
radialCost 0.8333333 0.9000000 0.9310345 0.9278456 0.9642857 1 0
radialSigma 0.8620690 0.9285714 0.9333333 0.9430640 0.9666667 1 0
Kappa
Min. 1st Qu. Median Mean 3rd Qu. Max. NA's
radialCost 0.7500000 0.8500000 0.8966132 0.8916728 0.9464627 1 0
radialSigma 0.7913669 0.8923077 0.9000000 0.9144600 0.9500000 1 0
```
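**caret** can also summarise the pairwise differences in performance between the models across the resamples; a minimal sketch using the **diff** method for resamples objects:
```r
# Paired differences in Accuracy and Kappa between radialCost and radialSigma
resamp_diff <- diff(resamps)
summary(resamp_diff)
```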
Plot resampling results
=========================================================
```r
bwplot(resamps)
```
Boxplots of resampling results
=========================================================
<img src="caret-figure/unnamed-chunk-36-1.png" title="plot of chunk unnamed-chunk-36" alt="plot of chunk unnamed-chunk-36" width="100%" style="display: block; margin: auto;" />
Predict test set
========================================================
type:section
Predict test set
========================================================
Predict the varieties of the test set using the best model.
```r
test_pred <- predict(rsFit, morphTest)
confusionMatrix(test_pred, varietyTest)
```
Confusion matrix
========================================================
```
Confusion Matrix and Statistics
Reference
Prediction Canadian Kama Rosa
Canadian 20 0 0
Kama 1 20 2
Rosa 0 1 19
Overall Statistics
Accuracy : 0.9365
95% CI : (0.8453, 0.9824)
No Information Rate : 0.3333
P-Value [Acc > NIR] : < 2.2e-16
Kappa : 0.9048
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: Canadian Class: Kama Class: Rosa
Sensitivity 0.9524 0.9524 0.9048
Specificity 1.0000 0.9286 0.9762
Pos Pred Value 1.0000 0.8696 0.9500
Neg Pred Value 0.9767 0.9750 0.9535
Prevalence 0.3333 0.3333 0.3333
Detection Rate 0.3175 0.3175 0.3016
Detection Prevalence 0.3175 0.3651 0.3175
Balanced Accuracy 0.9762 0.9405 0.9405
```
Performance measures
=========================================================
**sensitivity** = TPR = TP/P = TP/(TP+FN)
**specificity** = TNR = TN/N = TN/(TN+FP)
**precision** = PPV = TP/(TP+FP)
**negative predictive value** = TN/(TN+FN)
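A minimal sketch computing these measures for the Canadian class (one-vs-rest) directly from the test-set confusion table, rather than reading them from the **confusionMatrix** output above:
```r
# One-vs-rest counts for the 'Canadian' class
cm <- table(Prediction = test_pred, Reference = varietyTest)
TP <- cm["Canadian", "Canadian"]
FN <- sum(cm[, "Canadian"]) - TP   # Canadian seeds predicted as another variety
FP <- sum(cm["Canadian", ]) - TP   # other varieties predicted as Canadian
TN <- sum(cm) - TP - FN - FP

TP / (TP + FN)   # sensitivity
TN / (TN + FP)   # specificity
TP / (TP + FP)   # precision (positive predictive value)
TN / (TN + FN)   # negative predictive value
```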
Bias-variance tradeoff
=========================================================
![](img/overfitting.png)
- Bias is the residual error from fitting the Training data
- Variance is the generalisation error seen when the fitted model is applied to the Test data
![](img/bias-variance.png)
An underfit, overly simple model misses important features of the data, whereas an overfit, overly complex model fits the noise and outliers.
Resources
========================================================
- Manual: http://topepo.github.io/caret/index.html
- JSS Paper: http://www.jstatsoft.org/v28/i05/paper
- Book: http://appliedpredictivemodeling.com