The simplest OLSMultipleLinearRegression example ever

Assuming you already know what multiple linear regression by Ordinary Least Squares (OLS) is, and that you would like to use it in Java through OLSMultipleLinearRegression, a class of the Apache Commons Math library, here is the simplest usage example I could think of.

I have two observations of a single variable, and I want the coefficients of the straight line that fits them (with just two points, the line passes through both):
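// Commons Math 2.x package names, matching the InvalidMatrixException caught below
import org.apache.commons.math.linear.InvalidMatrixException;
import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression;

OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();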

double[] data = { 2, 1, 4, 2 }; // 1
int obs = 2;
int vars = 1; // 2
try {
    ols.newSampleData(data, obs, vars); // 3
}
catch(IllegalArgumentException e) {
    System.out.print("Can't sample data: ");
    e.printStackTrace();
    return;
}

double[] coe = null;
try {
    coe = ols.estimateRegressionParameters(); // 4
}
catch(IllegalArgumentException | InvalidMatrixException e) { // 5
    System.out.print("Can't estimate parameters: ");
    e.printStackTrace();
    return;
}

dumpEstimation(coe);
  1. The input data is flattened into a single array: the observations are stored one after the other, each with its y value first and then its x value(s). Here the two observations are (x = 1, y = 2) and (x = 2, y = 4).
  2. There must be more observations than variables, so this is the minimal case. With just one variable, the result is a straight line.
  3. The function performs a few checks on the data before using it: if the number of observations is not greater than the number of variables, or if the array holds fewer values than expected, an IllegalArgumentException is thrown.
  4. The estimated coefficients are returned in a double array, the intercept first and then one coefficient per variable. If the estimation fails, an exception is thrown.
  5. A handy Java 7 feature: exceptions that need the same handling can be grouped in a single multi-catch block, as long as none of the listed types is a subclass of another. In earlier Java versions, a separate catch block is required for each exception type.
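To make the layout from point 1 concrete, here is a minimal plain-Java sketch that does not depend on Commons Math: it unpacks a flattened array into its y and x parts and applies the two checks described in point 3. The class FlatSampleLayout and its methods are hypothetical, written only for illustration; the exact checks, exception types, and messages in the library may differ between versions.

```java
import java.util.Arrays;

/**
 * Hypothetical helper, not part of Commons Math. It mirrors the flattened
 * layout passed to newSampleData(data, nobs, nvars): each observation is a
 * block of (nvars + 1) values, the y value first, then its nvars x values.
 */
final class FlatSampleLayout {

    private FlatSampleLayout() {}

    /** Assumed checks, modeled on point 3 (the library's own may differ). */
    static void validate(double[] data, int nobs, int nvars) {
        if (data.length != nobs * (nvars + 1)) {
            throw new IllegalArgumentException(
                "expected " + nobs * (nvars + 1) + " values, got " + data.length);
        }
        if (nobs <= nvars) {
            throw new IllegalArgumentException(
                "need more observations (" + nobs + ") than variables (" + nvars + ")");
        }
    }

    /** The y values: the first entry of every block. */
    static double[] extractY(double[] data, int nobs, int nvars) {
        validate(data, nobs, nvars);
        double[] y = new double[nobs];
        for (int i = 0; i < nobs; i++) {
            y[i] = data[i * (nvars + 1)];
        }
        return y;
    }

    /** The x values: the remaining nvars entries of every block, one row per observation. */
    static double[][] extractX(double[] data, int nobs, int nvars) {
        validate(data, nobs, nvars);
        double[][] x = new double[nobs][nvars];
        for (int i = 0; i < nobs; i++) {
            System.arraycopy(data, i * (nvars + 1) + 1, x[i], 0, nvars);
        }
        return x;
    }

    public static void main(String[] args) {
        double[] data = {2, 1, 4, 2};
        System.out.println(Arrays.toString(extractY(data, 2, 1)));     // [2.0, 4.0]
        System.out.println(Arrays.deepToString(extractX(data, 2, 1))); // [[1.0], [2.0]]
    }
}
```

With the post's data, the y values are 2 and 4 and the x values are 1 and 2, which is exactly the pair of observations fitted by the regression.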
The last line of the example, dumpEstimation(coe), calls a small test function that prints some feedback. That function in turn calls another short function that computes the estimated y value for a given x:
private void dumpEstimation(double[] coe) {
    if(coe == null)
        return;

    for(double d : coe)
        System.out.print(d + " ");
    System.out.println();

    System.out.println("Estimations:");
    System.out.println("x = 1, y = " + calculateEstimation(1, coe));
    System.out.println("x = 2, y = " + calculateEstimation(2, coe));
    System.out.println("x = 3, y = " + calculateEstimation(3, coe));
    System.out.println("x = 4, y = " + calculateEstimation(4, coe));
}

private double calculateEstimation(double x, double[] coe) {
    double result = 0;
    for(int i = 0; i < coe.length; ++i)
        result += coe[i] * Math.pow(x, i); // 1
    return result;
}
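  1. The most interesting line in this test code. It shows how the coefficients in the array returned by OLSMultipleLinearRegression.estimateRegressionParameters() are used: coe[0] is the intercept and coe[1] is the slope of x. Since x^0 = 1 and x^1 = x, Math.pow(x, i) turns the sum into coe[0] + coe[1] * x. This only works because there is a single variable: with more variables, each coefficient after the first multiplies a different independent variable, not a power of x.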
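For the general case, here is a minimal sketch of an evaluation function that does not rely on powers of x. It assumes the model was fitted with an intercept (the default), so coefficient 0 is added as is and each following coefficient multiplies the corresponding independent variable. LinearModel and predict are hypothetical names, not part of Commons Math, and the two-variable coefficients in main are made up for illustration.

```java
/**
 * Hypothetical helper (not part of Commons Math) that evaluates a fitted
 * linear model for a single observation. Assumes the coefficient layout of
 * estimateRegressionParameters() when an intercept is fitted:
 * beta[0] is the intercept, beta[i] (i >= 1) multiplies the i-th variable.
 */
final class LinearModel {

    private LinearModel() {}

    static double predict(double[] beta, double... x) {
        if (beta.length != x.length + 1) {
            throw new IllegalArgumentException(
                "expected " + (beta.length - 1) + " variables, got " + x.length);
        }
        double result = beta[0]; // intercept term
        for (int i = 1; i < beta.length; i++) {
            result += beta[i] * x[i - 1]; // slope of the i-th independent variable
        }
        return result;
    }

    public static void main(String[] args) {
        double[] coe = {-0.0, 2.0};            // coefficients printed by the example
        System.out.println(predict(coe, 3.0)); // 6.0, same as calculateEstimation(3, coe)

        double[] beta = {1.0, 2.0, 3.0};       // made-up two-variable model
        System.out.println(predict(beta, 10.0, 100.0)); // 1 + 2*10 + 3*100 = 321.0
    }
}
```

With one variable this gives the same values as calculateEstimation(); with two or more it stays correct, whereas the Math.pow() version would not.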
The expected output of the OLSMultipleLinearRegression example is:
-0.0 2.0 
Estimations:
x = 1, y = 2.0
x = 2, y = 4.0
x = 3, y = 6.0
x = 4, y = 8.0
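These numbers can also be checked by hand with the textbook closed-form formulas for a single variable: slope = sum((x - xMean) * (y - yMean)) / sum((x - xMean)^2) and intercept = yMean - slope * xMean. The sketch below applies them to the two observations; SimpleOlsCheck is a hypothetical name, and this is only a sanity check, not how OLSMultipleLinearRegression computes its estimates internally.

```java
/**
 * Textbook closed-form OLS for a single variable, used only to check the
 * expected output by hand. Hypothetical helper, not part of Commons Math.
 */
final class SimpleOlsCheck {

    private SimpleOlsCheck() {}

    /** Returns {intercept, slope}, the same order as estimateRegressionParameters(). */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired observations");
        }
        int n = x.length;
        double xMean = 0, yMean = 0;
        for (int i = 0; i < n; i++) {
            xMean += x[i];
            yMean += y[i];
        }
        xMean /= n;
        yMean /= n;

        double sxy = 0, sxx = 0; // centered cross-product and sum of squares
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - xMean) * (y[i] - yMean);
            sxx += (x[i] - xMean) * (x[i] - xMean);
        }
        if (sxx == 0) {
            throw new IllegalArgumentException("all x values are equal");
        }
        double slope = sxy / sxx;
        return new double[] {yMean - slope * xMean, slope};
    }

    public static void main(String[] args) {
        double[] coe = fit(new double[] {1, 2}, new double[] {2, 4});
        System.out.println(coe[0] + " " + coe[1]); // 0.0 2.0
    }
}
```

The sketch prints 0.0 instead of -0.0; the sign of the library's zero presumably comes from its own floating-point computation, and in Java -0.0 == 0.0 is true, so the two results agree.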

8 comments:

  1. I'm trying with all my might to get the coefficients for a function that should describe a set of data, using Commons Math. I have a method that receives three params:

    method(t, a, b, c)

    And I need to figure out which a, b and c would give me the best R-squared for a set of data in the form:

    {{t, v}, {t, v},..} // double array

    Replies
    1. Hi Italo, thank you for leaving your comment here. Sorry I missed it at the time. I guess you found a solution on your own in the meantime ;-)

  2. I know I could run the code myself, but it would be nice to see the actual output from the example.

    Replies
    1. Good suggestion, I added it at the bottom of the post. Thanks.

  3. Hi,
    Thank you for the code.

    I have one question: why did you use Math.pow(x, i)? Will this work if I have more than one independent variable? Wouldn't coe[i] * x be much better?

    Replies
    1. Hello Ravjot, thank you for your feedback. You are right, I used Math.pow() so that calculateEstimation() would work properly whatever number of coefficients is returned by estimateRegressionParameters().
      I see it as a feature; however, it depends on the context in which the code is actually used.

  4. This example is actually incorrect. The coef matrix is [0, 2]... the last number is the x offset (it is a linear regression). If you look at r^2 it is 1.0, meaning nothing was learned...

    The math.pow() is what makes the number correct--it should not be used.

    Replies
    1. Sorry for the huge delay in my answer, I haven't checked the blog in a long time. Please have a look at the previous comment for details on why I used pow() in the example.
