TPU Monitoring Library
Unlock deep insights into your Cloud TPU hardware's performance and behavior with advanced TPU monitoring capabilities, built directly upon the foundational software layer, LibTPU. While LibTPU encompasses drivers, networking libraries, the XLA compiler, and TPU runtime for interacting with TPUs, the focus of this document is the TPU Monitoring Library.
The TPU Monitoring Library provides:
Comprehensive observability: Gain access to the telemetry API and metrics suite, which provides detailed insights into the operational performance and specific behaviors of your TPUs.
Diagnostic toolkits: An SDK and command-line interface (CLI) designed for debugging and in-depth performance analysis of your TPU resources.
These monitoring features are designed to be a top-level, customer-facing solution, providing you with the essential tools to optimize your TPU workloads effectively.
The TPU Monitoring Library gives you detailed information about how machine learning workloads perform on TPU hardware. It's designed to help you understand TPU utilization, identify bottlenecks, and debug performance issues, and it provides more granular detail than interruption metrics, goodput metrics, and other high-level metrics.
Get started with the TPU Monitoring Library
Accessing these powerful insights is straightforward. The TPU monitoring functionality is integrated with the LibTPU SDK, so the functionality is included when you install LibTPU.
Install LibTPU
```
pip install libtpu
```
Alternatively, LibTPU updates are coordinated with JAX releases, so installing the latest JAX release (released monthly) typically pins you to the latest compatible LibTPU version and its features.
Install JAX
```
pip install -U "jax[tpu]"
```
For PyTorch users, installing PyTorch/XLA provides the latest LibTPU and TPU monitoring functionality.
Install PyTorch/XLA
```
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
```
For more information about installing PyTorch/XLA, see Installation in the PyTorch/XLA GitHub repository.
Import the library in Python
To start using the TPU Monitoring Library, you need to import the libtpu module in your Python code.
```
from libtpu.sdk import tpumonitoring
```
List all supported functionality
List all metric names and the functionality they support:
```
from libtpu.sdk import tpumonitoring

tpumonitoring.help()
"
libtpu.sdk.monitoring.help():
  List all supported functionality.

libtpu.sdk.monitoring.list_support_metrics()
  List support metric names in the list of str format.

libtpu.sdk.monitoring.get_metric(metric_name:str)
  Get metric data with metric name. It represents the snapshot mode.
  The metric data is a object with `description()` and `data()` methods,
  where the `description()` returns a string describe the format of data
  and data unit, `data()` returns the metric data in the list in str format.
"
```
Supported metrics
The following code sample shows how to list all supported metric names:
```
from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
```
The following table shows all metrics and their corresponding definitions:
| Metric | Definition | Metric name for API | Example values |
|---|---|---|---|
| Tensor Core Utilization | Measures the percentage of your TensorCore usage, calculated as the percentage of operations that are part of TensorCore operations. Sampled for 10 microseconds every second; you cannot modify the sampling rate. This metric lets you monitor the efficiency of your workloads on TPU devices. | tensorcore_util | ['1.11', '2.22', '3.33', '4.44'] # Utilization percentage for accelerator ID 0-3. |
| Duty Cycle Percentage | Percentage of time over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag) during which the accelerator was actively processing (recorded with cycles used to execute HLO programs over the last sampling period). This metric represents how busy a TPU is. The metric is emitted per chip. | duty_cycle_pct | ['10.00', '20.00', '30.00', '40.00'] # Duty cycle percentage for accelerator ID 0-3. |
| HBM Capacity Total | This metric reports the total HBM capacity in bytes. | hbm_capacity_total | ['30000000000', '30000000000', '30000000000', '30000000000'] # Total HBM capacity in bytes attached to accelerator ID 0-3. |
| HBM Capacity Usage | This metric reports the HBM capacity usage in bytes over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag). | hbm_capacity_usage | ['100', '200', '300', '400'] # HBM capacity usage in bytes for accelerator ID 0-3. |
| Buffer transfer latency | Network transfer latencies for megascale multislice traffic. This visualization lets you understand the overall network performance environment. | buffer_transfer_latency | ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"] # buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution. |
| High Level Operation Execution Time Distribution Metrics | Provides granular performance insights into the HLO compiled binary execution status, enabling regression detection and model-level debugging. | hlo_exec_timing | ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"] # The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p99.9. |
| High Level Optimizer queue size | HLO execution queue size monitoring tracks the number of compiled HLO programs waiting or undergoing execution. This metric reveals execution pipeline congestion, enabling identification of performance bottlenecks in hardware execution, driver overhead, or resource allocation. | hlo_queue_size | ["tensorcore-0: 1", "tensorcore-1: 2"] # Measures queue size for CoreType-CoreID. |
| Collective End to End Latency | This metric measures the end-to-end collective latency over DCN in microseconds, from the host initiating the operation to all peers receiving the output. It includes host-side data reduction and sending output to the TPU. Results are strings detailing buffer size, type, and mean, p50, p90, p95, and p99.9 latencies. | collective_e2e_latency | ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …] # Transfer size-collective op, mean, p50, p90, p95, p99.9 of collective end-to-end latency. |
| Round Trip Latency at Transport Layer | Distribution of the minimum Round Trip Times (RTT) observed on TCP connections used by gRPC for multislice TPU traffic. | grpc_tcp_min_round_trip_times | ['27.63, 29.03, 38.52, 41.63, 52.74'] # Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
| Throughput at Transport Layer | Cumulative distribution of the recent throughput of TCP connections used by gRPC for multislice TPU traffic. | grpc_tcp_delivery_rates | ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55'] # Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles of throughput in Mbps. |
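For example, because hbm_capacity_total and hbm_capacity_usage are both reported per accelerator in bytes, you can combine them to estimate HBM utilization. The following is a minimal sketch, assuming the data() strings parse as integers as shown in the example values above:
```
from libtpu.sdk import tpumonitoring

# Read total and used HBM capacity, in bytes, for each attached accelerator.
total_bytes = [int(v) for v in tpumonitoring.get_metric("hbm_capacity_total").data()]
used_bytes = [int(v) for v in tpumonitoring.get_metric("hbm_capacity_usage").data()]

# Derive an HBM utilization percentage per accelerator ID.
for accel_id, (total, used) in enumerate(zip(total_bytes, used_bytes)):
    print(f"accelerator_{accel_id}: HBM utilization {100.0 * used / total:.2f}%")
```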
Read metric data
To read metric data, specify the metric name when you call the tpumonitoring.get_metric function. You can insert ad hoc metric checks into underperforming code to determine whether performance issues stem from software or hardware.
The following code sample shows how to read the duty_cycle_pct metric:
```
from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each accelerator (from accelerator_0 to accelerator_x). The duty cycle represents the percentage of time an accelerator was actively processing during the last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]  # accelerator_0-3
```
Use metrics to check TPU utilization
The following examples show how to use metrics from the TPU Monitoring Library to track TPU utilization.
Monitor TPU duty cycle during JAX training
Scenario: You are running a JAX training script and want to monitor the TPU's duty_cycle_pct metric throughout the training process to confirm your TPUs are being effectively utilized. You can log this metric periodically during training to track TPU utilization.
The following code sample shows how to monitor TPU Duty Cycle during JAX training:
```
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup) ---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0])  # Example params
optimizer = ...  # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5):  # Example steps per epoch
        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data()
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description()}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1)  # Simulate some computation

print("Training complete.")
```
Check HBM utilization before running JAX inference
Scenario: Before running inference with your JAX model, check the current HBM (High Bandwidth Memory) utilization on the TPU to confirm that you have enough memory available and to get a baseline measurement before inference starts.
The following code sample shows how to check HBM utilization before JAX inference:
```
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup) ---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ...  # Load your trained parameters

# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data()
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description()}")
print(f"  Data: {hbm_util_data}")
# End TPU Monitoring Library Integration

# Your Inference Logic
input_data = jnp.ones((1, 10))  # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")
```
Check network metrics
Scenario: You are running a multi-host and multislice workload and want to connect to one of the GKE pods or TPUs using ssh to view network metrics while the workload is running. The commands can also be incorporated directly into the multi-host workload.
```
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup) ---
def simple_model(x):
    return jnp.sum(x)

# --- details here ---

# Metric names from the supported metrics table.
metric_name_rate = "grpc_tcp_delivery_rates"
metric_name_rtt = "grpc_tcp_min_round_trip_times"

# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection
# (bytes delivered / elapsed time).
# The output is a list of strings representing throughput statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

# Print the description provided by the library
print("Description:", delivery_rate_metric.description())

# Print the actual data payload
print("Data:", delivery_rate_metric.data())

# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

# Print the description provided by the library
print("Description:", min_rtt_metric.description())

# Print the actual data payload
print("Data:", min_rtt_metric.data())
```
Refresh frequency of TPU metrics
The refresh frequency of TPU metrics is constrained to a minimum of one second. Host metric data is exported at a fixed frequency of 1 Hz. The latency introduced by this export process is negligible. Runtime metrics from LibTPU are not subject to the same frequency constraint. However, for consistency, these metrics are also sampled at 1 Hz or 1 sample per second.
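Because the underlying data refreshes at most once per second, polling a metric more frequently only returns repeated snapshots. The following is a minimal sketch of a once-per-second polling loop; the metric name and loop duration are illustrative:
```
import time
from libtpu.sdk import tpumonitoring

POLL_INTERVAL_SECONDS = 1  # Matches the 1 Hz refresh frequency; faster polling adds no new data.

for _ in range(10):  # Poll for roughly 10 seconds.
    metric = tpumonitoring.get_metric("duty_cycle_pct")
    print(metric.data())
    time.sleep(POLL_INTERVAL_SECONDS)
```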
TPU-Z module
TPU-Z is a telemetry and debugging facility for TPUs. It provides detailed runtime status information for all TPU cores attached to a host. The functionality is provided through the tpuz module, which is part of the libtpu.sdk module in the libtpu Python SDK. The module provides a snapshot of each core's state.
The primary use case for TPU-Z is diagnosing hangs or deadlocks in distributed TPU workloads. You can query the TPU-Z service on hosts to capture the state of every core, comparing the Program Counters, HLO locations, and Run IDs across all cores to identify anomalies.
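For example, a core whose Run ID differs from its peers, or whose program counter stops advancing between snapshots, is a likely hang candidate. The following is a minimal sketch of such a comparison on a single host, using the get_core_state_summary() output described in the following sections; the majority-vote heuristic is illustrative:
```
from collections import Counter
from libtpu import sdk

summary = sdk.tpuz.get_core_state_summary()

# Collect the run ID reported by the first sequencer on each core.
run_ids = {
    core: state["sequencer_info"][0]["run_id"]
    for core, state in summary["core_states"].items()
    if state["sequencer_info"]
}

# Cores whose run ID differs from the majority are candidates for a hang.
majority_run_id, _ = Counter(run_ids.values()).most_common(1)[0]
for core, run_id in run_ids.items():
    if run_id != majority_run_id:
        print(f"Core {core}: run_id {run_id} differs from majority run_id {majority_run_id}")
```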
Use the get_core_state_summary() function within the libtpu.sdk library to display the TPU-Z metrics:
```
summary = sdk.tpuz.get_core_state_summary()
```
The output for the TPU-Z metrics is provided as a dictionary. The following is a truncated example for a single core:
{ "host_name": "my-tpu-host-vm", "core_states": { "1": { "core_id": { "global_core_id": 1, "chip_id": 0, "core_on_chip": { "type": "TPU_CORE_TYPE_TENSOR_CORE", "index": 1 } }, "sequencer_info": [ { "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER", "sequencer_index": 0, "pc": 4490, "program_id": 3274167277388825310, "run_id": 3 } ], "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'", "queued_program_info": [], "error_message": "" } // ... } } To retrieve information about the High-Level Optimizers (HLO) on each core, set the include_hlo_info parameter to True:
```
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
```
The output includes additional HLO information:
"1": { "core_id": { "global_core_id": 1, "chip_id": 0, "core_on_chip": { "type": "TPU_CORE_TYPE_TENSOR_CORE", "index": 1 } }, "sequencer_info": [ { "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER", "sequencer_index": 0, "pc": 17776, "tag": 3, "tracemark": 2147483646, "program_id": 3230481660274331500, "run_id": 81, "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd", "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..." } ], "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>", "launch_id": 1394130914, "queued_program_info": [ { "run_id": 81, "launch_id": 1394130914, "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>" } ] } TPU-Z metrics
The get_core_state_summary function returns TPU-Z metrics in the form of a dictionary with the following structure.
CurrentCoreStateSummary
The CurrentCoreStateSummary dictionary provides a detailed summary of an individual TPU core's state.
| Field | Type | Description |
|---|---|---|
| core_id | dictionary | A TpuCoreIdentifier dictionary that contains ID information about the TPU core. |
| sequencer_info | list of dictionaries | A list of SequencerInfo dictionaries, describing the state of each sequencer on the core. |
| program_fingerprint | bytes | The fingerprint of the program executing on this core. |
| launch_id | integer | The launch ID of the current or most recent program. |
| queued_program_info | list of dictionaries | A list of QueuedProgramInfo dictionaries for programs queued for execution. |
| error_message | string | Any error messages for this core. |
TpuCoreIdentifier
The TpuCoreIdentifier dictionary provides ID information for cores within the TPU system.
| Field | Type | Description |
|---|---|---|
| global_core_id | integer | The ID of the core. |
| chip_id | integer | The ID of the chip that the core belongs to. |
| core_on_chip | dictionary | A TpuCoreOnChip dictionary describing the core's type and its index on the chip. |
TpuCoreOnChip
The TpuCoreOnChip dictionary contains information about a core's properties within a specific chip.
| Field | Type | Description |
|---|---|---|
| type | string | The type of the TPU core. For example: TPU_CORE_TYPE_TENSOR_CORE. |
| index | integer | The index of the core on the chip. |
SequencerInfo
The SequencerInfo dictionary contains information about the state of a single sequencer on a core.
| Field | Type | Description |
|---|---|---|
| sequencer_type | string | The type of sequencer. For example: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
| sequencer_index | integer | The index of the sequencer (if there are multiple of the same type). |
| pc | integer | The current Program Counter value. |
| program_id | integer | The ID associated with a specific instance of a program being launched for execution on a TPU core. |
| run_id | integer | The Run ID associated with a specific instance of a program's execution on a TPU core. |
| hlo_location | string | High Level Optimizer location information. |
| hlo_detailed_info | string | Detailed High Level Optimizer information. |
QueuedProgramInfo
The QueuedProgramInfo dictionary contains information about programs queued for execution on a core.
| Field | Type | Description |
|---|---|---|
| run_id | integer | The Run ID for the queued program. |
| launch_id | integer | The Launch ID for the queued program. |
| program_fingerprint | bytes | The fingerprint of the queued program. |
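As a sketch of how these fields fit together, the following example walks the returned dictionary and flags cores that report an error message or have programs waiting in the execution queue (field access follows the tables above):
```
from libtpu import sdk

summary = sdk.tpuz.get_core_state_summary()

for core, state in summary["core_states"].items():
    core_type = state["core_id"]["core_on_chip"]["type"]
    if state["error_message"]:
        print(f"Core {core} ({core_type}) reported an error: {state['error_message']}")
    for queued in state["queued_program_info"]:
        print(f"Core {core} ({core_type}) has a queued program with run_id {queued['run_id']}")
```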
TPU-Z with JAX
You can access TPU-Z metrics in JAX workloads through the libtpu.sdk library. The following Python script uses JAX for high-performance tensor computation, while simultaneously using the libtpu SDK in a background thread to monitor the state and activity of the underlying TPU hardware.
Include the following Python packages:
```
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial

from libtpu import sdk

# Example configuration (assumed values not defined in the original snippet; adjust for your workload).
MONITORING_INTERVAL_SECONDS = 5
JOB_DURATION_SECONDS = 60
monitoring_active = True
```
The monitor_tpu_status function uses a background thread to continuously show the operational status of the TPU cores while the main application executes a JAX workload. It acts as a real-time diagnostic tool.
```
def monitor_tpu_status():
    """Monitors TPU status in a background thread."""
    while monitoring_active:
        try:
            summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
            if summary and 'core_states' in summary:
                print(summary)
            else:
                print('WARNING: Call returned an empty or invalid summary.')
        except RuntimeError as e:
            print(f'FAIL: Error calling API: {e}')
        except Exception as e:
            print(f'FAIL: Unexpected error in monitor thread: {e}')
        for _ in range(MONITORING_INTERVAL_SECONDS * 2):
            if not monitoring_active:
                break
            time.sleep(0.5)
    print('✅ Monitoring thread stopped.')
```
The transformer_block function implements a complete layer of the Transformer architecture, which is the foundational building block for LLMs.
```
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
    """A simplified but computationally intensive Transformer block."""
    # Multi-head Self-Attention
    qkv = jnp.dot(x, params['qkv_kernel'])
    q, k, v = jnp.array_split(qkv, 3, axis=-1)

    # Reshape for multi-head attention
    q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
    k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
    v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

    # Scaled dot-product attention
    attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
    attention_weights = jax.nn.softmax(attention_scores, axis=-1)
    attention_output = jnp.einsum('nhqk,nhkd->nhqd', attention_weights, v)
    attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
    attention_output = jnp.dot(attention_output, params['o_kernel'])

    # Residual connection and Layer Normalization 1
    h1 = x + attention_output
    h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
    h1_norm = h1_norm / jnp.sqrt(
        jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
    )

    # Feed-Forward Network
    ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
    ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

    # Residual connection and Layer Normalization 2
    h2 = h1_norm + ffn_output
    h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
    h2_norm = h2_norm / jnp.sqrt(
        jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
    )
    return h2_norm
```
The main function orchestrates the setup of the JAX computation, initiates the background TPU monitoring, and runs the main workload loop.
```
def main():
    num_devices = jax.device_count()
    print(f"Running on {num_devices} devices.")

    batch_size = 128 * num_devices
    seq_len = 512
    embed_dim = 1024
    ffn_dim = embed_dim * 4

    key = jax.random.PRNGKey(0)
    params = {
        'qkv_kernel': jax.random.normal(
            key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
        ),
        'o_kernel': jax.random.normal(
            key, (embed_dim, embed_dim), dtype=jnp.bfloat16
        ),
        'ffn1_kernel': jax.random.normal(
            key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
        ),
        'ffn2_kernel': jax.random.normal(
            key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
        ),
    }
    input_data = jax.random.normal(
        key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
    )
    input_data = jax.device_put(input_data)

    monitor_thread = threading.Thread(target=monitor_tpu_status)
    monitor_thread.start()

    print("Starting JAX computation loop...")
    start_time = time.time()
    iterations = 0
    while time.time() - start_time < JOB_DURATION_SECONDS:
        result = transformer_block(params, input_data)
        result.block_until_ready()
        iterations += 1
        print(f' -> Jax iteration {iterations} complete.', end='\r')

    print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

    global monitoring_active
    monitoring_active = False
    monitor_thread.join()

if __name__ == '__main__':
    main()
```
Troubleshooting
This section provides troubleshooting information to help you identify and resolve problems you might encounter while using the TPU Monitoring Library.
Missing features or metrics
If you are unable to view some features or metrics, the most common cause is an outdated libtpu version. The TPU Monitoring Library features and metrics are included in the libtpu releases, and outdated versions might be missing new features and metrics.
Check the version of libtpu that is running in your environment:
Command line:
```
pip show libtpu
```
Python:
```
import libtpu
print(libtpu.__version__)
```
If you are not using the latest version of libtpu, use the following command to update the library:
```
pip install --upgrade libtpu
```