- Notifications
You must be signed in to change notification settings - Fork 3.8k
Open
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
For Metal codegen, when we use tvm_callback_metal_compile as a Python-side debugger, it also sets the format to metallib:
tvm/src/target/source/codegen_metal.cc
Lines 441 to 442 in 52e4547
| const auto fmetal_compile = tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile"); | |
| std::string fmt = fmetal_compile ? "metallib" : "metal"; |
which will make MetalModuleNode treat the text source as a binary metallib and try to load it with newLibraryWithData, making TVM throw a tvm.error.InternalError: Fail to compile metal lib:Invalid library file
tvm/src/runtime/metal/metal_module.mm
Lines 123 to 128 in 52e4547
| } else { | |
| // Build from library. | |
| auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); | |
| auto data = dispatch_data_create(source.c_str(), source.length(), q, | |
| ^{ | |
| }); |
Also, with tvm-ffi>=0.1.8, it does not give an error message but crashes instead, which seems problematic:
(building from the v0.1.8.post2 source works fine)
python(69796,0x1f328ec40) malloc: *** error for object 0x8000000000000070: pointer being freed was not allocated python(69796,0x1f328ec40) malloc: *** set a breakpoint in malloc_error_break to debug reproducer:
"""Reproducer: tilelang Metal matmul that triggers the metallib-format bug
when a Python-side `tvm_callback_metal_compile` postproc callback is registered."""
import tilelang

print("Imported tilelang")
from tilelang import tvm as tvm
from time import sleep
# import tilelang.testing
import tilelang.language as T
import json
import torch
import os

print("Imports done", flush=True)
from tilelang.engine.callback import register_metal_postproc_callback


@register_metal_postproc_callback
def _p(code, target):
    # Registering this callback is what flips the codegen format to "metallib"
    # even though the callback returns plain Metal source text.
    print(code)
    return code


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):

    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
                    bx,
                    by,
                ):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                for i, j in T.Parallel(block_M, block_N):
                    for k in T.Serial(block_K):
                        C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def benchmark(f, n, *args, **kwargs):
    """Time n invocations of f on the MPS device; returns seconds elapsed."""
    # trigger jit
    f(*args, **kwargs)
    torch.mps.synchronize()
    with torch.mps.profiler.profile(mode="interval,event", wait_until_completed=True):
        start = torch.mps.Event(enable_timing=True)
        end = torch.mps.Event(enable_timing=True)
        start.record()
        for _ in range(n):
            f(*args, **kwargs)
        end.record()
        start.synchronize()
        end.synchronize()
    # NOTE(review): return placement relative to the `with` block inferred
    # from the mangled paste — confirm against the original script.
    return start.elapsed_time(end) / 1000


if __name__ == "__main__":
    m = n = k = 128
    torch_dtype = torch.float16
    dtype = 'float16'
    a = torch.randn(m, k, device="mps", dtype=torch_dtype)
    b = torch.randn(k, n, device="mps", dtype=torch_dtype)
    c = torch.zeros(m, n, device="mps", dtype=torch_dtype)
    # torch_add = lambda: torch.matmul(a, b, out=c)
    # torch_add()
    # print(benchmark(torch_add, n=100))
    print("Starting compilation...", flush=True)
    jit_kernel = matmul(m, n, k, 16, 16, 16, dtype=dtype, accum_dtype="float")
    print("Compilation finished.", flush=True)
    # print(jit_kernel.get_kernel_source())
    jit_kernel(a, b, c)
    print(c)
    print(a @ b)
Metadata
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug