A friend of mine implemented a sparse version of torch.bmm that works, but when I run a test I get a runtime error, unrelated to that implementation, that I don't understand. I have seen a few threads about it but couldn't find a solution. Here is the code and the error:
    import time

    import torch

    # BatchCSR and SparseBMM come from my friend's sparse bmm implementation (not shown).

    if __name__ == "__main__":
        tmp = torch.zeros(1).cuda()
        batch_csr = BatchCSR()
        sparse_bmm = SparseBMM()

        i = torch.LongTensor([[0, 5, 8], [1, 5, 8], [2, 5, 8]])
        v = torch.FloatTensor([4, 3, 8])
        s = torch.Size([3, 500, 500])
        indices, values, size = i, v, s
        a_ = torch.sparse.FloatTensor(indices, values, size).cuda().transpose(2, 1)
        batch_size, num_nodes, num_faces = a_.size()
        a = a_.to_dense()

        for _ in range(10):
            b = torch.randn(batch_size, num_faces, 16).cuda()

            # Dense reference: cuBLAS bmm
            torch.cuda.synchronize()
            time1 = time.time()
            result = torch.bmm(a, b)
            torch.cuda.synchronize()
            time2 = time.time()
            print("{} CuBlas dense bmm".format(time2 - time1))

            # Sparse version: build the batched CSR structure, then multiply
            torch.cuda.synchronize()
            time1 = time.time()
            col_ind, col_ptr = batch_csr(a_.indices(), a_.size())
            my_result = sparse_bmm(a_.values(), col_ind, col_ptr, a_.size(), b)
            torch.cuda.synchronize()
            time2 = time.time()
            print("{} My sparse bmm".format(time2 - time1))

            print("{} Diff".format((result - my_result).abs().max()))

And the error:
    Traceback (most recent call last):
      File "sparse_bmm.py", line 72, in <module>
        b = torch.randn(3, 500, 16).cuda()
      File "/home/bizeul/virtual_env/lib/python2.7/site-packages/torch/_utils.py", line 65, in _cuda
        return new_type(self.size()).copy_(self, async)
    RuntimeError: cuda runtime error (59) : device-side assert triggered at /b/wheel/pytorch-src/torch/lib/THC/generic/THCTensorCopy.c:18

When running with CUDA_LAUNCH_BLOCKING=1, I get this error instead:
    /b/wheel/pytorch-src/torch/lib/THC/THCTensorIndex.cu:121: void indexAddSmallIndex(TensorInfo<T, IndexType>, TensorInfo<T, IndexType>, TensorInfo<long, IndexType>, int, int, IndexType, long) [with T = float, IndexType = unsigned int, DstDim = 1, SrcDim = 1, IdxDim = -2]: block: [0,0,0], thread: [0,0,0] Assertion `dstIndex < dstAddDimSize` failed.
    THCudaCheck FAIL file=/b/wheel/pytorch-src/torch/lib/THCS/generic/THCSTensorMath.cu line=292 error=59 : device-side assert triggered
    Traceback (most recent call last):
      File "sparse_bmm.py", line 69, in <module>
        a = a_.to_dense()
    RuntimeError: cuda runtime error (59) : device-side assert triggered at /b/wheel/pytorch-src/torch/lib/THCS/generic/THCSTensorMath.cu:292
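One thing that may be worth checking: the assert fires inside a_.to_dense(), and `dstIndex < dstAddDimSize` suggests an index past the end of the destination. A sparse COO tensor's indices tensor has shape (ndim, nnz), so each column, not each row, is one coordinate. Below is a small pure-Python sketch that lists every coordinate falling outside the given size; the helper name out_of_bounds_coords is hypothetical, not a PyTorch function.

```python
def out_of_bounds_coords(indices, size):
    """Return (column, dim, index) for every COO coordinate outside `size`.

    `indices` uses the COO layout of torch sparse tensors: one row per
    dimension and one column per non-zero value, i.e. shape (ndim, nnz).
    """
    ndim = len(indices)
    if ndim != len(size):
        raise ValueError("indices has %d rows but size has %d dims" % (ndim, len(size)))
    nnz = len(indices[0]) if ndim else 0
    bad = []
    for col in range(nnz):          # each column is one coordinate
        for dim in range(ndim):     # check it against every dimension
            idx = indices[dim][col]
            if not 0 <= idx < size[dim]:
                bad.append((col, dim, idx))
    return bad


size = [3, 500, 500]
print(out_of_bounds_coords([[0, 5, 8], [1, 5, 8], [2, 5, 8]], size))  # → [(1, 0, 5), (2, 0, 8)]
print(out_of_bounds_coords([[0, 1, 2], [5, 5, 5], [8, 8, 8]], size))  # → []
```

With i as written, the columns are (0, 1, 2), (5, 5, 5) and (8, 8, 8), so two of them have a dim-0 (batch) index of 5 or 8 while that dimension only has size 3. If each inner list was meant to be one (batch, row, col) coordinate, the transposed layout (i.t() in PyTorch) passes the check. The same list can be obtained from the real tensor with i.tolist() before calling .cuda().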
Run CUDA_LAUNCH_BLOCKING=1 python your_script.py and update your question.
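Because CUDA kernels are launched asynchronously, a device-side assert is often reported by a later, unrelated call, which would explain why the first traceback points at the .cuda() copy of b while the blocking run points at a_.to_dense(). If changing the shell command is inconvenient, the variable can also be set from inside the script. A minimal sketch, assuming it runs before anything initializes CUDA (the safest place is before import torch):

```python
import os

# Force synchronous kernel launches so the traceback points at the
# operation that actually failed. This only takes effect if it is set
# before CUDA is initialized, so do it before importing torch.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  # import torch (and touch the GPU) only after this line
```

Synchronous launches slow everything down, so this is meant for debugging only, not for the timing comparison itself.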