Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

catboost and predict threshold #58

Open
pecto2020 opened this issue Feb 2, 2022 · 1 comment
Open

catboost and predict threshold #58

pecto2020 opened this issue Feb 2, 2022 · 1 comment

Comments

@pecto2020
Copy link

pecto2020 commented Feb 2, 2022

Hi,
`predict()` on a catboost model fitted through tidymodels doesn't seem to use the default threshold of 0.5 but something else. Does catboost use a class_weight during the training process? If so, how do I change it in tidymodels/treesnip? I've attached a comparison between catboost and random forest below.
Thanks

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(mlbench)
library(catboost)
library(treesnip)
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")


# Load the Pima Indians diabetes data ----
data(PimaIndiansDiabetes)
diabetes_orig <- PimaIndiansDiabetes

# Fix the RNG state so the train/test split is reproducible
set.seed(123)

# Keep three quarters of the rows for training, the rest for testing
diabetes_split <- initial_split(diabetes_orig, prop = 3 / 4)
diabetes_split
#> <Analysis/Assess/Total>
#> <576/192/768>

# Materialise the two partitions
diabetes_train <- training(diabetes_split)
diabetes_test <- testing(diabetes_split)

# Random forest baseline (ranger engine) ----

# Classification spec; no hyperparameters are set explicitly
trees_spec <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger")

# Fit on the training partition with every other column as a predictor
trees_fit <- fit(trees_spec, diabetes ~ ., data = diabetes_train)

# Hard class labels, class probabilities and the true outcome, side by side
trees_pred <- bind_cols(
  predict(trees_fit, diabetes_test),
  predict(trees_fit, diabetes_test, type = "prob"),
  select(diabetes_test, diabetes)
)
# Test-set ROC AUC and sensitivity for the random forest, with "pos" (the
# second factor level) as the event for both metrics.
# Fixes: `trut` -> `truth` (it only worked through partial argument matching)
# and `event_levels` -> `event_level`. The misspelled name is not a yardstick
# argument, so sens() presumably fell back to its default event level.
# NOTE(review): the sens value in the posted reprex output predates this fix;
# rerun to refresh it.
trees_perf <- trees_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_pred %>%
      sens(truth = diabetes, estimate = .pred_class, event_level = "second")
  )

# Re-derive the hard class labels with an explicit 0.5 cut-off on P(pos).
# The factor is built with the outcome's own levels so its level set and order
# match `diabetes` even if every prediction lands in a single class
# (as.factor() on the character column would drop the unused level, and the
# superseded mutate_if() is no longer needed).
trees_05 <- trees_pred %>%
  mutate(
    .pred_class = factor(
      ifelse(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same two metrics on the re-thresholded labels, with "pos" as the event.
# Fix: `event_levels` -> `event_level`; the misspelled name is not a yardstick
# argument, so sens() presumably fell back to its default event level.
# NOTE(review): the sens value in the posted reprex output predates this fix;
# rerun to refresh it.
trees_perf_05 <- trees_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_05 %>%
      sens(truth = diabetes, estimate = .pred_class, event_level = "second")
  )

trees_perf
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.823
#> 2 sens    binary         0.856
trees_perf_05
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.823
#> 2 sens    binary         0.856

# CatBoost through the treesnip engine ----

# Boosted-tree classification spec with tree_depth = 10; nthread = 4 is
# handed to the engine via set_engine()
catboost_spec <- boost_tree(tree_depth = 10) %>%
  set_mode("classification") %>%
  set_engine("catboost", nthread = 4)

# Fit on the training partition with every other column as a predictor
catboost_fit <- fit(catboost_spec, diabetes ~ ., data = diabetes_train)

# Hard class labels, class probabilities and the true outcome, side by side
catboost_pred <- bind_cols(
  predict(catboost_fit, diabetes_test),
  predict(catboost_fit, diabetes_test, type = "prob"),
  select(diabetes_test, diabetes)
)

# Test-set ROC AUC and sensitivity for CatBoost, with "pos" (the second
# factor level) as the event for both metrics.
# Fix: `event_levels` -> `event_level`; the misspelled name is not a yardstick
# argument, so sens() presumably fell back to its default event level.
# NOTE(review): the sens value in the posted reprex output predates this fix;
# rerun to refresh it.
catboost_perf <- catboost_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_pred %>%
      sens(truth = diabetes, estimate = .pred_class, event_level = "second")
  )



# Re-derive the hard class labels with an explicit 0.5 cut-off on P(pos).
# The factor is built with the outcome's own levels so its level set and order
# match `diabetes` even if every prediction lands in a single class
# (as.factor() on the character column would drop the unused level, and the
# superseded mutate_if() is no longer needed).
catboost_05 <- catboost_pred %>%
  mutate(
    .pred_class = factor(
      ifelse(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same two metrics on the re-thresholded labels, with "pos" as the event.
# Fix: `event_levels` -> `event_level`; the misspelled name is not a yardstick
# argument, so sens() presumably fell back to its default event level.
# NOTE(review): the sens value in the posted reprex output predates this fix;
# rerun to refresh it.
catboost_perf_05 <- catboost_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_05 %>%
      sens(truth = diabetes, estimate = .pred_class, event_level = "second")
  )

catboost_perf
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.801
#> 2 sens    binary         1
catboost_perf_05
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.801
#> 2 sens    binary         0.992

Created on 2022-02-02 by the reprex package (v2.0.1)

@pecto2020
Copy link
Author

pecto2020 commented Feb 2, 2022

Notably, using catboost with caret seems to work:

library(mlbench)
library(catboost)
library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip


# Load the Pima Indians diabetes data ----
data(PimaIndiansDiabetes)
diabetes_orig <- PimaIndiansDiabetes

# Fix the RNG state so the train/test split is reproducible
set.seed(123)

# Keep three quarters of the rows for training, the rest for testing
diabetes_split <- initial_split(diabetes_orig, prop = 3 / 4)
diabetes_split
#> <Analysis/Assess/Total>
#> <576/192/768>

# Materialise the two partitions
diabetes_train <- training(diabetes_split)
diabetes_test <- testing(diabetes_split)


# 3-fold cross-validation that keeps the hold-out predictions and computes
# class probabilities, summarised with twoClassSummary
fitControl <- trainControl(
  method = "cv",
  number = 3,
  savePredictions = TRUE,
  summaryFunction = twoClassSummary,
  classProbs = TRUE
)

# Tune CatBoost through caret using the catboost.caret model definition,
# selecting the final model on the "ROC" metric
model <- train(
  x = select(diabetes_train, -diabetes),
  y = diabetes_train$diabetes,
  method = catboost.caret,
  trControl = fitControl,
  tuneLength = 3,
  metric = "ROC"
)


# Hard class labels from the caret model (renamed to `.pred_class`), the
# per-class probabilities (columns `neg` / `pos`) and the true outcome.
preds1 <- predict(model, diabetes_test) %>%
  as_tibble() %>%
  mutate(.pred_class = value, .keep = "unused") %>%
  bind_cols(predict(model, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Test-set ROC AUC and sensitivity, with "pos" (the second factor level) as
# the event for both metrics.
# Fix: `event_levels` -> `event_level`; the misspelled name is not a yardstick
# argument, so sens() presumably fell back to its default event level.
# NOTE(review): the sens value in the posted reprex output predates this fix;
# rerun to refresh it.
preds1 %>%
  roc_auc(truth = diabetes, pos, event_level = "second") %>%
  bind_rows(
    preds1 %>%
      sens(truth = diabetes, estimate = .pred_class, event_level = "second")
  )
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.821
#> 2 sens    binary         0.848
    

# Re-derive the hard class labels with an explicit 0.5 cut-off on P(pos).
# The factor is built with the outcome's own levels so its level set and order
# match `diabetes` even if every prediction lands in a single class
# (as.factor() on the character column would drop the unused level, and the
# superseded mutate_if() is no longer needed).
preds1_05 <- preds1 %>%
  mutate(
    .pred_class = factor(
      ifelse(pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same two metrics on the re-thresholded labels, with "pos" as the event.
# Fix: `event_levels` -> `event_level`; the misspelled name is not a yardstick
# argument, so sens() presumably fell back to its default event level.
# NOTE(review): the sens value in the posted reprex output predates this fix;
# rerun to refresh it.
preds1_05 %>%
  roc_auc(truth = diabetes, pos, event_level = "second") %>%
  bind_rows(
    preds1_05 %>%
      sens(truth = diabetes, estimate = .pred_class, event_level = "second")
  )
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.821
#> 2 sens    binary         0.848
Created on 2022-02-02 by the reprex package (v2.0.1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant