# Fix the RNG seed so the random test-set sample drawn further down
# (sample_frac) is reproducible between runs.
set.seed(503)
#libraries
library(tidyverse)
library(rpart) #for the CART models
library(rpart.plot)
library(dplyr) #For data manipulation
# Splitting the data set ----
# Hold out a 20% sample of every cut/color/clarity combination as the test
# set. Each row is tagged with a diamond_id (its row number in `diamonds`) so
# the sampled rows can be removed from the full data to form the training set.
diamonds_test <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()

# Training set: every row whose diamond_id was not drawn into the test set.
diamonds_train <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  anti_join(diamonds_test, by = "diamond_id")
Diamonds dataset: This dataset contains the prices and other attributes of about 54,000 diamonds (53,940 rows). The features included are the following:
carat: Weight of the diamond; cut: Quality of the cut (Fair / Good / Very Good / Premium / Ideal); color: Diamond color, ranging from J (worst) to D (best); clarity: A measurement of how clear the diamond is (I1 (worst), SI2, SI1, VS2, VS1, VVS2, VVS1, IF (best)); depth: Total depth percentage = z / mean(x, y) = 2 * z / (x + y) (43–79); table: Width of the top of the diamond relative to its widest point (43–95); price: Price in US dollars ($326–$18,823); x: Length in mm (0–10.74); y: Width in mm (0–58.9); z: Depth in mm (0–31.8)
# Display the training split; the tibble printout below shows only the first
# 10 rows and notes the columns that do not fit (diamond_id).
diamonds_train
## # A tibble: 43,143 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
## 5 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 6 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## 7 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
## 8 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
## 9 0.23 Very Good H VS1 59.4 61 338 4.00 4.05 2.39
## 10 0.30 Good J SI1 64.0 55 339 4.25 4.28 2.73
## # ... with 43,133 more rows, and 1 more variables: diamond_id <int>
# Display the held-out test split (first 10 rows are printed).
diamonds_test
## # A tibble: 10,797 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 3.40 Fair D I1 66.8 52 15964 9.42 9.34 6.27
## 2 0.90 Fair D SI2 64.7 59 3205 6.09 5.99 3.91
## 3 0.95 Fair D SI2 64.4 60 3384 6.06 6.02 3.89
## 4 1.00 Fair D SI2 65.2 56 3634 6.27 6.21 4.07
## 5 0.70 Fair D SI2 58.1 60 2358 5.79 5.82 3.37
## 6 1.04 Fair D SI2 64.9 56 4398 6.39 6.34 4.13
## 7 0.70 Fair D SI2 65.6 55 2167 5.59 5.50 3.64
## 8 1.03 Fair D SI2 66.4 56 3743 6.31 6.19 4.15
## 9 1.10 Fair D SI2 64.6 54 4725 6.56 6.49 4.22
## 10 2.01 Fair D SI2 59.4 66 15627 8.20 8.17 4.86
## # ... with 10,787 more rows, and 1 more variables: diamond_id <int>
# Check the size of each split as (rows, columns). Both have 11 columns: the
# 10 original variables plus diamond_id. The test split is roughly 20% of the
# 53,940 rows in total.
dim(diamonds_train)
## [1] 43143 11
dim(diamonds_test)
## [1] 10797 11
# Show the first six rows of the training split.
head(diamonds_train)
## # A tibble: 6 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
## 5 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 6 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## # ... with 1 more variables: diamond_id <int>
# Per-column summary of the training split (five-number summary and mean for
# numeric columns, level counts for the factors).
# NOTE(review): the recorded output shows a minimum of 0 for x, y and z, and
# maxima of 58.9 (y) and 31.8 (z). These look like data-entry errors rather
# than real dimensions; confirm and consider filtering before modelling.
summary(diamonds_train)
## carat cut color clarity
## Min. :0.2000 Fair : 1285 D:5416 SI1 :10449
## 1st Qu.:0.4000 Good : 3923 E:7835 VS2 : 9806
## Median :0.7000 Very Good: 9662 F:7629 SI2 : 7354
## Mean :0.7978 Premium :11036 G:9037 VS1 : 6538
## 3rd Qu.:1.0400 Ideal :17237 H:6646 VVS2 : 4052
## Max. :5.0100 I:4336 VVS1 : 2923
## J:2244 (Other): 2021
## depth table price x
## Min. :43.00 Min. :44.00 Min. : 326 Min. : 0.000
## 1st Qu.:61.00 1st Qu.:56.00 1st Qu.: 951 1st Qu.: 4.710
## Median :61.80 Median :57.00 Median : 2401 Median : 5.700
## Mean :61.74 Mean :57.46 Mean : 3933 Mean : 5.731
## 3rd Qu.:62.50 3rd Qu.:59.00 3rd Qu.: 5327 3rd Qu.: 6.540
## Max. :79.00 Max. :95.00 Max. :18823 Max. :10.740
##
## y z diamond_id
## Min. : 0.000 Min. : 0.000 Min. : 1
## 1st Qu.: 4.720 1st Qu.: 2.910 1st Qu.:13476
## Median : 5.710 Median : 3.530 Median :26981
## Mean : 5.735 Mean : 3.539 Mean :26986
## 3rd Qu.: 6.540 3rd Qu.: 4.030 3rd Qu.:40518
## Max. :58.900 Max. :31.800 Max. :53940
##
# Inspect column types. Per the output below, cut, color and clarity are
# ordered factors, price and diamond_id are integers, and the rest are numeric.
str(diamonds_train)
## Classes 'tbl_df', 'tbl' and 'data.frame': 43143 obs. of 11 variables:
## $ carat : num 0.23 0.21 0.23 0.29 0.24 0.24 0.26 0.22 0.23 0.3 ...
## $ cut : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 3 3 3 1 3 2 ...
## $ color : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 6 5 2 5 7 ...
## $ clarity : Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 6 7 3 4 5 3 ...
## $ depth : num 61.5 59.8 56.9 62.4 62.8 62.3 61.9 65.1 59.4 64 ...
## $ table : num 55 61 65 58 57 57 55 61 61 55 ...
## $ price : int 326 326 327 334 336 336 337 337 338 339 ...
## $ x : num 3.95 3.89 4.05 4.2 3.94 3.95 4.07 3.87 4 4.25 ...
## $ y : num 3.98 3.84 4.07 4.23 3.96 3.98 4.11 3.78 4.05 4.28 ...
## $ z : num 2.43 2.31 2.31 2.63 2.48 2.47 2.53 2.49 2.39 2.73 ...
## $ diamond_id: int 1 2 3 4 6 7 8 9 10 11 ...
# Scatter of price against carat. Note that this plots the full `diamonds`
# data, not only the training split. geom_jitter() adds a small amount of
# random noise to each point to reduce overplotting.
ggplot(data = diamonds, mapping = aes(x = carat, y = price)) +
  geom_jitter()
# Model specification: price as a function of the nine diamond attributes.
# The diamond_id helper column is not among the predictors.
formula <- price ~ carat + cut + color + clarity + depth + table + x + y + z
formula
## price ~ carat + cut + color + clarity + depth + table + x + y +
## z
# Fit a regression tree on the training split; method = "anova" selects
# rpart's regression (continuous response) mode. The fitted tree is then
# displayed as text.
model1 <- rpart(
  formula = formula,
  data = diamonds_train,
  method = "anova"
)
model1
## n= 43143
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 43143 687673200000 3933.419
## 2) carat< 0.995 27919 34912140000 1634.795
## 4) y< 5.525 19911 5428982000 1056.168 *
## 5) y>=5.525 8008 6241528000 3073.487 *
## 3) carat>=0.995 15224 234721600000 8148.821
## 6) y< 7.195 10290 48143650000 6134.445
## 12) clarity=I1,SI2,SI1,VS2 7828 16158110000 5394.948 *
## 13) clarity=VS1,VVS2,VVS1,IF 2462 14093830000 8485.700 *
## 7) y>=7.195 4934 57745320000 12349.860
## 14) y< 7.855 3214 27840990000 10961.400 *
## 15) y>=7.855 1720 12130460000 14944.340 *
# Draw the fitted tree with values shown to 3 significant digits and all
# leaves aligned at the bottom of the plot.
rpart.plot(
  model1,
  type = 4,
  digits = 3,
  fallen.leaves = TRUE
)
# Predict prices for the held-out test split, then summarise the predictions.
# A tree predicts one value per terminal node, so the predictions can only
# take the six leaf values listed in the printed model.
prediction <- predict(model1, newdata = diamonds_test)
summary(prediction)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 1056 1056 3073 3942 5395 14944
#' Mean absolute error between observed and predicted values.
#'
#' @param actual Vector of observed values.
#' @param predicted Vector of predictions. Must have the same length as
#'   `actual`, unless one of the two has length 1 (e.g. a constant baseline
#'   prediction), in which case it is broadcast.
#' @param na.rm Should `NA` absolute errors be dropped before averaging?
#'   Defaults to `FALSE`, which reproduces the original behaviour.
#' @return A single number: `mean(abs(actual - predicted))`.
MAE <- function(actual, predicted, na.rm = FALSE) {
  n_actual <- length(actual)
  n_predicted <- length(predicted)
  # Without this guard R silently recycles the shorter vector, which pairs
  # observations with the wrong predictions and yields a meaningless error.
  if (n_actual != n_predicted && n_actual != 1L && n_predicted != 1L) {
    stop(
      "`actual` (length ", n_actual, ") and `predicted` (length ",
      n_predicted, ") must have the same length, or one must have length 1.",
      call. = FALSE
    )
  }
  mean(abs(actual - predicted), na.rm = na.rm)
}
# Mean absolute error of the tree's predictions on the held-out test split,
# in the same units as price (the recorded result below is about 890).
MAE(diamonds_test$price, prediction)
## [1] 889.7092
References: R - Regression Trees - CART, Tree-Based Models