Skip to content

Commit 68746f8

Browse files
authored
feat: Add Chebyshev Iteration algorithm (#6963)
* feat: Add Chebyshev Iteration algorithm * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * update * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java * Update ChebyshevIteration.java * Update ChebyshevIterationTest.java
1 parent 3c70a54 commit 68746f8

File tree

2 files changed

+286
-0
lines changed

2 files changed

+286
-0
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package com.thealgorithms.maths;
2+
3+
/**
4+
* In numerical analysis, Chebyshev iteration is an iterative method for solving
5+
* systems of linear equations Ax = b. It is designed for systems where the
6+
* matrix A is symmetric positive-definite (SPD).
7+
*
8+
* <p>
9+
* This method is a "polynomial acceleration" method, meaning it finds the
10+
* optimal polynomial to apply to the residual to accelerate convergence.
11+
*
12+
* <p>
13+
* It requires knowledge of the bounds of the eigenvalues of the matrix A:
14+
* m(A) (smallest eigenvalue) and M(A) (largest eigenvalue).
15+
*
16+
* <p>
17+
* Wikipedia: https://en.wikipedia.org/wiki/Chebyshev_iteration
18+
*
19+
* @author Mitrajit Ghorui(KeyKyrios)
20+
*/
21+
public final class ChebyshevIteration {
22+
23+
private ChebyshevIteration() {
24+
}
25+
26+
/**
27+
* Solves the linear system Ax = b using the Chebyshev iteration method.
28+
*
29+
* <p>
30+
* NOTE: The matrix A *must* be symmetric positive-definite (SPD) for this
31+
* algorithm to converge.
32+
*
33+
* @param a The matrix A (must be square, SPD).
34+
* @param b The vector b.
35+
* @param x0 The initial guess vector.
36+
* @param minEigenvalue The smallest eigenvalue of A (m(A)).
37+
* @param maxEigenvalue The largest eigenvalue of A (M(A)).
38+
* @param maxIterations The maximum number of iterations to perform.
39+
* @param tolerance The desired tolerance for the residual norm.
40+
* @return The solution vector x.
41+
* @throws IllegalArgumentException if matrix/vector dimensions are
42+
* incompatible,
43+
* if maxIterations <= 0, or if eigenvalues are invalid (e.g., minEigenvalue
44+
* <= 0, maxEigenvalue <= minEigenvalue).
45+
*/
46+
public static double[] solve(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
47+
validateInputs(a, b, x0, minEigenvalue, maxEigenvalue, maxIterations, tolerance);
48+
49+
int n = b.length;
50+
double[] x = x0.clone();
51+
double[] r = vectorSubtract(b, matrixVectorMultiply(a, x));
52+
double[] p = new double[n];
53+
54+
double d = (maxEigenvalue + minEigenvalue) / 2.0;
55+
double c = (maxEigenvalue - minEigenvalue) / 2.0;
56+
57+
double alpha = 0.0;
58+
double alphaPrev = 0.0;
59+
60+
for (int k = 0; k < maxIterations; k++) {
61+
double residualNorm = vectorNorm(r);
62+
if (residualNorm < tolerance) {
63+
return x; // Solution converged
64+
}
65+
66+
if (k == 0) {
67+
alpha = 1.0 / d;
68+
System.arraycopy(r, 0, p, 0, n); // p = r
69+
} else {
70+
double beta = c * alphaPrev / 2.0 * (c * alphaPrev / 2.0);
71+
alpha = 1.0 / (d - beta / alphaPrev);
72+
double[] pUpdate = scalarMultiply(beta / alphaPrev, p);
73+
p = vectorAdd(r, pUpdate); // p = r + (beta / alphaPrev) * p
74+
}
75+
76+
double[] xUpdate = scalarMultiply(alpha, p);
77+
x = vectorAdd(x, xUpdate); // x = x + alpha * p
78+
79+
// Recompute residual for accuracy
80+
r = vectorSubtract(b, matrixVectorMultiply(a, x));
81+
alphaPrev = alpha;
82+
}
83+
84+
return x; // Return best guess after maxIterations
85+
}
86+
87+
/**
88+
* Validates the inputs for the Chebyshev solver.
89+
*/
90+
private static void validateInputs(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
91+
int n = a.length;
92+
if (n == 0) {
93+
throw new IllegalArgumentException("Matrix A cannot be empty.");
94+
}
95+
if (n != a[0].length) {
96+
throw new IllegalArgumentException("Matrix A must be square.");
97+
}
98+
if (n != b.length) {
99+
throw new IllegalArgumentException("Matrix A and vector b dimensions do not match.");
100+
}
101+
if (n != x0.length) {
102+
throw new IllegalArgumentException("Matrix A and vector x0 dimensions do not match.");
103+
}
104+
if (minEigenvalue <= 0) {
105+
throw new IllegalArgumentException("Smallest eigenvalue must be positive (matrix must be positive-definite).");
106+
}
107+
if (maxEigenvalue <= minEigenvalue) {
108+
throw new IllegalArgumentException("Max eigenvalue must be strictly greater than min eigenvalue.");
109+
}
110+
if (maxIterations <= 0) {
111+
throw new IllegalArgumentException("Max iterations must be positive.");
112+
}
113+
if (tolerance <= 0) {
114+
throw new IllegalArgumentException("Tolerance must be positive.");
115+
}
116+
}
117+
118+
// --- Vector/Matrix Helper Methods ---
119+
/**
120+
* Computes the product of a matrix A and a vector v (Av).
121+
*/
122+
private static double[] matrixVectorMultiply(double[][] a, double[] v) {
123+
int n = a.length;
124+
double[] result = new double[n];
125+
for (int i = 0; i < n; i++) {
126+
double sum = 0;
127+
for (int j = 0; j < n; j++) {
128+
sum += a[i][j] * v[j];
129+
}
130+
result[i] = sum;
131+
}
132+
return result;
133+
}
134+
135+
/**
136+
* Computes the subtraction of two vectors (v1 - v2).
137+
*/
138+
private static double[] vectorSubtract(double[] v1, double[] v2) {
139+
int n = v1.length;
140+
double[] result = new double[n];
141+
for (int i = 0; i < n; i++) {
142+
result[i] = v1[i] - v2[i];
143+
}
144+
return result;
145+
}
146+
147+
/**
148+
* Computes the addition of two vectors (v1 + v2).
149+
*/
150+
private static double[] vectorAdd(double[] v1, double[] v2) {
151+
int n = v1.length;
152+
double[] result = new double[n];
153+
for (int i = 0; i < n; i++) {
154+
result[i] = v1[i] + v2[i];
155+
}
156+
return result;
157+
}
158+
159+
/**
160+
* Computes the product of a scalar and a vector (s * v).
161+
*/
162+
private static double[] scalarMultiply(double scalar, double[] v) {
163+
int n = v.length;
164+
double[] result = new double[n];
165+
for (int i = 0; i < n; i++) {
166+
result[i] = scalar * v[i];
167+
}
168+
return result;
169+
}
170+
171+
/**
172+
* Computes the L2 norm (Euclidean norm) of a vector.
173+
*/
174+
private static double vectorNorm(double[] v) {
175+
double sumOfSquares = 0;
176+
for (double val : v) {
177+
sumOfSquares += val * val;
178+
}
179+
return Math.sqrt(sumOfSquares);
180+
}
181+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package com.thealgorithms.maths;
2+
3+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
public class ChebyshevIterationTest {
9+
10+
@Test
11+
public void testSolveSimple2x2Diagonal() {
12+
double[][] a = {{2, 0}, {0, 1}};
13+
double[] b = {2, 2};
14+
double[] x0 = {0, 0};
15+
double minEig = 1.0;
16+
double maxEig = 2.0;
17+
int maxIter = 50;
18+
double tol = 1e-9;
19+
double[] expected = {1.0, 2.0};
20+
21+
double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol);
22+
assertArrayEquals(expected, result, 1e-9);
23+
}
24+
25+
@Test
26+
public void testSolve2x2Symmetric() {
27+
double[][] a = {{4, 1}, {1, 3}};
28+
double[] b = {1, 2};
29+
double[] x0 = {0, 0};
30+
double minEig = (7.0 - Math.sqrt(5.0)) / 2.0;
31+
double maxEig = (7.0 + Math.sqrt(5.0)) / 2.0;
32+
int maxIter = 100;
33+
double tol = 1e-10;
34+
double[] expected = {1.0 / 11.0, 7.0 / 11.0};
35+
36+
double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol);
37+
assertArrayEquals(expected, result, 1e-9);
38+
}
39+
40+
@Test
41+
public void testAlreadyAtSolution() {
42+
double[][] a = {{2, 0}, {0, 1}};
43+
double[] b = {2, 2};
44+
double[] x0 = {1, 2};
45+
double minEig = 1.0;
46+
double maxEig = 2.0;
47+
int maxIter = 10;
48+
double tol = 1e-5;
49+
double[] expected = {1.0, 2.0};
50+
51+
double[] result = ChebyshevIteration.solve(a, b, x0, minEig, maxEig, maxIter, tol);
52+
assertArrayEquals(expected, result, 0.0);
53+
}
54+
55+
@Test
56+
public void testMismatchedDimensionsAB() {
57+
double[][] a = {{1, 0}, {0, 1}};
58+
double[] b = {1};
59+
double[] x0 = {0, 0};
60+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5));
61+
}
62+
63+
@Test
64+
public void testMismatchedDimensionsAX() {
65+
double[][] a = {{1, 0}, {0, 1}};
66+
double[] b = {1, 1};
67+
double[] x0 = {0};
68+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5));
69+
}
70+
71+
@Test
72+
public void testNonSquareMatrix() {
73+
double[][] a = {{1, 0, 0}, {0, 1, 0}};
74+
double[] b = {1, 1};
75+
double[] x0 = {0, 0};
76+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 10, 1e-5));
77+
}
78+
79+
@Test
80+
public void testInvalidEigenvalues() {
81+
double[][] a = {{1, 0}, {0, 1}};
82+
double[] b = {1, 1};
83+
double[] x0 = {0, 0};
84+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 2, 1, 10, 1e-5));
85+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 1, 10, 1e-5));
86+
}
87+
88+
@Test
89+
public void testNonPositiveDefinite() {
90+
double[][] a = {{1, 0}, {0, 1}};
91+
double[] b = {1, 1};
92+
double[] x0 = {0, 0};
93+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 0, 1, 10, 1e-5));
94+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, -1, 1, 10, 1e-5));
95+
}
96+
97+
@Test
98+
public void testInvalidIterationCount() {
99+
double[][] a = {{1, 0}, {0, 1}};
100+
double[] b = {1, 1};
101+
double[] x0 = {0, 0};
102+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, 0, 1e-5));
103+
assertThrows(IllegalArgumentException.class, () -> ChebyshevIteration.solve(a, b, x0, 1, 2, -1, 1e-5));
104+
}
105+
}

0 commit comments

Comments
 (0)