# Fix the RNG seed so the random train/test split below is reproducible.
set.seed(503)
library(tidyverse)

# Test set: tag every diamond with its row number, then draw 20% of the rows
# within each cut/color/clarity combination.
diamonds_test <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()

# Training set: every tagged diamond whose id was not drawn into the test set.
diamonds_train <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  anti_join(diamonds_test, by = "diamond_id")
# Sanity-check the split: the train and test row counts add up to the full
# data set, and both subsets carry the extra diamond_id column.
dim(diamonds)
## [1] 53940 10
dim(diamonds_train)
## [1] 43143 11
dim(diamonds_test)
## [1] 10797 11
# Preview the first rows of the training set.
head(diamonds_train)
## # A tibble: 6 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
## 5 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 6 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## # ... with 1 more variables: diamond_id <int>
library(dplyr)
library(rpart)
library(rpart.plot)
# Column-wise overview of the training data (column types and first values).
glimpse(diamonds_train)
## Observations: 43,143
## Variables: 11
## $ carat <dbl> 0.23, 0.21, 0.23, 0.29, 0.24, 0.24, 0.26, 0.22, 0.2...
## $ cut <ord> Ideal, Premium, Good, Premium, Very Good, Very Good...
## $ color <ord> E, E, E, I, J, I, H, E, H, J, J, F, J, E, E, I, J, ...
## $ clarity <ord> SI2, SI1, VS1, VS2, VVS2, VVS1, SI1, VS2, VS1, SI1,...
## $ depth <dbl> 61.5, 59.8, 56.9, 62.4, 62.8, 62.3, 61.9, 65.1, 59....
## $ table <dbl> 55, 61, 65, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54,...
## $ price <int> 326, 326, 327, 334, 336, 336, 337, 337, 338, 339, 3...
## $ x <dbl> 3.95, 3.89, 4.05, 4.20, 3.94, 3.95, 4.07, 3.87, 4.0...
## $ y <dbl> 3.98, 3.84, 4.07, 4.23, 3.96, 3.98, 4.11, 3.78, 4.0...
## $ z <dbl> 2.43, 2.31, 2.31, 2.63, 2.48, 2.47, 2.53, 2.49, 2.3...
## $ diamond_id <int> 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,...
# Per-column summary statistics of the training data.
# NOTE(review): the output below shows minima of 0 for x, y and z and maxima
# of y = 58.9 and z = 31.8, far away from the quartiles. These look like bad
# measurements, and the training data is passed to the models unfiltered --
# confirm this is intended.
summary(diamonds_train)
## carat cut color clarity
## Min. :0.2000 Fair : 1285 D:5416 SI1 :10449
## 1st Qu.:0.4000 Good : 3923 E:7835 VS2 : 9806
## Median :0.7000 Very Good: 9662 F:7629 SI2 : 7354
## Mean :0.7978 Premium :11036 G:9037 VS1 : 6538
## 3rd Qu.:1.0400 Ideal :17237 H:6646 VVS2 : 4052
## Max. :5.0100 I:4336 VVS1 : 2923
## J:2244 (Other): 2021
## depth table price x
## Min. :43.00 Min. :44.00 Min. : 326 Min. : 0.000
## 1st Qu.:61.00 1st Qu.:56.00 1st Qu.: 951 1st Qu.: 4.710
## Median :61.80 Median :57.00 Median : 2401 Median : 5.700
## Mean :61.74 Mean :57.46 Mean : 3933 Mean : 5.731
## 3rd Qu.:62.50 3rd Qu.:59.00 3rd Qu.: 5327 3rd Qu.: 6.540
## Max. :79.00 Max. :95.00 Max. :18823 Max. :10.740
##
## y z diamond_id
## Min. : 0.000 Min. : 0.000 Min. : 1
## 1st Qu.: 4.720 1st Qu.: 2.910 1st Qu.:13476
## Median : 5.710 Median : 3.530 Median :26981
## Mean : 5.735 Mean : 3.539 Mean :26986
## 3rd Qu.: 6.540 3rd Qu.: 4.030 3rd Qu.:40518
## Max. :58.900 Max. :31.800 Max. :53940
##
# Model specification used by the models fitted in this script: price
# explained by the nine original diamond attributes (diamond_id is not
# included). The outer parentheses print the formula as it is assigned.
(formula <- price ~ carat + cut + color + clarity + depth + table +
  x + y + z)
## price ~ carat + cut + color + clarity + depth + table + x + y +
## z
# Model 1: regression tree for price. method = "anova" requests a regression
# (rather than classification) tree; no other rpart settings are overridden.
m1 <- rpart(formula, data = diamonds_train, method = "anova")
# Print the fitted splits; yval is the mean price of the rows in each node.
m1
## n= 43143
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 43143 687673200000 3933.419
## 2) carat< 0.995 27919 34912140000 1634.795
## 4) y< 5.525 19911 5428982000 1056.168 *
## 5) y>=5.525 8008 6241528000 3073.487 *
## 3) carat>=0.995 15224 234721600000 8148.821
## 6) y< 7.195 10290 48143650000 6134.445
## 12) clarity=I1,SI2,SI1,VS2 7828 16158110000 5394.948 *
## 13) clarity=VS1,VVS2,VVS1,IF 2462 14093830000 8485.700 *
## 7) y>=7.195 4934 57745320000 12349.860
## 14) y< 7.855 3214 27840990000 10961.400 *
## 15) y>=7.855 1720 12130460000 14944.340 *
# Draw the fitted tree twice with different rpart.plot layouts
# (see ?rpart.plot for the meaning of type, digits and fallen.leaves).
rpart.plot(m1, type = 4, digits = 3, fallen.leaves = TRUE)
rpart.plot(m1, type = 3, digits = 2, fallen.leaves = FALSE)
# Predict prices for the held-out test set. The quantiles in the summary
# below coincide with leaf means of the printed tree (1056, 3073, 5395,
# 14944).
p1 <- predict(m1, diamonds_test)
summary(p1)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 1056 1056 3073 3942 5395 14944
# Mean absolute error: the average of |actual - predicted|.
MAE <- function(actual, predicted) {
  abs_errors <- abs(actual - predicted)
  mean(abs_errors)
}
# Test-set MAE of the tree model's predictions (p1).
MAE(diamonds_test$price, p1)
## [1] 889.7092
# Root mean squared error: the square root of the average squared difference
# between actual and predicted values.
RMSE <- function(actual, predicted) {
  squared_errors <- (actual - predicted)^2
  sqrt(mean(squared_errors))
}
# Test-set RMSE of the tree model's predictions (p1).
RMSE(diamonds_test$price, p1)
## [1] 1397.941
# Model 2: linear regression (lm) on the same formula and training data as
# the tree model, evaluated on the same test set so the errors are comparable.
formula
## price ~ carat + cut + color + clarity + depth + table + x + y +
## z
m2 <- lm(formula, diamonds_train)
p2 <- predict(m2, diamonds_test)
# Test-set errors of the linear model's predictions (p2).
MAE(diamonds_test$price, p2)
## [1] 745.5789
RMSE(diamonds_test$price, p2)
## [1] 1167.965
# Conclusion: the second model (linear regression, m2) predicts better than the CART model (m1): its test-set MAE (745.58 vs. 889.71) and RMSE (1167.97 vs. 1397.94) are both lower.