I turned on XLA when running TensorFlow. To further optimize the fused kernels, I set `export XLA_FLAGS="--xla_dump_to=/tmp/xla_dump"` and got the dumped IRs, including `lmhlo.xxx.mlir` and other LLVM IR files. Now I want to analyze those dumped MLIR files further by reading them as structured MLIR modules, so I need to read and parse them correctly. But I can't find any resources documenting how to do this in pure Python. I tried the "pymlir" module, but it doesn't work well with this TF XLA HLO module; maybe the dumped module uses a different format. Does anybody know how to read and parse this dumped MLIR?
You can use the MLIR Python bindings to parse the dumped MLIR files. Here's the general process; adapt the actual script as needed.
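Clone the LLVM project repository:

```bash
git clone https://github.com/llvm/llvm-project.git
```

Create a build directory:

```bash
cd llvm-project
mkdir build
cd build
```

Install the Python packages the bindings need, then configure the build to include MLIR with the Python bindings enabled. Note that `MLIR_ENABLE_BINDINGS_PYTHON` is the option that builds the `mlir` Python package (`LLVM_ENABLE_BINDINGS` does not):

```bash
python3 -m pip install -r ../mlir/python/requirements.txt
cmake -G Ninja ../llvm \
  -DLLVM_ENABLE_PROJECTS="mlir" \
  -DLLVM_BUILD_EXAMPLES=ON \
  -DLLVM_TARGETS_TO_BUILD="host" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_RTTI=ON \
  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
  -DPython3_EXECUTABLE=$(which python3)
```

Build the project: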
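```bash
ninja
```

Add the bindings to your `PYTHONPATH` so Python can find the MLIR modules (run this from the build directory):

```bash
export PYTHONPATH=$(pwd)/tools/mlir/python_packages/mlir_core
```

Once the environment is set up, you can parse and analyze MLIR files. One caveat: the `lmhlo` dialect is not part of upstream MLIR (it comes from the MLIR-HLO project used by XLA), so the upstream bindings have no `mlir.dialects.lmhlo` module to import or register. The script below therefore initializes an MLIR context, tells it to accept operations from unregistered dialects, reads the dumped file, and parses it into a structured module so you can look at its operations. This only works if the `lmhlo` ops in the dump are printed in MLIR's generic form (`"lmhlo.fusion"(...) ...`); ops printed in a dialect-specific custom syntax cannot be parsed unless the dialect itself is registered, which would require Python bindings built from the MLIR-HLO sources.

```python
from mlir import ir

# Initialize the MLIR context.
with ir.Context() as ctx:
    # lmhlo has no upstream Python bindings, so accept its ops as unregistered
    # (works only for ops printed in MLIR's generic form).
    ctx.allow_unregistered_dialects = True

    # Read the dumped file.
    with open('/tmp/xla_dump/lmhlo.xxx.mlir', 'r') as f:
        mlir_text = f.read()

    # Parse it into a structured module (uses the context from the `with` block).
    module = ir.Module.parse(mlir_text)

    # Loop through the top-level operations.
    for op in module.body.operations:
        print(op)
```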
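`module.body.operations` only yields the top-level operations. Anything nested inside a region, such as the ops in a function body or in the body of an `lmhlo.fusion`, is not visited. A minimal sketch of a recursive walk over regions, blocks, and operations follows, assuming the upstream bindings built as described and a dump that parses with unregistered dialects allowed (i.e. its `lmhlo` ops are in generic form). `walk_ops` and `main` are hypothetical helper names, not part of the MLIR API.

```python
"""Recursively walk every operation in a dumped MLIR file (sketch)."""
import sys


def walk_ops(op, depth=0):
    """Yield (depth, op_name) for `op` and every op nested in its regions."""
    yield depth, op.name
    # MLIR nesting: an operation owns regions, a region owns blocks,
    # and a block owns a list of operations.
    for region in op.regions:
        for block in region.blocks:
            for child in block.operations:
                yield from walk_ops(child, depth + 1)


def main(path):
    # Imported lazily so walk_ops can be reused without the bindings installed.
    from mlir import ir

    with ir.Context() as ctx:
        # lmhlo is not registered upstream; accept its generic-form ops as-is.
        ctx.allow_unregistered_dialects = True
        with open(path) as f:
            module = ir.Module.parse(f.read())
        # module.operation is the top-level builtin.module op.
        for depth, name in walk_ops(module.operation):
            print("  " * depth + name)


if __name__ == "__main__" and len(sys.argv) == 2:
    main(sys.argv[1])
```

Saved as, say, `walk_mlir.py`, it would be run as `python walk_mlir.py /tmp/xla_dump/lmhlo.xxx.mlir` and prints an indented tree of op names. Writing the recursion by hand mirrors MLIR's op → region → block → op structure directly, so it is easy to extend, for example to inspect `op.attributes` or `op.operands` at each step.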
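If building the bindings is not an option, or the bindings cannot parse the dump, a rough pure-Python fallback is to scan the text for operation names. This is a heuristic, not a parser. The regular expression below assumes one operation per line, written either in generic form (`"lmhlo.fusion"(...)`) or in custom form starting with a dotted name (`mhlo.add %a, %b`), optionally preceded by a `%result = ` prefix. It skips ops printed without a dialect prefix (such as `module` or `return`) and can be fooled by unusual formatting. `count_ops` is a hypothetical helper name.

```python
"""Heuristic, stdlib-only op counter for a textual MLIR dump (not a real parser)."""
import re
import sys
from collections import Counter

# An op name at the start of a line: optional "%res = " / "%res:2 = " prefix,
# then a dotted name, optionally quoted (generic form). Assumes one op per line.
_OP_RE = re.compile(
    r'^[ \t]*(?:%[^=\n]*=[ \t]*)?"?([A-Za-z_]\w*\.[\w.]+)"?',
    re.MULTILINE,
)


def count_ops(mlir_text):
    """Return a Counter mapping dotted op names (e.g. 'lmhlo.fusion') to occurrences."""
    return Counter(m.group(1) for m in _OP_RE.finditer(mlir_text))


if __name__ == "__main__" and len(sys.argv) == 2:
    with open(sys.argv[1]) as f:
        # Print the most frequent op names first.
        for name, n in count_ops(f.read()).most_common():
            print(f"{n:6d}  {name}")
```

This is enough to get a quick histogram of which ops appear in each dumped file, but for anything that depends on operands, types, or region structure, a real MLIR parser is the better tool.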