Diamond Price Estimation with CART Decision Trees

In this article I try to predict diamond prices using decision trees. The data is the diamonds dataset from the ggplot2 package, which is loaded as part of the tidyverse. Prices are predicted with the CART (Classification and Regression Trees) method. The materials used in this article are listed in the references section.

Exploratory Data Analysis

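library(tidyverse)  # data manipulation and plotting; also provides the diamonds data
library(rpart)      # to construct CART models
library(rpart.plot) # plotting of rpart trees (used by fancyRpartPlot)
library(rattle)     # provides fancyRpartPlot()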

The diamonds dataset contains the prices and other attributes of diamonds, such as carat, cut, color and clarity.

It has 53,940 observations and 10 variables.

Structure of Dataset

glimpse(diamonds)
## Observations: 53,940
## Variables: 10
## $ carat   <dbl> 0.23, 0.21, 0.23, 0.29, 0.31, 0.24, 0.24, 0.26, 0.22, ...
## $ cut     <ord> Ideal, Premium, Good, Premium, Good, Very Good, Very G...
## $ color   <ord> E, E, E, I, J, J, I, H, E, H, J, J, F, J, E, E, I, J, ...
## $ clarity <ord> SI2, SI1, VS1, VS2, SI2, VVS2, VVS1, SI1, VS2, VS1, SI...
## $ depth   <dbl> 61.5, 59.8, 56.9, 62.4, 63.3, 62.8, 62.3, 61.9, 65.1, ...
## $ table   <dbl> 55, 61, 65, 58, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54...
## $ price   <int> 326, 326, 327, 334, 335, 336, 336, 337, 337, 338, 339,...
## $ x       <dbl> 3.95, 3.89, 4.05, 4.20, 4.34, 3.94, 3.95, 4.07, 3.87, ...
## $ y       <dbl> 3.98, 3.84, 4.07, 4.23, 4.35, 3.96, 3.98, 4.11, 3.78, ...
## $ z       <dbl> 2.43, 2.31, 2.31, 2.63, 2.75, 2.48, 2.47, 2.53, 2.49, ...

Description of each variable

Variable  Description
carat     weight of the diamond (0.2–5.01)
cut       quality of the cut (Fair, Good, Very Good, Premium, Ideal)
color     diamond colour, from J (worst) to D (best)
clarity   a measurement of how clear the diamond is (I1 (worst), SI2, SI1, VS2, VS1, VVS2, VVS1, IF (best))
depth     total depth percentage = 100 * z / mean(x, y) = 200 * z / (x + y) (43–79)
table     width of top of diamond relative to widest point (43–95)
price     price in US dollars ($326–$18,823)
x         length in mm (0–10.74)
y         width in mm (0–58.9)
z         depth in mm (0–31.8)
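The depth formula can be checked against the data itself. A minimal sketch, assuming the tidyverse is installed: it recomputes depth from x, y and z and measures how far the result is from the stored depth column. Rows where any dimension is 0 are dropped first; the summary statistics show such zero values, which presumably mark missing measurements.

```r
library(tidyverse)

# Recompute total depth percentage from the physical dimensions.
# Rows with a zero dimension are dropped because the ratio is undefined
# or meaningless for them.
depth_check <- diamonds %>%
  filter(x > 0, y > 0, z > 0) %>%
  mutate(depth_calc = 100 * 2 * z / (x + y),
         depth_diff = abs(depth_calc - depth))

# Median absolute gap between the recomputed and the stored depth.
median(depth_check$depth_diff)
```

For the first row, for example, 200 * 2.43 / (3.95 + 3.98) is about 61.3, close to the stored 61.5; small gaps like this are expected because x, y and z are rounded to 0.01 mm.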

Let’s summarize the dataset and show the first 10 observations.

summary(diamonds)
##      carat               cut        color        clarity     
##  Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065  
##  1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258  
##  Median :0.7000   Very Good:12082   F: 9542   SI2    : 9194  
##  Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171  
##  3rd Qu.:1.0400   Ideal    :21551   H: 8304   VVS2   : 5066  
##  Max.   :5.0100                     I: 5422   VVS1   : 3655  
##                                     J: 2808   (Other): 2531  
##      depth           table           price             x         
##  Min.   :43.00   Min.   :43.00   Min.   :  326   Min.   : 0.000  
##  1st Qu.:61.00   1st Qu.:56.00   1st Qu.:  950   1st Qu.: 4.710  
##  Median :61.80   Median :57.00   Median : 2401   Median : 5.700  
##  Mean   :61.75   Mean   :57.46   Mean   : 3933   Mean   : 5.731  
##  3rd Qu.:62.50   3rd Qu.:59.00   3rd Qu.: 5324   3rd Qu.: 6.540  
##  Max.   :79.00   Max.   :95.00   Max.   :18823   Max.   :10.740  
##                                                                  
##        y                z         
##  Min.   : 0.000   Min.   : 0.000  
##  1st Qu.: 4.720   1st Qu.: 2.910  
##  Median : 5.710   Median : 3.530  
##  Mean   : 5.735   Mean   : 3.539  
##  3rd Qu.: 6.540   3rd Qu.: 4.040  
##  Max.   :58.900   Max.   :31.800  
## 
head(diamonds,10)
## # A tibble: 10 x 10
##    carat       cut color clarity depth table price     x     y     z
##    <dbl>     <ord> <ord>   <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
##  1  0.23     Ideal     E     SI2  61.5    55   326  3.95  3.98  2.43
##  2  0.21   Premium     E     SI1  59.8    61   326  3.89  3.84  2.31
##  3  0.23      Good     E     VS1  56.9    65   327  4.05  4.07  2.31
##  4  0.29   Premium     I     VS2  62.4    58   334  4.20  4.23  2.63
##  5  0.31      Good     J     SI2  63.3    58   335  4.34  4.35  2.75
##  6  0.24 Very Good     J    VVS2  62.8    57   336  3.94  3.96  2.48
##  7  0.24 Very Good     I    VVS1  62.3    57   336  3.95  3.98  2.47
##  8  0.26 Very Good     H     SI1  61.9    55   337  4.07  4.11  2.53
##  9  0.22      Fair     E     VS2  65.1    61   337  3.87  3.78  2.49
## 10  0.23 Very Good     H     VS1  59.4    61   338  4.00  4.05  2.39

Exploring the data

Correlation of variables with price

Examining the correlation of the numeric variables with price, we see a strong positive correlation for carat (0.92) and for x, y and z (0.86 to 0.88). The correlation with table (0.13) and depth (-0.01) is low.

with(diamonds,
     data.frame(cor_x_price = cor(x, price),
                cor_y_price = cor(y, price),
                cor_z_price = cor(z, price),
                cor_depth_price = cor(depth, price),
                cor_table_price2 = cor(table, price),
                cor_carat_price3 = cor(carat, price)
     )
)
##   cor_x_price cor_y_price cor_z_price cor_depth_price cor_table_price2
## 1   0.8844352   0.8654209   0.8612494      -0.0106474        0.1271339
##   cor_carat_price3
## 1        0.9215913
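The six separate cor() calls above can also be replaced by a single correlation matrix over all numeric columns, which makes it easy to rank the variables by their correlation with price. A minimal sketch, assuming dplyr 1.0 or later for where():

```r
library(tidyverse)

# Correlation matrix of every numeric column (carat, depth, table, price, x, y, z);
# the ordered factors cut, color and clarity are dropped by where(is.numeric).
price_cor <- diamonds %>%
  select(where(is.numeric)) %>%
  cor()

# Keep the price column, drop price's correlation with itself, and sort.
cor_with_price <- price_cor[, "price"]
cor_with_price <- sort(cor_with_price[names(cor_with_price) != "price"],
                       decreasing = TRUE)
cor_with_price
```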

As the correlation suggests, price increases with carat. However, the price range is very wide even at the same carat, which shows that other variables also affect the price.

ggplot(diamonds) + 
  geom_point(aes(x = carat, y = price), alpha = .05, color = "blue") + 
  labs(title = "Diamond Price vs. Carat", x = "Weight of diamond (carat)", y = "Price of diamond (USD)")

Let’s add clarity information to the plot.

ggplot(aes(x = carat, y = price), data = diamonds) + 
  geom_point(alpha = 0.5, size = 1, position = 'jitter', aes(color = clarity)) +
  scale_color_brewer(type = 'div',
                     guide = guide_legend(title = 'Clarity', reverse = TRUE,
                                          override.aes = list(alpha = 1, size = 2))) +
  ggtitle('Price by Carat and Clarity')

CART modelling

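First, we split the data into a training set, used to build the model, and a test set, used to evaluate it. The test set is a 20% random sample drawn within each combination of cut, color and clarity, so every combination is represented in roughly the same proportion; the remaining rows form the training set.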

set.seed(58) #Set the random seed

# Test data: a 20% sample within each cut/color/clarity group
diamonds_test <- diamonds %>% mutate(diamond_id = row_number()) %>% 
    group_by(cut, color, clarity) %>% sample_frac(0.2) %>% ungroup()

# Training data: all rows not sampled into the test set
diamonds_train <- anti_join(diamonds %>% mutate(diamond_id = row_number()), 
    diamonds_test, by = "diamond_id")
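A quick sanity check is that every diamond_id ends up in exactly one of the two sets. The sketch below repeats the split with slice_sample(prop = 0.2), the dplyr 1.0 replacement for the superseded sample_frac(); the names test_set and train_set are illustrative, and because the two sampling functions may draw differently, the sampled rows (and possibly the exact set sizes) need not match the ones used in this article.

```r
library(tidyverse)

set.seed(58)
diamonds_id <- diamonds %>% mutate(diamond_id = row_number())

# Stratified 20% test sample within each cut/color/clarity combination.
test_set <- diamonds_id %>%
  group_by(cut, color, clarity) %>%
  slice_sample(prop = 0.2) %>%
  ungroup()

# Everything not sampled into the test set becomes training data.
train_set <- anti_join(diamonds_id, test_set, by = "diamond_id")

# The two sets should be disjoint and together cover every row.
length(intersect(train_set$diamond_id, test_set$diamond_id))  # expected 0
nrow(train_set) + nrow(test_set) == nrow(diamonds)            # expected TRUE
```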

Because table and depth show little correlation with price (0.13 and -0.01), these variables are left out of the model.

diamond_model <- rpart(price ~ x+y+z+carat+cut+color+clarity, data=diamonds_train)
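When no options are given, rpart takes its settings from rpart.control(). Writing the main defaults out explicitly makes the CP table below easier to read: with the default cp = 0.01, a split is only kept if it raises the overall R-squared by at least 0.01, so the table can never go below that value. The sketch fits the same formula on the full diamonds data (not the article's training split) purely for illustration.

```r
library(tidyverse)
library(rpart)

# method = "anova" is what rpart picks automatically for a numeric response.
# The control values are the rpart.control() defaults, written out explicitly:
#   minsplit = 20 : a node needs at least 20 rows before a split is tried
#   cp = 0.01     : a split must raise the overall R-squared by at least 0.01
#   xval = 10     : 10-fold cross-validation fills the xerror column
#   maxdepth = 30 : hard limit on tree depth
fit_defaults <- rpart(price ~ x + y + z + carat + cut + color + clarity,
                      data = diamonds,
                      method = "anova",
                      control = rpart.control(minsplit = 20, cp = 0.01,
                                              xval = 10, maxdepth = 30))

fit_defaults$control$cp
tail(fit_defaults$cptable, 1)
```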

summary(diamond_model)
## Call:
## rpart(formula = price ~ x + y + z + carat + cut + color + clarity, 
##     data = diamonds_train)
##   n= 43143 
## 
##           CP nsplit rel error    xerror        xstd
## 1 0.60896855      0 1.0000000 1.0000530 0.009840248
## 2 0.18582398      1 0.3910315 0.3910476 0.004379470
## 3 0.03373935      2 0.2052075 0.2052427 0.002298806
## 4 0.02640837      3 0.1714681 0.1716709 0.002295891
## 5 0.02584292      4 0.1450598 0.1495979 0.002060551
## 6 0.01000000      5 0.1192168 0.1198359 0.001717405
## 
## Variable importance
##   carat       y       x       z clarity   color 
##      25      24      24      23       2       1 
## 
## Node number 1: 43143 observations,    complexity param=0.6089685
##   mean=3929.224, MSE=1.589526e+07 
##   left son=2 (27933 obs) right son=3 (15210 obs)
##   Primary splits:
##       carat < 0.995 to the left,  improve=0.60896850, (0 missing)
##       y     < 6.345 to the left,  improve=0.60707500, (0 missing)
##       x     < 6.305 to the left,  improve=0.60380560, (0 missing)
##       z     < 3.915 to the left,  improve=0.59811950, (0 missing)
##       color splits as  LLLLRRR,   improve=0.02246444, (0 missing)
##   Surrogate splits:
##       x       < 6.275 to the left,  agree=0.984, adj=0.954, (0 split)
##       y       < 6.285 to the left,  agree=0.981, adj=0.947, (0 split)
##       z       < 3.895 to the left,  agree=0.977, adj=0.936, (0 split)
##       clarity splits as  RRLLLLLL,  agree=0.679, adj=0.090, (0 split)
##       color   splits as  LLLLLRR,   agree=0.660, adj=0.037, (0 split)
## 
## Node number 2: 27933 observations,    complexity param=0.03373935
##   mean=1633.408, MSE=1247559 
##   left son=4 (19914 obs) right son=5 (8019 obs)
##   Primary splits:
##       y       < 5.525 to the left,  improve=0.66395110, (0 missing)
##       carat   < 0.625 to the left,  improve=0.66358320, (0 missing)
##       x       < 5.465 to the left,  improve=0.66085580, (0 missing)
##       z       < 3.375 to the left,  improve=0.65940960, (0 missing)
##       clarity splits as  RRRLLLLL,  improve=0.01034406, (0 missing)
##   Surrogate splits:
##       x       < 5.495 to the left,  agree=0.992, adj=0.974, (0 split)
##       carat   < 0.635 to the left,  agree=0.991, adj=0.970, (0 split)
##       z       < 3.395 to the left,  agree=0.984, adj=0.945, (0 split)
##       clarity splits as  RRLLLLLL,  agree=0.723, adj=0.036, (0 split)
##       cut     splits as  RLLLL,     agree=0.716, adj=0.010, (0 split)
## 
## Node number 3: 15210 observations,    complexity param=0.185824
##   mean=8145.466, MSE=1.533921e+07 
##   left son=6 (10283 obs) right son=7 (4927 obs)
##   Primary splits:
##       y       < 7.195 to the left,  improve=0.54619500, (0 missing)
##       carat   < 1.485 to the left,  improve=0.53825550, (0 missing)
##       x       < 7.195 to the left,  improve=0.53802950, (0 missing)
##       z       < 4.425 to the left,  improve=0.52571030, (0 missing)
##       clarity splits as  LLLRRRRR,  improve=0.05603985, (0 missing)
##   Surrogate splits:
##       x     < 7.185 to the left,  agree=0.984, adj=0.951, (0 split)
##       carat < 1.445 to the left,  agree=0.979, adj=0.936, (0 split)
##       z     < 4.435 to the left,  agree=0.964, adj=0.889, (0 split)
##       color splits as  LLLLLLR,   agree=0.678, adj=0.004, (0 split)
## 
## Node number 4: 19914 observations
##   mean=1055.871, MSE=274215.1 
## 
## Node number 5: 8019 observations
##   mean=3067.634, MSE=779390.7 
## 
## Node number 6: 10283 observations,    complexity param=0.02640837
##   mean=6141.886, MSE=4709926 
##   left son=12 (7835 obs) right son=13 (2448 obs)
##   Primary splits:
##       clarity splits as  LLLLRRRR,  improve=0.3739261, (0 missing)
##       y       < 6.775 to the left,  improve=0.1191022, (0 missing)
##       carat   < 1.175 to the left,  improve=0.1046128, (0 missing)
##       x       < 6.775 to the left,  improve=0.1035341, (0 missing)
##       color   splits as  RRRRLLL,   improve=0.1022999, (0 missing)
## 
## Node number 7: 4927 observations,    complexity param=0.02584292
##   mean=12327.08, MSE=1.165918e+07 
##   left son=14 (3213 obs) right son=15 (1714 obs)
##   Primary splits:
##       y       < 7.855 to the left,  improve=0.30851000, (0 missing)
##       x       < 7.845 to the left,  improve=0.29967630, (0 missing)
##       carat   < 1.915 to the left,  improve=0.28846080, (0 missing)
##       z       < 4.805 to the left,  improve=0.27258020, (0 missing)
##       clarity splits as  LRRRRRRR,  improve=0.07369484, (0 missing)
##   Surrogate splits:
##       x       < 7.885 to the left,  agree=0.983, adj=0.950, (0 split)
##       carat   < 1.835 to the left,  agree=0.974, adj=0.925, (0 split)
##       z       < 4.825 to the left,  agree=0.953, adj=0.863, (0 split)
##       clarity splits as  RRLLLLLL,  agree=0.681, adj=0.082, (0 split)
## 
## Node number 12: 7835 observations
##   mean=5400.087, MSE=2065554 
## 
## Node number 13: 2448 observations
##   mean=8516.066, MSE=5775533 
## 
## Node number 14: 3213 observations
##   mean=10941.86, MSE=8708725 
## 
## Node number 15: 1714 observations
##   mean=14923.76, MSE=6850265

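The variable importance section of the summary shows that carat, y, x and z are by far the most important variables (scores of 23 to 25), while clarity and color contribute little (2 and 1). The tree itself splits only on carat, y and clarity; x and z never appear as primary splits, so their importance comes from their role as surrogate splits.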
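rpart credits a variable both for the splits it makes and, weighted by agreement, for serving as a surrogate. That is why four size measures can score almost equally even though the tree uses only two of them: they are near-duplicates of one another. A sketch, fitted on the full diamonds data for illustration, that reproduces the 0 to 100 scaling used by summary() and shows how strongly the size variables are correlated:

```r
library(tidyverse)
library(rpart)

fit <- rpart(price ~ x + y + z + carat + cut + color + clarity, data = diamonds)

# rpart stores raw importance scores; summary() rescales them to sum to 100.
round(100 * fit$variable.importance / sum(fit$variable.importance))

# Correlations among the size measures: a split on one of them is almost
# as good as a split on any other, which spreads the importance evenly.
size_cor <- diamonds %>% select(carat, x, y, z) %>% cor()
round(size_cor, 3)
```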

Next, we draw the fitted decision tree.

fancyRpartPlot(diamond_model)

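With the fitted tree, we can predict prices for the test data. Each prediction is the mean training price of the leaf that a diamond falls into, so the predictions take only six distinct values, one per leaf.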

pred_Diamond_test <- predict(diamond_model, newdata = diamonds_test)
head(pred_Diamond_test,10)
##         1         2         3         4         5         6         7 
## 10941.860  3067.634 10941.860  1055.871  3067.634  5400.087  3067.634 
##         8         9        10 
##  3067.634  3067.634  5400.087
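The leaf-mean behaviour can be verified directly. The sketch below (fitted on the full diamonds data rather than the training split, for illustration) uses fit$where, which records the leaf row of fit$frame for every training observation, to check that each stored leaf value equals the mean price of that leaf and that predict() only ever returns those values.

```r
library(tidyverse)
library(rpart)

fit <- rpart(price ~ x + y + z + carat + cut + color + clarity, data = diamonds)

# fit$where holds, for each row, the row number in fit$frame of its leaf.
leaf_mean <- tapply(diamonds$price, fit$where, mean)
leaf_rows <- as.integer(names(leaf_mean))

# For an anova tree, the value stored for each leaf (yval) is the leaf's mean price.
all.equal(as.vector(leaf_mean), fit$frame$yval[leaf_rows])

# predict() returns the yval of the leaf each diamond lands in, so the number
# of distinct predictions cannot exceed the number of leaves.
pred <- predict(fit, newdata = diamonds)
c(distinct_predictions = length(unique(pred)),
  leaves = sum(fit$frame$var == "<leaf>"))
```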

Before pruning the tree, we display the CP table and look for the lowest cross-validated error (xerror). The lowest xerror is 0.11984, at a CP value of 0.01.

printcp(diamond_model)
## 
## Regression tree:
## rpart(formula = price ~ x + y + z + carat + cut + color + clarity, 
##     data = diamonds_train)
## 
## Variables actually used in tree construction:
## [1] carat   clarity y      
## 
## Root node error: 6.8577e+11/43143 = 15895265
## 
## n= 43143 
## 
##         CP nsplit rel error  xerror      xstd
## 1 0.608969      0   1.00000 1.00005 0.0098402
## 2 0.185824      1   0.39103 0.39105 0.0043795
## 3 0.033739      2   0.20521 0.20524 0.0022988
## 4 0.026408      3   0.17147 0.17167 0.0022959
## 5 0.025843      4   0.14506 0.14960 0.0020606
## 6 0.010000      5   0.11922 0.11984 0.0017174
# CP value of the row with the lowest cross-validated error (xerror)
min.xerror <- diamond_model$cptable[which.min(diamond_model$cptable[,"xerror"]),"CP"]

min.xerror
## [1] 0.01
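Choosing the CP with the lowest xerror is one option. A widely used alternative is the 1-SE rule: take the simplest tree whose xerror is within one standard error (xstd) of the minimum. In the CP table above the threshold is 0.11984 + 0.00172 = 0.12156, and only the last row falls below it, so both rules pick cp = 0.01 here. A sketch of the 1-SE rule, fitted on the full diamonds data for illustration (variable names are hypothetical):

```r
library(tidyverse)
library(rpart)

fit <- rpart(price ~ x + y + z + carat + cut + color + clarity, data = diamonds)
cpt <- fit$cptable

# Row with the smallest cross-validated error (the rule used in this article).
i_min <- which.min(cpt[, "xerror"])

# 1-SE rule: cptable rows go from fewest to most splits, so the first row
# whose xerror is within one xstd of the minimum is the simplest such tree.
threshold <- cpt[i_min, "xerror"] + cpt[i_min, "xstd"]
i_1se <- which(cpt[, "xerror"] <= threshold)[1]

cp_min <- cpt[i_min, "CP"]
cp_1se <- cpt[i_1se, "CP"]
c(cp_min = cp_min, cp_1se = cp_1se)
```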

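Next, we prune the tree at this CP value. Note that 0.01 is also rpart's default cp, so the tree was only grown down to this value; pruning at it removes no splits, and the pruned tree is the same as the full tree.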

# Prune the tree
diamond_model.pruned <- prune(diamond_model, cp = min.xerror) 

# Draw the pruned tree
fancyRpartPlot(diamond_model.pruned)
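To see pruning actually simplify a tree, one can first grow a deeper tree with a smaller cp and then prune it back at the cp with the lowest xerror. A sketch on the full diamonds data; cp = 0.0005 is an arbitrary illustrative choice, and with this much data xerror may keep falling for a long time, in which case the pruned tree keeps most of the deep tree's leaves.

```r
library(tidyverse)
library(rpart)

set.seed(58)  # the cross-validation folds behind xerror are random

# Grow a deliberately large tree by lowering cp below the 0.01 default.
deep_tree <- rpart(price ~ x + y + z + carat + cut + color + clarity,
                   data = diamonds,
                   control = rpart.control(cp = 0.0005))

# Prune back to the cp value with the lowest cross-validated error.
cpt <- deep_tree$cptable
best_cp <- cpt[which.min(cpt[, "xerror"]), "CP"]
pruned_tree <- prune(deep_tree, cp = best_cp)

# Number of leaves before and after pruning.
c(deep = sum(deep_tree$frame$var == "<leaf>"),
  pruned = sum(pruned_tree$frame$var == "<leaf>"))
```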

Then we use the pruned tree to predict prices for the test data.

pred_Diamond_test.pruned <- predict(diamond_model.pruned, newdata = diamonds_test)

Next, we compute a pseudo R², the squared correlation between actual and predicted test prices.

fitcorr <- format(cor(diamonds_test$price, pred_Diamond_test.pruned)^2, digits=4)

fitcorr
## [1] "0.8771"
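The squared correlation is only one summary of test performance: it is unchanged if every prediction is shifted or rescaled, whereas the out-of-sample R² defined as 1 - SSE/SST penalizes such errors, and the RMSE reports the typical error in dollars. A minimal sketch of a helper (the name test_metrics is hypothetical) that computes these from two numeric vectors:

```r
# Hypothetical helper: common regression metrics for a held-out set.
test_metrics <- function(actual, predicted) {
  err <- actual - predicted
  c(
    pseudo_r2 = cor(actual, predicted)^2,                         # squared correlation, as above
    r2        = 1 - sum(err^2) / sum((actual - mean(actual))^2),  # 1 - SSE/SST
    rmse      = sqrt(mean(err^2)),                                # root mean squared error
    mae       = mean(abs(err))                                    # mean absolute error
  )
}

# Predictions perfectly correlated with the truth but all 100 too high:
# pseudo_r2 stays at 1 while r2 drops to 0.95; rmse and mae are both 100.
actual <- c(400, 800, 1200, 1600)
test_metrics(actual, actual + 100)
```

With the objects from this article, the same helper would be called as test_metrics(diamonds_test$price, pred_Diamond_test.pruned).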

Result

# Create a data frame with the predictions for each method
all.predictions <- data.frame(actual = diamonds_test$price,
                              full.tree = pred_Diamond_test,
                              pruned.tree = pred_Diamond_test.pruned)

# Reshape to long format: one row per (actual price, model, prediction)
all.predictions <- pivot_longer(all.predictions, cols = c(full.tree, pruned.tree),
                                names_to = "model", values_to = "predictions")

# Plot "Predicted vs. actual, by model"
ggplot(data = all.predictions, aes(x = actual, y = predictions)) + 
  geom_point(colour = "blue") + 
  geom_abline(intercept = 0, slope = 1, colour = "red") +
  facet_wrap(~ model, ncol = 2) + 
  ggtitle("Predicted vs. Actual, by model")