@@ -183,10 +183,10 @@ VALUE nm_each_with_indices(VALUE self);
183183//VALUE nm_each_stored_with_indices(VALUE self);
184184//VALUE nm_each_ordered_stored_with_indices(VALUE self);
185185//VALUE nm_map_stored(VALUE self);
186+ VALUE nm_each_rank (VALUE self , VALUE dimension_idx );
186187VALUE nm_each_row (VALUE self );
187188VALUE nm_each_column (VALUE self );
188- //VALUE nm_each_rank(VALUE self);
189- //VALUE nm_each_layer(VALUE self);
189+ VALUE nm_each_layer (VALUE self );
190190
191191//VALUE nm_get_row(VALUE self, VALUE row_number);
192192//VALUE nm_get_column(VALUE self, VALUE column_number);
@@ -297,6 +297,11 @@ void get_dense_from_dia(const void* data_t, const size_t rows,
297297 const size_t cols , const size_t * offset ,
298298 void * elements_t , nm_dtype );
299299
300+ //forwards for internally used functions
301+ void get_slice (nmatrix * nmat , size_t * lower , size_t * upper , nmatrix * slice );
302+ size_t get_index (nmatrix * nmat , VALUE * indices );
303+
304+
300305void Init_nmatrix () {
301306
302307 ///////////////////////
@@ -361,8 +366,10 @@ void Init_nmatrix() {
361366 //rb_define_method(NMatrix, "each_stored_with_indices", nm_each_stored_with_indices, 0);
362367 //rb_define_method(NMatrix, "map_stored", nm_map_stored, 0);
363368 //rb_define_method(NMatrix, "each_ordered_stored_with_indices", nm_each_ordered_stored_with_indices, 0);
369+ rb_define_method (NMatrix , "each_rank" , nm_each_rank , 1 );
364370 rb_define_method (NMatrix , "each_row" , nm_each_row , 0 );
365371 rb_define_method (NMatrix , "each_column" , nm_each_column , 0 );
372+ rb_define_method (NMatrix , "each_layer" , nm_each_layer , 0 );
366373
367374 //rb_define_method(NMatrix, "row", nm_get_row, 1);
368375 //rb_define_method(NMatrix, "column", nm_get_column, 1);
@@ -997,138 +1004,56 @@ VALUE nm_map_stored(VALUE self) {
9971004 return Qnil ;
9981005}
9991006
1000- VALUE nm_each_row (VALUE self ) {
1007+ VALUE nm_each_rank (VALUE self , VALUE dimension_idx ) {
10011008 nmatrix * input ;
10021009 Data_Get_Struct (self , nmatrix , input );
10031010
1004- VALUE * curr_row = ALLOC_N (VALUE , input -> shape [1 ]);
1005- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1006- curr_row [index ] = INT2NUM (0 );
1007- }
1008-
1009- switch (input -> stype ){
1010- case nm_dense :
1011- {
1012- switch (input -> dtype ) {
1013- case nm_bool :
1014- {
1015- bool * elements = (bool * )input -> elements ;
1016- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1017-
1018- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1019- curr_row [index ] = elements [(row_index * input -> shape [1 ]) + index ] ? Qtrue : Qfalse ;
1020- }
1021- //rb_yield(DBL2NUM(elements[row_index]));
1022- rb_yield (rb_ary_new4 (input -> shape [1 ], curr_row ));
1023- }
1024- break ;
1025- }
1026- case nm_int :
1027- {
1028- int * elements = (int * )input -> elements ;
1029- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1011+ size_t dim_idx = NUM2SIZET (dimension_idx );
10301012
1031- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1032- curr_row [index ] = INT2NUM (elements [(row_index * input -> shape [1 ]) + index ]);
1033- }
1034- //rb_yield(DBL2NUM(elements[row_index]));
1035- rb_yield (rb_ary_new4 (input -> shape [1 ], curr_row ));
1036- }
1037- break ;
1038- }
1039- case nm_float64 :
1040- {
1041- double * elements = (double * )input -> elements ;
1042- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1013+ nmatrix * result = ALLOC (nmatrix );
1014+ result -> dtype = input -> dtype ;
1015+ result -> stype = input -> stype ;
1016+ result -> count = (input -> count / input -> shape [dim_idx ]);
1017+ result -> ndims = (input -> ndims ) - 1 ;
1018+ result -> shape = ALLOC_N (size_t , result -> ndims );
10431019
1044- for (size_t index = 0 ; index < input -> shape [ 1 ]; index ++ ) {
1045- curr_row [ index ] = DBL2NUM ( elements [( row_index * input -> shape [ 1 ]) + index ]);
1046- }
1047- //rb_yield(DBL2NUM(elements[row_index]));
1048- rb_yield ( rb_ary_new4 ( input -> shape [1 ], curr_row )) ;
1049- }
1020+ for (size_t i = 0 ; i < result -> ndims ; ++ i ) {
1021+ if ( i < dim_idx )
1022+ result -> shape [ i ] = input -> shape [ i ];
1023+ else
1024+ result -> shape [ i ] = input -> shape [i + 1 ] ;
1025+ }
10501026
1051- break ;
1052- }
1053- case nm_float32 :
1054- {
1055- float * elements = (float * )input -> elements ;
1056- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1027+ size_t * lower_indices = ALLOC_N (size_t , input -> ndims );
1028+ size_t * upper_indices = ALLOC_N (size_t , input -> ndims );
10571029
1058- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1059- curr_row [index ] = DBL2NUM (elements [(row_index * input -> shape [1 ]) + index ]);
1060- }
1061- //rb_yield(DBL2NUM(elements[row_index]));
1062- rb_yield (rb_ary_new4 (input -> shape [1 ], curr_row ));
1063- }
1064- for (size_t index = 0 ; index < input -> count ; index ++ ){
1065- rb_yield (DBL2NUM (elements [index ]));
1066- }
1067- break ;
1068- }
1069- case nm_complex32 :
1070- {
1071- float complex * elements = (float complex * )input -> elements ;
1072- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1030+ for (size_t i = 0 ; i < input -> ndims ; ++ i ) {
1031+ lower_indices [i ] = 0 ;
1032+ upper_indices [i ] = input -> shape [i ] - 1 ;
1033+ }
1034+ lower_indices [dim_idx ] = upper_indices [dim_idx ] = -1 ;
10731035
1074- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1075- curr_row [index ] = DBL2NUM (creal (elements [(row_index * input -> shape [1 ]) + index ])), DBL2NUM (cimag (elements [(row_index * input -> shape [1 ]) + index ]));
1076- }
1077- //rb_yield(DBL2NUM(elements[row_index]));
1078- rb_yield (rb_ary_new4 (input -> shape [1 ], curr_row ));
1079- }
1080- break ;
1081- }
1082- case nm_complex64 :
1083- {
1084- double complex * elements = (double complex * )input -> elements ;
1085- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1036+ for (size_t i = 0 ; i < input -> shape [dim_idx ]; ++ i ) {
1037+ lower_indices [dim_idx ] = upper_indices [dim_idx ] = i ;
10861038
1087- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1088- curr_row [index ] = DBL2NUM (creal (elements [(row_index * input -> shape [1 ]) + index ])), DBL2NUM (cimag (elements [(row_index * input -> shape [1 ]) + index ]));
1089- }
1090- //rb_yield(DBL2NUM(elements[row_index]));
1091- rb_yield (rb_ary_new4 (input -> shape [1 ], curr_row ));
1092- }
1093- break ;
1094- }
1095- }
1096- break ;
1097- }
1098- case nm_sparse : //this is to be modified later during sparse work
1099- {
1100- switch (input -> dtype ){
1101- case nm_float64 :
1102- {
1103- double * elements = (double * )input -> sp -> csr -> elements ;
1104- for (size_t row_index = 0 ; row_index < input -> shape [0 ]; row_index ++ ){
1039+ get_slice (input , lower_indices , upper_indices , result );
11051040
1106- for (size_t index = 0 ; index < input -> shape [1 ]; index ++ ){
1107- curr_row [index ] = DBL2NUM (elements [(row_index * input -> shape [1 ]) + index ]);
1108- }
1109- //rb_yield(DBL2NUM(elements[row_index]));
1110- rb_yield (rb_ary_new4 (input -> shape [1 ], curr_row ));
1111- }
1112- break ;
1113- }
1114- }
1115- break ;
1116- }
1041+ rb_yield (Data_Wrap_Struct (NMatrix , NULL , nm_free , result ));
11171042 }
11181043
11191044 return self ;
11201045}
11211046
1122- VALUE nm_each_column (VALUE self ) {
1123- return Qnil ;
1047+ VALUE nm_each_row (VALUE self ) {
1048+ return nm_each_rank ( self , SIZET2NUM ( 0 )) ;
11241049}
11251050
1126- VALUE nm_each_rank (VALUE self ) {
1127- return Qnil ;
1051+ VALUE nm_each_column (VALUE self ) {
1052+ return nm_each_rank ( self , SIZET2NUM ( 1 )) ;
11281053}
11291054
11301055VALUE nm_each_layer (VALUE self ) {
1131- return Qnil ;
1056+ return nm_each_rank ( self , SIZET2NUM ( 2 )) ;
11321057}
11331058
11341059/*
@@ -2269,7 +2194,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
22692194 size_t curr_index_value = NUM2SIZET (state_array [state_index ]);
22702195
22712196 if (curr_index_value == upper [state_index ]){
2272- curr_index_value = lower [i ];
2197+ curr_index_value = lower [state_index ];
22732198 state_array [state_index ] = SIZET2NUM (curr_index_value );
22742199 }
22752200 else {
@@ -2300,7 +2225,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23002225 size_t curr_index_value = NUM2SIZET (state_array [state_index ]);
23012226
23022227 if (curr_index_value == upper [state_index ]){
2303- curr_index_value = lower [i ];
2228+ curr_index_value = lower [state_index ];
23042229 state_array [state_index ] = SIZET2NUM (curr_index_value );
23052230 }
23062231 else {
@@ -2331,7 +2256,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23312256 size_t curr_index_value = NUM2SIZET (state_array [state_index ]);
23322257
23332258 if (curr_index_value == upper [state_index ]){
2334- curr_index_value = lower [i ];
2259+ curr_index_value = lower [state_index ];
23352260 state_array [state_index ] = SIZET2NUM (curr_index_value );
23362261 }
23372262 else {
@@ -2362,7 +2287,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23622287 size_t curr_index_value = NUM2SIZET (state_array [state_index ]);
23632288
23642289 if (curr_index_value == upper [state_index ]){
2365- curr_index_value = lower [i ];
2290+ curr_index_value = lower [state_index ];
23662291 state_array [state_index ] = SIZET2NUM (curr_index_value );
23672292 }
23682293 else {
@@ -2393,7 +2318,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
23932318 size_t curr_index_value = NUM2SIZET (state_array [state_index ]);
23942319
23952320 if (curr_index_value == upper [state_index ]){
2396- curr_index_value = lower [i ];
2321+ curr_index_value = lower [state_index ];
23972322 state_array [state_index ] = SIZET2NUM (curr_index_value );
23982323 }
23992324 else {
@@ -2424,7 +2349,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
24242349 size_t curr_index_value = NUM2SIZET (state_array [state_index ]);
24252350
24262351 if (curr_index_value == upper [state_index ]){
2427- curr_index_value = lower [i ];
2352+ curr_index_value = lower [state_index ];
24282353 state_array [state_index ] = SIZET2NUM (curr_index_value );
24292354 }
24302355 else {
0 commit comments