[Bug][Metal] metal codegen hook introduce unexpected side effect #18798

@oraluben

Description

For Metal codegen, when we register tvm_callback_metal_compile as a Python-side debugging hook, it also sets the format to metallib:

const auto fmetal_compile = tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";
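The side effect of that one-liner can be made concrete with a small Python model. This is a hypothetical sketch of the C++ logic above, not TVM's actual registry API: merely registering the callback flips the reported format to metallib, no matter what the callback returns.

```python
# Minimal model of the C++ format selection (hypothetical, for illustration):
#   std::string fmt = fmetal_compile ? "metallib" : "metal";
registry = {}

def register_func(name, func):
    # stand-in for tvm::ffi's global function registry
    registry[name] = func

def select_format():
    # mirrors Function::GetGlobal("tvm_callback_metal_compile") being non-null
    fmetal_compile = registry.get("tvm_callback_metal_compile")
    return "metallib" if fmetal_compile else "metal"

assert select_format() == "metal"

# A debug callback that returns the Metal *source text* unchanged...
register_func("tvm_callback_metal_compile", lambda code, target: code)

# ...still makes the runtime expect a compiled binary metallib.
assert select_format() == "metallib"
```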

which makes MetalModuleNode treat the text source as a binary metallib and try to load it with newLibraryWithData, so TVM throws a tvm.error.InternalError: Fail to compile metal lib:Invalid library file

} else {
  // Build from library.
  auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
  auto data = dispatch_data_create(source.c_str(), source.length(), q,
                                   ^{
                                   });
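One possible direction for a fix (a hedged sketch, not a patch against the actual TVM code): decide the format from the payload the callback returns instead of from the callback's mere presence. Compiled metallib containers start with the 4-byte magic MTLB while shader text does not; the magic check is an assumption about the container format, and the helper name here is made up.

```python
# Hypothetical helper: sniff whether a callback result is a compiled
# metallib or plain Metal source. The b"MTLB" magic is an assumption
# about the metallib container header, not something TVM does today.
METALLIB_MAGIC = b"MTLB"

def guess_metal_format(blob: bytes) -> str:
    # A compiled library starts with the magic; source text never does.
    return "metallib" if blob.startswith(METALLIB_MAGIC) else "metal"

# Text source from a pass-through debug callback stays "metal"...
assert guess_metal_format(b"#include <metal_stdlib>\nkernel void k() {}") == "metal"
# ...while a compiled blob would be detected as "metallib".
assert guess_metal_format(b"MTLB\x00\x00\x01") == "metallib"
```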

Also, with tvm-ffi>=0.1.8, it no longer gives an error message but crashes instead, which seems bad
(building from v0.1.8.post2 source works fine):

python(69796,0x1f328ec40) malloc: *** error for object 0x8000000000000070: pointer being freed was not allocated
python(69796,0x1f328ec40) malloc: *** set a breakpoint in malloc_error_break to debug

Reproducer:

import tilelang

print("Imported tilelang")

from tilelang import tvm as tvm
from time import sleep
# import tilelang.testing
import tilelang.language as T
import json
import torch
import os

print("Imports done", flush=True)

from tilelang.engine.callback import register_metal_postproc_callback


@register_metal_postproc_callback
def _p(code, target):
    print(code)
    return code


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
                bx,
                by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                for i, j in T.Parallel(block_M, block_N):
                    for k in T.Serial(block_K):
                        C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def benchmark(f, n, *args, **kwargs):
    # trigger jit
    f(*args, **kwargs)
    torch.mps.synchronize()
    with torch.mps.profiler.profile(mode="interval,event", wait_until_completed=True):
        start = torch.mps.Event(enable_timing=True)
        end = torch.mps.Event(enable_timing=True)
        start.record()
        for _ in range(n):
            f(*args, **kwargs)
        end.record()
        start.synchronize()
        end.synchronize()
    return start.elapsed_time(end) / 1000


if __name__ == "__main__":
    m = n = k = 128
    torch_dtype = torch.float16
    dtype = 'float16'
    a = torch.randn(m, k, device="mps", dtype=torch_dtype)
    b = torch.randn(k, n, device="mps", dtype=torch_dtype)
    c = torch.zeros(m, n, device="mps", dtype=torch_dtype)

    # torch_add = lambda: torch.matmul(a, b, out=c)
    # torch_add()
    # print(benchmark(torch_add, n=100))

    print("Starting compilation...", flush=True)
    jit_kernel = matmul(m, n, k, 16, 16, 16, dtype=dtype, accum_dtype="float")
    print("Compilation finished.", flush=True)
    # print(jit_kernel.get_kernel_source())
    jit_kernel(a, b, c)
    print(c)
    print(a @ b)

cc @echuraev @junrushao
