Skip to content

Commit 0bd38b1

Browse files
committed
First comment!
0 parents commit 0bd38b1

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

model.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
class LinearRegression():
    """
    A simple implementation of Linear Regression using Gradient Descent.
    """

    def __init__(self, step=0.001, n_iters=10000):
        """
        Initializes the LinearRegression class with the provided learning rate (`step`)
        and the number of iterations (`n_iters`). Also initializes the slope and intercept
        of the model, as well as an internal variable to track the number of features (`__m_`).

        Parameters:
        -----------
        step : float
            The learning rate for the gradient descent optimization. It controls how big
            of a step is taken towards the minimum during each iteration.
        n_iters : int
            The number of iterations for the gradient descent optimization.
        """
        self.step = step
        self.n_iters = n_iters
        self._slope_ = 0       # becomes a list (one coefficient per feature) after fit()
        self._intercept_ = 0
        self.__m_ = 0          # number of features seen during fit()
        self.__n_ = 0          # number of training samples seen during fit()

    def fit(self, X, y):
        """
        Trains the linear regression model using the input data `X` and target values `y`.
        It adjusts the slope (`_slope_`) and intercept (`_intercept_`) by performing gradient
        descent over `n_iters` iterations.

        Parameters:
        -----------
        X : list of list of float or list of float
            The input data, where each inner list represents a sample with multiple features,
            or a simple list if there's only one feature.
        y : list of float
            The target values corresponding to each sample in `X`.

        Raises:
        -------
        ValueError:
            If `X` is empty, or if the number of samples in `X` and `y` do not match.
        """
        if not X:
            raise ValueError("X must contain at least one sample")

        self.__n_ = len(X)

        if isinstance(X[0], list):
            self.__m_ = len(X[0])
        else:
            # Promote a flat list to a list of single-feature rows.
            self.__m_ = 1
            X = [[x] for x in X]

        if self.__n_ != len(y):
            raise ValueError(f"X and y must have the same number of samples: {(self.__n_, len(y))}")

        self._slope_ = [0] * self.__m_
        self._intercept_ = 0

        # Batch gradient descent on mean squared error:
        # dMSE/dw_j = -(2/n) * sum((y_i - y_pred_i) * x_ij)
        # dMSE/db   = -(2/n) * sum(y_i - y_pred_i)
        for _ in range(self.n_iters):
            y_pred = self.predict(X)

            for j in range(self.__m_):
                self._slope_[j] -= self.step * (-(2/self.__n_) * sum((y[i] - y_pred[i]) * X[i][j] for i in range(self.__n_)))
            self._intercept_ -= self.step * (-(2/self.__n_) * sum(y[i] - y_pred[i] for i in range(self.__n_)))

    def predict(self, X):
        """
        Predicts the target values for the given input data `X` using the trained linear
        regression model.

        Parameters:
        -----------
        X : list of list of float or list of float
            The input data, where each inner list represents a sample with multiple features,
            or a simple list if there's only one feature. Any number of samples is accepted;
            only the feature count must match the training data.

        Returns:
        --------
        y_pred : list of float
            The predicted target values corresponding to each sample in `X`.

        Raises:
        -------
        ValueError:
            If the input data `X` has a different number of features than the data
            used for training (or the model has not been fitted yet).
        """
        if isinstance(X[0], list):
            m = len(X[0])
        else:
            m = 1
            X = [[x] for x in X]

        # Bug fix: only the feature count must match the training data. The previous
        # version also required len(X) == number of training samples, which made it
        # impossible to predict on new data of a different size.
        if m != self.__m_:
            raise ValueError(f"X must have the same number of features as the training data: {m}, Expected: {self.__m_}")

        y_pred = []
        for i in range(len(X)):
            y_pred.append(sum(self._slope_[j] * X[i][j] for j in range(m)) + self._intercept_)

        return y_pred

    def MSE(self, y, y_pred):
        """
        Calculates the Mean Squared Error (MSE) between the true target values `y` and the
        predicted values `y_pred`.

        Parameters:
        -----------
        y : list of float
            The true target values.
        y_pred : list of float
            The predicted target values.

        Returns:
        --------
        mse : float
            The mean squared error between `y` and `y_pred`.

        Raises:
        -------
        ValueError:
            If `y` is empty (the mean would be undefined).
        """
        if not y:
            raise ValueError("y must contain at least one value")
        return sum((y[i] - y_pred[i]) ** 2 for i in range(len(y))) / len(y)

0 commit comments

Comments
 (0)