-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathR_simple_linear_regression.R
More file actions
56 lines (44 loc) · 1.94 KB
/
Copy pathR_simple_linear_regression.R
File metadata and controls
56 lines (44 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Load the data. Only the YearsExperience (predictor) and Salary (response)
# columns are used by the rest of this script; no imputation or categorical
# encoding is applied.
dataset <- read.csv("Salary_Data.csv")

# Fail fast with a readable message if the expected columns are absent,
# instead of a cryptic error later in sample.split() or lm().
required_cols <- c("YearsExperience", "Salary")
missing_cols <- setdiff(required_cols, names(dataset))
if (length(missing_cols) > 0) {
  stop("Salary_Data.csv is missing column(s): ",
       paste(missing_cols, collapse = ", "), call. = FALSE)
}
# Ensure caTools is available without reinstalling it on every run.
if (!requireNamespace("caTools", quietly = TRUE)) {
  install.packages("caTools")
}
library(caTools)

# Split into a training set (2/3 of rows) and a test set (the remaining 1/3).
# The fixed seed makes the split reproducible across runs.
set.seed(123)
split <- sample.split(dataset$Salary, SplitRatio = 2/3)
# Logical `[` indexing rather than subset(): subset() evaluates `split`
# inside the data frame first, so a column named "split" would silently
# replace this logical vector.
training_set <- dataset[split, , drop = FALSE]
test_set <- dataset[!split, , drop = FALSE]
# No feature scaling: rescaling the single predictor would not change the
# ordinary-least-squares fitted values.

# Fit Salary = b0 + b1 * YearsExperience on the training rows only.
regressor <- lm(Salary ~ YearsExperience, data = training_set)

# Predicted salaries for the held-out test rows.
y_pred <- predict(object = regressor, newdata = test_set)
# install.packages('ggplot2')
library(ggplot2)

# Training-set plot: observed salaries as red points, the fitted regression
# line in blue. Columns are mapped by name from each layer's `data` rather
# than via `df$col` inside aes().
training_line <- data.frame(
  YearsExperience = training_set$YearsExperience,
  Salary = predict(regressor, newdata = training_set)
)
training_plot <- ggplot(training_set, aes(x = YearsExperience, y = Salary)) +
  geom_point(colour = "red") +
  # No geom_smooth(): the blue line already shows the fitted model.
  geom_line(data = training_line, colour = "blue") +
  ggtitle("Salary vs Experience (Training set)") +
  xlab("Years of experience") +
  ylab("Salary")
# Explicit print() so the plot also renders when the script is source()d.
print(training_plot)
# Test-set plot: held-out observations as red points against the same
# regression line (fitted on the training set) in blue, to show how well
# the training fit generalises.
test_line <- data.frame(
  YearsExperience = training_set$YearsExperience,
  Salary = predict(regressor, newdata = training_set)
)
test_plot <- ggplot(test_set, aes(x = YearsExperience, y = Salary)) +
  geom_point(colour = "red") +
  # No geom_smooth(): the blue line already shows the fitted model.
  geom_line(data = test_line, colour = "blue") +
  ggtitle("Salary vs Experience (Test set)") +
  xlab("Years of experience") +
  ylab("Salary")
# Explicit print() so the plot also renders when the script is source()d.
print(test_plot)