new;
library gml;
/*
** Load and transform data
*/
// Create file path to the hitters dataset
dataset = getGAUSSHome $+ "pkgs/gml/examples/hitters.xlsx";
// Load salary and perform natural log transform
y = loadd(dataset, "ln(salary)");
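// Taking the log reduces the strong right-skew typical of salary data,
// which generally helps regression models fit better.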
// Load all variables except 'salary'
X = loadd(dataset, ". - salary");
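// Optional (not part of the original example): preview the first rows
// of the predictors; head() is available in recent GAUSS versions.
head(X);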
/*
** Split into test and training sets
*/
// Set seed for repeatable sampling
rndseed 234234;
// Split data into 70% training and 30% test sets
{ y_train, y_test, X_train, X_test } = trainTestSplit(y, X, 0.7);
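// Optional sanity check (not part of the original example):
// confirm the sizes of the 70/30 split.
print "Training observations:" rows(y_train);
print "Test observations:" rows(y_test);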
/*
** Estimate random forest model
*/
// Declare 'dfc' to be a dfControl structure
// and fill with default settings.
struct dfControl dfc;
dfc = dfControlCreate();
// Turn on variable importance
dfc.variableImportanceMethod = 1;
// Turn on OOB error
dfc.oobError = 1;
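// Other forest hyperparameters, such as the number of trees and
// tree depth, can also be tuned through the dfControl structure;
// see the dfControl documentation for member names and defaults.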
// Structure to hold model results
struct dfModel mdl;
// Fit training data using random forest
mdl = decForestRFit(y_train, X_train, dfc);
// OOB Error
print "Out-of-bag error:" mdl.oobError;
/*
** Plot variable importance
*/
// Assign predictor names to the dfModel structure so the
// importance plot is labeled. The column names of 'X' are used
// (rather than the file headers) because 'salary' is not a predictor.
mdl.varNames = getColNames(X);
// Draw variable importance plot
plotVariableImportance(mdl);
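// The plot ranks the predictors by estimated importance; higher
// values indicate variables the forest relies on more heavily.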
/*
** Predictions
*/
// Make predictions using test data
predictions = decForestPredict(mdl, X_test);
// Compare the first five predictions to the observed test values
print predictions[1:5,.]~y_test[1:5,.];
print "";
// Print random forest test MSE
print "Random forest test MSE:" meanc((predictions - y_test).^2);
// Estimate an OLS benchmark on the training data using GAUSS
// least squares division (an intercept column is appended to X_train)
b_hat = y_train / (ones(rows(X_train), 1)~X_train);
// Separate the intercept from the slope coefficients
alpha_hat = b_hat[1];
b_hat = trimr(b_hat, 1, 0);
// Predict the test observations and print the OLS test MSE
y_hat = alpha_hat + X_test * b_hat;
print "OLS test MSE:" meanc((y_hat - y_test).^2);