Skip to content

Commit 2bfda60

Browse files
committed
Implement getri
1 parent e48b1dd commit 2bfda60

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

ext/lapack.c

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,86 @@ VALUE nm_getrs(int argc, VALUE* argv) {
741741
return INT2NUM(-1);
742742
}
743743

744+
/*
745+
*
746+
*
747+
*/
748+
VALUE nm_getri(int argc, VALUE* argv) {
749+
nmatrix* matrix_lu;
750+
Data_Get_Struct(argv[0], nmatrix, matrix_lu);
751+
752+
int m = matrix_lu->shape[0]; //no. of rows
753+
int n = matrix_lu->shape[1]; //no. of cols
754+
int lda = n, info = -1;
755+
756+
nmatrix* matrix_ipiv;
757+
Data_Get_Struct(argv[1], nmatrix, matrix_ipiv);
758+
759+
nmatrix* result = nmatrix_new(matrix_lu->dtype, matrix_lu->stype, 2, matrix_lu->count, matrix_lu->shape, NULL);
760+
761+
switch(matrix_lu->dtype) {
762+
case nm_bool:
763+
{
764+
//not supported error
765+
break;
766+
}
767+
case nm_int:
768+
{
769+
//not supported error
770+
break;
771+
}
772+
case nm_float32:
773+
{
774+
float* elements = ALLOC_N(float, matrix_lu->count);
775+
memcpy(elements, matrix_lu->elements, sizeof(float)*matrix_lu->count);
776+
int* elements_ipiv = (int*)matrix_ipiv->elements;
777+
info = LAPACKE_sgetri(LAPACK_ROW_MAJOR, n, elements, lda, elements_ipiv);
778+
779+
result->elements = elements;
780+
781+
return Data_Wrap_Struct(NMatrix, NULL, nm_free, result);
782+
break;
783+
}
784+
case nm_float64:
785+
{
786+
double* elements = ALLOC_N(double, matrix_lu->count);
787+
memcpy(elements, matrix_lu->elements, sizeof(double)*matrix_lu->count);
788+
int* elements_ipiv = (int*)matrix_ipiv->elements;
789+
info = LAPACKE_dgetri(LAPACK_ROW_MAJOR, n, elements, lda, elements_ipiv);
790+
791+
result->elements = elements;
792+
793+
return Data_Wrap_Struct(NMatrix, NULL, nm_free, result);
794+
break;
795+
}
796+
case nm_complex32:
797+
{
798+
float complex* elements = ALLOC_N(float complex, matrix_lu->count);
799+
memcpy(elements, matrix_lu->elements, sizeof(float complex)*matrix_lu->count);
800+
int* elements_ipiv = (int*)matrix_ipiv->elements;
801+
info = LAPACKE_cgetri(LAPACK_ROW_MAJOR, n, elements, lda, elements_ipiv);
802+
803+
result->elements = elements;
804+
805+
return Data_Wrap_Struct(NMatrix, NULL, nm_free, result);
806+
break;
807+
}
808+
case nm_complex64:
809+
{
810+
double complex* elements = ALLOC_N(double complex, matrix_lu->count);
811+
memcpy(elements, matrix_lu->elements, sizeof(double complex)*matrix_lu->count);
812+
int* elements_ipiv = (int*)matrix_ipiv->elements;
813+
info = LAPACKE_zgetri(LAPACK_ROW_MAJOR, n, elements, lda, elements_ipiv);
814+
815+
result->elements = elements;
816+
817+
return Data_Wrap_Struct(NMatrix, NULL, nm_free, result);
818+
break;
819+
}
820+
}
821+
return INT2NUM(-1);
822+
}
823+
744824
// TODO: m should represent no. of rows and n no. of cols throughout
745825

746826
/*

0 commit comments

Comments
 (0)