Skip to content

Commit 2547a63

Browse files
authored
Merge pull request #22 from Uditgulati/iterators
Iterators nm_each_rank, nm_each_row, nm_each_column and nm_each_layer
2 parents ebb61b5 + b7540cd commit 2547a63

File tree

2 files changed

+52
-120
lines changed

2 files changed

+52
-120
lines changed

ext/ruby_nmatrix.c

Lines changed: 45 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ VALUE nm_each_with_indices(VALUE self);
183183
//VALUE nm_each_stored_with_indices(VALUE self);
184184
//VALUE nm_each_ordered_stored_with_indices(VALUE self);
185185
//VALUE nm_map_stored(VALUE self);
186+
VALUE nm_each_rank(VALUE self, VALUE dimension_idx);
186187
VALUE nm_each_row(VALUE self);
187188
VALUE nm_each_column(VALUE self);
188-
//VALUE nm_each_rank(VALUE self);
189-
//VALUE nm_each_layer(VALUE self);
189+
VALUE nm_each_layer(VALUE self);
190190

191191
//VALUE nm_get_row(VALUE self, VALUE row_number);
192192
//VALUE nm_get_column(VALUE self, VALUE column_number);
@@ -297,6 +297,11 @@ void get_dense_from_dia(const void* data_t, const size_t rows,
297297
const size_t cols, const size_t* offset,
298298
void* elements_t, nm_dtype);
299299

300+
//forwards for internally used functions
301+
void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice);
302+
size_t get_index(nmatrix* nmat, VALUE* indices);
303+
304+
300305
void Init_nmatrix() {
301306

302307
///////////////////////
@@ -361,8 +366,10 @@ void Init_nmatrix() {
361366
//rb_define_method(NMatrix, "each_stored_with_indices", nm_each_stored_with_indices, 0);
362367
//rb_define_method(NMatrix, "map_stored", nm_map_stored, 0);
363368
//rb_define_method(NMatrix, "each_ordered_stored_with_indices", nm_each_ordered_stored_with_indices, 0);
369+
rb_define_method(NMatrix, "each_rank", nm_each_rank, 1);
364370
rb_define_method(NMatrix, "each_row", nm_each_row, 0);
365371
rb_define_method(NMatrix, "each_column", nm_each_column, 0);
372+
rb_define_method(NMatrix, "each_layer", nm_each_layer, 0);
366373

367374
//rb_define_method(NMatrix, "row", nm_get_row, 1);
368375
//rb_define_method(NMatrix, "column", nm_get_column, 1);
@@ -997,138 +1004,56 @@ VALUE nm_map_stored(VALUE self) {
9971004
return Qnil;
9981005
}
9991006

1000-
VALUE nm_each_row(VALUE self) {
1007+
VALUE nm_each_rank(VALUE self, VALUE dimension_idx) {
10011008
nmatrix* input;
10021009
Data_Get_Struct(self, nmatrix, input);
10031010

1004-
VALUE* curr_row = ALLOC_N(VALUE, input->shape[1]);
1005-
for (size_t index = 0; index < input->shape[1]; index++){
1006-
curr_row[index] = INT2NUM(0);
1007-
}
1008-
1009-
switch(input->stype){
1010-
case nm_dense:
1011-
{
1012-
switch (input->dtype) {
1013-
case nm_bool:
1014-
{
1015-
bool* elements = (bool*)input->elements;
1016-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1017-
1018-
for (size_t index = 0; index < input->shape[1]; index++){
1019-
curr_row[index] = elements[(row_index * input->shape[1]) + index] ? Qtrue : Qfalse;
1020-
}
1021-
//rb_yield(DBL2NUM(elements[row_index]));
1022-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1023-
}
1024-
break;
1025-
}
1026-
case nm_int:
1027-
{
1028-
int* elements = (int*)input->elements;
1029-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1011+
size_t dim_idx = NUM2SIZET(dimension_idx);
10301012

1031-
for (size_t index = 0; index < input->shape[1]; index++){
1032-
curr_row[index] = INT2NUM(elements[(row_index * input->shape[1]) + index]);
1033-
}
1034-
//rb_yield(DBL2NUM(elements[row_index]));
1035-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1036-
}
1037-
break;
1038-
}
1039-
case nm_float64:
1040-
{
1041-
double* elements = (double*)input->elements;
1042-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1013+
nmatrix* result = ALLOC(nmatrix);
1014+
result->dtype = input->dtype;
1015+
result->stype = input->stype;
1016+
result->count = (input->count / input->shape[dim_idx]);
1017+
result->ndims = (input->ndims) - 1;
1018+
result->shape = ALLOC_N(size_t, result->ndims);
10431019

1044-
for (size_t index = 0; index < input->shape[1]; index++){
1045-
curr_row[index] = DBL2NUM(elements[(row_index * input->shape[1]) + index]);
1046-
}
1047-
//rb_yield(DBL2NUM(elements[row_index]));
1048-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1049-
}
1020+
for(size_t i = 0; i < result->ndims; ++i) {
1021+
if(i < dim_idx)
1022+
result->shape[i] = input->shape[i];
1023+
else
1024+
result->shape[i] = input->shape[i + 1];
1025+
}
10501026

1051-
break;
1052-
}
1053-
case nm_float32:
1054-
{
1055-
float* elements = (float*)input->elements;
1056-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1027+
size_t* lower_indices = ALLOC_N(size_t, input->ndims);
1028+
size_t* upper_indices = ALLOC_N(size_t, input->ndims);
10571029

1058-
for (size_t index = 0; index < input->shape[1]; index++){
1059-
curr_row[index] = DBL2NUM(elements[(row_index * input->shape[1]) + index]);
1060-
}
1061-
//rb_yield(DBL2NUM(elements[row_index]));
1062-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1063-
}
1064-
for (size_t index = 0; index < input->count; index++){
1065-
rb_yield(DBL2NUM(elements[index]));
1066-
}
1067-
break;
1068-
}
1069-
case nm_complex32:
1070-
{
1071-
float complex* elements = (float complex*)input->elements;
1072-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1030+
for(size_t i = 0; i < input->ndims; ++i) {
1031+
lower_indices[i] = 0;
1032+
upper_indices[i] = input->shape[i] - 1;
1033+
}
1034+
lower_indices[dim_idx] = upper_indices[dim_idx] = -1;
10731035

1074-
for (size_t index = 0; index < input->shape[1]; index++){
1075-
curr_row[index] = DBL2NUM(creal(elements[(row_index * input->shape[1]) + index])), DBL2NUM(cimag(elements[(row_index * input->shape[1]) + index]));
1076-
}
1077-
//rb_yield(DBL2NUM(elements[row_index]));
1078-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1079-
}
1080-
break;
1081-
}
1082-
case nm_complex64:
1083-
{
1084-
double complex* elements = (double complex*)input->elements;
1085-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1036+
for(size_t i = 0; i < input->shape[dim_idx]; ++i) {
1037+
lower_indices[dim_idx] = upper_indices[dim_idx] = i;
10861038

1087-
for (size_t index = 0; index < input->shape[1]; index++){
1088-
curr_row[index] = DBL2NUM(creal(elements[(row_index * input->shape[1]) + index])), DBL2NUM(cimag(elements[(row_index * input->shape[1]) + index]));
1089-
}
1090-
//rb_yield(DBL2NUM(elements[row_index]));
1091-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1092-
}
1093-
break;
1094-
}
1095-
}
1096-
break;
1097-
}
1098-
case nm_sparse: //this is to be modified later during sparse work
1099-
{
1100-
switch(input->dtype){
1101-
case nm_float64:
1102-
{
1103-
double* elements = (double*)input->sp->csr->elements;
1104-
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
1039+
get_slice(input, lower_indices, upper_indices, result);
11051040

1106-
for (size_t index = 0; index < input->shape[1]; index++){
1107-
curr_row[index] = DBL2NUM(elements[(row_index * input->shape[1]) + index]);
1108-
}
1109-
//rb_yield(DBL2NUM(elements[row_index]));
1110-
rb_yield(rb_ary_new4(input->shape[1], curr_row));
1111-
}
1112-
break;
1113-
}
1114-
}
1115-
break;
1116-
}
1041+
rb_yield(Data_Wrap_Struct(NMatrix, NULL, nm_free, result));
11171042
}
11181043

11191044
return self;
11201045
}
11211046

1122-
VALUE nm_each_column(VALUE self) {
1123-
return Qnil;
1047+
VALUE nm_each_row(VALUE self) {
1048+
return nm_each_rank(self, SIZET2NUM(0));
11241049
}
11251050

1126-
VALUE nm_each_rank(VALUE self) {
1127-
return Qnil;
1051+
VALUE nm_each_column(VALUE self) {
1052+
return nm_each_rank(self, SIZET2NUM(1));
11281053
}
11291054

11301055
VALUE nm_each_layer(VALUE self) {
1131-
return Qnil;
1056+
return nm_each_rank(self, SIZET2NUM(2));
11321057
}
11331058

11341059
/*
@@ -2269,7 +2194,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
22692194
size_t curr_index_value = NUM2SIZET(state_array[state_index]);
22702195

22712196
if(curr_index_value == upper[state_index]){
2272-
curr_index_value = lower[i];
2197+
curr_index_value = lower[state_index];
22732198
state_array[state_index] = SIZET2NUM(curr_index_value);
22742199
}
22752200
else{
@@ -2300,7 +2225,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23002225
size_t curr_index_value = NUM2SIZET(state_array[state_index]);
23012226

23022227
if(curr_index_value == upper[state_index]){
2303-
curr_index_value = lower[i];
2228+
curr_index_value = lower[state_index];
23042229
state_array[state_index] = SIZET2NUM(curr_index_value);
23052230
}
23062231
else{
@@ -2331,7 +2256,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23312256
size_t curr_index_value = NUM2SIZET(state_array[state_index]);
23322257

23332258
if(curr_index_value == upper[state_index]){
2334-
curr_index_value = lower[i];
2259+
curr_index_value = lower[state_index];
23352260
state_array[state_index] = SIZET2NUM(curr_index_value);
23362261
}
23372262
else{
@@ -2362,7 +2287,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23622287
size_t curr_index_value = NUM2SIZET(state_array[state_index]);
23632288

23642289
if(curr_index_value == upper[state_index]){
2365-
curr_index_value = lower[i];
2290+
curr_index_value = lower[state_index];
23662291
state_array[state_index] = SIZET2NUM(curr_index_value);
23672292
}
23682293
else{
@@ -2393,7 +2318,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23932318
size_t curr_index_value = NUM2SIZET(state_array[state_index]);
23942319

23952320
if(curr_index_value == upper[state_index]){
2396-
curr_index_value = lower[i];
2321+
curr_index_value = lower[state_index];
23972322
state_array[state_index] = SIZET2NUM(curr_index_value);
23982323
}
23992324
else{
@@ -2424,7 +2349,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
24242349
size_t curr_index_value = NUM2SIZET(state_array[state_index]);
24252350

24262351
if(curr_index_value == upper[state_index]){
2427-
curr_index_value = lower[i];
2352+
curr_index_value = lower[state_index];
24282353
state_array[state_index] = SIZET2NUM(curr_index_value);
24292354
}
24302355
else{

test/nmatrix_test.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ def setup
77
@b = NMatrix.new [2,2],[true, true, false, true], :nm_bool
88
@m = NMatrix.new [2,2,2],[1, 2, 3, 4, 5, 6, 7, 8]
99
@n = NMatrix.new [2,1,2],[1, 2, 3, 4]
10+
@s = NMatrix.new [2, 2],[1, 2, 3, 4]
11+
@s_int = NMatrix.new [2, 2],[1, 2, 3, 4], :nm_int
1012
end
1113

1214
def test_dims
@@ -52,4 +54,9 @@ def test_accessor_set
5254
assert_equal @n[0,0,1], 12
5355
end
5456

57+
def test_slicing
58+
assert_equal @m[0, 0..1, 0..1], @s
59+
assert_equal @m[0, 0..1, 0..1], @s_int
60+
end
61+
5562
end

0 commit comments

Comments
 (0)