library(tidyverse)  # data wrangling and plotting
library(ggtext)     # styled text in ggplot2
library(janitor)    # data cleaning helpers
library(flextable)  # display tables
library(caret)      # model training, tuning, and evaluation
library(recipes)    # preprocessing pipelines

buffer <- readRDS("./buffer.rds")
codebook <- readRDS("./codebook.rds")

The k-nearest neighbors (kNN) algorithm is a supervised machine-learning algorithm that classifies an observation by the “majority vote” of its k nearest neighbors. The value of k is a hyperparameter to optimize.

kNN is a non-parametric model, meaning it makes no assumptions about the underlying data distribution. This is often an advantage when the data does not follow standard distributions. kNN is also a “lazy” algorithm in that it uses the training data directly to classify new observations rather than using it to learn classification rules. These qualities make kNN relatively straightforward to understand.
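
To make the “majority vote” concrete, here is a minimal sketch in base R. The names are hypothetical (X a numeric feature matrix, y a factor of labels, x_new a single new observation); caret handles all of this internally.

knn_vote <- function(X, y, x_new, k = 5) {
  # Euclidean distance from x_new to every training observation
  d <- sqrt(rowSums(sweep(X, 2, x_new)^2))
  # labels of the k nearest neighbors
  nbrs <- y[order(d)[seq_len(k)]]
  # majority vote: the most frequent class among those neighbors
  names(which.max(table(nbrs)))
}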

Set Up Workflow

I’ll partition my data into an 80:20 train:test split.

set.seed(801)
train_idx <- createDataPartition(buffer$struggle, p = 0.8, list = FALSE)
dat_train <- buffer[train_idx, ]
dat_test <- buffer[-train_idx, ]
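
createDataPartition() samples within each level of the outcome, so both splits should preserve the class mix. A quick check:

round(prop.table(table(dat_train$struggle)), 2)
round(prop.table(table(dat_test$struggle)), 2)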

Training will use 10-fold cross-validation (CV).

train_control <- trainControl(
  method = "cv", 
  number = 10,
  savePredictions = "final",  # keep holdout predictions for the best tune
  classProbs = TRUE           # compute class probabilities, not just classes
)

Prep Model

The modeling data set consists of the outcome variable struggle plus 15 predictors. I’ll drop the character country column, plus the less informative role and industry columns, leaving struggle plus 12 predictors.

mdl_vars <- dat_train %>%
  select(struggle, everything(), -c(country, role, industry)) %>%
  colnames()
mdl_vars
##  [1] "struggle"         "race"             "work_exp"         "remote_exp"      
##  [5] "disability"       "caregiver"        "prefer_remote"    "recommend_remote"
##  [9] "benefit"          "covid"            "emp_type"         "fte"             
## [13] "pct_remote"

I’ll preprocess with caret’s recipe interface: dummy-encode the nominal predictors, then center and scale all predictors. Standardizing matters for kNN because it is distance-based; otherwise predictors on larger scales would dominate the distance calculation.

rcpe <- recipe(struggle ~ ., data = dat_train[, mdl_vars]) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%  # indicator columns for factor levels
  step_center(all_predictors()) %>%               # mean 0
  step_scale(all_predictors())                    # standard deviation 1

prep(rcpe, training = dat_train)
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor         12
## 
## Training data contained 1777 data points and no missing data.
## 
## Operations:
## 
## Dummy variables from race, work_exp, remote_exp, disability, ... [trained]
## Centering for race_Black, race_Hispanic, ... [trained]
## Scaling for race_Black, race_Hispanic, ... [trained]
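
As a sanity check, the prepped recipe can be baked to inspect the processed training data (illustrative only; caret applies the recipe inside train()). The predictors should now be all-numeric with mean 0 and standard deviation 1.

baked <- bake(prep(rcpe, training = dat_train), new_data = dat_train)
# centered: column means should all be ~0
round(colMeans(select(baked, -struggle)), 2)[1:5]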

Fit kNN

The kNN model has a single hyperparameter to tune: k.

set.seed(1970)
mdl_knn <- train(
  rcpe,
  data = dat_train[, mdl_vars],
  method = "knn",
  trControl = train_control,
  tuneGrid = expand.grid(k = c(25, 50, 100, 150, 200, 250, 300, 350)),
  metric = "Accuracy"
)

mdl_knn
## k-Nearest Neighbors 
## 
## 1777 samples
##   12 predictor
##    7 classes: 'Timezones', 'Collaboration', 'Distractions', 'Loneliness', 'Unplugging', 'Motivation', 'Other' 
## 
## Recipe steps: dummy, center, scale 
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 1599, 1601, 1599, 1600, 1600, 1599, ... 
## Resampling results across tuning parameters:
## 
##   k    Accuracy   Kappa       
##    25  0.2403922   0.031255551
##    50  0.2555513   0.031143308
##   100  0.2555004   0.008549662
##   150  0.2628135   0.007392772
##   200  0.2633944   0.003353586
##   250  0.2661811   0.003981376
##   300  0.2639308  -0.001040259
##   350  0.2644958  -0.001151478
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 250.

Prediction accuracy was maximized at k = 250. Kappa was maximized at k = 25 and was near zero throughout, a hint that the model barely improves on chance agreement.

# CV accuracy as a function of k
plot(mdl_knn)
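
The selected k and its cross-validated metrics can also be pulled straight from the train object:

mdl_knn$bestTune
subset(mdl_knn$results, k == mdl_knn$bestTune$k)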

Resampling Performance

The confusion matrix compares the predicted to actual values from the 10-fold CV resamples at the optimal k.

(confusion <- confusionMatrix(mdl_knn))
## Cross-Validated (10 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##                Reference
## Prediction      Timezones Collaboration Distractions Loneliness Unplugging
##   Timezones           0.0           0.0          0.0        0.0        0.0
##   Collaboration       0.1           0.5          0.1        0.3        0.5
##   Distractions        0.0           0.0          0.0        0.1        0.1
##   Loneliness          0.0           0.1          0.2        0.1        0.1
##   Unplugging          6.8          15.6         14.7       15.4       26.0
##   Motivation          0.0           0.1          0.1        0.0        0.0
##   Other               0.0           0.0          0.0        0.0        0.0
##                Reference
## Prediction      Motivation Other
##   Timezones            0.0   0.0
##   Collaboration        0.1   0.1
##   Distractions         0.1   0.0
##   Loneliness           0.3   0.0
##   Unplugging          11.8   7.0
##   Motivation           0.0   0.0
##   Other                0.0   0.0
##                             
##  Accuracy (average) : 0.2662

The model predicts Unplugging nearly all the time (97.3% of predictions). Since Unplugging is the most prevalent class, that’s not a bad guess, really.
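
The 97.3% figure is the Unplugging row of the resampled confusion matrix:

# percent of all predictions assigned to each class
sort(rowSums(confusion$table), decreasing = TRUE)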

Holdout Performance

Here is the model performance on the test data set.

preds_knn <- bind_cols(
  Predicted = predict(mdl_knn, newdata = dat_test, type = "raw"),
  Actual = dat_test$struggle
)

(confusion_test <- confusionMatrix(preds_knn$Predicted, reference = preds_knn$Actual))
## Confusion Matrix and Statistics
## 
##                Reference
## Prediction      Timezones Collaboration Distractions Loneliness Unplugging
##   Timezones             0             0            0          0          0
##   Collaboration         0             1            1          0          5
##   Distractions          0             0            0          0          0
##   Loneliness            0             0            1          0          0
##   Unplugging           30            71           64         70        113
##   Motivation            0             0            0          0          0
##   Other                 0             0            0          0          0
##                Reference
## Prediction      Motivation Other
##   Timezones              0     0
##   Collaboration          2     0
##   Distractions           0     0
##   Loneliness             0     0
##   Unplugging            52    31
##   Motivation             0     0
##   Other                  0     0
## 
## Overall Statistics
##                                          
##                Accuracy : 0.2585         
##                  95% CI : (0.2182, 0.302)
##     No Information Rate : 0.2676         
##     P-Value [Acc > NIR] : 0.6835         
##                                          
##                   Kappa : -0.0091        
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: Timezones Class: Collaboration Class: Distractions
## Sensitivity                   0.00000             0.013889              0.0000
## Specificity                   1.00000             0.978320              1.0000
## Pos Pred Value                    NaN             0.111111                 NaN
## Neg Pred Value                0.93197             0.835648              0.8503
## Prevalence                    0.06803             0.163265              0.1497
## Detection Rate                0.00000             0.002268              0.0000
## Detection Prevalence          0.00000             0.020408              0.0000
## Balanced Accuracy             0.50000             0.496104              0.5000
##                      Class: Loneliness Class: Unplugging Class: Motivation
## Sensitivity                   0.000000           0.95763            0.0000
## Specificity                   0.997305           0.01548            1.0000
## Pos Pred Value                0.000000           0.26218               NaN
## Neg Pred Value                0.840909           0.50000            0.8776
## Prevalence                    0.158730           0.26757            0.1224
## Detection Rate                0.000000           0.25624            0.0000
## Detection Prevalence          0.002268           0.97732            0.0000
## Balanced Accuracy             0.498652           0.48655            0.5000
##                      Class: Other
## Sensitivity               0.00000
## Specificity               1.00000
## Pos Pred Value                NaN
## Neg Pred Value            0.92971
## Prevalence                0.07029
## Detection Rate            0.00000
## Detection Prevalence      0.00000
## Balanced Accuracy         0.50000

Predictably, it predicted Unplugging 97.7% of the time and scored an overall accuracy of 25.9%, slightly below the no-information rate of 26.8%.
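
The headline numbers can be pulled from the confusionMatrix object; the accuracy does not significantly beat the no-information rate (AccuracyNull):

confusion_test$overall[c("Accuracy", "AccuracyNull", "AccuracyPValue")]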

Model Evaluation

caret has no kNN-specific variable importance method, so varImp() falls back on a model-free filter: for classification, ROC curve analysis is run on each predictor for each class. The matrix below places high importance on whether remote work stemmed from the COVID-19 pandemic.

varImp(mdl_knn) %>%
  pluck("importance") %>%
  rownames_to_column(var = "FactorLevel") %>%
  pivot_longer(-FactorLevel) %>%
  ggplot(aes(x = name, y = fct_rev(FactorLevel))) +
  geom_tile(aes(fill = value), show.legend = FALSE) +
  geom_text(aes(label = round(value, 0)), size = 3) +
  theme_light() +
  scale_fill_gradient(low = "white", high = "steelblue") +
  labs(title = "Variable Importance", y = NULL, x = NULL)
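
To see where those scores come from, filterVarImp() reproduces the model-free fallback directly. A sketch, reusing the baked training data from the recipe check above (an assumption; the exact data caret scores internally may differ):

# model-free importance: per-class ROC curve analysis on each predictor
imp <- filterVarImp(x = as.data.frame(select(baked, -struggle)), y = baked$struggle)
head(imp[order(-rowMeans(imp)), ])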