In this article, I'll try to predict diamond prices using decision trees. The diamonds dataset used here ships with the ggplot2 package, which is loaded as part of the tidyverse.
Prices are predicted with the CART (Classification and Regression Trees) method. The materials used in this article are listed in the references section.
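For a regression tree, CART grows the tree by repeatedly choosing the split that most reduces the sum of squared errors (SSE) within the resulting nodes. The function below is a minimal illustration of that criterion for a single numeric predictor. `best_split` is a hypothetical helper written for illustration, not part of rpart; rpart's own search also handles factors, surrogate splits and minimum node sizes.

```r
# Hypothetical helper: best binary split of numeric predictor x for response y,
# using the CART regression criterion (minimize total within-node SSE).
best_split <- function(x, y) {
  sse <- function(v) sum((v - mean(v))^2)
  xs <- sort(unique(x))
  # Candidate cut points: midpoints between consecutive distinct values
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  total_sse <- vapply(cuts, function(cut) sse(y[x < cut]) + sse(y[x >= cut]),
                      numeric(1))
  best <- which.min(total_sse)
  # improve: fraction of the parent node's SSE removed by the split
  list(cut = cuts[best], improve = 1 - total_sse[best] / sse(y))
}

# Toy example: two clearly separated groups
best_split(x = c(1, 2, 3, 10, 11, 12), y = c(5, 5, 5, 20, 20, 20))
# $cut is 6.5 and $improve is 1
```

The `improve` value here is analogous to the `improve` column that appears in rpart's model summary later in the article.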
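library(tidyverse) # Data manipulation and plotting; ggplot2 also provides the diamonds data
library(rpart) # To construct CART models
library(rpart.plot) # Tree plotting functions used by fancyRpartPlot()
library(rattle) # Provides fancyRpartPlot()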
The diamonds dataset contains the prices and other attributes of diamonds, such as carat and clarity.
It has 53,940 observations and 10 variables.
glimpse(diamonds)
## Observations: 53,940
## Variables: 10
## $ carat <dbl> 0.23, 0.21, 0.23, 0.29, 0.31, 0.24, 0.24, 0.26, 0.22, ...
## $ cut <ord> Ideal, Premium, Good, Premium, Good, Very Good, Very G...
## $ color <ord> E, E, E, I, J, J, I, H, E, H, J, J, F, J, E, E, I, J, ...
## $ clarity <ord> SI2, SI1, VS1, VS2, SI2, VVS2, VVS1, SI1, VS2, VS1, SI...
## $ depth <dbl> 61.5, 59.8, 56.9, 62.4, 63.3, 62.8, 62.3, 61.9, 65.1, ...
## $ table <dbl> 55, 61, 65, 58, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54...
## $ price <int> 326, 326, 327, 334, 335, 336, 336, 337, 337, 338, 339,...
## $ x <dbl> 3.95, 3.89, 4.05, 4.20, 4.34, 3.94, 3.95, 4.07, 3.87, ...
## $ y <dbl> 3.98, 3.84, 4.07, 4.23, 4.35, 3.96, 3.98, 4.11, 3.78, ...
## $ z <dbl> 2.43, 2.31, 2.31, 2.63, 2.75, 2.48, 2.47, 2.53, 2.49, ...
Variables | Description |
---|---|
carat | weight of the diamond (0.2–5.01) |
cut | quality of the cut (Fair, Good, Very Good, Premium, Ideal) |
color | diamond colour, from J (worst) to D (best) |
clarity | a measurement of how clear the diamond is (I1 (worst), SI2, SI1, VS2, VS1, VVS2, VVS1, IF (best)) |
depth | total depth percentage = 100 * z / mean(x, y) = 200 * z / (x + y) (43–79) |
table | width of top of diamond relative to widest point (43–95) |
price | price in US dollars ($326–$18,823) |
x | length in mm (0–10.74) |
y | width in mm (0–58.9) |
z | depth in mm (0–31.8) |
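As a quick check of the depth definition, the percentage can be recomputed from the three dimensions. `depth_pct` is a hypothetical helper for illustration. For the first diamond in the data (x = 3.95, y = 3.98, z = 2.43 mm) it gives about 61.3, close to but not exactly the recorded depth of 61.5, so the stored depth should not be assumed to match a value recomputed from the rounded dimensions.

```r
# Hypothetical helper: total depth percentage from length x, width y and depth z (mm)
depth_pct <- function(x, y, z) {
  # 100 * z / mean(x, y), written as 200 * z / (x + y);
  # return NA when x + y is 0 to avoid a NaN/Inf result
  ifelse(x + y > 0, 200 * z / (x + y), NA_real_)
}

round(depth_pct(3.95, 3.98, 2.43), 1)  # first row of the dataset: 61.3
depth_pct(0, 0, 0)                     # zero dimensions give NA
```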
Let's summarize the dataset and show the first 10 observations.
summary(diamonds)
## carat cut color clarity
## Min. :0.2000 Fair : 1610 D: 6775 SI1 :13065
## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258
## Median :0.7000 Very Good:12082 F: 9542 SI2 : 9194
## Mean :0.7979 Premium :13791 G:11292 VS1 : 8171
## 3rd Qu.:1.0400 Ideal :21551 H: 8304 VVS2 : 5066
## Max. :5.0100 I: 5422 VVS1 : 3655
## J: 2808 (Other): 2531
## depth table price x
## Min. :43.00 Min. :43.00 Min. : 326 Min. : 0.000
## 1st Qu.:61.00 1st Qu.:56.00 1st Qu.: 950 1st Qu.: 4.710
## Median :61.80 Median :57.00 Median : 2401 Median : 5.700
## Mean :61.75 Mean :57.46 Mean : 3933 Mean : 5.731
## 3rd Qu.:62.50 3rd Qu.:59.00 3rd Qu.: 5324 3rd Qu.: 6.540
## Max. :79.00 Max. :95.00 Max. :18823 Max. :10.740
##
## y z
## Min. : 0.000 Min. : 0.000
## 1st Qu.: 4.720 1st Qu.: 2.910
## Median : 5.710 Median : 3.530
## Mean : 5.735 Mean : 3.539
## 3rd Qu.: 6.540 3rd Qu.: 4.040
## Max. :58.900 Max. :31.800
##
head(diamonds,10)
## # A tibble: 10 x 10
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
## 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
## 6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 7 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## 8 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
## 9 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
## 10 0.23 Very Good H VS1 59.4 61 338 4.00 4.05 2.39
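The summary also shows a minimum of 0 for x, y and z, which cannot describe a real stone, so those rows are probably recording errors worth inspecting before modeling. A minimal sketch of a filter, using a hypothetical helper `has_zero_dim` and a small made-up data frame in place of the real data:

```r
# Hypothetical helper: TRUE for rows where any dimension (mm) is recorded as 0
has_zero_dim <- function(df) {
  df$x == 0 | df$y == 0 | df$z == 0
}

# Made-up rows for illustration only
toy <- data.frame(x = c(3.95, 0, 4.20),
                  y = c(3.98, 0, 4.23),
                  z = c(2.43, 0, 0))
sum(has_zero_dim(toy))      # 2 rows have at least one zero dimension
toy[!has_zero_dim(toy), ]   # keep only plausible rows
```

On the real data, `diamonds[has_zero_dim(diamonds), ]` would list the affected rows; the analysis below keeps them, as the original workflow does.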
Next, we examine the correlation of each numeric variable with price. Price is strongly and positively correlated with x, y, z and carat, while its correlation with table and depth is low.
with(diamonds,
     data.frame(cor_x_price = cor(x, price),
                cor_y_price = cor(y, price),
                cor_z_price = cor(z, price),
                cor_depth_price = cor(depth, price),
                cor_table_price = cor(table, price),
                cor_carat_price = cor(carat, price)
     )
)
## cor_x_price cor_y_price cor_z_price cor_depth_price cor_table_price
## 1 0.8844352 0.8654209 0.8612494 -0.0106474 0.1271339
## cor_carat_price
## 1 0.9215913
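The same correlations can be computed for every numeric column at once, which avoids listing each pair by hand. `cor_with_target` is a hypothetical helper using base R only; the toy data frame stands in for diamonds.

```r
# Hypothetical helper: Pearson correlation of each numeric column with `target`
cor_with_target <- function(df, target) {
  numeric_cols <- names(df)[vapply(df, is.numeric, logical(1))]
  predictors <- setdiff(numeric_cols, target)
  # vapply over a character vector keeps the column names as result names
  vapply(predictors, function(col) cor(df[[col]], df[[target]]), numeric(1))
}

toy <- data.frame(a = 1:5, b = c(2, 4, 6, 8, 10), c = 5:1, label = letters[1:5])
cor_with_target(toy, "b")   # a = 1, c = -1; the character column is skipped
```

With the real data, `sort(cor_with_target(diamonds, "price"), decreasing = TRUE)` would rank carat first, followed by x, y and z, consistent with the table above.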
As carat increases, price also increases, consistent with the correlation above. However, prices vary widely even among diamonds of the same carat, which shows that other variables also affect the price.
ggplot(diamonds) +
  geom_point(aes(x = carat, y = price), alpha = .05, color = "blue") +
  labs(title = "Diamond Price vs. Carat", x = "Carat (weight of diamond)", y = "Price of diamond (USD)")
Let's add clarity information to the plot.
ggplot(aes(x = carat, y = price), data = diamonds) +
  geom_point(alpha = 0.5, size = 1, position = 'jitter', aes(color = clarity)) +
  scale_color_brewer(type = 'div',
                     guide = guide_legend(title = 'Clarity', reverse = TRUE,
                                          override.aes = list(alpha = 1, size = 2))) +
  ggtitle('Price by Carat and Clarity')
First, we split the data into a training set, used to fit the model, and a test set, used to evaluate it.
set.seed(58) # Set the random seed for a reproducible split
# Test data: a 20% sample drawn within each cut/color/clarity combination
diamonds_test <- diamonds %>% mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>% sample_frac(0.2) %>% ungroup()
# Training data: all rows not selected for the test set
diamonds_train <- anti_join(diamonds %>% mutate(diamond_id = row_number()),
                            diamonds_test, by = "diamond_id")
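The split above samples 20% of the rows within every cut/color/clarity combination, so each combination is represented in both sets in roughly the same proportion. The sketch below expresses the same idea in base R. `stratified_split` is a hypothetical helper, and it will not select the same rows as `sample_frac()`, even with the same seed.

```r
# Hypothetical helper: stratified train/test split in base R
stratified_split <- function(df, strata, frac = 0.2, seed = 58) {
  set.seed(seed)
  # One group per observed combination of the stratifying columns
  groups <- interaction(df[strata], drop = TRUE)
  rows_by_group <- split(seq_len(nrow(df)), groups)
  # Sample round(frac * group size) row indices from each group
  test_rows <- unlist(lapply(rows_by_group, function(rows) {
    rows[sample.int(length(rows), size = round(frac * length(rows)))]
  }), use.names = FALSE)
  # A logical mask stays correct even if no rows end up in the test set
  in_test <- seq_len(nrow(df)) %in% test_rows
  list(train = df[!in_test, , drop = FALSE], test = df[in_test, , drop = FALSE])
}

toy <- data.frame(g = rep(c("a", "b"), each = 10), v = 1:20)
parts <- stratified_split(toy, "g", frac = 0.2)
table(parts$test$g)   # 2 rows from each group
```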
Because table and depth are only weakly correlated with price (about 0.13 and -0.01), these variables are left out of the model.
diamond_model <- rpart(price ~ x+y+z+carat+cut+color+clarity, data=diamonds_train)
summary(diamond_model)
## Call:
## rpart(formula = price ~ x + y + z + carat + cut + color + clarity,
## data = diamonds_train)
## n= 43143
##
## CP nsplit rel error xerror xstd
## 1 0.60896855 0 1.0000000 1.0000530 0.009840248
## 2 0.18582398 1 0.3910315 0.3910476 0.004379470
## 3 0.03373935 2 0.2052075 0.2052427 0.002298806
## 4 0.02640837 3 0.1714681 0.1716709 0.002295891
## 5 0.02584292 4 0.1450598 0.1495979 0.002060551
## 6 0.01000000 5 0.1192168 0.1198359 0.001717405
##
## Variable importance
## carat y x z clarity color
## 25 24 24 23 2 1
##
## Node number 1: 43143 observations, complexity param=0.6089685
## mean=3929.224, MSE=1.589526e+07
## left son=2 (27933 obs) right son=3 (15210 obs)
## Primary splits:
## carat < 0.995 to the left, improve=0.60896850, (0 missing)
## y < 6.345 to the left, improve=0.60707500, (0 missing)
## x < 6.305 to the left, improve=0.60380560, (0 missing)
## z < 3.915 to the left, improve=0.59811950, (0 missing)
## color splits as LLLLRRR, improve=0.02246444, (0 missing)
## Surrogate splits:
## x < 6.275 to the left, agree=0.984, adj=0.954, (0 split)
## y < 6.285 to the left, agree=0.981, adj=0.947, (0 split)
## z < 3.895 to the left, agree=0.977, adj=0.936, (0 split)
## clarity splits as RRLLLLLL, agree=0.679, adj=0.090, (0 split)
## color splits as LLLLLRR, agree=0.660, adj=0.037, (0 split)
##
## Node number 2: 27933 observations, complexity param=0.03373935
## mean=1633.408, MSE=1247559
## left son=4 (19914 obs) right son=5 (8019 obs)
## Primary splits:
## y < 5.525 to the left, improve=0.66395110, (0 missing)
## carat < 0.625 to the left, improve=0.66358320, (0 missing)
## x < 5.465 to the left, improve=0.66085580, (0 missing)
## z < 3.375 to the left, improve=0.65940960, (0 missing)
## clarity splits as RRRLLLLL, improve=0.01034406, (0 missing)
## Surrogate splits:
## x < 5.495 to the left, agree=0.992, adj=0.974, (0 split)
## carat < 0.635 to the left, agree=0.991, adj=0.970, (0 split)
## z < 3.395 to the left, agree=0.984, adj=0.945, (0 split)
## clarity splits as RRLLLLLL, agree=0.723, adj=0.036, (0 split)
## cut splits as RLLLL, agree=0.716, adj=0.010, (0 split)
##
## Node number 3: 15210 observations, complexity param=0.185824
## mean=8145.466, MSE=1.533921e+07
## left son=6 (10283 obs) right son=7 (4927 obs)
## Primary splits:
## y < 7.195 to the left, improve=0.54619500, (0 missing)
## carat < 1.485 to the left, improve=0.53825550, (0 missing)
## x < 7.195 to the left, improve=0.53802950, (0 missing)
## z < 4.425 to the left, improve=0.52571030, (0 missing)
## clarity splits as LLLRRRRR, improve=0.05603985, (0 missing)
## Surrogate splits:
## x < 7.185 to the left, agree=0.984, adj=0.951, (0 split)
## carat < 1.445 to the left, agree=0.979, adj=0.936, (0 split)
## z < 4.435 to the left, agree=0.964, adj=0.889, (0 split)
## color splits as LLLLLLR, agree=0.678, adj=0.004, (0 split)
##
## Node number 4: 19914 observations
## mean=1055.871, MSE=274215.1
##
## Node number 5: 8019 observations
## mean=3067.634, MSE=779390.7
##
## Node number 6: 10283 observations, complexity param=0.02640837
## mean=6141.886, MSE=4709926
## left son=12 (7835 obs) right son=13 (2448 obs)
## Primary splits:
## clarity splits as LLLLRRRR, improve=0.3739261, (0 missing)
## y < 6.775 to the left, improve=0.1191022, (0 missing)
## carat < 1.175 to the left, improve=0.1046128, (0 missing)
## x < 6.775 to the left, improve=0.1035341, (0 missing)
## color splits as RRRRLLL, improve=0.1022999, (0 missing)
##
## Node number 7: 4927 observations, complexity param=0.02584292
## mean=12327.08, MSE=1.165918e+07
## left son=14 (3213 obs) right son=15 (1714 obs)
## Primary splits:
## y < 7.855 to the left, improve=0.30851000, (0 missing)
## x < 7.845 to the left, improve=0.29967630, (0 missing)
## carat < 1.915 to the left, improve=0.28846080, (0 missing)
## z < 4.805 to the left, improve=0.27258020, (0 missing)
## clarity splits as LRRRRRRR, improve=0.07369484, (0 missing)
## Surrogate splits:
## x < 7.885 to the left, agree=0.983, adj=0.950, (0 split)
## carat < 1.835 to the left, agree=0.974, adj=0.925, (0 split)
## z < 4.825 to the left, agree=0.953, adj=0.863, (0 split)
## clarity splits as RRLLLLLL, agree=0.681, adj=0.082, (0 split)
##
## Node number 12: 7835 observations
## mean=5400.087, MSE=2065554
##
## Node number 13: 2448 observations
## mean=8516.066, MSE=5775533
##
## Node number 14: 3213 observations
## mean=10941.86, MSE=8708725
##
## Node number 15: 1714 observations
## mean=14923.76, MSE=6850265
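The variable importance section of the summary shows that carat, y, x and z carry almost all of the importance (23 to 25 each), while clarity and color contribute little. These size variables are highly correlated with one another, so x and z earn their importance mainly as surrogate splits; the tree itself splits only on carat, y and clarity.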
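The `improve` value reported for the root split can be reproduced from the node statistics in the summary: it is the fraction of the root node's sum of squared errors removed by the split, where each node's sum of squared errors is its number of observations times its MSE. The numbers below are copied from the output above.

```r
# Node sizes and MSEs copied from summary(diamond_model)
root  <- c(n = 43143, mse = 1.589526e+07)   # node 1
left  <- c(n = 27933, mse = 1247559)        # node 2 (carat < 0.995)
right <- c(n = 15210, mse = 1.533921e+07)   # node 3 (carat >= 0.995)

# Sum of squared errors of a node = n * MSE
sse <- function(node) node[["n"]] * node[["mse"]]

improve <- 1 - (sse(left) + sse(right)) / sse(root)
round(improve, 4)     # about 0.609, matching improve = 0.6089685
round(sse(root), -7)  # about 6.8577e+11, the "Root node error" in printcp()
```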
Next, we draw the decision tree based on our model:
fancyRpartPlot(diamond_model)
With the decision tree fitted, we can predict prices for the test data:
pred_Diamond_test <- predict(diamond_model, newdata = diamonds_test)
head(pred_Diamond_test,10)
## 1 2 3 4 5 6 7
## 10941.860 3067.634 10941.860 1055.871 3067.634 5400.087 3067.634
## 8 9 10
## 3067.634 3067.634 5400.087
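Each prediction is simply the mean price of the leaf the test diamond falls into, which is why only a handful of distinct values appear (for example 1055.871 and 3067.634, the means of nodes 4 and 5). To measure how far such predictions are from actual prices in dollars, the root mean squared error is a common choice. `rmse` below is a hypothetical helper, shown on made-up numbers:

```r
# Hypothetical helper: root mean squared error between actual and predicted values
rmse <- function(actual, predicted) {
  stopifnot(length(actual) == length(predicted))
  sqrt(mean((actual - predicted)^2))
}

rmse(actual = c(1, 2, 3), predicted = c(1, 2, 5))  # sqrt(4/3), about 1.155
```

For the tree above, this would be `rmse(diamonds_test$price, pred_Diamond_test)`, giving an error in US dollars.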
Before pruning the tree, we display the CP table and look for the lowest cross-validation error (xerror). The lowest xerror is 0.11984, at a CP value of 0.01.
printcp(diamond_model)
##
## Regression tree:
## rpart(formula = price ~ x + y + z + carat + cut + color + clarity,
## data = diamonds_train)
##
## Variables actually used in tree construction:
## [1] carat clarity y
##
## Root node error: 6.8577e+11/43143 = 15895265
##
## n= 43143
##
## CP nsplit rel error xerror xstd
## 1 0.608969 0 1.00000 1.00005 0.0098402
## 2 0.185824 1 0.39103 0.39105 0.0043795
## 3 0.033739 2 0.20521 0.20524 0.0022988
## 4 0.026408 3 0.17147 0.17167 0.0022959
## 5 0.025843 4 0.14506 0.14960 0.0020606
## 6 0.010000 5 0.11922 0.11984 0.0017174
# CP value of the row with the lowest cross-validation error (xerror)
min.xerror <- diamond_model$cptable[which.min(diamond_model$cptable[,"xerror"]),"CP"]
min.xerror
## [1] 0.01
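Picking the CP with the minimum xerror is one rule. A widely used alternative is the one-standard-error (1-SE) rule: choose the simplest tree whose xerror is within one xstd of the minimum. The sketch below applies both rules to the CP table values copied from the printcp() output above; the helper names are hypothetical. Here both rules select CP = 0.01, because no smaller tree comes within one standard error of the minimum (0.11984 + 0.0017174 ≈ 0.1216, while the next-best xerror is 0.14960).

```r
# CP table copied from printcp(diamond_model)
cptab <- cbind(
  CP     = c(0.608969, 0.185824, 0.033739, 0.026408, 0.025843, 0.010000),
  nsplit = c(0, 1, 2, 3, 4, 5),
  xerror = c(1.00005, 0.39105, 0.20524, 0.17167, 0.14960, 0.11984),
  xstd   = c(0.0098402, 0.0043795, 0.0022988, 0.0022959, 0.0020606, 0.0017174)
)

# Rule used in this article: CP of the row with the lowest xerror
cp_min_xerror <- function(tab) tab[which.min(tab[, "xerror"]), "CP"]

# 1-SE rule: first (smallest) tree with xerror <= min(xerror) + its xstd
cp_one_se <- function(tab) {
  best <- which.min(tab[, "xerror"])
  threshold <- tab[best, "xerror"] + tab[best, "xstd"]
  tab[which(tab[, "xerror"] <= threshold)[1], "CP"]
}

cp_min_xerror(cptab)  # 0.01
cp_one_se(cptab)      # also 0.01 for this table
```

Because an rpart model's cptable uses the same column names, `cp_one_se(diamond_model$cptable)` should work the same way on the fitted model.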
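Next, we prune the tree at this value of CP. Because the lowest xerror belongs to the largest tree in the CP table, pruning at this value leaves the tree unchanged: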
# Prune the tree
diamond_model.pruned <- prune(diamond_model, cp = min.xerror)
# Draw the pruned tree
fancyRpartPlot(diamond_model.pruned)
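The CP value of 0.01 is not a coincidence: it is the default `cp` in `rpart.control()`, which `rpart()` uses when no `control` argument is supplied, as in the model above. Splits that would not improve the fit by at least that fraction are not attempted, so the CP table cannot contain anything more complex to choose from. The defaults can be inspected directly; growing a larger tree first, for example with `control = rpart.control(cp = 0.001)`, would give pruning something to remove.

```r
library(rpart)

# Default control settings used by rpart() when `control` is not given
ctrl <- rpart.control()
ctrl$cp        # 0.01: minimum relative improvement required for a split
ctrl$minsplit  # 20: minimum observations in a node before a split is tried
ctrl$xval      # 10: number of cross-validation folds behind xerror
```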
Then we use the pruned tree to predict prices for the test data:
pred_Diamond_test.pruned <- predict(diamond_model.pruned, newdata = diamonds_test)
Next, we compute a pseudo R², the squared correlation between actual and predicted test prices:
fitcorr <- format(cor(diamonds_test$price, pred_Diamond_test.pruned)^2, digits=4)
fitcorr
## [1] "0.8771"
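The squared correlation measures how closely predictions track actual prices along a straight line, but it ignores systematic bias. The conventional coefficient of determination, 1 - SSE/SST, penalizes predictions that are shifted or scaled away from the actual values. The toy example below uses made-up numbers and hypothetical helper names: predictions that are exactly twice the actual values get a squared correlation of 1 but a 1 - SSE/SST of -5.

```r
# Pseudo R^2 as used above: squared Pearson correlation
pseudo_r2 <- function(actual, predicted) cor(actual, predicted)^2

# Coefficient of determination: 1 - SSE / SST
r2_sse <- function(actual, predicted) {
  sse <- sum((actual - predicted)^2)
  sst <- sum((actual - mean(actual))^2)
  1 - sse / sst
}

actual    <- c(1, 2, 3, 4)
predicted <- 2 * actual          # perfectly correlated but biased
pseudo_r2(actual, predicted)     # 1
r2_sse(actual, predicted)        # -5
```

For the pruned tree, `r2_sse(diamonds_test$price, pred_Diamond_test.pruned)` would give the conventional value for comparison.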
# Create a data frame with the predictions for each method
all.predictions <- data.frame(actual = diamonds_test$price,
full.tree = pred_Diamond_test,
pruned.tree = pred_Diamond_test.pruned)
# Reshape to long format: one row per actual price and model
all.predictions <- gather(all.predictions, key = model, value = predictions, 2:3)
# Plot predicted vs. actual prices, by model
ggplot(data = all.predictions, aes(x = actual, y = predictions)) +
geom_point(colour = "blue") +
geom_abline(intercept = 0, slope = 1, colour = "red") +
facet_wrap(~ model, ncol = 2) +
ggtitle("Predicted vs. Actual, by model")
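`gather()` still works but is superseded in current tidyr releases by `pivot_longer()`; the equivalent call would be `pivot_longer(all.predictions, cols = c(full.tree, pruned.tree), names_to = "model", values_to = "predictions")`. For readers without tidyr, the same long format can be built in base R. `to_long` is a hypothetical helper, shown on made-up numbers:

```r
# Hypothetical helper: stack several prediction columns into long format
to_long <- function(df, id_col, value_cols, key = "model", value = "predictions") {
  long <- do.call(rbind, lapply(value_cols, function(col) {
    out <- data.frame(df[[id_col]], col, df[[col]], stringsAsFactors = FALSE)
    names(out) <- c(id_col, key, value)
    out
  }))
  rownames(long) <- NULL
  long
}

wide <- data.frame(actual = c(100, 200),
                   full.tree = c(110, 190),
                   pruned.tree = c(110, 190))
to_long(wide, "actual", c("full.tree", "pruned.tree"))
```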