Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
Gemfile.lock
/.yardoc
.rake_tasks~
*.so
*.so
.vscode

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you shouldn't add this here, you could place this in a .global-gitignore as you can see here https://gist.github.com/subfuzion/db7f57fff2fb6998a16c WDYT? 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think *.so should be in .gitignore, .vscode can be shifted to .global-gitignore.


120 changes: 106 additions & 14 deletions lib/nmatrix/lapack.rb
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
module NumRuby::Linalg
def self.inv(obj)
if obj.is_a?(NMatrix)
return obj.invert
def self.inv(matrix)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could put this in functions. For example:

def self.inv(matrix) is_a_matrix?(matrix) valid_shape?(matrix, 2) ... end private def is_a_matrix?(matrix) raise("Invalid matrix. Not of type NMatrix.") unless matrix.is_a?(NMatrix) end def valid_shape?(matrix, shape) raise("Invalid shape of matrix. Should be 2.") unless matrix.dim != shape end

WDYT? 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, seems alright. Functions for these would be a good idea as these are quite frequent. I'll add these once the PR is complete. Thanks.

if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end
if matrix.shape[0] != matrix.shape[1]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you want validate here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Square matrix, I think

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so I think you could put this in a function as:

def square_matrix?(matrix) raise("Invalid shape. Expected square matrix.") unless matrix.shape[0] != matrix.shape[1] end

WDYT? 🤔

raise("Invalid shape. Expected square matrix.")
end
m, n = matrix.shape

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you main by m and n?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m is rows count and n is columns count.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I could name these fields as row and account_ columns to make this more readable. WDYT? 🤔


lu, ipiv = NumRuby::Lapack.getrf(matrix)
inv_a = NumRuby::Lapack.getri(lu, ipiv)

return inv_a

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is the last line in this function, it isn't necessary to put return. You could put inv_a and this variable is automatically returned.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe including return makes it more readable.

end

def self.dot(lha, rha)
Expand All @@ -13,12 +25,22 @@ def self.norm

end

def self.solve

def self.solve(a, b, sym_pos: False, lower: False, assume_a: "gen", transposed: False)
# TODO: implement this and remove NMatrix.solve
end

def self.det
def self.det(matrix)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end
if matrix.shape[0] != matrix.shape[1]
raise("Invalid shape. Expected square matrix.")
end

return matrix.det
end

def self.least_square
Expand Down Expand Up @@ -48,23 +70,47 @@ def self.eigvalsh
# Matrix Decomposition


def self.lu(matrix)
def self.lu(matrix, permute_l: False)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end

lu, ipiv = NumRuby::Linalg.getrf(matrix)

# TODO: calulate p, l, u
end

def self.lu_factor(matrix)
if not matrix.is_a?(NMatrix)
raise("Invalid matrix. Not of type NMatrix.")
end
if matrix.dim != 2
raise("Invalid shape of matrix. Should be 2.")
end
if matrix.shape[0] != matrix.shape[1]
raise("Invalid shape. Expected square matrix.")
end

lu, ipiv = NumRuby::Linalg.getrf(matrix)

return [lu, ipiv]
end

def self.lu_solve(matrix, rhs_val)
def self.lu_solve(lu, ipiv, b, trans: 0)
if lu.shape[0] != b.shape[0]
raise("Incompatibel dimensions.")
end

x = NumRuby::Lapack.getrs(lu, ipiv, b, trans)
return x
end

# Computes the QR decomposition of matrix.
# Computes the SVD decomposition of matrix.
# Args:
# - input matrix, type: NMatrix
# - mode, type: String
# - pivoting, type: Boolean
def self.svd(matrix)

end
Expand All @@ -89,11 +135,24 @@ def self.cholesky_solve(matrix)

end

# Computes the QR decomposition of matrix.
# Computes QR decomposition of a matrix.
#
# Calculates the decomposition A = Q*R where Q is unitary/orthogonal and R is upper triangular.
#
# Args:
# - input matrix, type: NMatrix
# - matrix, type: NMatrix
# Matrix to be decomposed
# - mode, type: String
# Determines what information is to be returned: either both Q and R
# ('full', default), only R ('r') or both Q and R but computed in
# economy-size ('economic', see Notes). The final option 'raw'
# (added in Scipy 0.11) makes the function return two matrices
# (Q, TAU) in the internal format used by LAPACK.
# - pivoting, type: Boolean
# Whether or not factorization should include pivoting for rank-revealing
# qr decomposition. If pivoting, compute the decomposition
# A*P = Q*R as above, but where P is chosen such that the diagonal
# of R is non-increasing.
def self.qr(matrix, mode: "full", pivoting: false)
if not ['full', 'r', 'economic', 'raw'].include?(mode.downcase)
raise("Invalid mode. Should be one of ['full', 'r', 'economic', 'raw']")
Expand All @@ -106,9 +165,42 @@ def self.qr(matrix, mode: "full", pivoting: false)
end
m, n = matrix.shape

if pivoting == false
if pivoting == true

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could put this as:

if pivoting ... end

WDYT? 🤔

qr, tau, jpvt = NumRuby::Lapack.geqp3(matrix)
jpvt -= 1
else
qr, tau = NumRuby::Lapack.geqrf(matrix)
end

# calculate R here for both pivot true & false

if ['economic', 'raw'].include?(mode.downcase) or m < n
r = NumRuby.triu(matrix)
else
r = NumRuby.triu(matrix[0...n, 0...n])
end

if pivoting == true

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validation is similar to line 168. I think you could put this in a function. WDYT? 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. cc @Uditgulati

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A function for an if statement?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 yes, maybe move only the conditional is weird, but I think that you could the conditional blocks in functions. WDYT? 🤔

rj = r, jpvt
else
rj = r
end

if mode == 'r'
return rj
elsif mode == 'raw'
return [qr, tau]
end

if m < n
q = NumRuby::Lapack.orgqr(qr[0...m, 0...m], tau)
elsif mode == 'economic'
q = NumRuby::Lapack.orgqr(qr, tau)
else
# TODO: Implement slice view and set slice
q = NumRuby::Lapack.orgqr(qr, tau)
end

return q, rj
end
end