#Set the random seed
set.seed(503)
# Install necessary packages
library(tidyverse)
## ── Attaching packages ────────────────────────────────── tidyverse 1.2.1 ──
## ✔ ggplot2 2.2.1 ✔ purrr 0.2.4
## ✔ tibble 1.3.4 ✔ dplyr 0.7.4
## ✔ tidyr 0.7.2 ✔ stringr 1.2.0
## ✔ readr 1.1.1 ✔ forcats 0.2.0
## ── Conflicts ───────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(rpart)
library(rpart.plot)
library(rattle)
## Rattle: A free graphical interface for data science with R.
## Version 5.1.0 Copyright (c) 2006-2017 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
# Prepare test and train data set
diamonds_test <- diamonds %>% mutate(diamond_id = row_number()) %>%
group_by(cut, color, clarity) %>% sample_frac(0.2) %>% ungroup()
diamonds_train <- anti_join(diamonds %>% mutate(diamond_id = row_number()),
diamonds_test, by = "diamond_id")
diamonds_train
## # A tibble: 43,143 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
## 5 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 6 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## 7 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
## 8 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
## 9 0.23 Very Good H VS1 59.4 61 338 4.00 4.05 2.39
## 10 0.30 Good J SI1 64.0 55 339 4.25 4.28 2.73
## # ... with 43,133 more rows, and 1 more variables: diamond_id <int>
diamonds_test
## # A tibble: 10,797 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 3.40 Fair D I1 66.8 52 15964 9.42 9.34 6.27
## 2 0.90 Fair D SI2 64.7 59 3205 6.09 5.99 3.91
## 3 0.95 Fair D SI2 64.4 60 3384 6.06 6.02 3.89
## 4 1.00 Fair D SI2 65.2 56 3634 6.27 6.21 4.07
## 5 0.70 Fair D SI2 58.1 60 2358 5.79 5.82 3.37
## 6 1.04 Fair D SI2 64.9 56 4398 6.39 6.34 4.13
## 7 0.70 Fair D SI2 65.6 55 2167 5.59 5.50 3.64
## 8 1.03 Fair D SI2 66.4 56 3743 6.31 6.19 4.15
## 9 1.10 Fair D SI2 64.6 54 4725 6.56 6.49 4.22
## 10 2.01 Fair D SI2 59.4 66 15627 8.20 8.17 4.86
## # ... with 10,787 more rows, and 1 more variables: diamond_id <int>
Exploratory Analysis
nrow(diamonds)
## [1] 53940
ncol(diamonds)
## [1] 10
dim(diamonds_train)
## [1] 43143 11
dim(diamonds_test)
## [1] 10797 11
summary(diamonds_train)
## carat cut color clarity
## Min. :0.2000 Fair : 1285 D:5416 SI1 :10449
## 1st Qu.:0.4000 Good : 3923 E:7835 VS2 : 9806
## Median :0.7000 Very Good: 9662 F:7629 SI2 : 7354
## Mean :0.7978 Premium :11036 G:9037 VS1 : 6538
## 3rd Qu.:1.0400 Ideal :17237 H:6646 VVS2 : 4052
## Max. :5.0100 I:4336 VVS1 : 2923
## J:2244 (Other): 2021
## depth table price x
## Min. :43.00 Min. :44.00 Min. : 326 Min. : 0.000
## 1st Qu.:61.00 1st Qu.:56.00 1st Qu.: 951 1st Qu.: 4.710
## Median :61.80 Median :57.00 Median : 2401 Median : 5.700
## Mean :61.74 Mean :57.46 Mean : 3933 Mean : 5.731
## 3rd Qu.:62.50 3rd Qu.:59.00 3rd Qu.: 5327 3rd Qu.: 6.540
## Max. :79.00 Max. :95.00 Max. :18823 Max. :10.740
##
## y z diamond_id
## Min. : 0.000 Min. : 0.000 Min. : 1
## 1st Qu.: 4.720 1st Qu.: 2.910 1st Qu.:13476
## Median : 5.710 Median : 3.530 Median :26981
## Mean : 5.735 Mean : 3.539 Mean :26986
## 3rd Qu.: 6.540 3rd Qu.: 4.030 3rd Qu.:40518
## Max. :58.900 Max. :31.800 Max. :53940
##
summary(diamonds_test)
## carat cut color clarity depth
## Min. :0.2000 Fair : 325 D:1359 SI1 :2616 Min. :53.10
## 1st Qu.:0.4000 Good : 983 E:1962 VS2 :2452 1st Qu.:61.10
## Median :0.7000 Very Good:2420 F:1913 SI2 :1840 Median :61.80
## Mean :0.7986 Premium :2755 G:2255 VS1 :1633 Mean :61.77
## 3rd Qu.:1.0400 Ideal :4314 H:1658 VVS2 :1014 3rd Qu.:62.50
## Max. :4.5000 I:1086 VVS1 : 732 Max. :73.60
## J: 564 (Other): 510
## table price x y
## Min. :43.00 Min. : 335 Min. : 0.000 Min. : 0.000
## 1st Qu.:56.00 1st Qu.: 949 1st Qu.: 4.720 1st Qu.: 4.720
## Median :57.00 Median : 2407 Median : 5.700 Median : 5.710
## Mean :57.45 Mean : 3930 Mean : 5.731 Mean : 5.733
## 3rd Qu.:59.00 3rd Qu.: 5317 3rd Qu.: 6.540 3rd Qu.: 6.540
## Max. :79.00 Max. :18818 Max. :10.230 Max. :10.160
##
## z diamond_id
## Min. :0.00 Min. : 5
## 1st Qu.:2.91 1st Qu.:13511
## Median :3.52 Median :26914
## Mean :3.54 Mean :26909
## 3rd Qu.:4.04 3rd Qu.:40191
## Max. :6.72 Max. :53939
##
glimpse(diamonds_train)
## Observations: 43,143
## Variables: 11
## $ carat <dbl> 0.23, 0.21, 0.23, 0.29, 0.24, 0.24, 0.26, 0.22, 0.2...
## $ cut <ord> Ideal, Premium, Good, Premium, Very Good, Very Good...
## $ color <ord> E, E, E, I, J, I, H, E, H, J, J, F, J, E, E, I, J, ...
## $ clarity <ord> SI2, SI1, VS1, VS2, VVS2, VVS1, SI1, VS2, VS1, SI1,...
## $ depth <dbl> 61.5, 59.8, 56.9, 62.4, 62.8, 62.3, 61.9, 65.1, 59....
## $ table <dbl> 55, 61, 65, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54,...
## $ price <int> 326, 326, 327, 334, 336, 336, 337, 337, 338, 339, 3...
## $ x <dbl> 3.95, 3.89, 4.05, 4.20, 3.94, 3.95, 4.07, 3.87, 4.0...
## $ y <dbl> 3.98, 3.84, 4.07, 4.23, 3.96, 3.98, 4.11, 3.78, 4.0...
## $ z <dbl> 2.43, 2.31, 2.31, 2.63, 2.48, 2.47, 2.53, 2.49, 2.3...
## $ diamond_id <int> 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,...
glimpse(diamonds_test)
## Observations: 10,797
## Variables: 11
## $ carat <dbl> 3.40, 0.90, 0.95, 1.00, 0.70, 1.04, 0.70, 1.03, 1.1...
## $ cut <ord> Fair, Fair, Fair, Fair, Fair, Fair, Fair, Fair, Fai...
## $ color <ord> D, D, D, D, D, D, D, D, D, D, D, D, D, D, D, D, D, ...
## $ clarity <ord> I1, SI2, SI2, SI2, SI2, SI2, SI2, SI2, SI2, SI2, SI...
## $ depth <dbl> 66.8, 64.7, 64.4, 65.2, 58.1, 64.9, 65.6, 66.4, 64....
## $ table <dbl> 52, 59, 60, 56, 60, 56, 55, 56, 54, 66, 58, 57, 57,...
## $ price <int> 15964, 3205, 3384, 3634, 2358, 4398, 2167, 3743, 47...
## $ x <dbl> 9.42, 6.09, 6.06, 6.27, 5.79, 6.39, 5.59, 6.31, 6.5...
## $ y <dbl> 9.34, 5.99, 6.02, 6.21, 5.82, 6.34, 5.50, 6.19, 6.4...
## $ z <dbl> 6.27, 3.91, 3.89, 4.07, 3.37, 4.13, 3.64, 4.15, 4.2...
## $ diamond_id <int> 26432, 2538, 3428, 4517, 51268, 8352, 49817, 5007, ...
str(diamonds_train)
## Classes 'tbl_df', 'tbl' and 'data.frame': 43143 obs. of 11 variables:
## $ carat : num 0.23 0.21 0.23 0.29 0.24 0.24 0.26 0.22 0.23 0.3 ...
## $ cut : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 3 3 3 1 3 2 ...
## $ color : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 6 5 2 5 7 ...
## $ clarity : Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 6 7 3 4 5 3 ...
## $ depth : num 61.5 59.8 56.9 62.4 62.8 62.3 61.9 65.1 59.4 64 ...
## $ table : num 55 61 65 58 57 57 55 61 61 55 ...
## $ price : int 326 326 327 334 336 336 337 337 338 339 ...
## $ x : num 3.95 3.89 4.05 4.2 3.94 3.95 4.07 3.87 4 4.25 ...
## $ y : num 3.98 3.84 4.07 4.23 3.96 3.98 4.11 3.78 4.05 4.28 ...
## $ z : num 2.43 2.31 2.31 2.63 2.48 2.47 2.53 2.49 2.39 2.73 ...
## $ diamond_id: int 1 2 3 4 6 7 8 9 10 11 ...
str(diamonds_test)
## Classes 'tbl_df', 'tbl' and 'data.frame': 10797 obs. of 11 variables:
## $ carat : num 3.4 0.9 0.95 1 0.7 1.04 0.7 1.03 1.1 2.01 ...
## $ cut : Ord.factor w/ 5 levels "Fair"<"Good"<..: 1 1 1 1 1 1 1 1 1 1 ...
## $ color : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 1 1 1 1 1 1 1 1 1 1 ...
## $ clarity : Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 1 2 2 2 2 2 2 2 2 2 ...
## $ depth : num 66.8 64.7 64.4 65.2 58.1 64.9 65.6 66.4 64.6 59.4 ...
## $ table : num 52 59 60 56 60 56 55 56 54 66 ...
## $ price : int 15964 3205 3384 3634 2358 4398 2167 3743 4725 15627 ...
## $ x : num 9.42 6.09 6.06 6.27 5.79 6.39 5.59 6.31 6.56 8.2 ...
## $ y : num 9.34 5.99 6.02 6.21 5.82 6.34 5.5 6.19 6.49 8.17 ...
## $ z : num 6.27 3.91 3.89 4.07 3.37 4.13 3.64 4.15 4.22 4.86 ...
## $ diamond_id: int 26432 2538 3428 4517 51268 8352 49817 5007 10151 26223 ...
Let’s see the statistical characteristics and distribution of prices
summary(diamonds$price)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 326 950 2401 3933 5324 18823
# Use histogram to see the distribution of price.
# As you see price distribution is right - skewed, mean of price seems to be inflated by a small amount of high prices
ggplot(diamonds, aes(x = price)) +
geom_histogram(fill = "blue", binwidth = 300) +
scale_x_continuous(breaks = seq(0, 20000, 1000)) +
theme(axis.text.x = element_text(angle = 90)) +
xlab("Price") + ylab("Count")
# Prepare correlation data for pca analysis
diamonds_train_pca_cor <- cor(diamonds_train%>%select(carat,depth, table,x,y,z))
diamonds_train_pca_cor
## carat depth table x y z
## carat 1.00000000 0.02595373 0.1841014 0.97586401 0.94678143 0.94991249
## depth 0.02595373 1.00000000 -0.3011976 -0.02848763 -0.03223782 0.09166475
## table 0.18410140 -0.30119756 1.0000000 0.19834246 0.18519210 0.15236302
## x 0.97586401 -0.02848763 0.1983425 1.00000000 0.96930921 0.96640370
## y 0.94678143 -0.03223782 0.1851921 0.96930921 1.00000000 0.94381041
## z 0.94991249 0.09166475 0.1523630 0.96640370 0.94381041 1.00000000
# Run PCA
diamonds_train_pca_result<-princomp(diamonds_train%>%select(carat,depth, table,x,y,z),cor=T)
# See the PCA results
summary(diamonds_train_pca_result,loadings=TRUE)
## Importance of components:
## Comp.1 Comp.2 Comp.3 Comp.4
## Standard deviation 1.9800490 1.1349816 0.8234788 0.227797884
## Proportion of Variance 0.6534323 0.2146972 0.1130196 0.008648646
## Cumulative Proportion 0.6534323 0.8681295 0.9811491 0.989797745
## Comp.5 Comp.6
## Standard deviation 0.216370481 0.119988944
## Proportion of Variance 0.007802697 0.002399558
## Cumulative Proportion 0.997600442 1.000000000
##
## Loadings:
## Comp.1 Comp.2 Comp.3 Comp.4 Comp.5 Comp.6
## carat -0.496 0.651 -0.422 0.385
## depth 0.734 0.671
## table -0.123 -0.669 0.733
## x -0.501 0.112 -0.855
## y -0.494 -0.749 -0.381 0.205
## z -0.493 0.103 0.818 0.276
# Run CART and try to improve the model with these extra properties.
rpart.control(
#minsplit = 20, #Min # of items that should be in a node to do a split
#minbucket = round(minsplit/3), #Minimum number of items in a final node
cp = 0.05, #Complexity parameter (min improvement to generate a split)
maxcompete = 4, #Not related to model. Some printout for analyses
maxsurrogate = 5, #Used to deal with missing values
usesurrogate = 2, #Used to deal with missing values
xval = 20, #Number of cross validations
surrogatestyle = 0, #Used to deal with missing values
maxdepth = 10 #Tree depth
)
## $minsplit
## [1] 20
##
## $minbucket
## [1] 7
##
## $cp
## [1] 0.05
##
## $maxcompete
## [1] 4
##
## $maxsurrogate
## [1] 5
##
## $usesurrogate
## [1] 2
##
## $surrogatestyle
## [1] 0
##
## $maxdepth
## [1] 10
##
## $xval
## [1] 20
# Run CART and plot the model
diamondsprice_model <- rpart(price ~ ., data=diamonds_train %>% select(-diamond_id))
fancyRpartPlot(diamondsprice_model)
# See in sample prediction results
diamonds_predict <- predict(diamondsprice_model,newdata=diamonds_test %>% select(-diamond_id))
str(diamonds_predict)
## Named num [1:10797] 14944 3073 3073 5395 3073 ...
## - attr(*, "names")= chr [1:10797] "1" "2" "3" "4" ...
# Compare the model predictions with the test data
difference <- as.integer(diamonds_test$price) - as.integer(diamonds_predict)
summary(difference)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -9893.00 -530.00 -176.00 -11.04 528.00 12640.00