I am trying to multiply two complex matrices in PyTorch, and it seems that torch.matmul does not yet support complex numbers.

Do you have any recommendation or is there another method to multiply complex matrices in PyTorch?

2 Answers

Currently torch.matmul is not supported for complex tensors such as ComplexFloatTensor, but you can do something as compact as the following code:

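import torch

def matmul_complex(t1, t2):
    # Real and imaginary parts of the product, each built from real-valued matmuls.
    real = t1.real @ t2.real - t1.imag @ t2.imag
    imag = t1.real @ t2.imag + t1.imag @ t2.real
    # view_as_complex expects the two parts in a trailing dimension of size 2.
    return torch.view_as_complex(torch.stack((real, imag), dim=-1))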
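The formula in this function comes from splitting each matrix into its real and imaginary parts. As a short derivation, write t1 = A + iB and t2 = C + iD, where A, B, C, D are real matrices (these symbols are introduced here only for illustration):

```latex
\begin{aligned}
(A + iB)(C + iD) &= AC + i\,AD + i\,BC + i^{2}\,BD \\
                 &= (AC - BD) + i\,(AD + BC)
\end{aligned}
```

So the real part of the product is AC - BD and the imaginary part is AD + BC, which is four real-valued matrix multiplications. The order of the factors is kept throughout, since matrix multiplication does not commute.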

Avoid for loops where possible, since they result in much slower implementations. Vectorization is achieved by using built-in methods, as demonstrated in the code above. For example, your loop-based code takes roughly 6.1 s on CPU, while the vectorized version takes only 101 ms (about 60 times faster) for two random complex matrices of size 1000 x 1000.
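A rough way to see the difference yourself is sketched below. It assumes a PyTorch version with complex tensor support (1.7 or later), and the function names, the 200 x 200 size, and the column loop are illustrative choices, not the exact code that was timed above. Timings depend heavily on hardware, so none are quoted here.

```python
import time
import torch

def matmul_vectorized(t1, t2):
    # Four real-valued matmuls, combined into one complex result.
    real = t1.real @ t2.real - t1.imag @ t2.imag
    imag = t1.real @ t2.imag + t1.imag @ t2.real
    return torch.complex(real, imag)

def matmul_column_loop(t1, t2):
    # Hypothetical loop-based variant: build the product one column at a time.
    columns = []
    for i in range(t2.size(1)):
        # Broadcast column i of t2 across the rows of t1,
        # then sum over the shared dimension.
        columns.append((t1 * t2[:, i]).sum(dim=1))
    return torch.stack(columns, dim=1)

torch.manual_seed(0)
a = torch.randn(200, 200, dtype=torch.cfloat)
b = torch.randn(200, 200, dtype=torch.cfloat)

start = time.perf_counter()
fast = matmul_vectorized(a, b)
fast_seconds = time.perf_counter() - start

start = time.perf_counter()
slow = matmul_column_loop(a, b)
slow_seconds = time.perf_counter() - start

# Compare through view_as_real so the check runs on plain float tensors.
same = torch.allclose(
    torch.view_as_real(fast), torch.view_as_real(slow), atol=1e-3
)
print(f"vectorized: {fast_seconds:.4f}s, loop: {slow_seconds:.4f}s, match: {same}")
```

A single run like this is only indicative; for a fairer measurement, repeat each call several times and discard the first run.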

Update:

Since PyTorch 1.7.0 (as @EduardoReis mentioned) you can do matrix multiplication between complex matrices similarly to real-valued matrices as follows:

t1 @ t2 (for t1, t2 complex matrices).
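As a quick sanity check, the sketch below compares the built-in operator with the manual real/imaginary decomposition. It assumes PyTorch 1.7 or later; the shapes and variable names are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
t1 = torch.randn(3, 4, dtype=torch.cfloat)
t2 = torch.randn(4, 5, dtype=torch.cfloat)

# Built-in complex matrix multiplication; equivalent to torch.matmul(t1, t2).
product = t1 @ t2

# Manual decomposition into real-valued matmuls, for comparison.
manual = torch.complex(
    t1.real @ t2.real - t1.imag @ t2.imag,
    t1.real @ t2.imag + t1.imag @ t2.real,
)

# Compare through view_as_real so the check runs on plain float tensors.
matches = torch.allclose(
    torch.view_as_real(product), torch.view_as_real(manual), atol=1e-5
)
print(product.shape, product.dtype, matches)
```

If the two results agree, the manual helper can simply be replaced by `@` on versions that support it.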

2 Comments

Recently, using torch 1.8.1+cu101 I have been able to simply multiply the two tensors by x*h, and this is producing their complex product.
@EduardoReis You are correct. Since PyTorch 1.7.0 you can shorten the code above. But, do note that t1 * t2 is pointwise multiplication between tensors t1 & t2. You can use t1 @ t2 to obtain matrix multiplication equivalent to the matmul_complex. I updated the post.
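To make the distinction between the two operators concrete, here is a small sketch with hand-picked 2 x 2 complex matrices (the values are arbitrary, and PyTorch 1.7 or later is assumed):

```python
import torch

a = torch.tensor([[1 + 1j, 0j], [0j, 1j]], dtype=torch.cfloat)
b = torch.tensor([[1j, 1 + 0j], [1 + 0j, 1 + 0j]], dtype=torch.cfloat)

# Pointwise product: each entry of a times the matching entry of b.
elementwise = a * b
# Matrix product: rows of a combined with columns of b.
matrix_product = a @ b

print(elementwise)     # [[-1+1j, 0], [0, 1j]]
print(matrix_product)  # [[-1+1j, 1+1j], [1j, 1j]]
```

The two results only share their top-left entry here, so using `*` where `@` is intended silently gives a different answer rather than raising an error.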

I implemented this function as a stand-in for torch.matmul on complex matrices using torch.mv, and it is working fine for the time being:

import torch

def matmul_complex(t1, t2):
    m = t1.size(0)
    n = t2.size(1)
    columns = []
    for i in range(n):
        # Each torch.mv call yields one column (length m) of the product.
        columns.append(torch.mv(t1, t2[:, i]))
    # Concatenation lays the columns out back to back, so reshape to (n, m)
    # and transpose to obtain the (m, n) result.
    t_total = torch.cat(columns, 0)
    return torch.reshape(t_total, (n, m)).t()

I am new to PyTorch, so please correct me if I am wrong.
