This example uses the red wine quality dataset from Cortez et al. (2009) to fit a random forest regression model. The dataset contains 1599 observations on 12 variables: fixed acidity, volatile acidity, citric acid, residual sugar, chlorides, free sulfur dioxide, total sulfur dioxide, density, pH, sulphates, alcohol, and quality.
Split the dataset
Before fitting the random forest model, the testTrainSplit function is used to split the data into training and test sets. The testTrainSplit function is compatible with the GAUSS formula string syntax and creates the training and test sets without loading the full dataset. In this regression model, quality is the response variable and all other variables are used as predictors, which the formula string "quality ~ ." expresses with the dot shorthand:
// Specify dataset name with full path
dataset = getGAUSSHome() $+ "pkgs/gml/examples/winequality-red.csv";
// Split data into 70% training and 30% test sets
{ y_train, y_test, x_train, x_test } = testTrainSplit(dataset, "quality ~ .", 0.7);
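A quick sanity check on the split can catch path or formula mistakes before any model is fit. The lines below are a minimal sketch that assumes the four outputs above are already in the workspace; the names n_train, n_test, and share are illustrative and not part of the library.

```gauss
// Row counts of the two response vectors (hypothetical variable names)
n_train = rows(y_train);
n_test = rows(y_test);

// Fraction of observations assigned to the training set
share = n_train / (n_train + n_test);

print "training rows:  " n_train;
print "test rows:      " n_test;
print "training share: " share;
```

The training share should be close to the 0.7 passed to testTrainSplit, and the two row counts should sum to the number of observations in the file.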
Estimate the model
The rfRegressFit function fits a random forest regression model to the y_train and x_train matrices. All results are stored in an rfModel structure:
// Output structure
struct rfModel rfm;
// Fit training data using random forest
rfm = rfRegressFit(y_train, x_train);
Make predictions
Once the model is fit, predictions can be made from the x_test dataset using the rfRegressPredict function. The rfRegressPredict function requires two inputs: an rfModel structure and a data matrix of predictors:
// Make predictions using test data
predictions = rfRegressPredict(rfm, x_test);
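// Print the first 10 predictions alongside the observed values
print predictions[1:10]~y_test[1:10];
// Print the random forest test MSE
print "random forest MSE: " meanc((predictions - y_test).^2);
// Fit OLS with an intercept on the training data for comparison
b_hat = y_train / (ones(rows(x_train), 1)~x_train);
// Predict on the test data and print the OLS test MSE
y_hat = (ones(rows(x_test), 1)~x_test) * b_hat;
print "OLS MSE : " meanc((y_hat - y_test).^2);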
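Because the same error calculation is applied to both models, it can be wrapped in a small procedure. The sketch below assumes predictions, y_hat, and y_test from the code above are in the workspace; calcMSE is a hypothetical helper written for this example, not a library function. Taking the square root gives the RMSE, which is on the same scale as the quality score.

```gauss
// Hypothetical helper: mean squared error between predicted and observed values
proc (1) = calcMSE(y_pred, y_obs);
    retp(meanc((y_pred - y_obs).^2));
endp;

// Test-set MSE for each model
rf_mse = calcMSE(predictions, y_test);
ols_mse = calcMSE(y_hat, y_test);

// RMSE is expressed in units of the response variable
print "random forest RMSE: " sqrt(rf_mse);
print "OLS RMSE:           " sqrt(ols_mse);
```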
Output
The output from the code above looks similar to:
5.0929643 5.0000000
5.1308175 5.0000000
5.1799206 5.0000000
5.1720873 5.0000000
5.5779881 7.0000000
5.3214921 5.0000000
5.3608810 5.0000000
5.7493413 7.0000000
5.2857103 5.0000000
5.1103095 4.0000000
random forest MSE: 0.35596292
OLS MSE : 0.42207614
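To put the two error figures on a common footing, the relative reduction in MSE can be computed directly. This is a small sketch that hard-codes the two values printed above; the variable names are illustrative, and values from another run will differ.

```gauss
// MSE values copied from the sample output above
rf_mse = 0.35596292;
ols_mse = 0.42207614;

// Percentage reduction in test MSE relative to OLS
reduction = 100 * (ols_mse - rf_mse) / ols_mse;

print "MSE reduction relative to OLS (%): " reduction;
```

For the sample output shown, this works out to a reduction of roughly 15.7% in test MSE for the random forest relative to OLS.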