First, install the required packages, then load the libraries.
library(tidyverse)
## -- Attaching packages ---------------------------------- tidyverse 1.2.1 --
## v ggplot2 2.2.1 v purrr 0.2.4
## v tibble 1.3.4 v dplyr 0.7.4
## v tidyr 0.7.2 v stringr 1.2.0
## v readr 1.1.1 v forcats 0.2.0
## -- Conflicts ------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(rpart)
library(rpart.plot)
library(rattle)
## Rattle: A free graphical interface for data science with R.
## Version 5.1.0 Copyright (c) 2006-2017 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
set.seed(503)
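The installation step itself is not shown above. A minimal sketch of how it might look, assuming the four packages loaded here are the only ones needed; `missing_packages` is a hypothetical helper written for this illustration, not part of any package.

```r
# Packages loaded in this analysis.
required <- c("tidyverse", "rpart", "rpart.plot", "rattle")

# Hypothetical helper: return the names of packages that are not installed.
missing_packages <- function(pkgs) {
  installed <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
  pkgs[!installed]
}

to_install <- missing_packages(required)
if (length(to_install) > 0) {
  message("Not installed: ", paste(to_install, collapse = ", "))
  # install.packages(to_install)  # uncomment to install from CRAN
}
```

The install call is left commented out so that running the block only reports what is missing.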
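Let’s separate the dataset into two pieces called diamonds_test and diamonds_train. The test set takes a 20% sample from each combination of cut, color, and clarity, and the training set keeps the remaining rows.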
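# Stratified 20% test sample: draw from every cut/color/clarity combination.
diamonds_test <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()
# The training set is every row not in the test set, keeping only the model columns.
diamonds_train <- anti_join(diamonds %>% mutate(diamond_id = row_number()),
                            diamonds_test, by = "diamond_id") %>%
  select(cut, color, clarity, price)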
diamonds_train
## # A tibble: 43,143 x 4
## cut color clarity price
## <ord> <ord> <ord> <int>
## 1 Ideal E SI2 326
## 2 Premium E SI1 326
## 3 Good E VS1 327
## 4 Premium I VS2 334
## 5 Very Good J VVS2 336
## 6 Very Good I VVS1 336
## 7 Very Good H SI1 337
## 8 Fair E VS2 337
## 9 Very Good H VS1 338
## 10 Good J SI1 339
## # ... with 43,133 more rows
diamonds_test
## # A tibble: 10,797 x 11
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 3.40 Fair D I1 66.8 52 15964 9.42 9.34 6.27
## 2 0.90 Fair D SI2 64.7 59 3205 6.09 5.99 3.91
## 3 0.95 Fair D SI2 64.4 60 3384 6.06 6.02 3.89
## 4 1.00 Fair D SI2 65.2 56 3634 6.27 6.21 4.07
## 5 0.70 Fair D SI2 58.1 60 2358 5.79 5.82 3.37
## 6 1.04 Fair D SI2 64.9 56 4398 6.39 6.34 4.13
## 7 0.70 Fair D SI2 65.6 55 2167 5.59 5.50 3.64
## 8 1.03 Fair D SI2 66.4 56 3743 6.31 6.19 4.15
## 9 1.10 Fair D SI2 64.6 54 4725 6.56 6.49 4.22
## 10 2.01 Fair D SI2 59.4 66 15627 8.20 8.17 4.86
## # ... with 10,787 more rows, and 1 more variables: diamond_id <int>
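A quick sanity check on the split can be sketched as follows. This rebuilds the two pieces under the hypothetical names `train_part` and `test_part`, keeping `diamond_id` in both so that overlap can be checked. It assumes the `diamonds` dataset from ggplot2 and is only an illustration, not part of the original analysis.

```r
library(dplyr)
library(ggplot2)  # provides the diamonds dataset

set.seed(503)
diamonds_id <- diamonds %>% mutate(diamond_id = row_number())

# Same stratified sampling as above, but diamond_id is kept in both pieces.
test_part <- diamonds_id %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()
train_part <- anti_join(diamonds_id, test_part, by = "diamond_id")

# Every row should land in exactly one of the two pieces.
stopifnot(nrow(train_part) + nrow(test_part) == nrow(diamonds_id))
stopifnot(length(intersect(train_part$diamond_id, test_part$diamond_id)) == 0)

# Share of rows held out for testing; close to, but not exactly, 0.2
# because sample_frac() rounds the sample size within each group.
nrow(test_part) / nrow(diamonds_id)
```

The row counts printed above agree with this check: 43,143 training rows plus 10,797 test rows give the 53,940 rows of `diamonds`.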
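To apply CART, we fit a regression tree with rpart() that predicts price from the other columns of diamonds_train (cut, color, and clarity), and then plot the tree diagram.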
diamonds_model <- rpart(price ~ ., data=diamonds_train)
fancyRpartPlot(diamonds_model)
As the tree diagram shows, the model splits the color levels into two groups, and does the same for clarity.
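The splits can also be read in text form, which helps when the plot is not available. The sketch below is an assumption-laden illustration: it refits the tree on the full `diamonds` dataset under the hypothetical name `fit`, so the exact splits may differ slightly from the model fitted on diamonds_train.

```r
library(rpart)
library(ggplot2)  # only for the diamonds dataset

# Assumption: refit on the full dataset with the same three predictors.
fit <- rpart(price ~ cut + color + clarity, data = diamonds)

# One line per node: the split rule, the number of observations, the
# deviance, and the mean price predicted in that node.
# Terminal nodes (leaves) are marked with '*'.
print(fit)

# Relative importance of each predictor, as computed by rpart.
fit$variable.importance
```

Because cut, color, and clarity are ordered factors, each split divides the ordered levels of one variable into a lower and an upper group.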
The predicted price for each training observation can be obtained as shown below.
diamonds_in_sample <- predict(diamonds_model)
print(head(diamonds_in_sample))
## 1 2 3 4 5 6
## 3540.079 3540.079 3540.079 5250.455 2543.766 2543.766
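Called without new data, predict() returns fitted values for the training rows. To price diamonds that were not in the training data, a `newdata` argument can be supplied. The following is a minimal sketch: `fit` and `new_diamonds` are hypothetical names, the two example diamonds are made up for illustration, and the tree is refitted on the full `diamonds` dataset to keep the block self-contained.

```r
library(rpart)
library(ggplot2)  # only for the diamonds dataset

fit <- rpart(price ~ cut + color + clarity, data = diamonds)

# Hypothetical new diamonds; the factor levels must match the training data.
new_diamonds <- data.frame(
  cut     = factor(c("Ideal", "Fair"), levels = levels(diamonds$cut), ordered = TRUE),
  color   = factor(c("E", "J"), levels = levels(diamonds$color), ordered = TRUE),
  clarity = factor(c("VS1", "SI2"), levels = levels(diamonds$clarity), ordered = TRUE)
)

# One predicted price per supplied row: the mean price of the leaf
# that each new diamond falls into.
predict(fit, newdata = new_diamonds)
```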
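To compare the actual and predicted prices, we add a column called ‘error’ (actual price minus predicted price). The mean of all errors, printed below, is practically zero. This is expected in-sample, because each leaf predicts the mean price of its training observations, so the errors within each leaf cancel out; it does not by itself show that individual predictions are accurate.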
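# Attach the in-sample predictions to the training data and compute the error.
# mutate() replaces the deprecated tbl_df() plus cbind() combination.
in_sample_prediction <- diamonds_train %>%
  mutate(value = unname(diamonds_in_sample),
         error = price - value)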
print(mean(in_sample_prediction$error))
## [1] 8.221958e-14
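Since positive and negative errors cancel in the mean, measures based on the size of the errors may be more informative. The sketch below computes the root mean squared error (RMSE) and the mean absolute error (MAE) next to the mean error. It is an illustration under assumptions: the tree is refitted on the full `diamonds` dataset under the hypothetical name `fit`, so the numbers will not match the model above exactly.

```r
library(rpart)
library(ggplot2)  # only for the diamonds dataset

fit <- rpart(price ~ cut + color + clarity, data = diamonds)

# In-sample error: actual price minus fitted price.
error <- diamonds$price - predict(fit)

mean_error <- mean(error)       # near zero by construction (leaf means)
rmse <- sqrt(mean(error^2))     # typical size of an error, penalising large ones
mae <- mean(abs(error))         # average absolute error

c(mean_error = mean_error, rmse = rmse, mae = mae)
```

Both RMSE and MAE are expressed in the same units as price, which makes them easier to interpret than a mean error that is zero up to floating-point rounding.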