titanic-tidymodels: boost simple

1 Load packages

library(tidyverse)  # data wrangling
library(tidymodels)  # modelling

2 Objective

The objective is to predict survival in the Titanic disaster, using a tidymodels approach.

3 Load and prepare data

3.1 Hide details in a function

# Import the Titanic train and test samples, stack them into one data frame
# and tidy the columns for modelling.
#
# Arguments:
#   traindata_url  Location of the training sample (CSV), read with read_csv().
#   testdata_url   Location of the test sample (CSV), read with read_csv().
#
# Returns one data frame with the training rows first and the test rows
# after them. Name, Cabin and Ticket are removed, every character column is
# turned into a factor, and Survived is turned into a factor so that the
# modelling functions treat the task as classification.
prepare_data <- function(traindata_url = "https://raw.githubusercontent.com/sebastiansauer/Statistiklehre/main/data/titanic/train.csv",
                         testdata_url = "https://raw.githubusercontent.com/sebastiansauer/Statistiklehre/main/data/titanic/test.csv") {

  # read both samples:
  titanic_train <- read_csv(traindata_url)
  titanic_test <- read_csv(testdata_url)

  # stack them, training rows first (the later split relies on this order):
  titanic_all <- bind_rows(titanic_train, titanic_test)

  # remove the variables that are not used for modelling:
  titanic_all <- select(titanic_all, -c(Name, Cabin, Ticket))

  # strings become factors:
  titanic_all <- mutate(titanic_all, across(where(is.character), as.factor))

  # the numeric outcome becomes nominal; this is also the return value:
  mutate(titanic_all, Survived = as.factor(Survived))
}

# Build the combined, cleaned data set using the default URLs.
data2 <- prepare_data()

4 Split data into train and test

# Split by row order rather than at random: the first 891/1309 of the rows
# go to the training set, the remaining rows to the test set.
# NOTE(review): 891 and 1309 are presumably the row counts of train.csv and
# of the stacked data -- confirm if the source files change.
split_titanic <-
  data2 %>%
  initial_time_split(prop = 891/1309)

train <- split_titanic %>% training()
test <- split_titanic %>% testing()

5 Define recipe


# Preprocessing recipe: declares the outcome and predictors and lists the
# steps that are estimated on `train` and later applied to new data.
titanic_recipe <-

  # define model formula (Survived is the outcome, everything else starts
  # out as a predictor):
  recipe(Survived ~ ., data = train) %>%

  # Use "ID" etc as ID, not as predictor:
  update_role(PassengerId, new_role = "ID") %>%

  # impute missing values from the 3 nearest neighbours
  # (step_impute_knn() replaces the deprecated step_knnimpute()):
  step_impute_knn(all_predictors(), neighbors = 3) %>%

  # convert character and factor type variables into dummy variables:
  step_dummy(all_nominal(), -all_outcomes()) %>%

  # exclude near zero variance predictors:
  step_nzv(all_predictors()) %>%

  # exclude highly correlated vars:
  step_corr(all_predictors()) %>%

  # center (set mean to zero); all_predictors() never selects the outcome,
  # so no extra exclusion is needed:
  step_center(all_predictors()) %>%

  # set sd=1
  step_scale(all_predictors())

6 Define model

# Specification of a boosted-tree classifier fitted with the xgboost engine.
# The hyperparameters are fixed here rather than tuned.
boost_mod <-
  boost_tree(
    mode = "classification",
    mtry = 100,
    min_n = 10,
    tree_depth = 4,
    learn_rate = 0.1
  ) %>%
  set_engine("xgboost")

See details of the model:

# Show how the specification maps onto the engine's own fitting call.
translate(boost_mod)
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   mtry = 100
#>   min_n = 10
#>   tree_depth = 4
#>   learn_rate = 0.1
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> parsnip::xgb_train(x = missing_arg(), y = missing_arg(), colsample_bytree = 100, 
#>     min_child_weight = 10, max_depth = 4, eta = 0.1, nthread = 1, 
#>     verbose = 0)

7 Define workflow

# Bundle the preprocessing recipe and the model specification into one
# workflow, so that both are fitted and applied together.
titanic_boost_wf1 <-
  workflow() %>%
  add_recipe(titanic_recipe) %>%
  add_model(boost_mod)

8 Fit the model

# Estimate the recipe and fit the model on the training data; the fitted
# workflow replaces the unfitted one under the same name.
titanic_boost_wf1 <- fit(titanic_boost_wf1, data = train)

9 Predict the test data

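The fitted workflow applies the recipe on its own when predicting, so `predict()` receives the raw test data. The test data are additionally prepped and baked here only to store the predictions next to the preprocessed columns: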

# Preprocessed version of the test rows: prep() estimates the recipe steps,
# bake() applies them to `test`.
test_baked <-
  titanic_recipe %>%
  prep() %>%
  bake(new_data = test)

# Predict from the raw test rows (the workflow preprocesses them itself) and
# attach the predicted class to the baked data as `pred_glm1`.
test_with_pred <-
  bind_cols(test_baked, predict(titanic_boost_wf1, new_data = test)) %>%
  rename(pred_glm1 = .pred_class)

10 Save csv file to disk

# Write the passenger id and the predicted class (stored under the column
# name `Survived`) to a csv file.
write_csv(
  select(test_with_pred, PassengerId, Survived = pred_glm1),
  file = "titanic-boost1.csv"
)

11 Reproducibility

#> ─ Session info ───────────────────────────────────────────────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS Catalina 10.15.7      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2020-12-14                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────
#>  package     * version     date       lib source                             
#>  assertthat    0.2.1       2019-03-21 [1] CRAN (R 4.0.0)                     
#>  backports     1.2.0       2020-11-02 [1] CRAN (R 4.0.2)                     
#>  blogdown      0.21        2020-10-11 [1] CRAN (R 4.0.2)                     
#>  bookdown      0.21        2020-10-13 [1] CRAN (R 4.0.2)                     
#>  broom         0.7.2       2020-10-20 [1] CRAN (R 4.0.2)                     
#>  callr         3.5.1       2020-10-13 [1] CRAN (R 4.0.2)                     
#>  cellranger    1.1.0       2016-07-27 [1] CRAN (R 4.0.0)                     
#>  cli           2.2.0       2020-11-20 [1] CRAN (R 4.0.2)                     
#>  codetools     0.2-16      2018-12-24 [2] CRAN (R 4.0.2)                     
#>  colorspace    2.0-0       2020-11-11 [1] CRAN (R 4.0.2)                     
#>  crayon        1.3.4       2017-09-16 [1] CRAN (R 4.0.0)                     
#>  DBI           1.1.0       2019-12-15 [1] CRAN (R 4.0.0)                     
#>  dbplyr        2.0.0       2020-11-03 [1] CRAN (R 4.0.2)                     
#>  desc          1.2.0       2018-05-01 [1] CRAN (R 4.0.0)                     
#>  devtools      2.3.2       2020-09-18 [1] CRAN (R 4.0.2)                     
#>  digest        0.6.27      2020-10-24 [1] CRAN (R 4.0.2)                     
#>  dplyr       * 1.0.2       2020-08-18 [1] CRAN (R 4.0.2)                     
#>  ellipsis      0.3.1       2020-05-15 [1] CRAN (R 4.0.0)                     
#>  evaluate      0.14        2019-05-28 [1] CRAN (R 4.0.0)                     
#>  fansi         0.4.1       2020-01-08 [1] CRAN (R 4.0.0)                     
#>  forcats     * 0.5.0       2020-03-01 [1] CRAN (R 4.0.0)                     
#>  fs            1.5.0       2020-07-31 [1] CRAN (R 4.0.2)                     
#>  generics      0.1.0       2020-10-31 [1] CRAN (R 4.0.2)                     
#>  ggplot2     * 3.3.2       2020-06-19 [1] CRAN (R 4.0.0)                     
#>  glue          1.4.2       2020-08-27 [1] CRAN (R 4.0.2)                     
#>  gtable        0.3.0       2019-03-25 [1] CRAN (R 4.0.0)                     
#>  haven         2.3.1       2020-06-01 [1] CRAN (R 4.0.0)                     
#>  hms           0.5.3       2020-01-08 [1] CRAN (R 4.0.0)                     
#>  htmltools     0.5.0       2020-06-16 [1] CRAN (R 4.0.0)                     
#>  httr          1.4.2       2020-07-20 [1] CRAN (R 4.0.2)                     
#>  jsonlite      1.7.1       2020-09-07 [1] CRAN (R 4.0.2)                     
#>  knitr         1.30        2020-09-22 [1] CRAN (R 4.0.2)                     
#>  lifecycle     0.2.0       2020-03-06 [1] CRAN (R 4.0.0)                     
#>  lubridate     1.7.9.2     2020-11-13 [1] CRAN (R 4.0.2)                     
#>  magrittr      2.0.1       2020-11-17 [1] CRAN (R 4.0.2)                     
#>  memoise       1.1.0       2017-04-21 [1] CRAN (R 4.0.0)                     
#>  modelr        0.1.8       2020-05-19 [1] CRAN (R 4.0.0)                     
#>  munsell       0.5.0       2018-06-12 [1] CRAN (R 4.0.0)                     
#>  pillar        1.4.7       2020-11-20 [1] CRAN (R 4.0.2)                     
#>  pkgbuild      1.1.0       2020-07-13 [1] CRAN (R 4.0.2)                     
#>  pkgconfig     2.0.3       2019-09-22 [1] CRAN (R 4.0.0)                     
#>  pkgload       1.1.0       2020-05-29 [1] CRAN (R 4.0.0)                     
#>  prettyunits   1.1.1       2020-01-24 [1] CRAN (R 4.0.0)                     
#>  processx      3.4.5       2020-11-30 [1] CRAN (R 4.0.2)                     
#>  ps            1.4.0       2020-10-07 [1] CRAN (R 4.0.2)                     
#>  purrr       * 0.3.4       2020-04-17 [1] CRAN (R 4.0.0)                     
#>  R6            2.5.0       2020-10-28 [1] CRAN (R 4.0.2)                     
#>  Rcpp          1.0.5       2020-07-06 [1] CRAN (R 4.0.2)                     
#>  readr       * 1.4.0       2020-10-05 [1] CRAN (R 4.0.2)                     
#>  readxl        1.3.1       2019-03-13 [1] CRAN (R 4.0.0)                     
#>  remotes       2.2.0       2020-07-21 [1] CRAN (R 4.0.2)                     
#>  reprex        0.3.0       2019-05-16 [1] CRAN (R 4.0.0)                     
#>  rlang         0.4.9       2020-11-26 [1] CRAN (R 4.0.2)                     
#>  rmarkdown     2.5         2020-10-21 [1] CRAN (R 4.0.2)                     
#>  rprojroot     2.0.2       2020-11-15 [1] CRAN (R 4.0.2)                     
#>  rstudioapi    0.13.0-9000 2020-12-09 [1] Github (rstudio/rstudioapi@4baeb39)
#>  rvest         0.3.6       2020-07-25 [1] CRAN (R 4.0.2)                     
#>  scales        1.1.1       2020-05-11 [1] CRAN (R 4.0.0)                     
#>  sessioninfo   1.1.1       2018-11-05 [1] CRAN (R 4.0.0)                     
#>  stringi       1.5.3       2020-09-09 [1] CRAN (R 4.0.2)                     
#>  stringr     * 1.4.0       2019-02-10 [1] CRAN (R 4.0.0)                     
#>  testthat      3.0.0       2020-10-31 [1] CRAN (R 4.0.2)                     
#>  tibble      * 3.0.4       2020-10-12 [1] CRAN (R 4.0.2)                     
#>  tidyr       * 1.1.2       2020-08-27 [1] CRAN (R 4.0.2)                     
#>  tidyselect    1.1.0       2020-05-11 [1] CRAN (R 4.0.0)                     
#>  tidyverse   * 1.3.0       2019-11-21 [1] CRAN (R 4.0.0)                     
#>  usethis       1.6.3       2020-09-17 [1] CRAN (R 4.0.2)                     
#>  vctrs         0.3.5       2020-11-17 [1] CRAN (R 4.0.2)                     
#>  withr         2.3.0       2020-09-22 [1] CRAN (R 4.0.2)                     
#>  xfun          0.19        2020-10-30 [1] CRAN (R 4.0.2)                     
#>  xml2          1.3.2       2020-04-23 [1] CRAN (R 4.0.0)                     
#>  yaml          2.2.1       2020-02-01 [1] CRAN (R 4.0.0)                     
#> 
#> [1] /Users/sebastiansaueruser/Rlibs
#> [2] /Library/Frameworks/R.framework/Versions/4.0/Resources/library