library(tidyverse)
library(ggtext)
library(janitor)
library(flextable)
library(caret)
library(recipes)
buffer <- readRDS("./buffer.rds")
codebook <- readRDS("./codebook.rds")
The k-nearest neighbors (kNN) algorithm is a type of supervised ML algorithm that classifies observations by the “majority vote” of the k nearest neighbors. The value of k is a hyperparameter to optimize.
kNN is a non-parametric model, meaning it makes no assumptions about the underlying data distribution. This is often an advantage when the data does not follow standard distributions. kNN is also a “lazy” algorithm in that it uses the training data directly to classify new responses rather than using it to build classification rules. These qualities make kNN relatively straightforward to understand.
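To make the “majority vote” concrete, here is a minimal hand-rolled sketch (purely illustrative and not part of this post’s pipeline; the knn_vote helper and the iris example are mine):

# Classify one point by the most common label among its k nearest
# neighbors, using Euclidean distance.
knn_vote <- function(train_x, train_y, new_x, k = 5) {
  d <- sqrt(rowSums(sweep(train_x, 2, new_x)^2)) # distance to every training row
  nn <- order(d)[1:k]                            # indices of the k nearest rows
  names(which.max(table(train_y[nn])))           # majority vote among their labels
}
knn_vote(as.matrix(iris[, 1:4]), iris$Species, c(6.1, 2.8, 4.7, 1.2), k = 7)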
I’ll partition my data into an 80:20 train:test split.
set.seed(801)
train_idx <- createDataPartition(buffer$struggle, p = 0.8, list = FALSE)
dat_train <- buffer[train_idx, ]
dat_test <- buffer[-train_idx, ]
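createDataPartition() samples within each level of the outcome factor, so the class mix should be preserved in both splits. A quick sanity check:

# Outcome proportions should be nearly identical in the two splits.
prop.table(table(dat_train$struggle))
prop.table(table(dat_test$struggle))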
and train a model with 10-fold CV, saving the final model’s hold-out predictions and class probabilities.
train_control <- trainControl(
  method = "cv",
  number = 10,
  savePredictions = "final",
  classProbs = TRUE
)
My model data set variables are the outcome variable struggle plus 15 predictors. I’ll drop the character country column, plus the less-informative role and industry columns, leaving struggle plus 12 predictors.
mdl_vars <- dat_train %>%
  select(struggle, everything(), -c(country, role, industry)) %>%
  colnames()
mdl_vars
## [1] "struggle" "race" "work_exp" "remote_exp"
## [5] "disability" "caregiver" "prefer_remote" "recommend_remote"
## [9] "benefit" "covid" "emp_type" "fte"
## [13] "pct_remote"
I’ll use the recipe method to train: dummy-encode the nominal predictors, then center and scale all predictors, since kNN’s distance calculations are sensitive to differences in predictor scale.
rcpe <- recipe(struggle ~ ., data = dat_train[, mdl_vars]) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())
prep(rcpe, training = dat_train)
## Data Recipe
##
## Inputs:
##
##       role #variables
##    outcome          1
##  predictor         12
##
## Training data contained 1777 data points and no missing data.
##
## Operations:
##
## Dummy variables from race, work_exp, remote_exp, disability, ... [trained]
## Centering for race_Black, race_Hispanic, ... [trained]
## Scaling for race_Black, race_Hispanic, ... [trained]
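To see exactly what those trained steps produce, the recipe can be prepped and baked by hand (a side check only; train() handles this internally):

# Apply the trained recipe: nominal predictors become 0/1 dummy
# columns, then every predictor is centered and scaled.
rcpe %>%
  prep(training = dat_train) %>%
  bake(new_data = head(dat_train[, mdl_vars])) %>%
  glimpse()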
The kNN model has a single hyperparameter to tune: k.
set.seed(1970)
mdl_knn <- train(
  rcpe,
  data = dat_train[, mdl_vars],
  method = "knn",
  trControl = train_control,
  tuneGrid = expand.grid(k = c(25, 50, 100, 150, 200, 250, 300, 350)),
  metric = "Accuracy"
)
mdl_knn
## k-Nearest Neighbors
##
## 1777 samples
## 12 predictor
## 7 classes: 'Timezones', 'Collaboration', 'Distractions', 'Loneliness', 'Unplugging', 'Motivation', 'Other'
##
## Recipe steps: dummy, center, scale
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 1599, 1601, 1599, 1600, 1600, 1599, ...
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 25 0.2403922 0.031255551
## 50 0.2555513 0.031143308
## 100 0.2555004 0.008549662
## 150 0.2628135 0.007392772
## 200 0.2633944 0.003353586
## 250 0.2661811 0.003981376
## 300 0.2639308 -0.001040259
## 350 0.2644958 -0.001151478
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 250.
Prediction accuracy was maximized at k = 250, while Kappa was maximized at k = 25 and drifted to zero (even slightly negative) at the largest values of k.
plot(mdl_knn)
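The chosen k and the full resampling table can also be pulled straight from the fit object:

# Best hyperparameter and the accuracy/kappa for every candidate k.
mdl_knn$bestTune
mdl_knn$results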
The confusion matrix compares the predicted to actual values from the 10-fold CV hold-outs at the optimal k.
(confusion <- confusionMatrix(mdl_knn))
## Cross-Validated (10 fold) Confusion Matrix
##
## (entries are percentual average cell counts across resamples)
##
## Reference
## Prediction Timezones Collaboration Distractions Loneliness Unplugging
## Timezones 0.0 0.0 0.0 0.0 0.0
## Collaboration 0.1 0.5 0.1 0.3 0.5
## Distractions 0.0 0.0 0.0 0.1 0.1
## Loneliness 0.0 0.1 0.2 0.1 0.1
## Unplugging 6.8 15.6 14.7 15.4 26.0
## Motivation 0.0 0.1 0.1 0.0 0.0
## Other 0.0 0.0 0.0 0.0 0.0
## Reference
## Prediction Motivation Other
## Timezones 0.0 0.0
## Collaboration 0.1 0.1
## Distractions 0.1 0.0
## Loneliness 0.3 0.0
## Unplugging 11.8 7.0
## Motivation 0.0 0.0
## Other 0.0 0.0
##
## Accuracy (average) : 0.2662
The model basically predicts Unplugging all the time (97.3% of predictions). Not a bad guess, really, since Unplugging is the most common class.
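A quick look at the training-set class frequencies backs that up:

# Unplugging is the modal struggle (~27%), so a constant "Unplugging"
# classifier already achieves the no-information rate.
dat_train %>%
  count(struggle) %>%
  mutate(prop = n / sum(n)) %>%
  arrange(desc(prop))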
Here is the model performance on the test data set.
preds_knn <- bind_cols(
  # dat_test,
  # predict(mdl_knn, newdata = dat_test, type = "prob"),
  Predicted = predict(mdl_knn, newdata = dat_test, type = "raw"),
  Actual = dat_test$struggle
)
(confusion_test <- confusionMatrix(preds_knn$Predicted, reference = preds_knn$Actual))
## Confusion Matrix and Statistics
##
## Reference
## Prediction Timezones Collaboration Distractions Loneliness Unplugging
## Timezones 0 0 0 0 0
## Collaboration 0 1 1 0 5
## Distractions 0 0 0 0 0
## Loneliness 0 0 1 0 0
## Unplugging 30 71 64 70 113
## Motivation 0 0 0 0 0
## Other 0 0 0 0 0
## Reference
## Prediction Motivation Other
## Timezones 0 0
## Collaboration 2 0
## Distractions 0 0
## Loneliness 0 0
## Unplugging 52 31
## Motivation 0 0
## Other 0 0
##
## Overall Statistics
##
## Accuracy : 0.2585
## 95% CI : (0.2182, 0.302)
## No Information Rate : 0.2676
## P-Value [Acc > NIR] : 0.6835
##
## Kappa : -0.0091
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: Timezones Class: Collaboration Class: Distractions
## Sensitivity 0.00000 0.013889 0.0000
## Specificity 1.00000 0.978320 1.0000
## Pos Pred Value NaN 0.111111 NaN
## Neg Pred Value 0.93197 0.835648 0.8503
## Prevalence 0.06803 0.163265 0.1497
## Detection Rate 0.00000 0.002268 0.0000
## Detection Prevalence 0.00000 0.020408 0.0000
## Balanced Accuracy 0.50000 0.496104 0.5000
## Class: Loneliness Class: Unplugging Class: Motivation
## Sensitivity 0.000000 0.95763 0.0000
## Specificity 0.997305 0.01548 1.0000
## Pos Pred Value 0.000000 0.26218 NaN
## Neg Pred Value 0.840909 0.50000 0.8776
## Prevalence 0.158730 0.26757 0.1224
## Detection Rate 0.000000 0.25624 0.0000
## Detection Prevalence 0.002268 0.97732 0.0000
## Balanced Accuracy 0.498652 0.48655 0.5000
## Class: Other
## Sensitivity 0.00000
## Specificity 1.00000
## Pos Pred Value NaN
## Neg Pred Value 0.92971
## Prevalence 0.07029
## Detection Rate 0.00000
## Detection Prevalence 0.00000
## Balanced Accuracy 0.50000
Predictably, it predicted Unplugging 97.7% of the time and scored an overall accuracy of 25.9%, just below the no-information rate of 26.8% (p = 0.68), so the model does no better than always guessing the majority class.
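Those headline numbers can be pulled directly from the confusionMatrix object:

# Accuracy, the no-information rate, and the p-value for accuracy > NIR.
confusion_test$overall[c("Accuracy", "AccuracyNull", "AccuracyPValue")]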
kNN has no native variable-importance measure; for models like this, caret falls back to a model-free “filter” estimate, which for classification is a ROC-curve analysis of each predictor against each class (see the filterVarImp() sketch after the plot code). The matrix below places high importance on whether remote work stemmed from the COVID pandemic.
varImp(mdl_knn) %>%
pluck("importance") %>%
rownames_to_column(var = "FactorLevel") %>%
pivot_longer(-FactorLevel) %>%
ggplot(aes(x = name, y = fct_rev(FactorLevel))) +
geom_tile(aes(fill = value), show.legend = FALSE) +
geom_text(aes(label = round(value, 0)), size = 3) +
theme_light() +
scale_fill_gradient(low = "white", high = "steelblue") +
labs(title = "Variable Importance", y = NULL, x = NULL)
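For the curious: caret documents that varImp() delegates to filterVarImp() for models without a native importance measure. Something like the sketch below should reproduce the matrix above, up to caret’s default 0–100 rescaling (my assumption; the baked object is constructed here just for illustration):

# Model-free importance: per-class AUC of each (baked) predictor.
baked <- rcpe %>%
  prep(training = dat_train) %>%
  bake(new_data = dat_train[, mdl_vars])
filterVarImp(x = select(baked, -struggle), y = baked$struggle) %>% head()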