This example uses the k-nearest neighbors method to classify the "iris.csv" dataset. The features used in this example to classify the data include:
- sepal length
- sepal width
- petal length
- petal width
Split the dataset
The loadd function is used to load the data from the dataset. In addition, prior to fitting the KNN model, the trainTestSplit function is used to split the data into training and test sets.
library gml;
fname = getGAUSSHome() $+ "pkgs/gml/examples/iris.csv";
//Load numeric predictors
X = loadd(fname, ". -Species");
//Load string labels
species = csvReadSA(fname, 2, 5);
//Split data set
{ X_train, X_test, y_train, y_test } = trainTestSplit(X, species, 0.7);
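If desired, the rows function can be used to verify how many observations fall in the training and test splits after the 70% split:
// Check the number of observations in each split
print "Training observations: " rows(X_train);
print "Test observations: " rows(X_test);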
Estimate the model
The knnClassifyFit function is used to fit the k-nearest neighbors model to the X_train and y_train data. All results are stored in a knnModel structure:
//Specify number of neighbors
k = 3;
//Declare the knnModel structure
struct knnModel mdl;
//Call knnClassifyFit
mdl = knnClassifyFit(y_train, X_train, k);
Make predictions
Once the model is fit, predictions can be made for the X_test data using the knnClassifyPredict function. The knnClassifyPredict function requires two inputs: a knnModel structure and a data matrix of predictors:
//Predict classes
y_hat = knnClassifyPredict(mdl, X_test);
print "prediction accuracy = " meanc(y_hat .$== y_test);