This part explains classification and regression trees (CART), following the work of Breiman et al. (a copy is available at the library). Although newer and improved tree methods exist, we cover CART because it is important to understand the fundamentals. The R package rpart follows the same study to construct trees.
We are going to use the following packages. Install any you are missing with install.packages("packagenamehere").
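For instance, to install all four at once:
install.packages(c("tidyverse","rpart","rpart.plot","rattle")) #Run once if not already installed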
library(tidyverse) #For data manipulation
library(rpart) #To construct CART models
library(rpart.plot) #To plot CART models; also includes the ptitanic data
library(rattle) #For visualization
The problem
Suppose we would like to correctly identify each survivor of the Titanic accident just by looking at a few traits such as passenger class, sex, and age.
Bonus: A version of this example can be found on Kaggle as a tutorial! Check it out at this link.
set.seed(58) #Set the random seed
data(ptitanic) #Call the data
titanic_data <-
ptitanic %>%
select(survived,pclass,sex,age) %>% #Select only required columns
mutate(age=as.numeric(age)) %>% #This is a fix, just for this data set
filter(complete.cases(.)) %>% #Remove all rows containing NA values
#Split the data into train and test
mutate(train_test = ifelse(runif(nrow(.)) < 0.25,"test","train")) %>%
tbl_df()
print(titanic_data)
## # A tibble: 1,046 x 5
## survived pclass sex age train_test
## <fctr> <fctr> <fctr> <dbl> <chr>
## 1 survived 1st female 29.0000 train
## 2 survived 1st male 0.9167 test
## 3 died 1st female 2.0000 train
## 4 died 1st male 30.0000 train
## 5 died 1st female 25.0000 train
## 6 survived 1st male 48.0000 train
## 7 survived 1st female 63.0000 test
## 8 died 1st male 39.0000 train
## 9 survived 1st female 53.0000 train
## 10 died 1st male 71.0000 train
## # ... with 1,036 more rows
#Build the model with the training data
titanic_train <- titanic_data %>% filter(train_test == "train") %>% select(-train_test)
titanic_model <- rpart(survived ~ ., data=titanic_train)
fancyRpartPlot(titanic_model)
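If you prefer a text view of the fitted tree, rpart also provides textual summaries (optional; not part of the original walkthrough):
print(titanic_model) #Splits, node sizes and predicted classes in text form
printcp(titanic_model) #Complexity parameter table with cross-validated error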
The CART algorithm is a convenient and easy-to-interpret method for partitioning data in a meaningful way with binary splits (i.e., \(x > A\) versus \(x \le A\)). Both regression and classification trees aim to minimize an error measure: squared error for regression and misclassification (or an impurity measure) for classification.
The error for a regression tree is the squared distance (much as in linear regression) between each item in a node and the node's average response \(\hat{Y_k}\). Suppose there are \(K\) terminal nodes; each node's predicted response is \(\hat{Y_k} = \dfrac{1}{N_k}\sum_{i \in G_k} y_i\). The objective function to minimize is then as follows.
\[\min \sum_{k=1}^{K} \sum_{i \in G_k} (y_i - \hat{Y_k})^2 + \alpha K \]
The complexity parameter \(\alpha\) balances the total error against the number of terminal nodes. If \(\alpha = 0\), the tree will keep growing until each terminal node contains a single item, which minimizes the training error.
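As a small aside, here is a minimal regression-tree sketch using the built-in mtcars data; the predictors and the cp value are arbitrary illustrations, not tuned choices. rpart's cp argument plays the role of the complexity penalty above, scaled by the root node error.
reg_tree <- rpart(mpg ~ wt + hp + disp, data = mtcars, method = "anova", cp = 0.02) #Regression tree
printcp(reg_tree) #Error for each value of the complexity parameter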
Classification is a bit different. For the sake of simplicity, assume binary classification (0 or 1). A node's class \(C_k\) is determined by majority vote: \(C_k = \arg\max_{c \in \{0,1\}} N_{c,k}\). We also measure the quality (or probability) of correct classification with \(\hat{p}_{(C_k,k)} = \dfrac{1}{N_k} \sum_{i \in G_k} I(y_i = C_k)\), the proportion of correctly classified items in the node. For binary classification we can simply write \(\hat{p}_k\).
There are three common error (impurity) measures: misclassification error, the Gini index, and cross-entropy. P.S.: rpart uses the Gini index by default.
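For reference, with \(\hat{p}_k\) the majority-class proportion in node \(k\), the three measures for a binary node can be written as follows (standard definitions, added here for completeness).
\[ \text{Misclassification} = 1 - \hat{p}_k, \qquad \text{Gini} = 2\hat{p}_k(1-\hat{p}_k), \qquad \text{Cross-entropy} = -\hat{p}_k \log \hat{p}_k - (1-\hat{p}_k)\log(1-\hat{p}_k) \]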
Fitting the model and getting good results on the training data is fine, but if the model cannot predict response values outside the training data, it overfits.
## In sample prediction classification results
titanic_in_sample <- predict(titanic_model)
print(head(titanic_in_sample))
## died survived
## 1 0.06358382 0.9364162
## 2 0.06358382 0.9364162
## 3 0.83690987 0.1630901
## 4 0.06358382 0.9364162
## 5 0.83690987 0.1630901
## 6 0.83690987 0.1630901
in_sample_prediction <-
cbind(
titanic_in_sample %>% tbl_df %>%
transmute(survived_predict = ifelse(survived >= 0.5,1,0)), #Predicted class (1 = survived)
titanic_train %>% tbl_df %>%
transmute(survived_actual = ifelse(survived == "survived",1,0)) #Actual class
) %>%
mutate(correct_class = (survived_predict == survived_actual)) %>% #Is the prediction correct?
group_by(correct_class) %>%
summarise(count=n(),percentage=n()/nrow(.)) #Counts and shares of correct/incorrect predictions
print(in_sample_prediction)
## # A tibble: 2 x 3
## correct_class count percentage
## <lgl> <int> <dbl>
## 1 FALSE 146 0.1843434
## 2 TRUE 646 0.8156566
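As a quick sanity check, you could also cross-tabulate predictions against the actual labels with base R's table(); this is just another view of the same in-sample result.
#Confusion matrix of in-sample predictions vs. actual outcomes
table(predicted = ifelse(titanic_in_sample[,"survived"] >= 0.5, "survived", "died"),
      actual = titanic_train$survived)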
titanic_test <- titanic_data %>% filter(train_test=="test") %>% select(-train_test)
titanic_predict <- predict(titanic_model,newdata=titanic_test)
print(head(titanic_predict))
## died survived
## 1 0.00000000 1.0000000
## 2 0.06358382 0.9364162
## 3 0.83690987 0.1630901
## 4 0.83690987 0.1630901
## 5 0.83690987 0.1630901
## 6 0.06358382 0.9364162
out_of_sample_prediction <-
cbind(
titanic_predict %>% tbl_df %>%
transmute(survived_predict = ifelse(survived >= 0.5,1,0)),
titanic_test %>% tbl_df %>%
transmute(survived_actual = ifelse(survived == "survived",1,0))
) %>%
mutate(correct_class = (survived_predict == survived_actual)) %>%
group_by(correct_class) %>%
summarise(count=n(),percentage=n()/nrow(.))
print(out_of_sample_prediction)
## # A tibble: 2 x 3
## correct_class count percentage
## <lgl> <int> <dbl>
## 1 FALSE 56 0.2204724
## 2 TRUE 198 0.7795276
You can tune the model with the control options below, passed to rpart via its control argument; the values shown are the defaults.
rpart.control(
minsplit = 20, #Minimum number of observations in a node for a split to be attempted
minbucket = round(minsplit/3), #Minimum number of observations in a terminal node
cp = 0.01, #Complexity parameter (minimum improvement required to keep a split)
maxcompete = 4, #Number of competing splits kept in the printout (does not affect the model)
maxsurrogate = 5, #Number of surrogate splits kept (used to handle missing values)
usesurrogate = 2, #How surrogate splits are used to handle missing values
xval = 10, #Number of cross-validations
surrogatestyle = 0, #How surrogate splits are selected
maxdepth = 30 #Maximum tree depth
)
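Note that rpart.control() does nothing on its own; its result must be passed to rpart via the control argument. A minimal sketch (the cp and minsplit values are arbitrary illustrations, not tuned choices):
titanic_model_tuned <- rpart(survived ~ ., data = titanic_train,
                             control = rpart.control(cp = 0.005, minsplit = 10)) #A deeper tree
As a benchmark for the tree, let's also fit two logistic regression models, one with a logit link and one with a probit link.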
titanic_logit_model <- glm(survived ~ ., data=titanic_train,family=binomial(link = "logit"))
titanic_probit_model <- glm(survived ~ ., data=titanic_train,family=binomial(link = "probit"))
titanic_logit_in_sample <- predict(titanic_logit_model,type="response")
titanic_logit_in_sample_prediction <-
data.frame(in_sample=(titanic_logit_in_sample >= 0.5)*1,
actual=(titanic_train$survived == "survived")*1) %>%
mutate(correct_class= (in_sample == actual)) %>%
group_by(correct_class) %>%
summarise(count=n(),percentage=n()/nrow(.))
print(titanic_logit_in_sample_prediction)
## # A tibble: 2 x 3
## correct_class count percentage
## <lgl> <int> <dbl>
## 1 FALSE 169 0.2133838
## 2 TRUE 623 0.7866162
titanic_logit_out_of_sample <- predict(titanic_logit_model,newdata=titanic_test,type="response")
titanic_logit_out_of_sample_prediction <-
data.frame(out_of_sample=(titanic_logit_out_of_sample >= 0.5)*1,
actual=(titanic_test$survived == "survived")*1) %>%
mutate(correct_class= (out_of_sample == actual)) %>%
group_by(correct_class) %>%
summarise(count=n(),percentage=n()/nrow(.))
print(titanic_logit_out_of_sample_prediction)
## # A tibble: 2 x 3
## correct_class count percentage
## <lgl> <int> <dbl>
## 1 FALSE 59 0.2322835
## 2 TRUE 195 0.7677165
titanic_probit_in_sample <- predict(titanic_probit_model,type="response")
titanic_probit_in_sample_prediction <-
data.frame(in_sample=(titanic_probit_in_sample >= 0.5)*1,
actual=(titanic_train$survived == "survived")*1) %>%
mutate(correct_class= (in_sample == actual)) %>%
group_by(correct_class) %>%
summarise(count=n(),percentage=n()/nrow(.))
print(titanic_probit_in_sample_prediction)
## # A tibble: 2 x 3
## correct_class count percentage
## <lgl> <int> <dbl>
## 1 FALSE 166 0.209596
## 2 TRUE 626 0.790404
titanic_probit_out_of_sample <- predict(titanic_probit_model,newdata=titanic_test,type="response")
titanic_probit_out_of_sample_prediction <-
data.frame(out_of_sample=(titanic_probit_out_of_sample >= 0.5)*1,
actual=(titanic_test$survived == "survived")*1) %>%
mutate(correct_class= (out_of_sample == actual)) %>%
group_by(correct_class) %>%
summarise(count=n(),percentage=n()/nrow(.))
print(titanic_probit_out_of_sample_prediction)
## # A tibble: 2 x 3
## correct_class count percentage
## <lgl> <int> <dbl>
## 1 FALSE 60 0.2362205
## 2 TRUE 194 0.7637795
Let's see all the analyses together. Keep in mind that we did not do any external cross-validation (it would be a good exercise; a minimal sketch follows the benchmark table below).
complete_benchmark <- data.frame(
model = c("CART","Logistic Reg. - Logit Link","Logistic Reg. - Probit Link"),
in_sample_accuracy = c(
in_sample_prediction %>% filter(correct_class) %>% transmute(round(percentage,4)) %>% unlist(),
titanic_logit_in_sample_prediction %>% filter(correct_class) %>% transmute(round(percentage,4)) %>% unlist(),
titanic_probit_in_sample_prediction %>% filter(correct_class) %>% transmute(round(percentage,4)) %>% unlist()
),
out_of_sample_accuracy = c(
out_of_sample_prediction %>% filter(correct_class) %>% transmute(round(percentage,4)) %>% unlist(),
titanic_logit_out_of_sample_prediction %>% filter(correct_class) %>% transmute(round(percentage,4)) %>% unlist(),
titanic_probit_out_of_sample_prediction %>% filter(correct_class) %>% transmute(round(percentage,4)) %>% unlist()
)
)
print(complete_benchmark)
## model in_sample_accuracy out_of_sample_accuracy
## 1 CART 0.8157 0.7795
## 2 Logistic Reg. - Logit Link 0.7866 0.7677
## 3 Logistic Reg. - Probit Link 0.7904 0.7638
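As mentioned above, external cross-validation is left as an exercise; below is a minimal sketch of a repeated hold-out evaluation for the CART model. The helper name, the test share, and the number of repetitions are illustrative choices, not part of the original analysis.
holdout_accuracy <- function(data, test_share = 0.25) {
  idx <- runif(nrow(data)) < test_share #Random test assignment
  model <- rpart(survived ~ ., data = data[!idx,]) #Fit on the training part
  probs <- predict(model, newdata = data[idx,])[,"survived"] #Predict on the held-out part
  mean((probs >= 0.5) == (data$survived[idx] == "survived")) #Out-of-sample accuracy
}
set.seed(58)
replicate(10, holdout_accuracy(titanic_data %>% select(-train_test))) #Ten repeated hold-outs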