Utilities for Machine Learning in Julia
To install the package, run the following command in the Julia REPL:
Pkg.clone("https://github.com/Wedg/MLTools.jl.git")
The ROC plot is built with the function:
plot_ROC_curve(y_cond, y_prob)
y_cond is the "truth" vector, with each element either 0 or 1.
y_prob is the hypothesis vector, i.e. the model's predicted probability for each example, with each element in the range [0, 1].
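The input requirements above can be checked before plotting. The following is a minimal sketch; check_roc_inputs is a hypothetical helper (not part of MLTools), the equal-length requirement is an assumption implied by the pairing of the two vectors, and the example values are made up.

```julia
# Hypothetical helper (not part of MLTools): checks that the inputs have the
# form described for plot_ROC_curve. Equal lengths are an assumed requirement.
function check_roc_inputs(y_cond::AbstractVector, y_prob::AbstractVector)
    length(y_cond) == length(y_prob) ||
        throw(ArgumentError("y_cond and y_prob must have the same length"))
    all(y -> y == 0 || y == 1, y_cond) ||
        throw(ArgumentError("every element of y_cond must be 0 or 1"))
    all(p -> 0 <= p <= 1, y_prob) ||
        throw(ArgumentError("every element of y_prob must lie in [0, 1]"))
    return true
end

# Made-up example inputs: ground-truth labels and predicted probabilities.
y_cond = [1, 0, 1, 1, 0, 0]
y_prob = [0.9, 0.2, 0.65, 0.4, 0.55, 0.1]

check_roc_inputs(y_cond, y_prob)
# With MLTools installed, the same vectors could then be passed to:
# plot_ROC_curve(y_cond, y_prob)
```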
In addition to the predictor's ROC curve, the plot shows the model's accuracy, true positive rate, and false positive rate, together with the summary statistic AUC (Area Under the Curve).
The plot will look something like this.
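For reference, the statistics reported on the plot can be computed by hand. This is a minimal sketch, not MLTools' own code: it assumes a prediction counts as positive when its probability is at least 0.5, and it computes AUC with the pairwise (Mann-Whitney) formulation, which equals the area under the empirical ROC curve. The function names binary_rates and roc_auc are hypothetical, and the data is made up.

```julia
# Accuracy, true positive rate and false positive rate at a fixed threshold.
# The default threshold of 0.5 is an assumption, not taken from MLTools.
function binary_rates(y_cond::AbstractVector, y_prob::AbstractVector; threshold=0.5)
    y_hat = y_prob .>= threshold                      # predicted positives
    idx = eachindex(y_cond)
    tp = count(i -> y_cond[i] == 1 && y_hat[i], idx)
    fp = count(i -> y_cond[i] == 0 && y_hat[i], idx)
    tn = count(i -> y_cond[i] == 0 && !y_hat[i], idx)
    fn = count(i -> y_cond[i] == 1 && !y_hat[i], idx)
    accuracy = (tp + tn) / length(y_cond)
    tpr = tp / (tp + fn)   # true positive rate (recall)
    fpr = fp / (fp + tn)   # false positive rate
    return (accuracy = accuracy, tpr = tpr, fpr = fpr)
end

# AUC as the probability that a random positive example scores higher than a
# random negative one, with ties counted as one half.
function roc_auc(y_cond::AbstractVector, y_prob::AbstractVector)
    pos = y_prob[y_cond .== 1]
    neg = y_prob[y_cond .== 0]
    wins = 0.0
    for p in pos, n in neg
        wins += (p > n) ? 1.0 : ((p == n) ? 0.5 : 0.0)
    end
    return wins / (length(pos) * length(neg))
end

# Made-up example data.
y_cond = [1, 0, 1, 1, 0, 0]
y_prob = [0.9, 0.2, 0.65, 0.4, 0.55, 0.1]

rates = binary_rates(y_cond, y_prob)   # accuracy 4/6, TPR 2/3, FPR 1/3
auc = roc_auc(y_cond, y_prob)          # 8/9
```

Sweeping the threshold from 1 down to 0 and plotting the resulting (FPR, TPR) pairs traces out the ROC curve itself.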
The confusion matrix plot is built with the function:
plot_confusion_matrix(y_cond, y_pred, "y_label", "x_label", classes)
y_cond is the "truth" vector, with each element one of the labels 1 through k, where k is the number of classes.
y_pred is the model's prediction vector, with each element one of the labels 1 through k.
"y_label" is the string used as the label of the y axis.
"x_label" is the string used as the label of the x axis.
classes is a vector of strings naming the classes.
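The counts behind such a plot can be tallied directly from the two label vectors. This is a minimal sketch, not MLTools' implementation: confusion_counts is a hypothetical name, the layout with rows for the truth and columns for the prediction is an assumption chosen to match a plot with the truth on the y axis, and the labels are made up.

```julia
# Hypothetical sketch: tally a k×k confusion matrix from integer labels 1..k.
# Rows index the truth (y_cond), columns the prediction (y_pred); this
# orientation is an assumption, not taken from MLTools.
function confusion_counts(y_cond::AbstractVector{<:Integer},
                          y_pred::AbstractVector{<:Integer}, k::Integer)
    length(y_cond) == length(y_pred) ||
        throw(ArgumentError("y_cond and y_pred must have the same length"))
    C = zeros(Int, k, k)
    for (t, p) in zip(y_cond, y_pred)
        C[t, p] += 1   # one more example with truth t predicted as p
    end
    return C
end

classes = ["airplane", "car", "cat", "dog"]
y_cond = [1, 2, 3, 4, 3, 4]   # made-up truth labels
y_pred = [1, 2, 4, 4, 3, 3]   # made-up predictions
C = confusion_counts(y_cond, y_pred, length(classes))
```

The diagonal holds the correctly classified examples, so here 4 of the 6 are correct; the off-diagonal entries C[3, 4] and C[4, 3] record a cat predicted as a dog and a dog predicted as a cat.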
An example should make this clearer. It is taken from the Alice package demo on a reduced STL10 dataset, in which each image belongs to one of four classes: airplane, car, cat, or dog.
See the demo for more context. The plot is produced with the following call:
classes = ["airplane", "car", "cat", "dog"]
plot_confusion_matrix(y_cond, y_pred, "Truth", "Prediction", classes)
This produces the following plot: