Cross Validation and its Importance
In this post, we will take a look at cross validation and
its importance in building machine learning models.
There are multiple approaches to splitting a dataset for
training and estimating the accuracy of a model. A common approach
is to split the dataset into three sets: a training set, a validation set, and a
test set. The training set is used for training the model, the validation
set is used for choosing the best model from among many candidates, and the test set
is finally used to estimate the model's performance on unseen data.
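The three-way split above can be sketched in plain Python. This is a minimal illustration, not a production utility; the 60/20/20 fractions and the seed are arbitrary choices for the example:

```python
import random

def split_dataset(data, train_frac=0.6, val_frac=0.2, seed=42):
    """Shuffle the data, then slice it into train/validation/test sets."""
    data = list(data)
    random.Random(seed).shuffle(data)  # deterministic shuffle for reproducibility
    n = len(data)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    train = data[:n_train]
    val = data[n_train:n_train + n_val]
    test = data[n_train + n_val:]
    return train, val, test

train, val, test = split_dataset(range(100))
# 60/20/20 split: 60 training, 20 validation, 20 test points
```

Note that the shuffle decides which points land in which set, which is exactly the weakness discussed next.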
Now, what if our split results in some important data
points being grouped into the validation and test sets? Our model would never
be trained on those important data points, resulting in a poorer model. To
alleviate this, we use cross validation. There are multiple flavors of cross
validation, viz., K-fold cross validation, stratified cross validation, and
leave-one-out cross validation (LOOCV).
K-Fold cross validation:
In this approach, the training and validation sets are
combined and then split into k smaller subsets, called folds. The model is trained on
k-1 folds and validated on the remaining fold. This process is
carried out k times, so that every fold serves as the validation set exactly
once. The process is illustrated in the image above.
How do we estimate the performance of the model?
The model performance can be estimated as the average of
the k validation estimates.
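The fold-construction step can be sketched in plain Python. This is a minimal sketch of the index bookkeeping only; the actual model fitting and scoring on each fold is left as a placeholder comment:

```python
def k_fold_indices(n, k):
    """Yield (train_idx, val_idx) pairs for k-fold cross validation.

    Earlier folds absorb the remainder when n is not divisible by k.
    """
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    indices = list(range(n))
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]              # one fold held out
        train_idx = indices[:start] + indices[start + size:]  # remaining k-1 folds
        yield train_idx, val_idx
        start += size

folds = list(k_fold_indices(10, 5))
# For each (train_idx, val_idx) pair you would fit on train_idx,
# score on val_idx, and average the k scores for the final estimate.
```

Shuffling the indices before folding (as in the split example earlier) is common; it is omitted here to keep the fold structure easy to inspect.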
Stratified cross validation:
Let us consider a scenario where we are building a model to
predict whether a loan applicant will default or not. Suppose our dataset contains
1000 datapoints and only 20 of them are defaulters; thus, only 2% of our
dataset consists of defaulting applicants. When we split our data into
training, validation, and test sets (or into cross-validation folds), there is no guarantee that
each split will contain a proportional share of these datapoints.
To resolve this, we can split the data in such a way that each
set receives an equal proportion of each class. Stratified cross validation
does exactly this, giving us proportionally representative folds to work with.
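The stratification idea can be sketched by grouping indices by class label and dealing each class out to the folds round-robin. This is a simplified illustration, assuming the loan-default numbers from the paragraph above (1000 points, 2% defaulters):

```python
from collections import defaultdict

def stratified_folds(labels, k):
    """Assign datapoint indices to k folds so class proportions are preserved."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(k)]
    for idxs in by_class.values():
        # deal each class's indices round-robin across the folds
        for i, idx in enumerate(idxs):
            folds[i % k].append(idx)
    return folds

labels = [0] * 980 + [1] * 20   # 2% defaulters, as in the example
folds = stratified_folds(labels, 5)
# each fold gets 196 non-defaulters and 4 defaulters, preserving the 2% ratio
```

Without stratification, a random fold of 200 points could easily contain zero defaulters.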
Leave one out cross validation:
In this approach, if there are N datapoints available for
training and validation combined, N-1 datapoints are used for training and one
datapoint is used for validation. This is repeated N times, so that every
datapoint is used for validation exactly once.
This approach is very compute intensive: if
we have 1000 datapoints, training happens 1000 times. It is generally
used only when we don't have enough data to work with.
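Leave-one-out is just K-fold with k equal to N, so the split generator collapses to a few lines. A minimal sketch, with model fitting again left as a placeholder:

```python
def leave_one_out(n):
    """Yield (train_idx, val_idx) pairs where one point is held out each time."""
    for i in range(n):
        train_idx = [j for j in range(n) if j != i]  # all points except i
        yield train_idx, [i]                          # single held-out point

splits = list(leave_one_out(5))
# 5 splits; each validation set holds exactly one point,
# so a dataset of 1000 points would require 1000 training runs.
```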