Introduction
If you've explored machine learning models, you've probably come across the term "cross-validation" at some point. But what exactly is it, and why is it important?
In this blog, we'll break cross-validation down into simple terms. With a practical demonstration, we'll equip you with the knowledge to confidently use cross-validation in your machine learning projects.
Model Validation in Machine Learning
Machine learning validation methods provide a means for us to estimate generalization error. This is crucial for determining which model provides the best predictions for unobserved data.
In cases where large amounts of data are available, machine learning data validation begins with splitting the data into three separate datasets:
- A training set is used to train the machine learning model(s) during development.
- A validation set is used to estimate the generalization error of the model created from the training set for the purpose of model selection.
- A test set is used to estimate the generalization error of the final model.
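For example, a common convention is a 60/20/20 split. Here is a minimal GAUSS sketch, assuming the data have already been shuffled and loaded into a matrix called dataset (the proportions and variable names are illustrative):
// Carve the shuffled data into 60% train, 20% validation, 20% test
n = rows(dataset);
n_train = floor(0.6 * n);
n_val = floor(0.2 * n);

train = dataset[1:n_train, .];
val = dataset[n_train+1:n_train+n_val, .];
test = dataset[n_train+n_val+1:n, .];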
Cross-Validation in Machine Learning
The model validation process in the previous section works when we have large datasets. When data is limited, we instead use a technique called cross-validation.
The purpose of cross-validation is to provide a better estimate of a model's ability to perform on unseen data. It gives a more reliable estimate of the generalization error than a single train/test split, especially in the case of limited data.
There are many reasons we may want to do this:
- To have a clearer measure of how our model performs.
- To tune hyperparameters.
- To make model selections.
The intuition behind cross-validation is simple: rather than training our model on one training set, we train it on multiple subsets of the data.
The basic steps of cross-validation are:
- Split the data into portions.
- Train our model on a subset of the portions.
- Test our model on the remaining portions.
- Repeat steps 2-3 until the model has been trained and tested on the entire dataset.
- Average the model performance across all testing iterations to get the overall model performance.
Common Cross-Validation Methods
Though the basic concept of cross-validation is fairly simple, there are a number of ways to go about each step. A few examples of cross-validation methods include:
k-Fold Cross-Validation
In k-fold cross-validation:
- The dataset is divided into k equal-sized folds.
- The model is trained on k-1 folds and tested on the remaining fold.
- The process is repeated k times, with each fold serving as the test set exactly once.
- The performance metrics are averaged over the k iterations.
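To make the mechanics concrete, here is a minimal sketch of how n observations can be assigned to k folds by index arithmetic (illustrative only; later in this post we let the GML cvSplit procedure handle the splitting for us):
// Assign each of n observations to one of k folds
n = 1599;    // number of observations (illustrative)
k = 5;       // number of folds

// Fold ids cycle 1, 2, ..., k, 1, 2, ... through the rows
fold_id = 1 + (seqa(0, 1, n) % k);

// On iteration j, rows with fold_id .== j form the test set
// and all remaining rows form the training set.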
Stratified k-Fold Cross-Validation
This process is similar to k-fold cross-validation, with minor but important differences:
- The class distribution in each fold is preserved.
- It is useful for imbalanced datasets.
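Here is a minimal sketch of how stratified fold assignment can be done by hand in GAUSS; the procedure below is illustrative, not part of the GML API:
// Spread each class's rows evenly across k folds
proc (1) = stratifiedFoldIds(y, k);
    local fold_id, classes, i, idx;

    fold_id = zeros(rows(y), 1);
    classes = unique(y);

    for i(1, rows(classes), 1);
        // Row indices belonging to class i
        idx = selif(seqa(1, 1, rows(y)), y .== classes[i]);

        // Cycle this class's rows through folds 1..k
        fold_id[idx] = 1 + (seqa(0, 1, rows(idx)) % k);
    endfor;

    retp(fold_id);
endp;
On iteration j, the rows with fold_id .== j serve as the test set, so each fold inherits approximately the full dataset's class proportions.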
Leave-One-Out Cross-Validation
The leave-one-out cross-validation (LOOCV) process:
- Trains the model using all observations except one.
- Tests the model using the single held-out data point.
- Repeats this for n iterations until each data point is used exactly once as a test set.
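As a minimal sketch, LOOCV is k-fold cross-validation with k equal to the number of observations n. Using the same GML fitting functions as the worked example later in this post, and assuming y, X, and a filled dfControl structure dfc are already defined, it might look like:
// LOOCV: fit n models, each leaving out a single observation
struct dfModel dfm_i;
n = rows(y);
correct = zeros(n, 1);

for i(1, n, 1);
    // Hold out observation i as the single test point
    mask = seqa(1, 1, n) .== i;
    y_train = delif(y, mask);
    X_train = delif(X, mask);

    // Fit on the remaining n-1 observations
    dfm_i = decForestCFit(y_train, X_train, dfc);

    // Score the prediction for the held-out observation
    correct[i] = decForestPredict(dfm_i, X[i,.]) .== y[i];
endfor;

// LOOCV accuracy estimate
print meanc(correct);
Note that this fits n separate forests, so LOOCV is only practical for small datasets.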
Time-Series Cross-Validation
This cross-validation method, designed specifically for time-series data:
- Splits the data into training and testing sets in chronological order, for example using sliding or expanding windows.
- Trains the model on past data and tests the model on future data, based on the splitting point.
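Here is a minimal expanding-window sketch, assuming y and X are in chronological order; the window sizes t0 and h are illustrative assumptions:
t0 = 100;    // initial training window size (illustrative)
h = 20;      // test horizon per split (illustrative)
n = rows(y);
nsplits = floor((n - t0) / h);

for s(1, nsplits, 1);
    t = t0 + (s - 1) * h;

    // Train strictly on the past ...
    y_train = y[1:t];
    X_train = X[1:t, .];

    // ... and test strictly on the future
    y_test = y[t+1:t+h];
    X_test = X[t+1:t+h, .];

    // Fit and evaluate the model on this split
endfor;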
Method | Advantages | Disadvantages
---|---|---
k-Fold Cross-Validation | Every observation is used for both training and testing; more reliable than a single train/test split. | Requires fitting the model k times; random folds are inappropriate for time-series data.
Stratified k-Fold Cross-Validation | Preserves class proportions in every fold; more representative estimates for imbalanced data. | Applies only to classification problems.
Leave-One-Out Cross-Validation (LOOCV) | Uses the maximum possible training data in each iteration; no randomness in the splits. | Requires n model fits, which is computationally expensive; estimates can have high variance.
Time Series Cross-Validation | Respects temporal ordering, so no future information leaks into training. | Early splits train on relatively little data; results depend on window choices.
k-Fold Cross-Validation Example
Let's look at k-fold cross-validation in action, using the wine quality dataset included in the GAUSS Machine Learning (GML) library. This file is based on the Kaggle Wine Quality dataset.
Our objective is to classify wines into quality categories using 11 features:
- Fixed acidity.
- Volatile acidity.
- Citric acid.
- Residual sugar.
- Chlorides.
- Free sulfur dioxide.
- Total sulfur dioxide.
- Density.
- pH.
- Sulphates.
- Alcohol.
We'll use k-fold cross-validation to examine the performance of a random forest classification model.
Data Loading and Organization
First, we will load our data directly from the GML library:
/*
** Load data and prepare data
*/
// Filename
fname = getGAUSSHome("pkgs/gml/examples/winequality.csv");
// Load wine quality dataset
dataset = loadd(fname);
After loading the data, we shuffle it and extract our dependent and independent variables.
// Enable repeatable sampling
rndseed 754931;
// Shuffle the dataset (sample without replacement),
// because cvSplit does not shuffle.
dataset = sampleData(dataset, rows(dataset));
y = dataset[.,"quality"];
X = delcols(dataset, "quality");
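Because class balance affects both model performance and, as discussed above, the choice of cross-validation method, it can be worth a quick look at the class frequencies before fitting. A minimal sketch using base GAUSS functions:
// Tabulate the number of observations in each quality class
classes = unique(y);
for i(1, rows(classes), 1);
    sprintf("Quality %g: %d observations", classes[i], sumc(y .== classes[i]));
endfor;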
Setting Random Forest Hyperparameters
After loading our data, we will set the random forest hyperparameters using the dfControl structure.
// Enable GML library functions
library gml;
/*
** Model settings
*/
// The dfModel structure holds the trained model
struct dfModel dfm;
// Declare 'dfc' to be a dfControl
// structure and fill with default settings
struct dfControl dfc;
dfc = dfControlCreate();
// Create 200 decision trees
dfc.numTrees = 200;
// Stop splitting if impurity at
// a node is less than 0.15
dfc.impurityThreshold = 0.15;
// Only consider 2 features per split
dfc.featuresPerSplit = 2;
Implementing k-Fold Cross-Validation
Now that we have loaded our data and set our hyperparameters, we are ready to fit our random forest model and implement k-fold cross-validation.
First, we set up the number of folds and pre-allocate a storage vector for model accuracy.
// Specify number of folds
// Generally 5 to 10 folds are used
nfolds = 5;
// Pre-allocate vector to hold the results
accuracy = zeros(nfolds, 1);
Next, we use a GAUSS for loop to complete four steps:
- Select testing and training data from our folds using the cvSplit procedure.
- Fit our random forest classification model on the chosen training data using the decForestCFit procedure.
- Make classification predictions on the chosen testing data using the decForestPredict procedure.
- Compute and store model accuracy for each iteration.
for i(1, nfolds, 1);
    // Select this fold's testing and training data
    { y_train, y_test, X_train, X_test } = cvSplit(y, X, nfolds, i);

    // Fit model using this fold's training data
    dfm = decForestCFit(y_train, X_train, dfc);

    // Make predictions using this fold's test data
    predictions = decForestPredict(dfm, X_test);

    // Store this fold's classification accuracy
    accuracy[i] = meanc(y_test .== predictions);
endfor;
Results
Let's print the accuracy results and the total model accuracy:
/*
** Print Results
*/
sprintf("%7s %10s", "Fold", "Accuracy");;
sprintf("%7d %10.2f", seqa(1,1,nfolds), accuracy);
sprintf("Total model accuracy : %10.2f", meanc(accuracy));
sprintf("Accuracy variation across folds: %10.3f", stdc(accuracy));
   Fold   Accuracy
      1       0.70
      2       0.73
      3       0.65
      4       0.71
      5       0.71

Total model accuracy :       0.70
Accuracy variation across folds:      0.028
Our results provide some important insights into why we conduct cross-validation:
- The model accuracy is different across folds, with a standard deviation of 0.028.
- The maximum accuracy, using fold 2, is 0.73.
- The minimum accuracy, using fold 3, is 0.65.
Depending on how we split our testing and training data, we could get a different picture of model performance.
The total model accuracy, at 0.70, gives a better overall measure of model performance. The standard deviation of the accuracy gives us some insight into how much our prediction accuracy might vary.
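One of the motivations we listed earlier was hyperparameter tuning. As a minimal sketch, we could wrap the cross-validation loop above in a search over candidate values of featuresPerSplit; the candidate grid below is illustrative:
// Candidate values for the number of features per split
candidates = { 2, 4, 6 };
cv_acc = zeros(rows(candidates), 1);

for c(1, rows(candidates), 1);
    dfc.featuresPerSplit = candidates[c];

    // Re-run k-fold cross-validation for this candidate
    accuracy = zeros(nfolds, 1);
    for i(1, nfolds, 1);
        { y_train, y_test, X_train, X_test } = cvSplit(y, X, nfolds, i);
        dfm = decForestCFit(y_train, X_train, dfc);
        accuracy[i] = meanc(y_test .== decForestPredict(dfm, X_test));
    endfor;

    cv_acc[c] = meanc(accuracy);
endfor;

// Report the candidate with the highest cross-validated accuracy
best = maxindc(cv_acc);
sprintf("Best featuresPerSplit: %d", candidates[best]);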
Conclusion
If you're looking to improve the accuracy and reliability of your statistical analysis, cross-validation is a crucial technique to learn. In today's blog we've provided a guide to getting started with cross-validation.
Our step-by-step practical demonstration using GAUSS should prepare you to confidently implement cross-validation in your own data analysis projects.
Further Machine Learning Reading
- Predicting Recessions with Machine Learning Techniques
- Applications of Principal Components Analysis in Finance
- Predicting The Output Gap With Machine Learning Regression Models
- Fundamentals of Tuning Machine Learning Hyperparameters
- Machine Learning With Real-World Data
- Classification with Regularized Logistic Regression
Eric has been working to build, distribute, and strengthen the GAUSS universe since 2012. He is an economist skilled in data analysis and software development. He has earned a B.A. and MSc in economics and engineering and has over 18 years of combined industry and academic experience in data analysis and research.