Skip to content

Commit 6ae0ce0

Browse files
committed
Implement initial part of qr
1 parent 0e2fc96 commit 6ae0ce0

File tree

4 files changed

+87
-48
lines changed

4 files changed

+87
-48
lines changed

ext/lapack.c

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,20 @@ VALUE nm_geqrf(int argc, VALUE* argv) {
1111
int n = matrix->shape[1]; //no. of cols
1212
int lda = m, info = -1;
1313

14-
nmatrix* result_qr = nmatrix_new(matrix->dtype, matrix->stype, 2, 0, NULL, NULL);
15-
nmatrix* result_tau = nmatrix_new(matrix->dtype, matrix->stype, 2, 0, NULL, NULL);
14+
nmatrix* result_qr = nmatrix_new(matrix->dtype, matrix->stype, 2, matrix->count, matrix->shape, NULL);
15+
nmatrix* result_tau = nmatrix_new(matrix->dtype, matrix->stype, 1, min(m, n), NULL, NULL);
16+
result_tau->shape[0] = min(m, n);
1617

1718
switch(matrix->dtype) {
1819
case nm_bool:
1920
{
2021
//not supported error
22+
break;
2123
}
2224
case nm_int:
2325
{
2426
//not supported error
27+
break;
2528
}
2629
case nm_float32:
2730
{
@@ -37,6 +40,7 @@ VALUE nm_geqrf(int argc, VALUE* argv) {
3740
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_qr));
3841
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_tau));
3942
return ary;
43+
break;
4044
}
4145
case nm_float64:
4246
{
@@ -47,11 +51,17 @@ VALUE nm_geqrf(int argc, VALUE* argv) {
4751

4852
result_qr->elements = elements;
4953
result_tau->elements = tau_elements;
50-
51-
VALUE ary = rb_ary_new();
52-
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_qr));
53-
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_tau));
54-
return ary;
54+
// for(size_t i = 0; i < matrix->count; ++i) {
55+
// printf("%f\n", elements[i]);
56+
// }
57+
// for(size_t i = 0; i < min(m, n); ++i) {
58+
// printf("%f\n", tau_elements[i]);
59+
// }
60+
61+
VALUE qr = Data_Wrap_Struct(NMatrix, NULL, nm_free, result_qr);
62+
VALUE tau = Data_Wrap_Struct(NMatrix, NULL, nm_free, result_tau);
63+
return rb_ary_new3(2, qr, tau);
64+
break;
5565
}
5666
case nm_complex32:
5767
{
@@ -67,6 +77,7 @@ VALUE nm_geqrf(int argc, VALUE* argv) {
6777
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_qr));
6878
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_tau));
6979
return ary;
80+
break;
7081
}
7182
case nm_complex64:
7283
{
@@ -82,13 +93,13 @@ VALUE nm_geqrf(int argc, VALUE* argv) {
8293
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_qr));
8394
rb_ary_push(ary, Data_Wrap_Struct(NMatrix, NULL, nm_free, result_tau));
8495
return ary;
96+
break;
8597
}
8698
}
99+
return INT2NUM(-1);
87100
}
88101

89102

90-
91-
92103
// TODO: m should represent no. of rows and n no. of cols throughout
93104

94105
/*
@@ -418,15 +429,6 @@ VALUE nm_cholesky_solve(VALUE self){
418429
return Qnil;
419430
}
420431

421-
/*
422-
* Computes the QR decomposition of matrix.
423-
* Args:
424-
* - input matrix, type: NMatrix
425-
* - mode, type: String
426-
* - pivoting, type: Boolean
427-
*
428-
* returns the vector of type NMatrix with values of unknowns
429-
*/
430-
VALUE nm_qr(VALUE self, VALUE mode, VALUE pivoting){
431-
432+
VALUE nm_qr(VALUE self){
433+
return Qnil;
432434
}

ext/ruby_nmatrix.c

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,14 @@ nmatrix* nmatrix_new(
133133
matrix->count = count;
134134

135135
matrix->shape = ALLOC_N(size_t, matrix->ndims);
136-
for(size_t i = 0; i < ndims; ++i) {
137-
matrix->shape[i] = shape[i];
136+
if(shape != NULL) {
137+
for(size_t i = 0; i < ndims; ++i) {
138+
matrix->shape[i] = shape[i];
139+
}
140+
}
141+
142+
if(elements == NULL) {
143+
return matrix;
138144
}
139145

140146
switch(dtype) {
@@ -211,16 +217,22 @@ nmatrix* matrix_copy(nmatrix* original_matrix) {
211217
matrix->count = original_matrix->count;
212218

213219
matrix->shape = ALLOC_N(size_t, matrix->ndims);
214-
for(size_t i = 0; i < ndims; ++i) {
215-
matrix->shape[i] = original_matrix->shape[i];
220+
if(original_matrix->shape != NULL) {
221+
for(size_t i = 0; i < original_matrix->ndims; ++i) {
222+
matrix->shape[i] = original_matrix->shape[i];
223+
}
216224
}
217225

218-
switch(dtype) {
226+
if(original_matrix->elements == NULL) {
227+
return matrix;
228+
}
229+
230+
switch(original_matrix->dtype) {
219231
case nm_bool:
220232
{
221233
bool* temp_elements = (bool*)original_matrix->elements;
222234
bool* matrix_elements = ALLOC_N(bool, matrix->count);
223-
for(size_t i = 0; i < count; ++i) {
235+
for(size_t i = 0; i < original_matrix->count; ++i) {
224236
matrix_elements[i] = temp_elements[i];
225237
}
226238
matrix->elements = matrix_elements;
@@ -230,7 +242,7 @@ nmatrix* matrix_copy(nmatrix* original_matrix) {
230242
{
231243
int* temp_elements = (int*)original_matrix->elements;
232244
int* matrix_elements = ALLOC_N(int, matrix->count);
233-
for(size_t i = 0; i < count; ++i) {
245+
for(size_t i = 0; i < original_matrix->count; ++i) {
234246
matrix_elements[i] = temp_elements[i];
235247
}
236248
matrix->elements = matrix_elements;
@@ -240,7 +252,7 @@ nmatrix* matrix_copy(nmatrix* original_matrix) {
240252
{
241253
float* temp_elements = (float*)original_matrix->elements;
242254
float* matrix_elements = ALLOC_N(float, matrix->count);
243-
for(size_t i = 0; i < count; ++i) {
255+
for(size_t i = 0; i < original_matrix->count; ++i) {
244256
matrix_elements[i] = temp_elements[i];
245257
}
246258
matrix->elements = matrix_elements;
@@ -250,7 +262,7 @@ nmatrix* matrix_copy(nmatrix* original_matrix) {
250262
{
251263
double* temp_elements = (double*)original_matrix->elements;
252264
double* matrix_elements = ALLOC_N(double, matrix->count);
253-
for(size_t i = 0; i < count; ++i) {
265+
for(size_t i = 0; i < original_matrix->count; ++i) {
254266
matrix_elements[i] = temp_elements[i];
255267
}
256268
matrix->elements = matrix_elements;
@@ -260,7 +272,7 @@ nmatrix* matrix_copy(nmatrix* original_matrix) {
260272
{
261273
float complex* temp_elements = (float complex*)original_matrix->elements;
262274
float complex* matrix_elements = ALLOC_N(float complex, matrix->count);
263-
for(size_t i = 0; i < count; ++i) {
275+
for(size_t i = 0; i < original_matrix->count; ++i) {
264276
matrix_elements[i] = temp_elements[i];
265277
}
266278
matrix->elements = matrix_elements;
@@ -270,7 +282,7 @@ nmatrix* matrix_copy(nmatrix* original_matrix) {
270282
{
271283
double complex* temp_elements = (double complex*)original_matrix->elements;
272284
double complex* matrix_elements = ALLOC_N(double complex, matrix->count);
273-
for(size_t i = 0; i < count; ++i) {
285+
for(size_t i = 0; i < original_matrix->count; ++i) {
274286
matrix_elements[i] = temp_elements[i];
275287
}
276288
matrix->elements = matrix_elements;
@@ -425,7 +437,7 @@ VALUE nm_geqp3(int argc, VALUE* argv);
425437
VALUE nm_orth(VALUE self);
426438
VALUE nm_cholesky(VALUE self);
427439
VALUE nm_cholesky_solve(VALUE self);
428-
VALUE nm_qr(VALUE self, VALUE mode, VALUE pivoting);
440+
VALUE nm_qr(VALUE self);
429441

430442
VALUE nm_accessor_get(int argc, VALUE* argv, VALUE self);
431443
VALUE nm_accessor_set(int argc, VALUE* argv, VALUE self);
@@ -484,15 +496,15 @@ void Init_nmatrix() {
484496
rb_define_singleton_method(NumRuby, "ones", ones_nmatrix, -1);
485497
// rb_define_singleton_method(NumRuby, "matrix", nmatrix_init, -1);
486498

487-
Lapack = rb_define_module("NumRuby::Linalg::Lapack");
499+
Lapack = rb_define_module_under(NumRuby, "Lapack");
488500
rb_define_singleton_method(Lapack, "geqrf", nm_geqrf, -1);
489-
rb_define_singleton_method(Lapack, "orgqr", nm_orgqr, -1);
490-
rb_define_singleton_method(Lapack, "geqp3", nm_geqp3, -1);
501+
// rb_define_singleton_method(Lapack, "orgqr", nm_orgqr, -1);
502+
// rb_define_singleton_method(Lapack, "geqp3", nm_geqp3, -1);
491503
// rb_define_singleton_method(Lapack, "geqrf", nm_geqrf, -1);
492504
// rb_define_singleton_method(Lapack, "geqrf", nm_geqrf, -1);
493505
// rb_define_singleton_method(Lapack, "geqrf", nm_geqrf, -1);
494506

495-
Blas = rb_define_module("NumRuby::Linalg::Blas");
507+
Blas = rb_define_module("Blas");
496508

497509
/*
498510
* Exception raised when there's a problem with data.

lib/nmatrix/lapack.rb

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,43 +45,69 @@ def self.eigvalsh
4545

4646
end
4747

48-
def self.lu
48+
# Matrix Decomposition
49+
50+
51+
def self.lu(matrix)
4952

5053
end
5154

52-
def self.lu_factor
55+
def self.lu_factor(matrix)
5356

5457
end
5558

56-
def self.lu_solve
59+
def self.lu_solve(matrix, rhs_val)
5760

5861
end
5962

60-
def self.svd
63+
# Computes the QR decomposition of matrix.
64+
# Args:
65+
# - input matrix, type: NMatrix
66+
# - mode, type: String
67+
# - pivoting, type: Boolean
68+
def self.svd(matrix)
6169

6270
end
6371

64-
def self.svdvals
72+
def self.svdvals(matrix)
6573

6674
end
6775

68-
def self.diagsvd
76+
def self.diagsvd(matrix)
6977

7078
end
7179

72-
def self.orth
80+
def self.orth(matrix)
7381

7482
end
7583

76-
def self.cholesky
84+
def self.cholesky(matrix)
7785

7886
end
7987

80-
def self.cholesky_solve
88+
def self.cholesky_solve(matrix)
8189

8290
end
8391

84-
def self.qr
92+
# Computes the QR decomposition of matrix.
93+
# Args:
94+
# - input matrix, type: NMatrix
95+
# - mode, type: String
96+
# - pivoting, type: Boolean
97+
def self.qr(matrix, mode: "full", pivoting: false)
98+
if not ['full', 'r', 'economic', 'raw'].include?(mode.downcase)
99+
raise("Invalid mode. Should be one of ['full', 'r', 'economic', 'raw']")
100+
end
101+
if not matrix.is_a?(NMatrix)
102+
raise("Invalid matrix. Not of type NMatrix")
103+
end
104+
if matrix.dim != 2
105+
raise("Invalid shape of matrix. Should be 2.")
106+
end
107+
m, n = matrix.shape
85108

109+
if pivoting == false
110+
qr, tau = NumRuby::Lapack.geqrf(matrix)
111+
end
86112
end
87113
end

test/numruby_test.rb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def test_append
4141
x = NumRuby.array [2,2],[0, 0, 0, 0]
4242
y = NumRuby.array [2,2],[0, 0, 0, 0]
4343
result = NumRuby.append(x, y)
44-
assert_equal result, NumRuby.array([4,2],
45-
[0, 0, 0, 0, 0, 0, 0, 0])
44+
assert_equal result, NumRuby.array([4,2], [0, 0, 0, 0, 0, 0, 0, 0])
4645
end
4746
end

0 commit comments

Comments
 (0)