Skip to content

Commit ecb07f7

Browse files
committed
Add tests for posv and gesv
1 parent 9e1d648 commit ecb07f7

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

ext/lapack.c

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -875,14 +875,14 @@ VALUE nm_posv(int argc, VALUE* argv) {
875875
Data_Get_Struct(argv[1], nmatrix, matrix_b);
876876

877877
int m_b = matrix_b->shape[0]; //no. of rows
878-
int n_b = matrix_b->shape[1]; //no. of cols
878+
int n_b = 1; //no. of cols
879879
int lda_b = n_b;
880880

881881
bool lower = (bool)RTEST(argv[2]);
882882
char uplo = lower ? 'L' : 'U';
883883

884884
nmatrix* result_c = nmatrix_new(matrix_a->dtype, matrix_a->stype, 2, matrix_a->count, matrix_a->shape, NULL);
885-
nmatrix* result_x = nmatrix_new(matrix_b->dtype, matrix_b->stype, 2, matrix_b->count, matrix_b->shape, NULL);
885+
nmatrix* result_x = nmatrix_new(matrix_b->dtype, matrix_b->stype, 1, matrix_b->count, matrix_b->shape, NULL);
886886

887887
switch(matrix_a->dtype) {
888888
case nm_bool:
@@ -991,9 +991,6 @@ VALUE nm_gesv(int argc, VALUE* argv) {
991991
int n_b = matrix_b->shape[1]; //no. of cols
992992
int lda_b = n_b;
993993

994-
bool lower = (bool)RTEST(argv[2]);
995-
char uplo = lower ? 'L' : 'U';
996-
997994
nmatrix* result_lu = nmatrix_new(matrix_a->dtype, matrix_a->stype, 2, matrix_a->count, matrix_a->shape, NULL);
998995
nmatrix* result_x = nmatrix_new(matrix_b->dtype, matrix_b->stype, 2, matrix_b->count, matrix_b->shape, NULL);
999996
nmatrix* result_ipiv = nmatrix_new(nm_int, matrix_a->stype, 1, n_a, NULL, NULL);
@@ -1095,6 +1092,19 @@ VALUE nm_gesv(int argc, VALUE* argv) {
10951092
* the infinity norm, or the element of largest absolute value of a
10961093
* real matrix A.
10971094
*
1095+
* LANGE = ( max(abs(A(i,j))), NORM = 'M' or 'm'
1096+
* (
1097+
* ( norm1(A), NORM = '1', 'O' or 'o'
1098+
* (
1099+
* ( normI(A), NORM = 'I' or 'i'
1100+
* (
1101+
* ( normF(A), NORM = 'F', 'f', 'E' or 'e'
1102+
*
1103+
* where norm1 denotes the one norm of a matrix (maximum column sum),
1104+
* normI denotes the infinity norm of a matrix (maximum row sum) and
1105+
* normF denotes the Frobenius norm of a matrix (square root of sum of
1106+
* squares). Note that max(abs(A(i,j))) is not a consistent matrix norm.
1107+
*
10981108
*/
10991109
VALUE nm_lange(int argc, VALUE* argv) {
11001110
nmatrix* matrix;
@@ -1104,7 +1114,8 @@ VALUE nm_lange(int argc, VALUE* argv) {
11041114
int n = matrix->shape[1]; //no. of cols
11051115
int lda = n;
11061116

1107-
char norm = NUM2CHAR(argv[1]);
1117+
char* norm_str = StringValueCStr(argv[1]);
1118+
char norm = norm_str[0];
11081119

11091120
switch(matrix->dtype) {
11101121
case nm_bool:
@@ -1131,13 +1142,13 @@ VALUE nm_lange(int argc, VALUE* argv) {
11311142
}
11321143
case nm_complex32:
11331144
{
1134-
float complex val = LAPACKE_clange(LAPACK_ROW_MAJOR, norm, m, n, matrix->elements, lda);
1145+
float val = LAPACKE_clange(LAPACK_ROW_MAJOR, norm, m, n, matrix->elements, lda);
11351146
return val;
11361147
break;
11371148
}
11381149
case nm_complex64:
11391150
{
1140-
double complex val = LAPACKE_zlange(LAPACK_ROW_MAJOR, norm, m, n, matrix->elements, lda);
1151+
double val = LAPACKE_zlange(LAPACK_ROW_MAJOR, norm, m, n, matrix->elements, lda);
11411152
return val;
11421153
break;
11431154
}

test/lapack_test.rb

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,39 @@ def test_gelss
144144
end
145145

146146
def test_posv
147-
147+
matrix = NMatrix.new [3, 3], [2, -1, 0, -1, 2, -1, 0, -1, 2]
148+
b = NMatrix.new [3], [1, 2, 3]
149+
150+
c, x = NumRuby::Lapack.posv(matrix, b, true)
151+
x_soln = NMatrix.new [3], [2.5, 4.0, 3.5]
152+
assert_equal x, x_soln
153+
154+
matrix = NMatrix.new [3, 3], [2, -1, 0, -1, 2, -1, 0, -1, 2]
155+
b = NMatrix.new [3], [1, 2, 3]
156+
157+
c, x = NumRuby::Lapack.posv(matrix, b, false)
158+
x_soln = NMatrix.new [3], [2.5, 4.0, 3.5]
159+
assert_equal x, x_soln
148160
end
149161

150162
def test_gesv
163+
matrix = NMatrix.new [3, 3], [2, -1, 0, -1, 2, -1, 0, -1, 2]
164+
b = NMatrix.new [3], [1, 2, 3]
165+
166+
lu, x, ipiv = NumRuby::Lapack.posv(matrix, b)
167+
x_soln = NMatrix.new [3], [2.5, 4.0, 3.5]
168+
assert_equal x, x_soln
151169

170+
matrix = NMatrix.new [3, 3], [2, -1, 0, -1, 2, -1, 0, -1, 2]
171+
b = NMatrix.new [3], [1, 2, 3]
172+
173+
lu, x, ipiv = NumRuby::Lapack.posv(matrix, b)
174+
x_soln = NMatrix.new [3], [2.5, 4.0, 3.5]
175+
assert_equal x, x_soln
152176
end
153177

154178
def test_lange
155-
179+
# TODO: fix nm_lange
156180
end
157181

158182
def test_pinv

0 commit comments

Comments
 (0)