- Notifications
You must be signed in to change notification settings - Fork 10
Implement NumRuby::Linalg #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -3,4 +3,6 @@ | |
| Gemfile.lock | ||
| /.yardoc | ||
| .rake_tasks~ | ||
| *.so | ||
| *.so | ||
| .vscode | ||
| | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,20 @@ | ||
| module NumRuby::Linalg | ||
| def self.inv(obj) | ||
| if obj.is_a?(NMatrix) | ||
| return obj.invert | ||
| def self.inv(matrix) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you could put this in functions. For example: def self.inv(matrix) is_a_matrix?(matrix) valid_shape?(matrix, 2) ... end private def is_a_matrix?(matrix) raise("Invalid matrix. Not of type NMatrix.") unless matrix.is_a?(NMatrix) end def valid_shape?(matrix, shape) raise("Invalid shape of matrix. Should be 2.") unless matrix.dim != shape endWDYT? 🤔 Member Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, seems alright. Functions for these would be a good idea as these are quite frequent. I'll add these once the PR is complete. Thanks. | ||
| if not matrix.is_a?(NMatrix) | ||
| raise("Invalid matrix. Not of type NMatrix.") | ||
| end | ||
| if matrix.dim != 2 | ||
| raise("Invalid shape of matrix. Should be 2.") | ||
| end | ||
| if matrix.shape[0] != matrix.shape[1] | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what do you want validate here? Member There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Square matrix, I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so I think you could put this in a function as: def square_matrix?(matrix) raise("Invalid shape. Expected square matrix.") unless matrix.shape[0] != matrix.shape[1] endWDYT? 🤔 | ||
| raise("Invalid shape. Expected square matrix.") | ||
| end | ||
| m, n = matrix.shape | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what do you main by Member Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I could name these fields as | ||
| | ||
| lu, ipiv = NumRuby::Lapack.getrf(matrix) | ||
| inv_a = NumRuby::Lapack.getri(lu, ipiv) | ||
| | ||
| return inv_a | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As this is the last line in this function, it isn't necessary to put Member Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe including | ||
| end | ||
| | ||
| def self.dot(lha, rha) | ||
| | @@ -13,12 +25,22 @@ def self.norm | |
| | ||
| end | ||
| | ||
| def self.solve | ||
| | ||
| def self.solve(a, b, sym_pos: False, lower: False, assume_a: "gen", transposed: False) | ||
| # TODO: implement this and remove NMatrix.solve | ||
| end | ||
| | ||
| def self.det | ||
| def self.det(matrix) | ||
| if not matrix.is_a?(NMatrix) | ||
| raise("Invalid matrix. Not of type NMatrix.") | ||
| end | ||
| if matrix.dim != 2 | ||
| raise("Invalid shape of matrix. Should be 2.") | ||
| end | ||
| if matrix.shape[0] != matrix.shape[1] | ||
| raise("Invalid shape. Expected square matrix.") | ||
| end | ||
| | ||
| return matrix.det | ||
| end | ||
| | ||
| def self.least_square | ||
| | @@ -48,23 +70,47 @@ def self.eigvalsh | |
| # Matrix Decomposition | ||
| | ||
| | ||
| def self.lu(matrix) | ||
| def self.lu(matrix, permute_l: False) | ||
| if not matrix.is_a?(NMatrix) | ||
| raise("Invalid matrix. Not of type NMatrix.") | ||
| end | ||
| if matrix.dim != 2 | ||
| raise("Invalid shape of matrix. Should be 2.") | ||
| end | ||
| | ||
| lu, ipiv = NumRuby::Linalg.getrf(matrix) | ||
| | ||
| # TODO: calulate p, l, u | ||
| end | ||
| | ||
| def self.lu_factor(matrix) | ||
| if not matrix.is_a?(NMatrix) | ||
| raise("Invalid matrix. Not of type NMatrix.") | ||
| end | ||
| if matrix.dim != 2 | ||
| raise("Invalid shape of matrix. Should be 2.") | ||
| end | ||
| if matrix.shape[0] != matrix.shape[1] | ||
| raise("Invalid shape. Expected square matrix.") | ||
| end | ||
| | ||
| lu, ipiv = NumRuby::Linalg.getrf(matrix) | ||
| | ||
| return [lu, ipiv] | ||
| end | ||
| | ||
| def self.lu_solve(matrix, rhs_val) | ||
| def self.lu_solve(lu, ipiv, b, trans: 0) | ||
| if lu.shape[0] != b.shape[0] | ||
| raise("Incompatibel dimensions.") | ||
| end | ||
| | ||
| x = NumRuby::Lapack.getrs(lu, ipiv, b, trans) | ||
| return x | ||
| end | ||
| | ||
| # Computes the QR decomposition of matrix. | ||
| # Computes the SVD decomposition of matrix. | ||
| # Args: | ||
| # - input matrix, type: NMatrix | ||
| # - mode, type: String | ||
| # - pivoting, type: Boolean | ||
| def self.svd(matrix) | ||
| | ||
| end | ||
| | @@ -89,11 +135,24 @@ def self.cholesky_solve(matrix) | |
| | ||
| end | ||
| | ||
| # Computes the QR decomposition of matrix. | ||
| # Computes QR decomposition of a matrix. | ||
| # | ||
| # Calculates the decomposition A = Q*R where Q is unitary/orthogonal and R is upper triangular. | ||
| # | ||
| # Args: | ||
| # - input matrix, type: NMatrix | ||
| # - matrix, type: NMatrix | ||
| # Matrix to be decomposed | ||
| # - mode, type: String | ||
| # Determines what information is to be returned: either both Q and R | ||
| # ('full', default), only R ('r') or both Q and R but computed in | ||
| # economy-size ('economic', see Notes). The final option 'raw' | ||
| # (added in Scipy 0.11) makes the function return two matrices | ||
| # (Q, TAU) in the internal format used by LAPACK. | ||
| # - pivoting, type: Boolean | ||
| # Whether or not factorization should include pivoting for rank-revealing | ||
| # qr decomposition. If pivoting, compute the decomposition | ||
| # A*P = Q*R as above, but where P is chosen such that the diagonal | ||
| # of R is non-increasing. | ||
| def self.qr(matrix, mode: "full", pivoting: false) | ||
| if not ['full', 'r', 'economic', 'raw'].include?(mode.downcase) | ||
| raise("Invalid mode. Should be one of ['full', 'r', 'economic', 'raw']") | ||
| | @@ -106,9 +165,42 @@ def self.qr(matrix, mode: "full", pivoting: false) | |
| end | ||
| m, n = matrix.shape | ||
| | ||
| if pivoting == false | ||
| if pivoting == true | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you could put this as: if pivoting ... endWDYT? 🤔 | ||
| qr, tau, jpvt = NumRuby::Lapack.geqp3(matrix) | ||
| jpvt -= 1 | ||
| else | ||
| qr, tau = NumRuby::Lapack.geqrf(matrix) | ||
| end | ||
| | ||
| # calculate R here for both pivot true & false | ||
| | ||
| if ['economic', 'raw'].include?(mode.downcase) or m < n | ||
| r = NumRuby.triu(matrix) | ||
| else | ||
| r = NumRuby.triu(matrix[0...n, 0...n]) | ||
| end | ||
| | ||
| if pivoting == true | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This validation is similar to line Member There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. cc @Uditgulati Member Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A function for an if statement? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🤔 yes, maybe move only the conditional is weird, but I think that you could the conditional blocks in functions. WDYT? 🤔 | ||
| rj = r, jpvt | ||
| else | ||
| rj = r | ||
| end | ||
| | ||
| if mode == 'r' | ||
| return rj | ||
| elsif mode == 'raw' | ||
| return [qr, tau] | ||
| end | ||
| | ||
| if m < n | ||
| q = NumRuby::Lapack.orgqr(qr[0...m, 0...m], tau) | ||
| elsif mode == 'economic' | ||
| q = NumRuby::Lapack.orgqr(qr, tau) | ||
| else | ||
| # TODO: Implement slice view and set slice | ||
| q = NumRuby::Lapack.orgqr(qr, tau) | ||
| end | ||
| | ||
| return q, rj | ||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you shouldn't add this here, you could place this in a
.global-gitignoreas you can see here https://gist.github.com/subfuzion/db7f57fff2fb6998a16c WDYT? 🤔There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
*.soshould be in.gitignore,.vscodecan be shifted to.global-gitignore.