1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.apache.commons.math.stat.regression;
17
18 import java.util.Random;
19
20 import junit.framework.Test;
21 import junit.framework.TestCase;
22 import junit.framework.TestSuite;
23
24
25
26
27
28
29 public final class SimpleRegressionTest extends TestCase {
30
31
32
33
34
35
36 private double[][] data = { { 0.1, 0.2 }, {338.8, 337.4 }, {118.1, 118.2 },
37 {888.0, 884.6 }, {9.2, 10.1 }, {228.1, 226.5 }, {668.5, 666.3 }, {998.5, 996.3 },
38 {449.1, 448.6 }, {778.9, 777.0 }, {559.2, 558.2 }, {0.3, 0.4 }, {0.1, 0.6 }, {778.1, 775.5 },
39 {668.8, 666.9 }, {339.3, 338.0 }, {448.9, 447.5 }, {10.8, 11.6 }, {557.7, 556.0 },
40 {228.3, 228.1 }, {998.0, 995.8 }, {888.8, 887.6 }, {119.6, 120.2 }, {0.3, 0.3 },
41 {0.6, 0.3 }, {557.6, 556.8 }, {339.3, 339.1 }, {888.0, 887.2 }, {998.5, 999.0 },
42 {778.9, 779.0 }, {10.2, 11.1 }, {117.6, 118.3 }, {228.9, 229.2 }, {668.4, 669.1 },
43 {449.2, 448.9 }, {0.2, 0.5 }
44 };
45
46
47
48
49
50 private double[][] corrData = { { 101.0, 99.2 }, {100.1, 99.0 }, {100.0, 100.0 },
51 {90.6, 111.6 }, {86.5, 122.2 }, {89.7, 117.6 }, {90.6, 121.1 }, {82.8, 136.0 },
52 {70.1, 154.2 }, {65.4, 153.6 }, {61.3, 158.5 }, {62.5, 140.6 }, {63.6, 136.2 },
53 {52.6, 168.0 }, {59.7, 154.3 }, {59.5, 149.0 }, {61.3, 165.5 }
54 };
55
56
57
58
59
60 private double[][] infData = { { 15.6, 5.2 }, {26.8, 6.1 }, {37.8, 8.7 }, {36.4, 8.5 },
61 {35.5, 8.8 }, {18.6, 4.9 }, {15.3, 4.5 }, {7.9, 2.5 }, {0.0, 1.1 }
62 };
63
64
65
66
67 private double[][] infData2 = { { 1, 1 }, {2, 0 }, {3, 5 }, {4, 2 },
68 {5, -1 }, {6, 12 }
69 };
70
71 public SimpleRegressionTest(String name) {
72 super(name);
73 }
74
75 public void setUp() {
76 }
77
78 public static Test suite() {
79 TestSuite suite = new TestSuite(SimpleRegressionTest.class);
80 suite.setName("BivariateRegression Tests");
81 return suite;
82 }
83
84 public void testNorris() {
85 SimpleRegression regression = new SimpleRegression();
86 for (int i = 0; i < data.length; i++) {
87 regression.addData(data[i][1], data[i][0]);
88 }
89
90
91 assertEquals("slope", 1.00211681802045, regression.getSlope(), 10E-12);
92 assertEquals("slope std err", 0.429796848199937E-03,
93 regression.getSlopeStdErr(),10E-12);
94 assertEquals("number of observations", 36, regression.getN());
95 assertEquals( "intercept", -0.262323073774029,
96 regression.getIntercept(),10E-12);
97 assertEquals("std err intercept", 0.232818234301152,
98 regression.getInterceptStdErr(),10E-12);
99 assertEquals("r-square", 0.999993745883712,
100 regression.getRSquare(), 10E-12);
101 assertEquals("SSR", 4255954.13232369,
102 regression.getRegressionSumSquares(), 10E-9);
103 assertEquals("MSE", 0.782864662630069,
104 regression.getMeanSquareError(), 10E-10);
105 assertEquals("SSE", 26.6173985294224,
106 regression.getSumSquaredErrors(),10E-9);
107
108
109 assertEquals( "predict(0)", -0.262323073774029,
110 regression.predict(0), 10E-12);
111 assertEquals("predict(1)", 1.00211681802045 - 0.262323073774029,
112 regression.predict(1), 10E-12);
113 }
114
115 public void testCorr() {
116 SimpleRegression regression = new SimpleRegression();
117 regression.addData(corrData);
118 assertEquals("number of observations", 17, regression.getN());
119 assertEquals("r-square", .896123, regression.getRSquare(), 10E-6);
120 assertEquals("r", -0.94663767742, regression.getR(), 1E-10);
121 }
122
123 public void testNaNs() {
124 SimpleRegression regression = new SimpleRegression();
125 assertTrue("intercept not NaN", Double.isNaN(regression.getIntercept()));
126 assertTrue("slope not NaN", Double.isNaN(regression.getSlope()));
127 assertTrue("slope std err not NaN", Double.isNaN(regression.getSlopeStdErr()));
128 assertTrue("intercept std err not NaN", Double.isNaN(regression.getInterceptStdErr()));
129 assertTrue("MSE not NaN", Double.isNaN(regression.getMeanSquareError()));
130 assertTrue("e not NaN", Double.isNaN(regression.getR()));
131 assertTrue("r-square not NaN", Double.isNaN(regression.getRSquare()));
132 assertTrue( "RSS not NaN", Double.isNaN(regression.getRegressionSumSquares()));
133 assertTrue("SSE not NaN",Double.isNaN(regression.getSumSquaredErrors()));
134 assertTrue("SSTO not NaN", Double.isNaN(regression.getTotalSumSquares()));
135 assertTrue("predict not NaN", Double.isNaN(regression.predict(0)));
136
137 regression.addData(1, 2);
138 regression.addData(1, 3);
139
140
141 assertTrue("intercept not NaN", Double.isNaN(regression.getIntercept()));
142 assertTrue("slope not NaN", Double.isNaN(regression.getSlope()));
143 assertTrue("slope std err not NaN", Double.isNaN(regression.getSlopeStdErr()));
144 assertTrue("intercept std err not NaN", Double.isNaN(regression.getInterceptStdErr()));
145 assertTrue("MSE not NaN", Double.isNaN(regression.getMeanSquareError()));
146 assertTrue("e not NaN", Double.isNaN(regression.getR()));
147 assertTrue("r-square not NaN", Double.isNaN(regression.getRSquare()));
148 assertTrue("RSS not NaN", Double.isNaN(regression.getRegressionSumSquares()));
149 assertTrue("SSE not NaN", Double.isNaN(regression.getSumSquaredErrors()));
150 assertTrue("predict not NaN", Double.isNaN(regression.predict(0)));
151
152
153 assertTrue("SSTO NaN", !Double.isNaN(regression.getTotalSumSquares()));
154
155 regression = new SimpleRegression();
156
157 regression.addData(1, 2);
158 regression.addData(3, 3);
159
160
161 assertTrue("interceptNaN", !Double.isNaN(regression.getIntercept()));
162 assertTrue("slope NaN", !Double.isNaN(regression.getSlope()));
163 assertTrue ("slope std err not NaN", Double.isNaN(regression.getSlopeStdErr()));
164 assertTrue("intercept std err not NaN", Double.isNaN(regression.getInterceptStdErr()));
165 assertTrue("MSE not NaN", Double.isNaN(regression.getMeanSquareError()));
166 assertTrue("r NaN", !Double.isNaN(regression.getR()));
167 assertTrue("r-square NaN", !Double.isNaN(regression.getRSquare()));
168 assertTrue("RSS NaN", !Double.isNaN(regression.getRegressionSumSquares()));
169 assertTrue("SSE NaN", !Double.isNaN(regression.getSumSquaredErrors()));
170 assertTrue("SSTO NaN", !Double.isNaN(regression.getTotalSumSquares()));
171 assertTrue("predict NaN", !Double.isNaN(regression.predict(0)));
172
173 regression.addData(1, 4);
174
175
176 assertTrue("MSE NaN", !Double.isNaN(regression.getMeanSquareError()));
177 assertTrue("slope std err NaN", !Double.isNaN(regression.getSlopeStdErr()));
178 assertTrue("intercept std err NaN", !Double.isNaN(regression.getInterceptStdErr()));
179 }
180
181 public void testClear() {
182 SimpleRegression regression = new SimpleRegression();
183 regression.addData(corrData);
184 assertEquals("number of observations", 17, regression.getN());
185 regression.clear();
186 assertEquals("number of observations", 0, regression.getN());
187 regression.addData(corrData);
188 assertEquals("r-square", .896123, regression.getRSquare(), 10E-6);
189 regression.addData(data);
190 assertEquals("number of observations", 53, regression.getN());
191 }
192
193 public void testInference() throws Exception {
194
195
196 SimpleRegression regression = new SimpleRegression();
197 regression.addData(infData);
198 assertEquals("slope std err", 0.011448491,
199 regression.getSlopeStdErr(), 1E-10);
200 assertEquals("std err intercept", 0.286036932,
201 regression.getInterceptStdErr(),1E-8);
202 assertEquals("significance", 4.596e-07,
203 regression.getSignificance(),1E-8);
204 assertEquals("slope conf interval half-width", 0.0270713794287,
205 regression.getSlopeConfidenceInterval(),1E-8);
206
207 regression = new SimpleRegression();
208 regression.addData(infData2);
209 assertEquals("slope std err", 1.07260253,
210 regression.getSlopeStdErr(), 1E-8);
211 assertEquals("std err intercept",4.17718672,
212 regression.getInterceptStdErr(),1E-8);
213 assertEquals("significance", 0.261829133982,
214 regression.getSignificance(),1E-11);
215 assertEquals("slope conf interval half-width", 2.97802204827,
216 regression.getSlopeConfidenceInterval(),1E-8);
217
218
219
220 assertTrue("tighter means wider",
221 regression.getSlopeConfidenceInterval() < regression.getSlopeConfidenceInterval(0.01));
222
223 try {
224 double x = regression.getSlopeConfidenceInterval(1);
225 fail("expecting IllegalArgumentException for alpha = 1");
226 } catch (IllegalArgumentException ex) {
227 ;
228 }
229
230 }
231
232 public void testPerfect() throws Exception {
233 SimpleRegression regression = new SimpleRegression();
234 int n = 100;
235 for (int i = 0; i < n; i++) {
236 regression.addData(((double) i) / (n - 1), i);
237 }
238 assertEquals(0.0, regression.getSignificance(), 1.0e-5);
239 assertTrue(regression.getSlope() > 0.0);
240 }
241
242 public void testPerfectNegative() throws Exception {
243 SimpleRegression regression = new SimpleRegression();
244 int n = 100;
245 for (int i = 0; i < n; i++) {
246 regression.addData(- ((double) i) / (n - 1), i);
247 }
248
249 assertEquals(0.0, regression.getSignificance(), 1.0e-5);
250 assertTrue(regression.getSlope() < 0.0);
251 }
252
253 public void testRandom() throws Exception {
254 SimpleRegression regression = new SimpleRegression();
255 Random random = new Random(1);
256 int n = 100;
257 for (int i = 0; i < n; i++) {
258 regression.addData(((double) i) / (n - 1), random.nextDouble());
259 }
260
261 assertTrue( 0.0 < regression.getSignificance()
262 && regression.getSignificance() < 1.0);
263 }
264 }