This document was prepared for an assignment.
# Fix the random-number generator state so the random train/test sampling
# later in this document is reproducible from run to run.
set.seed(seed = 503L)
library(tidyverse)
## -- Attaching packages ---------------------------------- tidyverse 1.2.1 --
## <U+221A> ggplot2 2.2.1 <U+221A> purrr 0.2.4
## <U+221A> tibble 1.3.4 <U+221A> dplyr 0.7.4
## <U+221A> tidyr 0.7.2 <U+221A> stringr 1.2.0
## <U+221A> readr 1.1.1 <U+221A> forcats 0.2.0
## -- Conflicts ------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Stratified hold-out split (roughly 80/20).
# Every diamond is tagged with its original row index as `diamond_id`, then
# 20% of the rows inside each cut x color x clarity cell become the test set.
diamonds_test <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()

# Training set: every tagged row whose id was not drawn into the test set.
diamonds_train <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  anti_join(diamonds_test, by = "diamond_id")
# Show the first rows of the training split.
print(diamonds_train)
## # A tibble: 43,143 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.20 4.23 2.63
## 5 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## 6 0.24 Very Good I VVS1 62.3 57 336 3.95 3.98 2.47
## 7 0.26 Very Good H SI1 61.9 55 337 4.07 4.11 2.53
## 8 0.22 Fair E VS2 65.1 61 337 3.87 3.78 2.49
## 9 0.23 Very Good H VS1 59.4 61 338 4.00 4.05 2.39
## 10 0.30 Good J SI1 64.0 55 339 4.25 4.28 2.73
## # ... with 43,133 more rows, and 1 more variables: diamond_id <int>
# Show the first rows of the held-out test split.
print(diamonds_test)
## # A tibble: 10,797 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 3.40 Fair D I1 66.8 52 15964 9.42 9.34 6.27
## 2 0.90 Fair D SI2 64.7 59 3205 6.09 5.99 3.91
## 3 0.95 Fair D SI2 64.4 60 3384 6.06 6.02 3.89
## 4 1.00 Fair D SI2 65.2 56 3634 6.27 6.21 4.07
## 5 0.70 Fair D SI2 58.1 60 2358 5.79 5.82 3.37
## 6 1.04 Fair D SI2 64.9 56 4398 6.39 6.34 4.13
## 7 0.70 Fair D SI2 65.6 55 2167 5.59 5.50 3.64
## 8 1.03 Fair D SI2 66.4 56 3743 6.31 6.19 4.15
## 9 1.10 Fair D SI2 64.6 54 4725 6.56 6.49 4.22
## 10 2.01 Fair D SI2 59.4 66 15627 8.20 8.17 4.86
## # ... with 10,787 more rows, and 1 more variables: diamond_id <int>
library(rpart)
library(rpart.plot)
# Per-column summary statistics of the full diamonds data set.
summary(object = diamonds)
## carat cut color clarity
## Min. :0.2000 Fair : 1610 D: 6775 SI1 :13065
## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258
## Median :0.7000 Very Good:12082 F: 9542 SI2 : 9194
## Mean :0.7979 Premium :13791 G:11292 VS1 : 8171
## 3rd Qu.:1.0400 Ideal :21551 H: 8304 VVS2 : 5066
## Max. :5.0100 I: 5422 VVS1 : 3655
## J: 2808 (Other): 2531
## depth table price x
## Min. :43.00 Min. :43.00 Min. : 326 Min. : 0.000
## 1st Qu.:61.00 1st Qu.:56.00 1st Qu.: 950 1st Qu.: 4.710
## Median :61.80 Median :57.00 Median : 2401 Median : 5.700
## Mean :61.75 Mean :57.46 Mean : 3933 Mean : 5.731
## 3rd Qu.:62.50 3rd Qu.:59.00 3rd Qu.: 5324 3rd Qu.: 6.540
## Max. :79.00 Max. :95.00 Max. :18823 Max. :10.740
##
## y z
## Min. : 0.000 Min. : 0.000
## 1st Qu.: 4.720 1st Qu.: 2.910
## Median : 5.710 Median : 3.530
## Mean : 5.735 Mean : 3.539
## 3rd Qu.: 6.540 3rd Qu.: 4.040
## Max. :58.900 Max. :31.800
##
# Compact column-by-column overview: type and first values of each variable.
glimpse(x = diamonds)
## Observations: 53,940
## Variables: 10
## $ carat <dbl> 0.23, 0.21, 0.23, 0.29, 0.31, 0.24, 0.24, 0.26, 0.22, ...
## $ cut <ord> Ideal, Premium, Good, Premium, Good, Very Good, Very G...
## $ color <ord> E, E, E, I, J, J, I, H, E, H, J, J, F, J, E, E, I, J, ...
## $ clarity <ord> SI2, SI1, VS1, VS2, SI2, VVS2, VVS1, SI1, VS2, VS1, SI...
## $ depth <dbl> 61.5, 59.8, 56.9, 62.4, 63.3, 62.8, 62.3, 61.9, 65.1, ...
## $ table <dbl> 55, 61, 65, 58, 58, 57, 57, 55, 61, 61, 55, 56, 61, 54...
## $ price <int> 326, 326, 327, 334, 335, 336, 336, 337, 337, 338, 339,...
## $ x <dbl> 3.95, 3.89, 4.05, 4.20, 4.34, 3.94, 3.95, 4.07, 3.87, ...
## $ y <dbl> 3.98, 3.84, 4.07, 4.23, 4.35, 3.96, 3.98, 4.11, 3.78, ...
## $ z <dbl> 2.43, 2.31, 2.31, 2.63, 2.75, 2.48, 2.47, 2.53, 2.49, ...
# Scatter plot of price against carat, limited to the lower 99% of each
# variable so a handful of extreme stones do not stretch the axes.
# xlim()/ylim() set scale limits: values outside them become NA and those
# points are dropped from the plot (not merely zoomed out of view).
# shape = 21 is a circle with a separate border and interior, so `fill`
# actually takes effect; with the default point shape the pink fill was
# silently ignored.
ggplot(diamonds, aes(x = carat, y = price)) +
  geom_point(shape = 21, colour = "blue", fill = "pink") +
  xlim(0, quantile(diamonds$carat, 0.99)) +
  ylim(0, quantile(diamonds$price, 0.99)) +
  ggtitle("Diamond price vs. carat")
# Distribution of price within each cut: one histogram panel per cut, each
# with its own axis ranges (scales = "free").
# The bin count is stated explicitly (30, the value stat_bin() falls back to)
# so the choice is visible and ggplot2 does not emit its
# "Pick better value with `binwidth`" message.
ggplot(diamonds, aes(x = price)) +
  geom_histogram(bins = 30) +
  facet_wrap(~cut, scales = "free")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Regression tree (rpart, method = "anova") predicting price from the other
# diamond attributes.
# The tree is fit on the training split only, so that diamonds_test is not
# used during fitting and can give an honest estimate of predictive error.
# A formula literal is already a formula, so no as.formula() call is needed.
foranova <- price ~ carat + cut + color + clarity + depth + table + x + y + z
# NOTE: this object name masks stats::anova(); it is kept unchanged because
# the plotting and prediction code refers to the fitted tree as `anova`.
anova <- rpart(foranova, data = diamonds_train, method = "anova")
anova
## n= 53940
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 53940 858473100000 3932.800
## 2) carat< 0.995 34880 43459420000 1632.641
## 4) y< 5.535 24951 6860691000 1058.546 *
## 5) y>=5.535 9929 7710112000 3075.309 *
## 3) carat>=0.995 19060 292761600000 8142.115
## 6) y< 7.195 12884 60679350000 6137.844
## 12) clarity=I1,SI2,SI1,VS2 9804 20256360000 5397.093 *
## 13) clarity=VS1,VVS2,VVS1,IF 3080 17919640000 8495.739 *
## 7) y>=7.195 6176 72354930000 12323.300
## 14) y< 7.815 3945 33996520000 10899.960
## 28) clarity=I1,SI2 954 3380193000 8375.178 *
## 29) clarity=SI1,VS2,VS1,VVS2,VVS1,IF 2991 22595360000 11705.260
## 58) color=H,I,J 1554 5588830000 10014.970 *
## 59) color=D,E,F,G 1437 7765314000 13533.160 *
## 15) y>=7.815 2231 16233830000 14840.160 *
# Draw the fitted regression tree; `type` and `digits` control the label
# layout and the displayed numeric precision (see ?rpart.plot).
rpart.plot(x = anova, type = 3, digits = 2)
# Predict prices for the held-out test split (the rows set aside from
# diamonds_train), rather than for the full data set.
# For a regression tree, predict() returns a numeric vector named by row, and
# printing those names makes the result look like a matrix of numbers;
# unname() drops the names so a plain numeric vector remains.
predictt <- unname(predict(anova, newdata = diamonds_test))

# Root-mean-squared error of the test-set predictions, in the units of price.
test_rmse <- sqrt(mean((diamonds_test$price - predictt)^2))
test_rmse
Thank you!