1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.apache.commons.math.stat.inference;
17
18 import org.apache.commons.math.MathException;
19 import org.apache.commons.math.distribution.DistributionFactory;
20 import org.apache.commons.math.distribution.ChiSquaredDistribution;
21
22
23
24
25
26
27 public class ChiSquareTestImpl implements ChiSquareTest {
28
29
30 private DistributionFactory distributionFactory = null;
31
32
33
34
35 public ChiSquareTestImpl() {
36 super();
37 }
38
39
40
41
42
43
44
45
46 public double chiSquare(double[] expected, long[] observed)
47 throws IllegalArgumentException {
48 double sumSq = 0.0d;
49 double dev = 0.0d;
50 if ((expected.length < 2) || (expected.length != observed.length)) {
51 throw new IllegalArgumentException(
52 "observed, expected array lengths incorrect");
53 }
54 if (!isPositive(expected) || !isNonNegative(observed)) {
55 throw new IllegalArgumentException(
56 "observed counts must be non-negative and expected counts must be postive");
57 }
58 for (int i = 0; i < observed.length; i++) {
59 dev = ((double) observed[i] - expected[i]);
60 sumSq += dev * dev / expected[i];
61 }
62 return sumSq;
63 }
64
65
66
67
68
69
70
71
72 public double chiSquareTest(double[] expected, long[] observed)
73 throws IllegalArgumentException, MathException {
74 ChiSquaredDistribution chiSquaredDistribution =
75 getDistributionFactory().createChiSquareDistribution(
76 (double) expected.length - 1);
77 return 1 - chiSquaredDistribution.cumulativeProbability(
78 chiSquare(expected, observed));
79 }
80
81
82
83
84
85
86
87
88
89
90 public boolean chiSquareTest(double[] expected, long[] observed,
91 double alpha) throws IllegalArgumentException, MathException {
92 if ((alpha <= 0) || (alpha > 0.5)) {
93 throw new IllegalArgumentException(
94 "bad significance level: " + alpha);
95 }
96 return (chiSquareTest(expected, observed) < alpha);
97 }
98
99
100
101
102
103
104 public double chiSquare(long[][] counts) throws IllegalArgumentException {
105
106 checkArray(counts);
107 int nRows = counts.length;
108 int nCols = counts[0].length;
109
110
111 double[] rowSum = new double[nRows];
112 double[] colSum = new double[nCols];
113 double total = 0.0d;
114 for (int row = 0; row < nRows; row++) {
115 for (int col = 0; col < nCols; col++) {
116 rowSum[row] += (double) counts[row][col];
117 colSum[col] += (double) counts[row][col];
118 total += (double) counts[row][col];
119 }
120 }
121
122
123 double sumSq = 0.0d;
124 double expected = 0.0d;
125 for (int row = 0; row < nRows; row++) {
126 for (int col = 0; col < nCols; col++) {
127 expected = (rowSum[row] * colSum[col]) / total;
128 sumSq += (((double) counts[row][col] - expected) *
129 ((double) counts[row][col] - expected)) / expected;
130 }
131 }
132 return sumSq;
133 }
134
135
136
137
138
139
140
141 public double chiSquareTest(long[][] counts)
142 throws IllegalArgumentException, MathException {
143 checkArray(counts);
144 double df = ((double) counts.length -1) * ((double) counts[0].length - 1);
145 ChiSquaredDistribution chiSquaredDistribution =
146 getDistributionFactory().createChiSquareDistribution(df);
147 return 1 - chiSquaredDistribution.cumulativeProbability(chiSquare(counts));
148 }
149
150
151
152
153
154
155
156
157
158 public boolean chiSquareTest(long[][] counts, double alpha)
159 throws IllegalArgumentException, MathException {
160 if ((alpha <= 0) || (alpha > 0.5)) {
161 throw new IllegalArgumentException("bad significance level: " + alpha);
162 }
163 return (chiSquareTest(counts) < alpha);
164 }
165
166
167
168
169
170
171
172
173
174 private void checkArray(long[][] in) throws IllegalArgumentException {
175
176 if (in.length < 2) {
177 throw new IllegalArgumentException("Input table must have at least two rows");
178 }
179
180 if (in[0].length < 2) {
181 throw new IllegalArgumentException("Input table must have at least two columns");
182 }
183
184 if (!isRectangular(in)) {
185 throw new IllegalArgumentException("Input table must be rectangular");
186 }
187
188 if (!isNonNegative(in)) {
189 throw new IllegalArgumentException("All entries in input 2-way table must be non-negative");
190 }
191
192 }
193
194
195
196
197
198
199
200 protected DistributionFactory getDistributionFactory() {
201 if (distributionFactory == null) {
202 distributionFactory = DistributionFactory.newInstance();
203 }
204 return distributionFactory;
205 }
206
207
208
209
210
211
212
213
214
215
216
217 private boolean isRectangular(long[][] in) {
218 for (int i = 1; i < in.length; i++) {
219 if (in[i].length != in[0].length) {
220 return false;
221 }
222 }
223 return true;
224 }
225
226
227
228
229
230
231
232
233
234 private boolean isPositive(double[] in) {
235 for (int i = 0; i < in.length; i ++) {
236 if (in[i] <= 0) {
237 return false;
238 }
239 }
240 return true;
241 }
242
243
244
245
246
247
248
249
250
251 private boolean isNonNegative(long[] in) {
252 for (int i = 0; i < in.length; i ++) {
253 if (in[i] < 0) {
254 return false;
255 }
256 }
257 return true;
258 }
259
260
261
262
263
264
265
266
267
268 private boolean isNonNegative(long[][] in) {
269 for (int i = 0; i < in.length; i ++) {
270 for (int j = 0; j < in[i].length; j++) {
271 if (in[i][j] < 0) {
272 return false;
273 }
274 }
275 }
276 return true;
277 }
278
279 }