This post is part of the Regression and Model Fitting Tutorial.
Linear regression is one of the most widely used model-fitting techniques. On a recent project, I was generating a large number of reports with many charts based on linear regression. Producing them by hand would have taken months; instead, a small amount of development effort made it possible to regenerate the charts every day with the most current data.
Before we get too far into the details of the query, let's review a little of the math behind linear regression. The basic idea is to fit a line, usually written as y = a + bx, to a set of points so that the total ("overall") error between the line and the data points is as small as possible (in the figure below, the yellow line is the regression line and the blue points are the data points).
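In least squares, that "overall" error is the sum of the squared vertical distances between each point and the line. A minimal Python sketch of this error measure makes it concrete; the function name sum_squared_errors and the sample points are our own, for illustration only:

```python
from typing import Sequence


def sum_squared_errors(a: float, b: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Total squared vertical distance between the points and the line y = a + b*x."""
    return sum((y - (a + b * x)) ** 2 for x, y in zip(xs, ys))


if __name__ == "__main__":
    # Made-up points that lie close to y = 2x
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [2.0, 4.1, 5.9, 8.2]
    # Two candidate lines: least squares prefers the one with the smaller total
    print(sum_squared_errors(0.0, 2.0, xs, ys))
    print(sum_squared_errors(1.0, 1.5, xs, ys))
```

The line y = 2x scores roughly 0.06 and y = 1 + 1.5x roughly 1.86, so the first is the better fit; linear regression finds the a and b that make this total as small as possible.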
Mathematically, this problem is well understood. The method of least squares turns it into two simultaneous linear equations in a and b, and solving them gives the line with the smallest total squared error. The equations take the following form:
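Written out for n data points (x_i, y_i), and assuming the standard least-squares setup for y = a + bx, the squared error S and the two equations obtained by setting its partial derivatives to zero (the normal equations) are:

```latex
S(a, b) = \sum_{i=1}^{n} \left( y_i - a - b\,x_i \right)^2

\frac{\partial S}{\partial a} = 0 \;\Rightarrow\; \sum_{i=1}^{n} y_i = n\,a + b \sum_{i=1}^{n} x_i

\frac{\partial S}{\partial b} = 0 \;\Rightarrow\; \sum_{i=1}^{n} x_i y_i = a \sum_{i=1}^{n} x_i + b \sum_{i=1}^{n} x_i^2
```

Setting both derivatives to zero is what turns "minimize the error" into two linear equations in the two unknowns a and b.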
Although you can solve them algebraically in this form, doing so is tedious and quickly gets out of hand once you have more than a few data points. The best approach I've learned goes back to linear algebra and matrices. We take our observations and create the following three matrices (x1, x2, x3, etc. are observations of the X variable, and y1, y2, y3, etc. are observations of the Y variable):
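One conventional way to lay out these three matrices (the labels X, β and Y are our own) is a design matrix with a column of ones for the intercept a and a column of x values for the slope b, the vector of unknown parameters, and the column of y values:

```latex
\mathbf{X} = \begin{bmatrix} 1 & x_1 \\ 1 & x_2 \\ 1 & x_3 \\ \vdots & \vdots \\ 1 & x_n \end{bmatrix},
\qquad
\boldsymbol{\beta} = \begin{bmatrix} a \\ b \end{bmatrix},
\qquad
\mathbf{Y} = \begin{bmatrix} y_1 \\ y_2 \\ y_3 \\ \vdots \\ y_n \end{bmatrix}
```

With this layout, row i of the product X·β is a + b·x_i, the model's prediction for observation i.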
Then we solve the following matrix equation; its solution is the vector containing a and b.
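In conventional notation, with X the n × 2 matrix whose rows are (1, x_i), β = (a, b)ᵀ, and Y the column of y values (labels assumed, not taken from the text), the equation to solve is the matrix form of the normal equations:

```latex
\mathbf{X}^{\mathsf{T}}\mathbf{X}\,\boldsymbol{\beta} = \mathbf{X}^{\mathsf{T}}\mathbf{Y}
\quad\Longrightarrow\quad
\boldsymbol{\beta} = \left(\mathbf{X}^{\mathsf{T}}\mathbf{X}\right)^{-1}\mathbf{X}^{\mathsf{T}}\mathbf{Y}

\mathbf{X}^{\mathsf{T}}\mathbf{X} = \begin{bmatrix} n & \sum_{i=1}^{n} x_i \\ \sum_{i=1}^{n} x_i & \sum_{i=1}^{n} x_i^2 \end{bmatrix},
\qquad
\mathbf{X}^{\mathsf{T}}\mathbf{Y} = \begin{bmatrix} \sum_{i=1}^{n} y_i \\ \sum_{i=1}^{n} x_i y_i \end{bmatrix}
```

Multiplying out the left-hand side row by row gives back the two normal equations, n·a + b·Σx_i = Σy_i and a·Σx_i + b·Σx_i² = Σx_i·y_i.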
Working through the algebra, we end up with the following expressions for a and b (simplified one step further by writing them in terms of the averages of x and y):
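Under the same least-squares setup, with x̄ and ȳ the averages of x and y over the n points, the solution in average notation is below. These are the same expressions the query computes: the SS terms correspond to its ss_xx, ss_yy and ss_xy columns, and r² to its r_r column.

```latex
b = \frac{\sum_{i=1}^{n} x_i y_i - n\,\bar{x}\,\bar{y}}{\sum_{i=1}^{n} x_i^2 - n\,\bar{x}^2}
\qquad
a = \frac{\bar{y}\sum_{i=1}^{n} x_i^2 - \bar{x}\sum_{i=1}^{n} x_i y_i}{\sum_{i=1}^{n} x_i^2 - n\,\bar{x}^2} = \bar{y} - b\,\bar{x}

SS_{xx} = \sum_{i=1}^{n} x_i^2 - n\,\bar{x}^2,
\qquad
SS_{yy} = \sum_{i=1}^{n} y_i^2 - n\,\bar{y}^2,
\qquad
SS_{xy} = \sum_{i=1}^{n} x_i y_i - n\,\bar{x}\,\bar{y},
\qquad
r^2 = \frac{SS_{xy}^2}{SS_{xx}\,SS_{yy}}
```

The denominator of both a and b is SS_xx, so the fit is undefined when every x value is the same.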
With these formulas, it is relatively easy to build a query that returns the parameters. Below is the SQL query that I use for linear regression. It works in MySQL and Microsoft SQL Server, and it may also work on other platforms (e.g. PostgreSQL, Oracle, IBM DB2) if they provide the functions it needs. Be careful with how your database platform evaluates sums: the query subtracts large, nearly equal quantities (for example Σx² − n·x̄²), so some data sets can run into numerical precision problems:
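The precision warning is easiest to see with a small experiment. The Python sketch below is our own illustration, not part of the original post: it computes the sum of squares of x the way the query does, Σx² − n·x̄², and with a two-pass centered formula, Σ(x − x̄)², using double-precision floats. With hypothetical values clustered tightly around a large offset, the one-pass version cancels two nearly equal numbers of about 3 × 10^18 and does not return the exact answer of 2, while the centered version does.

```python
from typing import Sequence


def ss_one_pass(xs: Sequence[float]) -> float:
    # Same shape as the query's ss_xx: sum(x*x) - n * avg(x) * avg(x)
    n = len(xs)
    mean = sum(xs) / n
    return sum(x * x for x in xs) - n * mean * mean


def ss_two_pass(xs: Sequence[float]) -> float:
    # Subtract the mean first, then square, so no huge nearly equal values are subtracted
    n = len(xs)
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs)


if __name__ == "__main__":
    # Hypothetical data: a small spread around a large offset (e.g. epoch-style timestamps)
    xs = [1e9 + 1, 1e9 + 2, 1e9 + 3]  # the exact sum of squared deviations is 2
    print("one-pass:", ss_one_pass(xs))
    print("two-pass:", ss_two_pass(xs))
```

How much this matters in the database depends on whether your platform evaluates the sums in floating point or in exact decimal arithmetic. If it does bite, one option is to subtract a constant c (for example the minimum x) from x in the source query: the slope b is unchanged, and if the query returns a for the shifted data, the intercept for the original x is a − b·c.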
-- Developed by Mike Burr
-- This query calculates a linear regression
-- in the form y = a + bx and calculates
-- the squared correlation coefficient (r^2)
-- for the source data
select a,
       b,
       -- Squared correlation coefficient (r^2);
       -- NULL if x or y has no variation
       (ss_xy * ss_xy) / nullif(ss_xx * ss_yy, 0) as r_r
from (
    -- In this query we calculate the parameters a and b
    -- of the linear model, plus the sums of squares
    -- (ss_xx, ss_yy, ss_xy) used for r^2
    select
        ((avg_yi * sum_xi_xi) - (avg_xi * sum_xi_yi)) /
            nullif(sum_xi_xi - (n * avg_xi * avg_xi), 0)
            as a,
        (sum_xi_yi - (n * avg_xi * avg_yi)) /
            nullif(sum_xi_xi - (n * avg_xi * avg_xi), 0)
            as b,
        sum_xi_xi - (n * avg_xi * avg_xi) as ss_xx,
        sum_yi_yi - (n * avg_yi * avg_yi) as ss_yy,
        sum_xi_yi - (n * avg_xi * avg_yi) as ss_xy
    from (
        -- In this innermost query we build the sums and
        -- averages used in the regression calculation
        select avg(y) as avg_yi,
               avg(x) as avg_xi,
               count(x) as n,
               sum(x * x) as sum_xi_xi,
               sum(y * y) as sum_yi_yi,
               sum(x * y) as sum_xi_yi,
               sum(x) as sum_xi
        from (
            -- Insert source data query here
            -- Alias the x-variable column as 'x'
            -- Alias the y-variable column as 'y'
            -- If x and y are integer columns, cast them to a
            -- decimal or floating-point type here: SQL Server's
            -- AVG over an integer column returns an integer
        ) as source_data
        -- Skip rows missing either value so that n and the
        -- averages are computed over the same rows
        where x is not null and y is not null
    ) as regression
) as final_parameters
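To try the query end to end without a database server, the Python sketch below runs it against an in-memory SQLite database using the standard-library sqlite3 module. This is our own test harness, not the author's setup, and SQLite is not one of the platforms named above. The query text is the one above adapted for SQLite (unquoted aliases, nullif guards on the divisions, a null filter), and the source-data subquery reads a hypothetical table observations with hypothetical columns day_number and sales.

```python
import sqlite3
from typing import Iterable, Optional, Tuple

# The regression query, with a sample source-data subquery that reads
# the hypothetical table observations(day_number, sales).
REGRESSION_SQL = """
select a,
       b,
       (ss_xy * ss_xy) / nullif(ss_xx * ss_yy, 0) as r_r
from (
    select
        ((avg_yi * sum_xi_xi) - (avg_xi * sum_xi_yi)) /
            nullif(sum_xi_xi - (n * avg_xi * avg_xi), 0) as a,
        (sum_xi_yi - (n * avg_xi * avg_yi)) /
            nullif(sum_xi_xi - (n * avg_xi * avg_xi), 0) as b,
        sum_xi_xi - (n * avg_xi * avg_xi) as ss_xx,
        sum_yi_yi - (n * avg_yi * avg_yi) as ss_yy,
        sum_xi_yi - (n * avg_xi * avg_yi) as ss_xy
    from (
        select avg(y) as avg_yi,
               avg(x) as avg_xi,
               count(x) as n,
               sum(x * x) as sum_xi_xi,
               sum(y * y) as sum_yi_yi,
               sum(x * y) as sum_xi_yi,
               sum(x) as sum_xi
        from (
            -- Source data: alias the x column as x and the y column as y
            select day_number as x, sales as y from observations
        ) as source_data
        where x is not null and y is not null
    ) as regression
) as final_parameters
"""

Result = Tuple[Optional[float], Optional[float], Optional[float]]


def fit_line(rows: Iterable[Tuple[float, float]]) -> Result:
    """Load (x, y) pairs into an in-memory SQLite table and return (a, b, r_r)."""
    conn = sqlite3.connect(":memory:")
    try:
        # REAL columns so the sums and averages are computed in floating point
        conn.execute("create table observations (day_number real, sales real)")
        conn.executemany("insert into observations values (?, ?)", rows)
        a, b, r_r = conn.execute(REGRESSION_SQL).fetchone()
        return a, b, r_r
    finally:
        conn.close()


if __name__ == "__main__":
    # Made-up points that lie exactly on y = 1 + 2x
    print(fit_line([(1, 3), (2, 5), (3, 7), (4, 9)]))
```

The made-up points lie exactly on y = 1 + 2x, so this prints (1.0, 2.0, 1.0). If every x value is the same, ss_xx is zero and the nullif guards return NULL (None in Python) for a, b and r_r; on SQL Server, dividing by zero without such a guard raises an error by default.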