This function uses the resampling results from a train
object to generate performance statistics over a set of probability
thresholds for two-class problems.
thresholder(x, threshold, final = TRUE, statistics = "all")
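| Argument | Description |
|---|---|
| x | A train object. The held-out predictions and class probabilities must have been saved, e.g. with savePredictions = "all" and classProbs = TRUE in trainControl (as in the example below). |
| threshold | A numeric vector of candidate probability thresholds in [0, 1]. If the class probability corresponding to the first level of the outcome is greater than the threshold, the data point is classified as that level. |
| final | A logical: should only the final tuning parameters chosen by train be used? If FALSE, statistics are computed for every tuning parameter combination that was evaluated. |
| statistics | A character vector indicating which statistics to calculate. See the details below for the possible choices; the default value, "all", computes all of them. |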
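The classification rule described for threshold can be illustrated with a small base-R sketch. The probabilities are made up, and classify_at() is a hypothetical helper written for this illustration, not a caret function.

```r
# Made-up probabilities for the first outcome level ("Class1")
prob_class1 <- c(0.92, 0.61, 0.48, 0.30, 0.75)
lvls <- c("Class1", "Class2")

# Hypothetical helper: assign the first level when its probability
# is greater than the threshold, otherwise the second level
classify_at <- function(prob_first, threshold, levels) {
  factor(ifelse(prob_first > threshold, levels[1], levels[2]), levels = levels)
}

classify_at(prob_class1, 0.5, lvls)  # Class1 Class1 Class2 Class2 Class1
classify_at(prob_class1, 0.7, lvls)  # Class1 Class2 Class2 Class2 Class1
```

Raising the threshold can only move points from the first level to the second, which is why sensitivity for the first level tends to fall as the threshold increases.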
A data frame with one column for each of the model's tuning parameters,
plus an additional column called prob_threshold that holds the probability
threshold. It also contains columns for the summary statistics averaged over
resamples, named after the values requested in the statistics argument.
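The phrase "averaged over resamples" can be sketched in base R: compute a statistic within each resample at each threshold, then take the mean per threshold. This is a rough illustration, not caret's implementation. The predictions are simulated, tuning-parameter columns are omitted, and the column layout (Resample, obs, Class1) and the helper sens_at() are assumptions made for this example.

```r
set.seed(1)
# Simulated held-out predictions from three resamples (assumed layout,
# not real caret output); each fold has 10 events and 10 non-events
preds <- data.frame(
  Resample = rep(c("Fold1", "Fold2", "Fold3"), each = 20),
  obs = factor(rep(c("Class1", "Class2"), times = 30),
               levels = c("Class1", "Class2")),
  Class1 = runif(60),
  stringsAsFactors = FALSE
)
thresholds <- c(0.3, 0.5, 0.7)

# Hypothetical helper: sensitivity within one resample at one cutoff,
# treating the first level ("Class1") as the event
sens_at <- function(d, cut) {
  pred <- ifelse(d$Class1 > cut, "Class1", "Class2")
  is_event <- d$obs == "Class1"
  mean(pred[is_event] == "Class1")
}

# One row per resample and threshold
per_resample <- do.call(rbind, lapply(split(preds, preds$Resample), function(d) {
  data.frame(Resample = d$Resample[1],
             prob_threshold = thresholds,
             Sensitivity = sapply(thresholds, function(cut) sens_at(d, cut)))
}))

# Average the per-resample statistic at each threshold
averaged <- aggregate(Sensitivity ~ prob_threshold, data = per_resample, FUN = mean)
print(averaged)
```

The result has one row per threshold, mirroring the prob_threshold column of the returned data frame; with real output there would be one such set of rows per tuning parameter combination when final = FALSE.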
The statistics argument designates which statistics to compute for each
probability threshold. One or more of the following can be selected:
- Sensitivity
- Specificity
- Pos Pred Value
- Neg Pred Value
- Precision
- Recall
- F1
- Prevalence
- Detection Rate
- Detection Prevalence
- Balanced Accuracy
- Accuracy
- Kappa
- J
- Dist
For descriptions of these statistics (except the last two), see the
documentation of confusionMatrix. The last two statistics are Youden's J
statistic (J) and the distance to the best possible cutoff, i.e. perfect
sensitivity and specificity (Dist).
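To make J and Dist concrete, the base-R sketch below computes sensitivity, specificity, Youden's J (sensitivity + specificity - 1) and Dist at a few cutoffs for six made-up observations. It assumes Dist is the Euclidean distance from (sensitivity, specificity) to the perfect point (1, 1); check the caret source if the exact formula matters. threshold_stats() is a hypothetical helper, not part of caret.

```r
# Made-up observed classes and first-level ("Class1") probabilities
obs <- factor(c("Class1", "Class1", "Class1", "Class2", "Class2", "Class2"),
              levels = c("Class1", "Class2"))
prob <- c(0.9, 0.6, 0.4, 0.55, 0.2, 0.1)

# Hypothetical helper: confusion-matrix counts and derived statistics at one cutoff
threshold_stats <- function(cut) {
  pred <- factor(ifelse(prob > cut, "Class1", "Class2"), levels = levels(obs))
  tp <- sum(pred == "Class1" & obs == "Class1")
  fn <- sum(pred == "Class2" & obs == "Class1")
  tn <- sum(pred == "Class2" & obs == "Class2")
  fp <- sum(pred == "Class1" & obs == "Class2")
  sens <- tp / (tp + fn)
  spec <- tn / (tn + fp)
  data.frame(prob_threshold = cut,
             Sensitivity = sens,
             Specificity = spec,
             J = sens + spec - 1,                       # Youden's J
             Dist = sqrt((1 - sens)^2 + (1 - spec)^2))  # assumed: distance to (1, 1)
}

res <- do.call(rbind, lapply(c(0.3, 0.5, 0.7), threshold_stats))
print(res)

# Larger J is better; smaller Dist is better
best_J <- res$prob_threshold[which.max(res$J)]        # 0.3 in this toy data
best_Dist <- res$prob_threshold[which.min(res$Dist)]  # 0.3 in this toy data
```

The same selection (which.max on J or which.min on Dist) can be applied to the data frame returned by thresholder, although the two criteria need not agree on real data.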
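    library(caret)
    library(ggplot2)

    # Simulate a two-class data set and check the class balance
    set.seed(2444)
    dat <- twoClassSim(500, intercept = -10)
    table(dat$Class)

    # Save the held-out predictions and class probabilities so that
    # thresholder() can recompute statistics at new cutoffs
    ctrl <- trainControl(method = "cv",
                         classProbs = TRUE,
                         savePredictions = "all",
                         summaryFunction = twoClassSummary)

    # Regularized discriminant analysis (method = "rda" uses the klaR package)
    set.seed(2863)
    mod <- train(Class ~ ., data = dat,
                 method = "rda",
                 tuneLength = 4,
                 metric = "ROC",
                 trControl = ctrl)

    # Resampled statistics for thresholds from 0.5 to 1 in steps of 0.05
    resample_stats <- thresholder(mod,
                                  threshold = seq(.5, 1, by = 0.05),
                                  final = TRUE)

    ggplot(resample_stats, aes(x = prob_threshold, y = J)) +
      geom_point()
    ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) +
      geom_point()
    # Sensitivity in black, specificity in red
    ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) +
      geom_point() +
      geom_point(aes(y = Specificity), col = "red")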