# Set the random seed so that the random train/test sampling below gives the
# same split on every run
set.seed(503)
# Load the required packages

library(tidyverse)
## ── Attaching packages ────────────────────────────────── tidyverse 1.2.1 ──
## ✔ ggplot2 2.2.1     ✔ purrr   0.2.4
## ✔ tibble  1.3.4     ✔ dplyr   0.7.4
## ✔ tidyr   0.7.2     ✔ stringr 1.2.0
## ✔ readr   1.1.1     ✔ forcats 0.2.0
## ── Conflicts ───────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(rpart)
library(rpart.plot)
library(rattle)
## Rattle: A free graphical interface for data science with R.
## Version 5.1.0 Copyright (c) 2006-2017 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
# Build the hold-out (test) and training sets.
# Every row gets a diamond_id taken from its position in `diamonds`, so the
# two sets can be separated by id afterwards.

# Test set: a 20% random sample drawn within each cut/color/clarity
# combination.
diamonds_test <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>%
  sample_frac(size = 0.2) %>%
  ungroup()

# Training set: every row whose diamond_id was not drawn for the test set.
diamonds_train <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  anti_join(diamonds_test, by = "diamond_id")

# Print the training set; as a tibble it shows only the first rows
diamonds_train
## # A tibble: 43,143 x 11
##    carat       cut color clarity depth table price     x     y     z
##    <dbl>     <ord> <ord>   <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
##  1  0.23     Ideal     E     SI2  61.5    55   326  3.95  3.98  2.43
##  2  0.21   Premium     E     SI1  59.8    61   326  3.89  3.84  2.31
##  3  0.23      Good     E     VS1  56.9    65   327  4.05  4.07  2.31
##  4  0.29   Premium     I     VS2  62.4    58   334  4.20  4.23  2.63
##  5  0.24 Very Good     J    VVS2  62.8    57   336  3.94  3.96  2.48
##  6  0.24 Very Good     I    VVS1  62.3    57   336  3.95  3.98  2.47
##  7  0.26 Very Good     H     SI1  61.9    55   337  4.07  4.11  2.53
##  8  0.22      Fair     E     VS2  65.1    61   337  3.87  3.78  2.49
##  9  0.23 Very Good     H     VS1  59.4    61   338  4.00  4.05  2.39
## 10  0.30      Good     J     SI1  64.0    55   339  4.25  4.28  2.73
## # ... with 43,133 more rows, and 1 more variables: diamond_id <int>
# Print the test set; as a tibble it shows only the first rows
diamonds_test
## # A tibble: 10,797 x 11
##    carat   cut color clarity depth table price     x     y     z
##    <dbl> <ord> <ord>   <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
##  1  3.40  Fair     D      I1  66.8    52 15964  9.42  9.34  6.27
##  2  0.90  Fair     D     SI2  64.7    59  3205  6.09  5.99  3.91
##  3  0.95  Fair     D     SI2  64.4    60  3384  6.06  6.02  3.89
##  4  1.00  Fair     D     SI2  65.2    56  3634  6.27  6.21  4.07
##  5  0.70  Fair     D     SI2  58.1    60  2358  5.79  5.82  3.37
##  6  1.04  Fair     D     SI2  64.9    56  4398  6.39  6.34  4.13
##  7  0.70  Fair     D     SI2  65.6    55  2167  5.59  5.50  3.64
##  8  1.03  Fair     D     SI2  66.4    56  3743  6.31  6.19  4.15
##  9  1.10  Fair     D     SI2  64.6    54  4725  6.56  6.49  4.22
## 10  2.01  Fair     D     SI2  59.4    66 15627  8.20  8.17  4.86
## # ... with 10,787 more rows, and 1 more variables: diamond_id <int>

# Exploratory Analysis

# Size of the full data set: number of rows, then number of columns
nrow(diamonds)
## [1] 53940
ncol(diamonds)
## [1] 10
# Dimensions (rows, columns) of the two splits; both carry one extra column,
# diamond_id, compared with the full data set
dim(diamonds_train)
## [1] 43143    11
dim(diamonds_test)
## [1] 10797    11
# Per-column summary of the training set: quartiles and mean for numeric
# columns, level counts for the ordered factors
summary(diamonds_train)
##      carat               cut        color       clarity     
##  Min.   :0.2000   Fair     : 1285   D:5416   SI1    :10449  
##  1st Qu.:0.4000   Good     : 3923   E:7835   VS2    : 9806  
##  Median :0.7000   Very Good: 9662   F:7629   SI2    : 7354  
##  Mean   :0.7978   Premium  :11036   G:9037   VS1    : 6538  
##  3rd Qu.:1.0400   Ideal    :17237   H:6646   VVS2   : 4052  
##  Max.   :5.0100                     I:4336   VVS1   : 2923  
##                                     J:2244   (Other): 2021  
##      depth           table           price             x         
##  Min.   :43.00   Min.   :44.00   Min.   :  326   Min.   : 0.000  
##  1st Qu.:61.00   1st Qu.:56.00   1st Qu.:  951   1st Qu.: 4.710  
##  Median :61.80   Median :57.00   Median : 2401   Median : 5.700  
##  Mean   :61.74   Mean   :57.46   Mean   : 3933   Mean   : 5.731  
##  3rd Qu.:62.50   3rd Qu.:59.00   3rd Qu.: 5327   3rd Qu.: 6.540  
##  Max.   :79.00   Max.   :95.00   Max.   :18823   Max.   :10.740  
##                                                                  
##        y                z            diamond_id   
##  Min.   : 0.000   Min.   : 0.000   Min.   :    1  
##  1st Qu.: 4.720   1st Qu.: 2.910   1st Qu.:13476  
##  Median : 5.710   Median : 3.530   Median :26981  
##  Mean   : 5.735   Mean   : 3.539   Mean   :26986  
##  3rd Qu.: 6.540   3rd Qu.: 4.030   3rd Qu.:40518  
##  Max.   :58.900   Max.   :31.800   Max.   :53940  
## 
# Same per-column summary for the test set, to compare against the training
# set
summary(diamonds_test)
##      carat               cut       color       clarity         depth      
##  Min.   :0.2000   Fair     : 325   D:1359   SI1    :2616   Min.   :53.10  
##  1st Qu.:0.4000   Good     : 983   E:1962   VS2    :2452   1st Qu.:61.10  
##  Median :0.7000   Very Good:2420   F:1913   SI2    :1840   Median :61.80  
##  Mean   :0.7986   Premium  :2755   G:2255   VS1    :1633   Mean   :61.77  
##  3rd Qu.:1.0400   Ideal    :4314   H:1658   VVS2   :1014   3rd Qu.:62.50  
##  Max.   :4.5000                    I:1086   VVS1   : 732   Max.   :73.60  
##                                    J: 564   (Other): 510                  
##      table           price             x                y         
##  Min.   :43.00   Min.   :  335   Min.   : 0.000   Min.   : 0.000  
##  1st Qu.:56.00   1st Qu.:  949   1st Qu.: 4.720   1st Qu.: 4.720  
##  Median :57.00   Median : 2407   Median : 5.700   Median : 5.710  
##  Mean   :57.45   Mean   : 3930   Mean   : 5.731   Mean   : 5.733  
##  3rd Qu.:59.00   3rd Qu.: 5317   3rd Qu.: 6.540   3rd Qu.: 6.540  
##  Max.   :79.00   Max.   :18818   Max.   :10.230   Max.   :10.160  
##                                                                   
##        z          diamond_id   
##  Min.   :0.00   Min.   :    5  
##  1st Qu.:2.91   1st Qu.:13511  
##  Median :3.52   Median :26914  
##  Mean   :3.54   Mean   :26909  
##  3rd Qu.:4.04   3rd Qu.:40191  
##  Max.   :6.72   Max.   :53939  
## 
# Compact overview of the training set: one line per column with its type
# and leading values
glimpse(diamonds_train)
## Observations: 43,143
## Variables: 11
## $ carat      <dbl> 0.23, 0.21, 0.23, 0.29, 0.24, 0.24, 0.26, 0.22, 0.2...
## $ cut        <ord> Ideal, Premium, Good, Premium, Very Good, Very Good...
## $ color      <ord> E, E, E, I, J, I, H, E, H, J, J, F, J, E, E, I, J, ...
## $ clarity    <ord> SI2, SI1, VS1, VS2, VVS2, VVS1, SI1, VS2, VS1, SI1,...
## $ depth      <dbl> 61.5, 59.8, 56.9, 62.4, 62.8, 62.3, 61.9, 65.1, 59....
## $ table      <dbl> 55, 61, 65, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54,...
## $ price      <int> 326, 326, 327, 334, 336, 336, 337, 337, 338, 339, 3...
## $ x          <dbl> 3.95, 3.89, 4.05, 4.20, 3.94, 3.95, 4.07, 3.87, 4.0...
## $ y          <dbl> 3.98, 3.84, 4.07, 4.23, 3.96, 3.98, 4.11, 3.78, 4.0...
## $ z          <dbl> 2.43, 2.31, 2.31, 2.63, 2.48, 2.47, 2.53, 2.49, 2.3...
## $ diamond_id <int> 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,...
# Compact overview of the test set: one line per column with its type and
# leading values
glimpse(diamonds_test)
## Observations: 10,797
## Variables: 11
## $ carat      <dbl> 3.40, 0.90, 0.95, 1.00, 0.70, 1.04, 0.70, 1.03, 1.1...
## $ cut        <ord> Fair, Fair, Fair, Fair, Fair, Fair, Fair, Fair, Fai...
## $ color      <ord> D, D, D, D, D, D, D, D, D, D, D, D, D, D, D, D, D, ...
## $ clarity    <ord> I1, SI2, SI2, SI2, SI2, SI2, SI2, SI2, SI2, SI2, SI...
## $ depth      <dbl> 66.8, 64.7, 64.4, 65.2, 58.1, 64.9, 65.6, 66.4, 64....
## $ table      <dbl> 52, 59, 60, 56, 60, 56, 55, 56, 54, 66, 58, 57, 57,...
## $ price      <int> 15964, 3205, 3384, 3634, 2358, 4398, 2167, 3743, 47...
## $ x          <dbl> 9.42, 6.09, 6.06, 6.27, 5.79, 6.39, 5.59, 6.31, 6.5...
## $ y          <dbl> 9.34, 5.99, 6.02, 6.21, 5.82, 6.34, 5.50, 6.19, 6.4...
## $ z          <dbl> 6.27, 3.91, 3.89, 4.07, 3.37, 4.13, 3.64, 4.15, 4.2...
## $ diamond_id <int> 26432, 2538, 3428, 4517, 51268, 8352, 49817, 5007, ...
# Structure of the training set, including the number of levels of each
# ordered factor
str(diamonds_train)
## Classes 'tbl_df', 'tbl' and 'data.frame':    43143 obs. of  11 variables:
##  $ carat     : num  0.23 0.21 0.23 0.29 0.24 0.24 0.26 0.22 0.23 0.3 ...
##  $ cut       : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 3 3 3 1 3 2 ...
##  $ color     : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 6 5 2 5 7 ...
##  $ clarity   : Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 6 7 3 4 5 3 ...
##  $ depth     : num  61.5 59.8 56.9 62.4 62.8 62.3 61.9 65.1 59.4 64 ...
##  $ table     : num  55 61 65 58 57 57 55 61 61 55 ...
##  $ price     : int  326 326 327 334 336 336 337 337 338 339 ...
##  $ x         : num  3.95 3.89 4.05 4.2 3.94 3.95 4.07 3.87 4 4.25 ...
##  $ y         : num  3.98 3.84 4.07 4.23 3.96 3.98 4.11 3.78 4.05 4.28 ...
##  $ z         : num  2.43 2.31 2.31 2.63 2.48 2.47 2.53 2.49 2.39 2.73 ...
##  $ diamond_id: int  1 2 3 4 6 7 8 9 10 11 ...
# Structure of the test set, including the number of levels of each ordered
# factor
str(diamonds_test)
## Classes 'tbl_df', 'tbl' and 'data.frame':    10797 obs. of  11 variables:
##  $ carat     : num  3.4 0.9 0.95 1 0.7 1.04 0.7 1.03 1.1 2.01 ...
##  $ cut       : Ord.factor w/ 5 levels "Fair"<"Good"<..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ color     : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 1 1 1 1 1 1 1 1 1 1 ...
##  $ clarity   : Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 1 2 2 2 2 2 2 2 2 2 ...
##  $ depth     : num  66.8 64.7 64.4 65.2 58.1 64.9 65.6 66.4 64.6 59.4 ...
##  $ table     : num  52 59 60 56 60 56 55 56 54 66 ...
##  $ price     : int  15964 3205 3384 3634 2358 4398 2167 3743 4725 15627 ...
##  $ x         : num  9.42 6.09 6.06 6.27 5.79 6.39 5.59 6.31 6.56 8.2 ...
##  $ y         : num  9.34 5.99 6.02 6.21 5.82 6.34 5.5 6.19 6.49 8.17 ...
##  $ z         : num  6.27 3.91 3.89 4.07 3.37 4.13 3.64 4.15 4.22 4.86 ...
##  $ diamond_id: int  26432 2538 3428 4517 51268 8352 49817 5007 10151 26223 ...

# Let’s look at the summary statistics and the distribution of prices

# Minimum, quartiles, mean and maximum of price over the full data set
# (not only the training split)
summary(diamonds$price)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##     326     950    2401    3933    5324   18823
# Histogram of price over the full data set, using bins of width 300.
# The distribution is right-skewed: a small number of very expensive
# diamonds appears to pull the mean above the median.

ggplot(data = diamonds, mapping = aes(x = price)) +
  geom_histogram(binwidth = 300, fill = "blue") +
  scale_x_continuous(breaks = seq(from = 0, to = 20000, by = 1000)) +
  labs(x = "Price", y = "Count") +
  theme(axis.text.x = element_text(angle = 90))

# Correlation matrix of the numeric columns that go into the PCA

diamonds_train_pca_cor <- diamonds_train %>%
  select(carat, depth, table, x, y, z) %>%
  cor()

# Print the correlation matrix
diamonds_train_pca_cor
##            carat       depth      table           x           y          z
## carat 1.00000000  0.02595373  0.1841014  0.97586401  0.94678143 0.94991249
## depth 0.02595373  1.00000000 -0.3011976 -0.02848763 -0.03223782 0.09166475
## table 0.18410140 -0.30119756  1.0000000  0.19834246  0.18519210 0.15236302
## x     0.97586401 -0.02848763  0.1983425  1.00000000  0.96930921 0.96640370
## y     0.94678143 -0.03223782  0.1851921  0.96930921  1.00000000 0.94381041
## z     0.94991249  0.09166475  0.1523630  0.96640370  0.94381041 1.00000000
# Run PCA on the numeric columns of the training set.
# cor = TRUE bases the components on the correlation matrix rather than the
# covariance matrix. TRUE is spelled out because the shorthand `T` is an
# ordinary variable that can be reassigned.

diamonds_train_pca_result <- diamonds_train %>%
  select(carat, depth, table, x, y, z) %>%
  princomp(cor = TRUE)

# See the PCA results: standard deviation, proportion of variance and
# cumulative proportion for each component, followed by the loadings

summary(diamonds_train_pca_result,loadings=TRUE)
## Importance of components:
##                           Comp.1    Comp.2    Comp.3      Comp.4
## Standard deviation     1.9800490 1.1349816 0.8234788 0.227797884
## Proportion of Variance 0.6534323 0.2146972 0.1130196 0.008648646
## Cumulative Proportion  0.6534323 0.8681295 0.9811491 0.989797745
##                             Comp.5      Comp.6
## Standard deviation     0.216370481 0.119988944
## Proportion of Variance 0.007802697 0.002399558
## Cumulative Proportion  0.997600442 1.000000000
## 
## Loadings:
##       Comp.1 Comp.2 Comp.3 Comp.4 Comp.5 Comp.6
## carat -0.496                0.651 -0.422  0.385
## depth         0.734  0.671                     
## table -0.123 -0.669  0.733                     
## x     -0.501                0.112        -0.855
## y     -0.494               -0.749 -0.381  0.205
## z     -0.493  0.103                0.818  0.276
# Run CART with explicit tuning parameters.
# The settings are stored in an object so they can be handed to rpart()
# below; calling rpart.control() without keeping its result has no effect on
# the fitted tree.

diamonds_rpart_control <- rpart.control(
  # minsplit = 20, # Min # of items that should be in a node to do a split
  # minbucket = round(minsplit / 3), # Minimum number of items in a final node
  cp = 0.05, # Complexity parameter (min improvement to generate a split)
  maxcompete = 4, # Not related to model. Some printout for analyses
  maxsurrogate = 5, # Used to deal with missing values
  usesurrogate = 2, # Used to deal with missing values
  xval = 20, # Number of cross validations
  surrogatestyle = 0, # Used to deal with missing values
  maxdepth = 10 # Tree depth
)

# Show the complete set of control values, including the ones left at their
# defaults (minsplit, minbucket)
diamonds_rpart_control
## $minsplit
## [1] 20
## 
## $minbucket
## [1] 7
## 
## $cp
## [1] 0.05
## 
## $maxcompete
## [1] 4
## 
## $maxsurrogate
## [1] 5
## 
## $usesurrogate
## [1] 2
## 
## $surrogatestyle
## [1] 0
## 
## $maxdepth
## [1] 10
## 
## $xval
## [1] 20
# Fit the regression tree for price on every column except the row
# identifier, using the control settings defined above, then plot it.
# NOTE(review): cp = 0.05 prunes more aggressively than rpart's default, so
# confirm this is the value the analysis intends.
diamondsprice_model <- rpart(
  price ~ .,
  data = diamonds_train %>% select(-diamond_id),
  control = diamonds_rpart_control
)
fancyRpartPlot(diamondsprice_model)

# Predict prices for the held-out test set (out-of-sample predictions);
# the row identifier is dropped because it is not a model input

diamonds_predict <- predict(
  object = diamondsprice_model,
  newdata = select(diamonds_test, -diamond_id)
)

# Inspect the structure of the prediction vector
str(diamonds_predict)
##  Named num [1:10797] 14944 3073 3073 5395 3073 ...
##  - attr(*, "names")= chr [1:10797] "1" "2" "3" "4" ...
# Compare the model predictions with the test data.
# Residual = observed price - predicted price. The predictions are kept as
# doubles: truncating them with as.integer() would drop the fractional part
# of every prediction and shift the residuals. unname() removes the names
# that predict() attaches, so the result is a plain numeric vector.

difference <- diamonds_test$price - unname(diamonds_predict)

summary(difference)
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -9893.00  -530.00  -176.00   -11.04   528.00 12640.00