Cross Validation and its Importance


In this post, we will take a look at cross validation and its importance in building machine learning models.


There are multiple approaches to splitting a dataset for training a model and estimating its accuracy. One common approach is to split the dataset into three sets: a training set, a validation set and a test set. The training set is used to train the model, the validation set is used to choose the best model from among several candidates, and the test set is used at the end to estimate the model's performance on unseen data.
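A minimal sketch of such a three-way split in plain Python, working on row indices. The function name `train_val_test_split` and the 60/20/20 proportions are illustrative assumptions, not something prescribed by this post; libraries such as scikit-learn offer their own splitting utilities.

```python
import random

def train_val_test_split(n, val_frac=0.2, test_frac=0.2, seed=0):
    """Shuffle the indices 0..n-1 once and cut them into train/validation/test lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed so the split is reproducible
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]      # everything left over goes to training
    return train, val, test

# Usage: 1000 data points -> 600 training, 200 validation and 200 test indices.
train, val, test = train_val_test_split(1000)
print(len(train), len(val), len(test))
```

Because each data point lands in exactly one of the three lists, whatever ends up in the validation and test lists is never seen during training, which is the weakness discussed next.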

Now, what if our split happens to put some important data points into the validation or test set? The model would never be trained on them and could end up worse as a result. Cross validation helps alleviate this: instead of relying on a single fixed split, it rotates which part of the data is held out, so every data point is used for training in some rounds and for validation in another. There are multiple flavors of cross validation, viz., K-fold cross validation, Stratified cross validation and Leave one out cross validation (LOOCV).

K-Fold cross validation:

In this approach, the training and validation sets are combined and then split into k smaller subsets (folds) of roughly equal size. The model is trained on k-1 of these folds and validated on the remaining fold. This process is repeated k times, each time holding out a different fold, so that every fold is used for validation exactly once and every data point is used for training in k-1 of the rounds. The process is illustrated in the image above.
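The splitting step can be sketched in plain Python as below. `k_fold_splits` is a hypothetical helper written for illustration; it yields index lists rather than the data itself and shuffles once before splitting. scikit-learn's `KFold` class in `sklearn.model_selection` provides a production version of the same idea.

```python
import random

def k_fold_splits(n, k, shuffle=True, seed=0):
    """Yield (train_indices, val_indices) for each of the k folds of n data points."""
    if not 2 <= k <= n:
        raise ValueError("k must be between 2 and n")
    indices = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Spread the remainder over the first folds so fold sizes differ by at most one.
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = indices[start:start + size]                 # the held-out fold
        train = indices[:start] + indices[start + size:]  # the other k-1 folds
        yield train, val
        start += size

# Usage: 10 data points and 3 folds -> validation folds of size 4, 3 and 3.
for train, val in k_fold_splits(10, 3):
    print(len(train), len(val))
```

Each index appears in exactly one validation fold, so after the k rounds every data point has been validated on once and trained on k-1 times.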

How do we estimate the performance of the model:

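The model's performance is estimated as the average of the k validation scores (for example, the mean accuracy across the k held-out folds). The spread of the k scores also gives a sense of how stable that estimate is.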
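Written out, using notation assumed here where $E_i$ is the validation score (for example, accuracy) obtained when fold $i$ is held out, the cross validation estimate is:

$$
\mathrm{CV}_{(k)} = \frac{1}{k} \sum_{i=1}^{k} E_i
$$

For instance, with k = 5 and purely illustrative fold accuracies of 0.80, 0.85, 0.90, 0.75 and 0.80, the estimate is (0.80 + 0.85 + 0.90 + 0.75 + 0.80) / 5 = 4.10 / 5 = 0.82.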

Stratified cross validation:

Let us consider a scenario where we are building a model to predict whether a loan applicant will default. Suppose our dataset contains 1000 data points and only 20 of them are defaulters, i.e. only 2% of the dataset. When we split the data into training, validation and test sets (or into cross validation folds), a purely random split gives no guarantee that each part gets a fair share of these 20 data points: the training set, or any given fold, may end up with very few of them or even none.

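To resolve this, we can split the data so that every subset keeps (approximately) the same class proportions as the full dataset. Stratified cross validation does exactly this. For example, with 5 folds, each fold of our loan dataset would contain 200 applicants, 4 of whom are defaulters, so every training set also gets a representative share of defaulters.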
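A minimal sketch of stratified folds in plain Python: indices are grouped by class and dealt out to the folds round-robin, so each fold receives roughly the same share of every class. `stratified_k_fold` is a hypothetical helper for illustration; scikit-learn's `StratifiedKFold` implements the same idea. The usage reuses the loan numbers above (20 defaulters out of 1000 applicants).

```python
import random
from collections import defaultdict

def stratified_k_fold(labels, k, seed=0):
    """Yield (train_indices, val_indices) with class proportions kept roughly equal per fold."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    folds = [[] for _ in range(k)]
    offset = 0
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # Deal this class's indices round-robin; the running offset keeps the
        # overall fold sizes balanced across classes.
        for j, idx in enumerate(idxs):
            folds[(j + offset) % k].append(idx)
        offset += len(idxs)
    for val in folds:
        val_set = set(val)
        train = [i for i in range(len(labels)) if i not in val_set]
        yield train, sorted(val)

# Usage: 20 defaulters (label 1) and 980 non-defaulters (label 0), 5 folds.
labels = [1] * 20 + [0] * 980
for train, val in stratified_k_fold(labels, 5):
    print(len(val), sum(labels[i] for i in val))  # fold size, defaulters in fold
```

Every fold here holds 200 applicants with exactly 4 defaulters, so each training set (the other four folds) contains 16 of the 20 defaulters.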

Leave one out cross validation:

In this approach, if there are N data points available for training and validation combined, N-1 data points are used for training and the single remaining data point is used for validation. This is repeated N times, so that every data point is held out for validation exactly once.
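Leave one out cross validation is K-fold cross validation with k = N. The sketch below makes that concrete with a deliberately trivial "model" that predicts the mean of its training values; `leave_one_out_mse` and the mean predictor are illustrative assumptions, not a recommended model. scikit-learn's `LeaveOneOut` class provides the split itself.

```python
def leave_one_out_mse(values):
    """LOOCV mean squared error of a model that predicts the training mean."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two data points")
    total = 0.0
    for i in range(n):                          # N rounds, one per data point
        train = values[:i] + values[i + 1:]     # the other N-1 points
        prediction = sum(train) / len(train)    # "train" the mean predictor
        total += (values[i] - prediction) ** 2  # error on the held-out point
    return total / n

print(leave_one_out_mse([1, 2, 3, 4, 5]))  # 3.125
```

Note that the loop fits the model once per data point, which is exactly where the cost discussed below comes from.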

This approach is very compute intensive: with 1000 data points, the model has to be trained 1000 times. It is generally used only when we don’t have enough data to work with.

