Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 45 additions & 120 deletions ext/ruby_nmatrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ VALUE nm_each_with_indices(VALUE self);
//VALUE nm_each_stored_with_indices(VALUE self);
//VALUE nm_each_ordered_stored_with_indices(VALUE self);
//VALUE nm_map_stored(VALUE self);
VALUE nm_each_rank(VALUE self, VALUE dimension_idx);
VALUE nm_each_row(VALUE self);
VALUE nm_each_column(VALUE self);
//VALUE nm_each_rank(VALUE self);
//VALUE nm_each_layer(VALUE self);
VALUE nm_each_layer(VALUE self);

//VALUE nm_get_row(VALUE self, VALUE row_number);
//VALUE nm_get_column(VALUE self, VALUE column_number);
Expand Down Expand Up @@ -297,6 +297,11 @@ void get_dense_from_dia(const double* data, const size_t rows,
const size_t cols, const size_t* offset,
double* elements);

//forwards for internally used functions
void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice);
size_t get_index(nmatrix* nmat, VALUE* indices);


void Init_nmatrix() {

///////////////////////
Expand Down Expand Up @@ -361,8 +366,10 @@ void Init_nmatrix() {
//rb_define_method(NMatrix, "each_stored_with_indices", nm_each_stored_with_indices, 0);
//rb_define_method(NMatrix, "map_stored", nm_map_stored, 0);
//rb_define_method(NMatrix, "each_ordered_stored_with_indices", nm_each_ordered_stored_with_indices, 0);
rb_define_method(NMatrix, "each_rank", nm_each_rank, 1);
rb_define_method(NMatrix, "each_row", nm_each_row, 0);
rb_define_method(NMatrix, "each_column", nm_each_column, 0);
rb_define_method(NMatrix, "each_layer", nm_each_layer, 0);

//rb_define_method(NMatrix, "row", nm_get_row, 1);
//rb_define_method(NMatrix, "column", nm_get_column, 1);
Expand Down Expand Up @@ -997,138 +1004,56 @@ VALUE nm_map_stored(VALUE self) {
return Qnil;
}

VALUE nm_each_row(VALUE self) {
VALUE nm_each_rank(VALUE self, VALUE dimension_idx) {
nmatrix* input;
Data_Get_Struct(self, nmatrix, input);

VALUE* curr_row = ALLOC_N(VALUE, input->shape[1]);
for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = INT2NUM(0);
}

switch(input->stype){
case nm_dense:
{
switch (input->dtype) {
case nm_bool:
{
bool* elements = (bool*)input->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = elements[(row_index * input->shape[1]) + index] ? Qtrue : Qfalse;
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
break;
}
case nm_int:
{
int* elements = (int*)input->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
size_t dim_idx = NUM2SIZET(dimension_idx);

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = INT2NUM(elements[(row_index * input->shape[1]) + index]);
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
break;
}
case nm_float64:
{
double* elements = (double*)input->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
nmatrix* result = ALLOC(nmatrix);
result->dtype = input->dtype;
result->stype = input->stype;
result->count = (input->count / input->shape[dim_idx]);
result->ndims = (input->ndims) - 1;
result->shape = ALLOC_N(size_t, result->ndims);

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = DBL2NUM(elements[(row_index * input->shape[1]) + index]);
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
for(size_t i = 0; i < result->ndims; ++i) {
if(i < dim_idx)
result->shape[i] = input->shape[i];
else
result->shape[i] = input->shape[i + 1];
}

break;
}
case nm_float32:
{
float* elements = (float*)input->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
size_t* lower_indices = ALLOC_N(size_t, input->ndims);
size_t* upper_indices = ALLOC_N(size_t, input->ndims);

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = DBL2NUM(elements[(row_index * input->shape[1]) + index]);
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
for (size_t index = 0; index < input->count; index++){
rb_yield(DBL2NUM(elements[index]));
}
break;
}
case nm_complex32:
{
float complex* elements = (float complex*)input->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
for(size_t i = 0; i < input->ndims; ++i) {
lower_indices[i] = 0;
upper_indices[i] = input->shape[i] - 1;
}
lower_indices[dim_idx] = upper_indices[dim_idx] = -1;

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = DBL2NUM(creal(elements[(row_index * input->shape[1]) + index])), DBL2NUM(cimag(elements[(row_index * input->shape[1]) + index]));
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
break;
}
case nm_complex64:
{
double complex* elements = (double complex*)input->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
for(size_t i = 0; i < input->shape[dim_idx]; ++i) {
lower_indices[dim_idx] = upper_indices[dim_idx] = i;

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = DBL2NUM(creal(elements[(row_index * input->shape[1]) + index])), DBL2NUM(cimag(elements[(row_index * input->shape[1]) + index]));
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
break;
}
}
break;
}
case nm_sparse: //this is to be modified later during sparse work
{
switch(input->dtype){
case nm_float64:
{
double* elements = (double*)input->sp->csr->elements;
for (size_t row_index = 0; row_index < input->shape[0]; row_index++){
get_slice(input, lower_indices, upper_indices, result);

for (size_t index = 0; index < input->shape[1]; index++){
curr_row[index] = DBL2NUM(elements[(row_index * input->shape[1]) + index]);
}
//rb_yield(DBL2NUM(elements[row_index]));
rb_yield(rb_ary_new4(input->shape[1], curr_row));
}
break;
}
}
break;
}
rb_yield(Data_Wrap_Struct(NMatrix, NULL, nm_free, result));
}

return self;
}

VALUE nm_each_column(VALUE self) {
return Qnil;
VALUE nm_each_row(VALUE self) {
return nm_each_rank(self, SIZET2NUM(0));
}

VALUE nm_each_rank(VALUE self) {
return Qnil;
VALUE nm_each_column(VALUE self) {
return nm_each_rank(self, SIZET2NUM(1));
}

VALUE nm_each_layer(VALUE self) {
return Qnil;
return nm_each_rank(self, SIZET2NUM(2));
}

/*
Expand Down Expand Up @@ -2028,7 +1953,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
size_t curr_index_value = NUM2SIZET(state_array[state_index]);

if(curr_index_value == upper[state_index]){
curr_index_value = lower[i];
curr_index_value = lower[state_index];
state_array[state_index] = SIZET2NUM(curr_index_value);
}
else{
Expand Down Expand Up @@ -2059,7 +1984,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
size_t curr_index_value = NUM2SIZET(state_array[state_index]);

if(curr_index_value == upper[state_index]){
curr_index_value = lower[i];
curr_index_value = lower[state_index];
state_array[state_index] = SIZET2NUM(curr_index_value);
}
else{
Expand Down Expand Up @@ -2090,7 +2015,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
size_t curr_index_value = NUM2SIZET(state_array[state_index]);

if(curr_index_value == upper[state_index]){
curr_index_value = lower[i];
curr_index_value = lower[state_index];
state_array[state_index] = SIZET2NUM(curr_index_value);
}
else{
Expand Down Expand Up @@ -2121,7 +2046,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
size_t curr_index_value = NUM2SIZET(state_array[state_index]);

if(curr_index_value == upper[state_index]){
curr_index_value = lower[i];
curr_index_value = lower[state_index];
state_array[state_index] = SIZET2NUM(curr_index_value);
}
else{
Expand Down Expand Up @@ -2152,7 +2077,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
size_t curr_index_value = NUM2SIZET(state_array[state_index]);

if(curr_index_value == upper[state_index]){
curr_index_value = lower[i];
curr_index_value = lower[state_index];
state_array[state_index] = SIZET2NUM(curr_index_value);
}
else{
Expand Down Expand Up @@ -2183,7 +2108,7 @@ void get_slice(nmatrix* nmat, size_t* lower, size_t* upper, nmatrix* slice){
size_t curr_index_value = NUM2SIZET(state_array[state_index]);

if(curr_index_value == upper[state_index]){
curr_index_value = lower[i];
curr_index_value = lower[state_index];
state_array[state_index] = SIZET2NUM(curr_index_value);
}
else{
Expand Down
7 changes: 7 additions & 0 deletions test/nmatrix_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ def setup
@b = NMatrix.new [2,2],[true, true, false, true], :nm_bool
@m = NMatrix.new [2,2,2],[1, 2, 3, 4, 5, 6, 7, 8]
@n = NMatrix.new [2,1,2],[1, 2, 3, 4]
@s = NMatrix.new [2, 2],[1, 2, 3, 4]
@s_int = NMatrix.new [2, 2],[1, 2, 3, 4], :nm_int
end

def test_dims
Expand Down Expand Up @@ -52,4 +54,9 @@ def test_accessor_set
assert_equal @n[0,0,1], 12
end

def test_slicing
assert_equal @m[0, 0..1, 0..1], @s
assert_equal @m[0, 0..1, 0..1], @s_int
end

end