6.6 Constructing the Regression Line

Our problem of finding the regression line for a given set of data points has become one of finding the values of $a_0$ and $a_1$ that minimize the value of the error function $E(a_0, a_1)$. $E(a_0, a_1)$ is a function of the two variables $a_0$ and $a_1$, and its value is always nonnegative. Its graph is the surface of a smooth, three-dimensional shape sitting above (or possibly touching) the $a_0 a_1$ plane. The region near its minimum is like the bottom of a round bowl. At the minimum point, tangent lines drawn parallel to the $a_0$ axis and parallel to the $a_1$ axis both have zero slopes. [3]
So we take the partial derivatives of the error function

$$E(a_0, a_1) = \sum_{i=1}^{n} \left[ y_i - (a_0 + a_1 x_i) \right]^2$$

first with respect to $a_0$ and then with respect to $a_1$. To start, multiply out the squared expression:

$$E(a_0, a_1) = \sum \left( y_i^2 - 2a_0 y_i - 2a_1 x_i y_i + a_0^2 + 2a_0 a_1 x_i + a_1^2 x_i^2 \right)$$
Differentiate with respect to $a_0$:

$$\frac{\partial E}{\partial a_0} = \sum \left( -2y_i + 2a_0 + 2a_1 x_i \right)$$
and then differentiate with respect to $a_1$:

$$\frac{\partial E}{\partial a_1} = \sum \left( -2x_i y_i + 2a_0 x_i + 2a_1 x_i^2 \right)$$
Since the slopes at the minimum are 0, set both derivatives equal to 0 and divide both equations through by $-2$. For the first equation,

$$\sum \left( y_i - a_0 - a_1 x_i \right) = 0$$
and then for the second equation,

$$\sum \left( x_i y_i - a_0 x_i - a_1 x_i^2 \right) = 0$$
If $n$ is the total number of data points, then in the first equation, $\sum a_0$ is simply $na_0$. If we rearrange the terms of the two equations, we have the normal equations, a system of two linear equations in the two unknowns $a_0$ and $a_1$:

$$na_0 + a_1 \sum x_i = \sum y_i$$
$$a_0 \sum x_i + a_1 \sum x_i^2 = \sum x_i y_i$$
Solve the first equation for $a_0$:

$$a_0 = \frac{\sum y_i - a_1 \sum x_i}{n}$$
and substitute it into the second equation:

$$\left( \frac{\sum y_i - a_1 \sum x_i}{n} \right) \sum x_i + a_1 \sum x_i^2 = \sum x_i y_i$$
Multiply through by $n$ and rearrange the terms:

$$a_1 \left[ n \sum x_i^2 - \left( \sum x_i \right)^2 \right] = n \sum x_i y_i - \sum x_i \sum y_i$$
and so

$$a_1 = \frac{n \sum x_i y_i - \sum x_i \sum y_i}{n \sum x_i^2 - \left( \sum x_i \right)^2}$$
If we let $\bar{x}$ represent $\frac{1}{n} \sum x_i$, the average of the $x_i$ values, and $\bar{y}$ represent $\frac{1}{n} \sum y_i$, the average of the $y_i$ values, we can simplify the expression for $a_0$:

$$a_0 = \bar{y} - a_1 \bar{x}$$
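As a quick sanity check of these closed-form formulas, we can apply them to points that lie exactly on a known line and confirm that the coefficients are recovered. This is a sketch that is not part of the book's code; the class name NormalEquationsCheck is invented for illustration.

```java
/**
 * A sketch (not part of the book's code) that applies the closed-form
 * least-squares formulas directly. The test points lie exactly on
 * y = 2x + 1, so the formulas should recover a1 = 2 and a0 = 1.
 */
public class NormalEquationsCheck
{
    /** Return {a1, a0} computed from the closed-form formulas. */
    static double[] fit(double[] x, double[] y)
    {
        int n = x.length;
        double sumX = 0, sumY = 0, sumXX = 0, sumXY = 0;

        for (int i = 0; i < n; ++i) {
            sumX  += x[i];
            sumY  += y[i];
            sumXX += x[i]*x[i];
            sumXY += x[i]*y[i];
        }

        // a1 = (n*sum(xy) - sum(x)*sum(y)) / (n*sum(xx) - sum(x)^2)
        double a1 = (n*sumXY - sumX*sumY)/(n*sumXX - sumX*sumX);
        // a0 = yBar - a1*xBar
        double a0 = sumY/n - a1*(sumX/n);

        return new double[] {a1, a0};
    }

    public static void main(String[] args)
    {
        // Points on y = 2x + 1
        double[] c = fit(new double[] {1, 2, 3, 4},
                         new double[] {3, 5, 7, 9});
        System.out.println("a1 = " + c[0] + ", a0 = " + c[1]);
        // prints: a1 = 2.0, a0 = 1.0
    }
}
```

Because all the sums here are small integers, the arithmetic is exact and the coefficients come back without rounding error.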
So to calculate the least-squares regression line for a given set of data points, we need to compute the following quantities: $\sum x_i$, $\sum y_i$, $\sum x_i^2$, $\sum x_i y_i$, $\bar{x}$, and $\bar{y}$. Listing 6-2a shows the RegressionLine class in package numbercruncher.mathutils. It computes these quantities for a set of data points; method addDataPoint() updates the sums for each new data point. Before methods getA0() and getA1() return the values of the coefficients $a_0$ and $a_1$, respectively, each invokes method validateCoefficients(), which uses the formulas to compute the current values of the coefficients. The class implements the Evaluatable interface.

Listing 6-2a A class that implements a least-squares regression line.

package numbercruncher.mathutils;

/**
 * A least-squares regression line function.
 */
public class RegressionLine implements Evaluatable
{
    /** sum of x */                   private double sumX;
    /** sum of y */                   private double sumY;
    /** sum of x*x */                 private double sumXX;
    /** sum of x*y */                 private double sumXY;

    /** line coefficient a0 */        private float a0;
    /** line coefficient a1 */        private float a1;

    /** number of data points */      private int n;
    /** true if coefficients valid */ private boolean coefsValid;

    /**
     * Constructor.
     */
    public RegressionLine() {}

    /**
     * Constructor.
     * @param data the array of data points
     */
    public RegressionLine(DataPoint data[])
    {
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Return the current number of data points.
     * @return the count
     */
    public int getDataPointCount() { return n; }

    /**
     * Return the coefficient a0.
     * @return the value of a0
     */
    public float getA0()
    {
        validateCoefficients();
        return a0;
    }

    /**
     * Return the coefficient a1.
     * @return the value of a1
     */
    public float getA1()
    {
        validateCoefficients();
        return a1;
    }

    /**
     * Return the sum of the x values.
     * @return the sum
     */
    public double getSumX() { return sumX; }

    /**
     * Return the sum of the y values.
     * @return the sum
     */
    public double getSumY() { return sumY; }

    /**
     * Return the sum of the x*x values.
     * @return the sum
     */
    public double getSumXX() { return sumXX; }

    /**
     * Return the sum of the x*y values.
     * @return the sum
     */
    public double getSumXY() { return sumXY; }

    /**
     * Add a new data point: Update the sums.
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint)
    {
        sumX  += dataPoint.x;
        sumY  += dataPoint.y;
        sumXX += dataPoint.x*dataPoint.x;
        sumXY += dataPoint.x*dataPoint.y;

        ++n;
        coefsValid = false;
    }

    /**
     * Return the value of the regression line function at x.
     * (Implementation of Evaluatable.)
     * @param x the value of x
     * @return the value of the function at x
     */
    public float at(float x)
    {
        if (n < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }

    /**
     * Reset.
     */
    public void reset()
    {
        n = 0;
        sumX = sumY = sumXX = sumXY = 0;
        coefsValid = false;
    }

    /**
     * Validate the coefficients.
     */
    private void validateCoefficients()
    {
        if (coefsValid) return;

        if (n >= 2) {
            float xBar = (float) sumX/n;
            float yBar = (float) sumY/n;

            a1 = (float) ((n*sumXY - sumX*sumY)
                            /(n*sumXX - sumX*sumX));
            a0 = (float) (yBar - a1*xBar);
        }
        else {
            a0 = a1 = Float.NaN;
        }

        coefsValid = true;
    }
}

There could be some computational problems with the sums. If the data points have both positive and negative x and y values, then sumX and sumY may suffer cancellation errors. If the data points are spread far apart from one another, there may be magnitude errors in all the sums, especially sumXX and sumXY. Of course, there is also the danger of overflow. Therefore, we may need to rewrite method addDataPoint() to employ some of the summation algorithms described in Chapter 4.

Program 6-2 instantiates a RegressionLine object and uses a set of seven data points to construct and print the equation of a least-squares regression line. Listing 6-2b shows the noninteractive version of the program.
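One of the summation techniques Chapter 4 describes is compensated (Kahan) summation, which carries a correction term alongside each running sum. The sketch below shows the idea; the class name KahanSum is invented here for illustration and is not part of the numbercruncher package.

```java
/**
 * A sketch of compensated (Kahan) summation. The class name KahanSum
 * is invented for illustration; it is not part of the book's code.
 */
public class KahanSum
{
    private double sum;         // running sum
    private double correction;  // compensation for lost low-order bits

    public void add(double value)
    {
        double corrected = value - correction;    // apply the prior correction
        double newSum    = sum + corrected;       // low-order bits may be lost here
        correction = (newSum - sum) - corrected;  // recover what was just lost
        sum = newSum;
    }

    public double value() { return sum; }

    // Demonstration: adding small values to a sum of much larger magnitude.
    public static void main(String[] args)
    {
        KahanSum kahan = new KahanSum();
        kahan.add(1e16);

        double naive = 1e16;

        for (int i = 0; i < 10; ++i) {
            kahan.add(1.0);   // 1.0 is below one ulp of 1e16
            naive += 1.0;
        }

        System.out.println("naive error = " + (naive - 1e16));         // prints 0.0
        System.out.println("kahan error = " + (kahan.value() - 1e16)); // prints 10.0
    }
}
```

Method addDataPoint() could then maintain four KahanSum fields in place of the four plain double sums, at the cost of a few extra operations per point.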
Listing 6-2b The noninteractive version of Program 6-2 constructs a regression line for a set of data points.

package numbercruncher.program6_2;

import numbercruncher.mathutils.DataPoint;
import numbercruncher.mathutils.RegressionLine;

/**
 * PROGRAM 6-2: Linear Regression
 *
 * Demonstrate linear regression by constructing
 * the regression line for a set of data points.
 */
public class LinearRegression
{
    private static final int MAX_POINTS = 10;

    /**
     * Main program.
     * @param args the array of runtime arguments
     */
    public static void main(String args[])
    {
        RegressionLine line = new RegressionLine();

        line.addDataPoint(new DataPoint(6.2f, 6.0f));
        line.addDataPoint(new DataPoint(1.3f, 0.75f));
        line.addDataPoint(new DataPoint(5.5f, 3.05f));
        line.addDataPoint(new DataPoint(2.8f, 2.96f));
        line.addDataPoint(new DataPoint(4.7f, 4.72f));
        line.addDataPoint(new DataPoint(7.9f, 5.81f));
        line.addDataPoint(new DataPoint(3.0f, 2.49f));

        printSums(line);
        printLine(line);
    }

    /**
     * Print the computed sums.
     * @param line the regression line
     */
    private static void printSums(RegressionLine line)
    {
        System.out.println("n      = " + line.getDataPointCount());
        System.out.println("Sum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
    }

    /**
     * Print the regression line function.
     * @param line the regression line
     */
    private static void printLine(RegressionLine line)
    {
        System.out.println("\nRegression line: y = " + line.getA1()
                            + "x + " + line.getA0());
    }
}

Output:

n      = 7
Sum x  = 31.399999618530273
Sum y  = 25.77999973297119
Sum xx = 171.71999621391296
Sum xy = 138.7909932732582

Regression line: y = 0.74993044x + 0.31888318

The interactive version of Program 6-2 allows you to set up to 100 arbitrary data points by clicking the mouse on the graph, and the program then creates and plots the regression line through the data points. Screen 6-2a shows a screen shot.

Screen 6-2a.
A least-squares regression line created by the interactive version of Program 6-2.
You can add new data points to see their effect on the regression line. Screen 6-2b is a screen shot taken after new data points have been added, and it shows the new regression line.

Screen 6-2b. The result of adding new data points.
The fact that the least-squares algorithm minimizes the vertical distances between the data points and the regression line, rather than their perpendicular distances from the line, causes pathological behavior when most of the data points are stacked vertically. Screen 6-2c shows an example. Fortunately, most data sets, such as those gathered from experiments, are spread more horizontally than vertically.

Screen 6-2c. A pathological regression line when most of the data points are stacked vertically.
A regression line that is nearly vertical (as the data points in Screen 6-2c suggest it should be) would have very large vertical distances to the data points, and hence a very large error function value.
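The algebraic symptom of this pathology is that when the $x_i$ values are all nearly equal, the denominator $n \sum x_i^2 - (\sum x_i)^2$ of the slope formula is nearly zero, so a tiny change in one x value swings $a_1$ enormously, and identical x values make the slope undefined. The sketch below (not from the book; the class name VerticalStackDemo is invented for illustration) demonstrates this:

```java
/**
 * A sketch (not from the book) showing why vertically stacked points
 * are pathological: with nearly equal x values, the denominator
 * n*sum(xx) - sum(x)^2 of the slope formula is nearly zero.
 */
public class VerticalStackDemo
{
    /** Compute the least-squares slope a1 for the given points. */
    static double slope(double[] x, double[] y)
    {
        int n = x.length;
        double sumX = 0, sumY = 0, sumXX = 0, sumXY = 0;

        for (int i = 0; i < n; ++i) {
            sumX  += x[i];
            sumY  += y[i];
            sumXX += x[i]*x[i];
            sumXY += x[i]*y[i];
        }

        return (n*sumXY - sumX*sumY)/(n*sumXX - sumX*sumX);
    }

    public static void main(String[] args)
    {
        double[] y = {0, 1, 2, 3};

        // Points stacked almost vertically at x = 5: a huge slope (about 2000),
        System.out.println(slope(new double[] {5, 5, 5, 5.001}, y));
        // and moving the perturbation to another point changes it to about 667.
        System.out.println(slope(new double[] {5, 5, 5.001, 5}, y));

        // Identical x values: numerator and denominator are both 0, so 0.0/0.0
        // yields NaN -- the slope of a vertical line is undefined.
        System.out.println(slope(new double[] {5, 5, 5, 5}, y));
    }
}
```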