@@ -741,6 +741,86 @@ VALUE nm_getrs(int argc, VALUE* argv) {
741741 return INT2NUM (-1 );
742742}
743743
744+ /*
745+ *
746+ *
747+ */
748+ VALUE nm_getri (int argc , VALUE * argv ) {
749+ nmatrix * matrix_lu ;
750+ Data_Get_Struct (argv [0 ], nmatrix , matrix_lu );
751+
752+ int m = matrix_lu -> shape [0 ]; //no. of rows
753+ int n = matrix_lu -> shape [1 ]; //no. of cols
754+ int lda = n , info = -1 ;
755+
756+ nmatrix * matrix_ipiv ;
757+ Data_Get_Struct (argv [1 ], nmatrix , matrix_ipiv );
758+
759+ nmatrix * result = nmatrix_new (matrix_lu -> dtype , matrix_lu -> stype , 2 , matrix_lu -> count , matrix_lu -> shape , NULL );
760+
761+ switch (matrix_lu -> dtype ) {
762+ case nm_bool :
763+ {
764+ //not supported error
765+ break ;
766+ }
767+ case nm_int :
768+ {
769+ //not supported error
770+ break ;
771+ }
772+ case nm_float32 :
773+ {
774+ float * elements = ALLOC_N (float , matrix_lu -> count );
775+ memcpy (elements , matrix_lu -> elements , sizeof (float )* matrix_lu -> count );
776+ int * elements_ipiv = (int * )matrix_ipiv -> elements ;
777+ info = LAPACKE_sgetri (LAPACK_ROW_MAJOR , n , elements , lda , elements_ipiv );
778+
779+ result -> elements = elements ;
780+
781+ return Data_Wrap_Struct (NMatrix , NULL , nm_free , result );
782+ break ;
783+ }
784+ case nm_float64 :
785+ {
786+ double * elements = ALLOC_N (double , matrix_lu -> count );
787+ memcpy (elements , matrix_lu -> elements , sizeof (double )* matrix_lu -> count );
788+ int * elements_ipiv = (int * )matrix_ipiv -> elements ;
789+ info = LAPACKE_dgetri (LAPACK_ROW_MAJOR , n , elements , lda , elements_ipiv );
790+
791+ result -> elements = elements ;
792+
793+ return Data_Wrap_Struct (NMatrix , NULL , nm_free , result );
794+ break ;
795+ }
796+ case nm_complex32 :
797+ {
798+ float complex * elements = ALLOC_N (float complex , matrix_lu -> count );
799+ memcpy (elements , matrix_lu -> elements , sizeof (float complex )* matrix_lu -> count );
800+ int * elements_ipiv = (int * )matrix_ipiv -> elements ;
801+ info = LAPACKE_cgetri (LAPACK_ROW_MAJOR , n , elements , lda , elements_ipiv );
802+
803+ result -> elements = elements ;
804+
805+ return Data_Wrap_Struct (NMatrix , NULL , nm_free , result );
806+ break ;
807+ }
808+ case nm_complex64 :
809+ {
810+ double complex * elements = ALLOC_N (double complex , matrix_lu -> count );
811+ memcpy (elements , matrix_lu -> elements , sizeof (double complex )* matrix_lu -> count );
812+ int * elements_ipiv = (int * )matrix_ipiv -> elements ;
813+ info = LAPACKE_zgetri (LAPACK_ROW_MAJOR , n , elements , lda , elements_ipiv );
814+
815+ result -> elements = elements ;
816+
817+ return Data_Wrap_Struct (NMatrix , NULL , nm_free , result );
818+ break ;
819+ }
820+ }
821+ return INT2NUM (-1 );
822+ }
823+
744824// TODO: m should represent no. of rows and n no. of cols throughout
745825
746826/*
0 commit comments