Track C: Agent Behavior & Infrastructure | Intermediate | 20-25 hours

Distributed Simulation Pipeline

Scale simulations across multiple GPUs with efficient orchestration.

JAX pmap | Distributed Systems | GPU Programming

Scale autonomous driving simulations from a single GPU bottleneck to a multi-device pipeline that achieves near-linear throughput scaling -- the infrastructure leap that separates research prototypes from production-grade validation systems.


Overview

Validating autonomous driving systems requires running millions of simulation scenarios. A single challenging intersection scenario takes a fraction of a second to simulate, but multiply that by ten million scenarios across hundreds of policy variants and you face a wall: single-device simulation cannot keep up with the pace of AD development. The solution is distributed simulation -- splitting work across multiple GPUs (or TPUs) with minimal overhead.

In this project you will build a complete distributed simulation pipeline using JAX's device parallelism primitives (jax.pmap, jax.sharding, device meshes). Starting from a single-device batched simulation, you will progressively scale to multi-device execution, instrument the pipeline with profiling to identify bottlenecks, and ultimately demonstrate near-linear throughput scaling. Along the way you will confront the real engineering challenges of distributed simulation: uneven scenario lengths, device synchronization overhead, memory pressure from large batches, and the gap between ideal linear scaling and the scaling you actually measure, which Amdahl's law helps explain.

Your deliverable is a modular Python pipeline consisting of: (1) a batched simulation engine using jax.vmap for scenario-level parallelism on a single device, (2) a multi-device parallel executor using jax.pmap and sharding to distribute scenarios across GPUs, (3) an async data pipeline with prefetching for continuous throughput, (4) a profiling and benchmarking harness that measures scenarios/second, GPU utilization, memory consumption, and scaling efficiency, and (5) a performance dashboard that visualizes the scaling curve and identifies optimization opportunities. The system should work on any number of JAX-visible devices -- from a single CPU to a multi-GPU workstation.

This matters because the industry is converging on simulation-centric validation. Waymo's Safety Report describes running billions of simulated miles. NVIDIA's Omniverse platform processes millions of sensor frames in parallel. Applied Intuition's platform scales across GPU clusters for fleet-level testing. Understanding the principles of distributed simulation -- batching strategy, data layout, communication minimization, and pipeline overlap -- is foundational for anyone building AD infrastructure at scale.


Learning Objectives

By completing this project, you will be able to:

  • Design batched simulation functions using jax.vmap that exploit data parallelism within a single device, understanding how batch size affects throughput, memory, and JIT compilation time.
  • Implement multi-device parallelism using jax.pmap and JAX's sharding API to distribute independent simulation scenarios across multiple GPUs/TPUs, including proper data distribution and result collection.
  • Build async data pipelines that overlap data loading and preprocessing with simulation execution, using Python threading, JAX async dispatch, and prefetch buffers to eliminate I/O bottlenecks.
  • Apply JAX profiling tools (jax.profiler, TensorBoard integration, device timeline analysis) to identify compute, memory, and communication bottlenecks in a distributed pipeline.
  • Analyze scaling behavior using Amdahl's law, strong scaling, and weak scaling frameworks, and explain why real-world pipelines achieve sub-linear scaling.
  • Implement practical optimizations including buffer donation (the donate_argnums argument of jax.jit and jax.pmap), padding minimization for ragged batches, computation/communication overlap, and memory-efficient gradient accumulation patterns.
  • Build a performance benchmarking harness that produces reproducible throughput measurements (scenarios/second, steps/second) with proper warmup, statistical repetitions, and confidence intervals.
  • Create a monitoring dashboard that visualizes real-time throughput, per-device utilization, memory pressure, and scaling curves for pipeline diagnosis and optimization.

Prerequisites

  • Required: Python proficiency (comfortable with NumPy, dataclasses, and threading), JAX fundamentals (pure functions, jax.jit, jax.vmap, PyTrees), and a conceptual understanding of GPU programming (device memory, kernel execution, host-device transfer).
  • Recommended: Experience with jax.pmap or jax.sharding, familiarity with profiling tools (any framework), basic understanding of parallel computing concepts (Amdahl's law, data parallelism, communication overhead).
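  • Hardware: The notebooks are designed to work on any hardware configuration. Users with a single CPU or GPU can emulate multi-device behavior on the CPU backend using JAX's device count emulation (XLA_FLAGS=--xla_force_host_platform_device_count=8). This flag only creates additional CPU devices, so on a GPU machine you also need to select the CPU backend (for example with JAX_PLATFORMS=cpu). Multi-GPU users will see actual parallel execution.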
  • Deep Dive Reading:
    • JAX documentation on parallelism: jax.pmap, jax.sharding, device meshes
    • Scaling JAX for RL -- Techniques for scaling reinforcement learning workloads across devices with JAX, directly applicable to simulation scaling.

Key Concepts

JAX Device Parallelism: pmap, Sharding, and Meshes

JAX provides multiple mechanisms for distributing computation across devices. Understanding when to use each is the core skill for distributed simulation.

jax.vmap (vectorized map) transforms a function that operates on a single example into one that operates on a batch. The entire batch runs on a single device. This is the foundation -- before distributing across devices, you need efficient within-device batching.

# Single scenario simulation step
def sim_step(state, action):
    """Advance one scenario by one timestep."""
    new_x = state['x'] + state['speed'] * jnp.cos(state['heading']) * DT
    new_y = state['y'] + state['speed'] * jnp.sin(state['heading']) * DT
    new_speed = jnp.clip(state['speed'] + action['accel'] * DT, 0.0, 30.0)
    new_heading = state['heading'] + action['steer'] * DT
    return {'x': new_x, 'y': new_y, 'speed': new_speed, 'heading': new_heading}

# Batched version: operates on (batch_size,) arrays for each field
batched_sim_step = jax.vmap(sim_step)

jax.pmap (parallel map) replicates a function across devices. Each device gets a slice of the data and executes the function independently. The outputs stay on the devices as a single array sharded along the leading axis, and are copied to the host only when you request them (for example with jax.device_get). This is the workhorse for embarrassingly parallel workloads like independent scenario simulation.

# Distribute scenarios across devices
@jax.pmap
def parallel_sim_step(state_shard, action_shard):
    """Each device simulates its shard of scenarios."""
    return batched_sim_step(state_shard, action_shard)

# state has shape (n_devices, scenarios_per_device, ...)
# Each device gets (scenarios_per_device, ...) automatically

jax.sharding (the newer API) provides finer-grained control over data placement. You define a Mesh of devices and use NamedSharding to specify how array dimensions map to mesh dimensions (older JAX releases also offered PositionalSharding, which has since been deprecated). This is more flexible than pmap and is the recommended approach for complex distributed workloads.

from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Create a device mesh
devices = jax.devices()
mesh = Mesh(devices, axis_names=('devices',))

# Shard data along the batch dimension
sharding = NamedSharding(mesh, PartitionSpec('devices'))

# Place data on devices
sharded_states = jax.device_put(states, sharding)

# jit-compiled function automatically distributes based on input sharding
@jax.jit
def distributed_sim(states, actions):
    return jax.vmap(sim_step)(states, actions)

result = distributed_sim(sharded_states, sharded_actions)

The key insight is that simulation scenarios are embarrassingly parallel -- each scenario is independent, requiring no inter-device communication during the simulation loop. This is the best possible case for distributed computing, and you should expect near-linear scaling if you manage data transfer and synchronization correctly.

Simulation Batching Strategies

There are two orthogonal axes of batching for AD simulation:

Scenario-level batching: Multiple independent scenarios simulated in parallel. Each scenario has its own road geometry, initial traffic configuration, and agent set. This is the primary batching axis for validation workloads where you need to evaluate a policy across thousands of diverse situations.

Scenario Batch (vmap over scenarios):
  Scenario 1: [ego + 12 agents, 80 timesteps, highway merge]
  Scenario 2: [ego + 8 agents, 80 timesteps, unprotected left turn]
  Scenario 3: [ego + 15 agents, 80 timesteps, pedestrian crossing]
  ...

Agent-level batching: Within a single scenario, multiple agents simulated in parallel. This is useful when each agent runs its own policy (e.g., in multi-agent RL) and you want to vectorize across agents.

Agent Batch (vmap over agents within a scenario):
  Agent 1: IDM controller, following Agent 0
  Agent 2: learned policy, navigating intersection
  Agent 3: IDM controller, yielding to pedestrian
  ...

For distributed simulation, scenario-level batching is the primary strategy because scenarios are fully independent. Agent-level batching is a secondary optimization within each scenario. The pipeline in this project uses scenario-level batching as the outer loop and can optionally use agent-level batching within each scenario.

Handling ragged batches: Real scenarios have different numbers of agents (5 to 128+). Naive batching requires padding all scenarios to the maximum agent count, wasting memory and compute. Strategies include: (1) sorting scenarios by agent count and batching similar sizes together, (2) carrying a boolean validity mask (such as the 'valid' field used in this project) and applying it with jnp.where so that padded agents do not affect results, (3) bucketing scenarios into size classes.
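
Strategies (2) and (3) can be sketched together as follows. This is a minimal illustration, not the project's reference code: the bucket boundaries in BUCKETS and the helper names bucket_size, pad_scenario, and bucket_and_pad are assumptions made for this example, and each scenario is reduced to a single per-agent array.

```python
import numpy as np

BUCKETS = (8, 16, 32, 64, 128)  # hypothetical size classes


def bucket_size(n_agents: int) -> int:
    """Return the smallest bucket that fits n_agents."""
    for b in BUCKETS:
        if n_agents <= b:
            return b
    raise ValueError(f"{n_agents} agents exceeds the largest bucket {BUCKETS[-1]}")


def pad_scenario(x: np.ndarray, size: int):
    """Pad a (n_agents,) array to (size,) and return it with a validity mask."""
    n = x.shape[0]
    padded = np.zeros(size, dtype=x.dtype)
    padded[:n] = x
    valid = np.arange(size) < n
    return padded, valid


def bucket_and_pad(scenarios):
    """Group scenarios by bucket and stack each group into (batch, size) arrays."""
    groups = {}
    for x in scenarios:
        groups.setdefault(bucket_size(x.shape[0]), []).append(x)
    out = {}
    for size, xs in groups.items():
        padded, valid = zip(*(pad_scenario(x, size) for x in xs))
        out[size] = (np.stack(padded), np.stack(valid))
    return out


# Four scenarios with 5, 7, 12, and 40 agents, all driving at 10 m/s
speeds = [np.full(n, 10.0, dtype=np.float32) for n in (5, 7, 12, 40)]
batches = bucket_and_pad(speeds)

# A masked mean ignores the padded agents in the 8-agent bucket
values, valid = batches[8]
mean_speed = (values * valid).sum() / valid.sum()
```

Each bucket produces arrays of one fixed shape, so a jitted rollout compiles once per bucket instead of once per distinct agent count.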

Data Pipeline Design

A distributed simulation pipeline has three stages that should overlap:

Time -->
  Data Loading:  [Load B1] [Load B2] [Load B3] [Load B4] ...
  Simulation:           [Sim B1]  [Sim B2]  [Sim B3]  ...
  Collection:                  [Collect B1] [Collect B2] ...

Without overlap, the GPU sits idle during data loading and result collection. With overlap (pipelining), each stage processes a different batch concurrently:

Prefetching: Load the next batch of scenarios into host memory while the current batch is being simulated on the GPU. Use Python threads or concurrent.futures.ThreadPoolExecutor for async loading.

Host-to-device transfer: Move data from host RAM to GPU memory. Use jax.device_put with explicit sharding to place data directly on the correct device. For multi-GPU, this means each device gets its shard without an extra copy.

Async dispatch: JAX operations are dispatched asynchronously -- a jax.jit-compiled function returns a jax.Array immediately, before the computation has finished, and only blocks when the value is actually needed (for example on printing, conversion to NumPy, or block_until_ready). You can therefore dispatch the next batch's host-to-device transfer while the GPU is still computing the current batch.

Result collection: After simulation, results (metrics, trajectories) need to be gathered from devices. Use jax.device_get or let results materialize lazily. For large-scale runs, stream results to disk rather than accumulating in host memory.
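
Async dispatch is easiest to see by timing a call twice: once when it returns, and once after blocking on the result. The sketch below uses a made-up function, heavy, as a stand-in for a simulation step. The size of the gap between the two timings depends on the backend and may be small on CPU.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # A few matmuls so the computation takes measurable time
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((512, 512)) * 0.01
jax.block_until_ready(heavy(x))  # warm up so compilation is excluded

t0 = time.perf_counter()
y = heavy(x)  # returns once the work has been enqueued
t_dispatch = time.perf_counter() - t0
jax.block_until_ready(y)  # wait for the device to finish
t_total = time.perf_counter() - t0

print(f"dispatch: {t_dispatch * 1e3:.2f} ms, total: {t_total * 1e3:.2f} ms")
```

The time between dispatch and completion is the window in which the host can prepare and transfer the next batch.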

import queue
import threading

_DONE = object()  # sentinel marking the end of the data source

class PrefetchPipeline:
    """Overlap data loading with simulation."""

    def __init__(self, data_source, batch_size, prefetch_count=2):
        self.data_source = data_source
        self.batch_size = batch_size
        # A bounded Queue blocks the loader when full. A deque(maxlen=...)
        # would instead silently discard the oldest unconsumed batch.
        self.buffer = queue.Queue(maxsize=prefetch_count)
        self.loader_thread = None

    def _load_worker(self):
        """Background thread that continuously loads batches."""
        for batch in self.data_source:
            # preprocess_batch is a placeholder for your own preprocessing
            processed = preprocess_batch(batch, self.batch_size)
            self.buffer.put(processed)  # blocks while the buffer is full
        self.buffer.put(_DONE)

    def start(self):
        self.loader_thread = threading.Thread(target=self._load_worker, daemon=True)
        self.loader_thread.start()

    def get_batch(self):
        """Return the next batch, or None once the data source is exhausted."""
        batch = self.buffer.get()  # blocks without busy-waiting
        return None if batch is _DONE else batch

Profiling and Bottleneck Identification

JAX provides several profiling tools:

jax.profiler.trace: Captures a trace of JAX operations that can be viewed in TensorBoard or Chrome's trace viewer. Shows kernel execution times, host-device transfers, and compilation events.

with jax.profiler.trace("/tmp/jax-trace"):
    for batch in pipeline:
        result = simulate(batch)
        # Ensure computation completes within the trace; the free function
        # also works when result is a tuple or dict of arrays
        jax.block_until_ready(result)

jax.profiler.start_server: Starts a profiling server on a given port so that a trace can be captured on demand (for example from TensorBoard) while the pipeline is running.

Device timeline: Shows when each device is computing, idle, or transferring data. Idle gaps indicate pipeline bubbles -- opportunities for optimization.

Key metrics to monitor:

  • Scenarios/second: Primary throughput metric. Total scenarios processed divided by wall-clock time.
  • GPU utilization: Fraction of time the GPU is actively computing. Target: >90%.
  • Memory utilization: Peak device memory usage. Determines maximum batch size.
  • Scaling efficiency: throughput(N devices) / (N * throughput(1 device)). 1.0 is perfect linear scaling.

Common bottlenecks and fixes:

| Bottleneck | Symptom | Fix |
|---|---|---|
| JIT compilation | Slow first batch | Warm up with dummy batch; use jax.jit with static_argnums |
| Host-device transfer | GPU idle between batches | Prefetch; overlap transfer with compute |
| Memory pressure | OOM on large batches | Reduce batch size; use gradient accumulation pattern; donate buffers |
| Padding waste | Low effective utilization | Sort scenarios by size; use bucketed batching |
| Synchronization | All devices wait for slowest | Balance workload; avoid unnecessary all-reduce |
| Python overhead | GIL contention in data pipeline | Use multiprocessing or move preprocessing to C++/Rust |

Scaling Laws and Amdahl's Law

Amdahl's Law describes the theoretical speedup from parallelization:

$$S(N) = \frac{1}{(1 - p) + \frac{p}{N}}$$

where $p$ is the fraction of work that is parallelizable and $N$ is the number of devices. Even with $p = 0.99$, the maximum speedup with 100 devices is only $\approx 50\times$ because the 1% serial portion becomes the bottleneck.
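
The formula is easy to evaluate directly. The sketch below reproduces the 100-device figure and also inverts the formula to estimate $p$ from a measured speedup. The function names amdahl_speedup and estimate_parallel_fraction are chosen for this example only.

```python
def amdahl_speedup(p: float, n: int) -> float:
    """Theoretical speedup on n devices when a fraction p of the work is parallel."""
    return 1.0 / ((1.0 - p) + p / n)


def estimate_parallel_fraction(speedup: float, n: int) -> float:
    """Invert Amdahl's law: recover p from a speedup measured on n > 1 devices."""
    return (1.0 - 1.0 / speedup) / (1.0 - 1.0 / n)


for n in (1, 8, 100, 10_000):
    print(f"N={n:>6d}: speedup = {amdahl_speedup(0.99, n):.2f}")

# As N grows, the speedup approaches 1 / (1 - p)
asymptote = 1.0 / (1.0 - 0.99)

# Round trip: the speedup predicted for p = 0.99 yields p = 0.99 again
p_hat = estimate_parallel_fraction(amdahl_speedup(0.99, 100), 100)
```

Applying estimate_parallel_fraction to your own measured speedups gives a quick estimate of how much serial work remains in the pipeline.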

For simulation pipelines, the serial portions include:

  • Data loading from disk (can be partially parallelized with multiple reader threads)
  • JIT compilation (one-time cost, amortized over many batches)
  • Result aggregation (collecting and reducing metrics from all devices)
  • Host-side Python orchestration (loop overhead, batch preparation)

Strong scaling: Fix total work, increase devices. Measures how much faster you complete a fixed workload. Limited by serial fraction and communication overhead.

Weak scaling: Fix work per device, increase devices. Measures how throughput grows with resources. For embarrassingly parallel simulation, weak scaling should be near-linear because each device does independent work.
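
The two definitions lead to different efficiency formulas. A minimal sketch follows; the timings are invented purely to show the arithmetic and are not measurements.

```python
def strong_scaling_efficiency(t1: float, tn: float, n: int) -> float:
    """Fixed total work: the ideal time on n devices is t1 / n."""
    return t1 / (n * tn)


def weak_scaling_efficiency(t1: float, tn: float) -> float:
    """Fixed work per device: the ideal time on n devices stays equal to t1."""
    return t1 / tn


# Hypothetical timings (seconds), for illustration only
# Strong scaling: 80,000 scenarios take 8.0 s on 1 device and 1.25 s on 8
strong = strong_scaling_efficiency(t1=8.0, tn=1.25, n=8)

# Weak scaling: 10,000 scenarios per device take 2.0 s on 1 device
# and 2.5 s when 8 devices each process 10,000 scenarios
weak = weak_scaling_efficiency(t1=2.0, tn=2.5)

print(f"strong-scaling efficiency: {strong:.2f}")
print(f"weak-scaling efficiency:   {weak:.2f}")
```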

As a rough guide, scenario-level parallelism can reach 85-95% scaling efficiency on 4-8 GPUs, dropping to 70-85% on 16-32 GPUs as data transfer overhead and host orchestration bottlenecks grow. Treat these figures as targets to verify with your own benchmarks rather than as guarantees.

Checkpoint and Fault Tolerance

Long-running simulation campaigns (millions of scenarios over hours or days) must handle failures gracefully:

Progress checkpointing: Periodically save which scenarios have been completed. On restart, skip already-processed scenarios. Use a simple file-based checkpoint (JSON or SQLite) that records scenario IDs and their status.

Result streaming: Write results to disk incrementally rather than accumulating in memory. This prevents data loss if the process crashes and keeps memory usage bounded.
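
One way to stream results is an append-only JSON Lines file with one record per scenario. The sketch below assumes that metrics are small JSON-serializable dicts. JsonlResultWriter and read_completed_ids are illustrative names, not part of the project code.

```python
import json
import os
import tempfile


class JsonlResultWriter:
    """Append one JSON record per scenario, syncing to disk periodically."""

    def __init__(self, path, flush_every=100):
        self.flush_every = flush_every
        self._count = 0
        self._f = open(path, "a", encoding="utf-8")

    def write(self, scenario_id, metrics):
        record = {"scenario_id": scenario_id, **metrics}
        self._f.write(json.dumps(record) + "\n")
        self._count += 1
        if self._count % self.flush_every == 0:
            self._f.flush()
            os.fsync(self._f.fileno())  # push buffered records to disk

    def close(self):
        self._f.flush()
        self._f.close()


def read_completed_ids(path):
    """Recover the set of finished scenario IDs, skipping any torn final line."""
    ids = set()
    if not os.path.exists(path):
        return ids
    with open(path, encoding="utf-8") as f:
        for line in f:
            try:
                ids.add(json.loads(line)["scenario_id"])
            except (json.JSONDecodeError, KeyError):
                continue  # partially written record from a crash
    return ids


path = os.path.join(tempfile.mkdtemp(), "results.jsonl")
writer = JsonlResultWriter(path, flush_every=2)
for i in range(3):
    writer.write(f"s{i}", {"collision": False, "mean_speed": 12.5})
writer.close()

completed = read_completed_ids(path)
```

Because every record carries its scenario ID, the results file can also serve as the progress record on restart.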

Device failure handling: In multi-GPU setups, a single device failure can crash the entire pmap operation. Strategies include: (1) catch and retry with reduced device count, (2) persist simulation state with a checkpointing library (for example Orbax) so that an interrupted run can resume; note that jax.checkpoint is a rematerialization tool, not a state-recovery mechanism, (3) design the pipeline so each batch is independent and can be retried.

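import json
import os

class CheckpointedPipeline:
    def __init__(self, checkpoint_path):
        self.checkpoint_path = checkpoint_path
        self.completed = self._load_checkpoint()

    def _load_checkpoint(self):
        if os.path.exists(self.checkpoint_path):
            with open(self.checkpoint_path, 'r') as f:
                return set(json.load(f))
        return set()

    def _save_checkpoint(self):
        # Write to a temporary file, then rename it into place, so a crash
        # mid-write cannot corrupt the existing checkpoint.
        tmp_path = self.checkpoint_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump(list(self.completed), f)
        os.replace(tmp_path, self.checkpoint_path)

    def should_process(self, scenario_id):
        return scenario_id not in self.completed

    def mark_complete(self, scenario_id):
        self.completed.add(scenario_id)
        # Saves every 100 scenarios; call _save_checkpoint() once more at the
        # end of a run so the final partial group is recorded.
        if len(self.completed) % 100 == 0:
            self._save_checkpoint()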

Step-by-Step Implementation Guide

Step 1: Environment Setup (30 min)

Set up the project structure, install dependencies, and verify JAX can see your devices.

1.1 Project Directory Structure

distributed-simulation-pipeline/
├── notebooks/
│   ├── 01_batched_simulation.ipynb
│   ├── 02_multi_device_parallelism.ipynb
│   └── 03_full_pipeline.ipynb
├── requirements.txt
└── README.md

1.2 Install Dependencies

python -m venv .venv
source .venv/bin/activate

pip install -r requirements.txt

1.3 Verify JAX Setup and Device Visibility

import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Number of devices: {jax.device_count()}")
print(f"Local device count: {jax.local_device_count()}")
print(f"Default backend: {jax.default_backend()}")

# Quick functionality test
x = jnp.ones((1000,))
y = jax.jit(lambda x: x ** 2)(x)
print(f"JIT test: sum of squares = {y.sum()}")

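For users with a single GPU or only a CPU, emulate multiple devices on the CPU backend. The flag only creates extra CPU devices, so on a GPU machine also select the CPU platform:

# Emulate 8 devices for development/testing
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
export JAX_PLATFORMS=cpu  # needed on GPU machines; harmless on CPU-only ones
python -c "import jax; print(f'Devices: {jax.device_count()}')"
# Should print: Devices: 8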

1.4 Define the Simulation State

We use a simple kinematic vehicle model throughout this project. The state is a JAX-compatible PyTree (a dict of arrays) that works seamlessly with jit, vmap, and pmap.

import jax
import jax.numpy as jnp
from typing import Dict

# Type alias for clarity
State = Dict[str, jnp.ndarray]

def make_initial_state(n_agents: int, key: jnp.ndarray) -> State:
    """
    Create initial state for a scenario with n_agents vehicles.

    State is a dict of (n_agents,) arrays.
    """
    keys = jax.random.split(key, 5)
    return {
        'x': jax.random.uniform(keys[0], (n_agents,), minval=0.0, maxval=200.0),
        'y': jax.random.uniform(keys[1], (n_agents,), minval=-3.7, maxval=7.4),
        'heading': jax.random.uniform(keys[2], (n_agents,), minval=-0.1, maxval=0.1),
        'speed': jax.random.uniform(keys[3], (n_agents,), minval=5.0, maxval=25.0),
        'valid': jnp.ones(n_agents, dtype=jnp.bool_),
    }

DT = 0.1  # 10 Hz simulation
NUM_STEPS = 80  # 8 seconds

Step 2: Single-Device Batched Simulation -- Notebook 01 (4-5 hours)

The first notebook builds the foundation: a single-device simulation engine that uses jax.vmap to process many scenarios in parallel. This establishes the baseline throughput that we will scale in later notebooks.

2.1 Implement the Simulation Step Function

The step function must be a pure JAX function (no side effects, no Python control flow that depends on array values) so that jit and vmap can transform it.

@jax.jit
def sim_step(state: State, actions: Dict[str, jnp.ndarray]) -> State:
    """
    Advance all agents in a scenario by one timestep.

    Args:
        state: dict with arrays of shape (n_agents,) for each field
        actions: dict with 'accel' and 'steer' arrays of shape (n_agents,)

    Returns:
        Updated state dict
    """
    cos_h = jnp.cos(state['heading'])
    sin_h = jnp.sin(state['heading'])

    new_speed = jnp.clip(state['speed'] + actions['accel'] * DT, 0.0, 30.0)
    new_heading = state['heading'] + actions['steer'] * DT
    new_x = state['x'] + new_speed * cos_h * DT
    new_y = state['y'] + new_speed * sin_h * DT

    return {
        'x': new_x,
        'y': new_y,
        'heading': new_heading,
        'speed': new_speed,
        'valid': state['valid'],
    }

2.2 Roll Out a Full Scenario

Use jax.lax.scan for the temporal loop -- this is more efficient than a Python for-loop because the loop body is traced and compiled once rather than unrolled step by step, so compile time does not grow with the number of timesteps.

def rollout_scenario(initial_state: State, action_sequence: Dict) -> tuple[State, State]:
    """
    Simulate a full scenario using lax.scan for the time loop.

    Args:
        initial_state: state at t=0, arrays of shape (n_agents,)
        action_sequence: dict of arrays with shape (num_steps, n_agents)

    Returns:
        Final state and trajectory (all intermediate states)
    """
    def scan_fn(state, actions_t):
        next_state = sim_step(state, actions_t)
        return next_state, next_state  # (carry, output)

    # actions_t is a dict; scan iterates over the leading axis
    final_state, trajectory = jax.lax.scan(
        scan_fn, initial_state, action_sequence
    )
    return final_state, trajectory
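
To confirm that lax.scan computes the same thing as an explicit loop, the sketch below rolls out a simplified 1-D model both ways and compares the results. The toy step function and its field names are assumptions made for this example; it is not the project's sim_step.

```python
import jax
import jax.numpy as jnp

DT = 0.1


def step(state, accel):
    """Toy 1-D step: integrate speed, then position."""
    speed = jnp.clip(state["speed"] + accel * DT, 0.0, 30.0)
    return {"x": state["x"] + speed * DT, "speed": speed}


def rollout_loop(state, accels):
    """Reference implementation: a plain Python loop over timesteps."""
    for t in range(accels.shape[0]):
        state = step(state, accels[t])
    return state


def rollout_scan(state, accels):
    """Same rollout with lax.scan; also returns the x trajectory."""
    def body(carry, accel_t):
        nxt = step(carry, accel_t)
        return nxt, nxt["x"]  # (carry, per-step output)
    return jax.lax.scan(body, state, accels)


state0 = {"x": jnp.zeros(4), "speed": jnp.full(4, 10.0)}
accels = jnp.full((80, 4), 0.5)  # (num_steps, n_agents)

final_loop = rollout_loop(state0, accels)
final_scan, xs = rollout_scan(state0, accels)
```

The per-step outputs come back stacked along a new leading time axis, which is why the trajectory returned by rollout_scenario has shape (num_steps, n_agents).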

2.3 Vectorize Across Scenarios

This is where vmap shines. A single vmap call transforms rollout_scenario from operating on one scenario to operating on a batch:

# Batch over scenarios: each field goes from (n_agents,) to (batch, n_agents)
# Wrap in jax.jit so the whole batched rollout compiles to one XLA program
batched_rollout = jax.jit(jax.vmap(rollout_scenario))

# Now:
#   initial_states: dict of (batch_size, n_agents) arrays
#   action_sequences: dict of (batch_size, num_steps, n_agents) arrays
#   returns: trajectory of (batch_size, num_steps, n_agents) arrays
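
By default vmap maps over the leading axis of every argument. The in_axes parameter changes that, for example to replay one shared action sequence across every scenario in the batch. The sketch below uses a toy rollout function, assumed for this example, to show both forms and the resulting shapes.

```python
import jax
import jax.numpy as jnp


def rollout(x0, accels):
    """Toy rollout: x0 is (n_agents,), accels is (num_steps, n_agents)."""
    def body(x, a):
        x = x + a
        return x, x
    return jax.lax.scan(body, x0, accels)


batch, steps, agents = 3, 5, 2
x0 = jnp.zeros((batch, agents))
per_scenario = jnp.ones((batch, steps, agents))  # one sequence per scenario
shared = jnp.ones((steps, agents))               # one sequence for all scenarios

# Default: map over axis 0 of both arguments
final_a, traj_a = jax.vmap(rollout)(x0, per_scenario)

# in_axes=(0, None): map over x0 only and broadcast the shared actions
final_b, traj_b = jax.vmap(rollout, in_axes=(0, None))(x0, shared)

print(traj_a.shape, traj_b.shape)
```

Broadcasting with in_axes=None avoids materializing batch_size copies of data that is identical across scenarios.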

2.4 Throughput Benchmarking

Measure how throughput changes with batch size. The key insight: there is an optimal batch size that maximizes GPU utilization without exceeding memory.

import time

def benchmark_throughput(batch_sizes, n_agents=16, n_steps=80, n_repeats=5):
    """Benchmark simulation throughput at different batch sizes."""
    results = []

    for batch_size in batch_sizes:
        # Generate random scenarios
        key = jax.random.PRNGKey(0)
        states = jax.vmap(make_initial_state, in_axes=(None, 0))(
            n_agents, jax.random.split(key, batch_size)
        )
        actions = {
            'accel': jax.random.normal(jax.random.PRNGKey(1),
                                       (batch_size, n_steps, n_agents)) * 0.5,
            'steer': jax.random.normal(jax.random.PRNGKey(2),
                                       (batch_size, n_steps, n_agents)) * 0.1,
        }

        # Warmup (includes JIT compilation)
        _ = batched_rollout(states, actions)
        jax.block_until_ready(_)

        # Timed runs
        times = []
        for _ in range(n_repeats):
            start = time.perf_counter()
            result = batched_rollout(states, actions)
            jax.block_until_ready(result)
            elapsed = time.perf_counter() - start
            times.append(elapsed)

        mean_time = sum(times) / len(times)
        throughput = batch_size / mean_time

        results.append({
            'batch_size': batch_size,
            'mean_time': mean_time,
            'throughput': throughput,
            'scenarios_per_sec': throughput,
        })
        print(f"Batch {batch_size:>6d}: {mean_time:.4f}s, "
              f"{throughput:.0f} scenarios/sec")

    return results
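
The benchmark above reports only the mean. To add the confidence intervals called for in the learning objectives, the timings can be summarized as sketched below. summarize_timings is an illustrative helper, and the sample timings are invented. With only five repetitions a Student-t critical value (2.776 for 4 degrees of freedom at 95%) is more appropriate than the normal value 1.96, so the critical value is a parameter.

```python
import math
import statistics


def summarize_timings(times, batch_size, critical_value=2.776):
    """Mean, sample standard deviation, and a confidence interval for the mean."""
    mean = statistics.mean(times)
    stdev = statistics.stdev(times) if len(times) > 1 else 0.0
    half_width = critical_value * stdev / math.sqrt(len(times))
    return {
        "mean_s": mean,
        "median_s": statistics.median(times),
        "stdev_s": stdev,
        "ci_s": (mean - half_width, mean + half_width),
        "throughput": batch_size / mean,
    }


# Hypothetical timings in seconds, for illustration only
times = [0.10, 0.11, 0.09, 0.10, 0.10]
summary = summarize_timings(times, batch_size=1024)

low, high = summary["ci_s"]
print(f"mean {summary['mean_s']:.4f}s, interval [{low:.4f}, {high:.4f}]s, "
      f"{summary['throughput']:.0f} scenarios/sec")
```

Reporting the median alongside the mean helps flag runs in which a single outlier, such as a background process or a garbage collection pause, skewed the average.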

2.5 Memory Analysis

Understanding memory consumption is critical for choosing the right batch size:

def estimate_memory(batch_size, n_agents, n_steps):
    """Estimate memory usage for a batched simulation."""
    # State: 5 fields x batch x n_agents, counted at 4 bytes each (float32).
    # The boolean 'valid' field is smaller, so this slightly overestimates.
    state_bytes = 5 * batch_size * n_agents * 4
    # Actions: 2 fields x batch x n_steps x n_agents x float32
    action_bytes = 2 * batch_size * n_steps * n_agents * 4
    # Trajectory output: 5 fields x batch x n_steps x n_agents x float32
    trajectory_bytes = 5 * batch_size * n_steps * n_agents * 4
    total = state_bytes + action_bytes + trajectory_bytes
    return {
        'state_mb': state_bytes / 1e6,
        'actions_mb': action_bytes / 1e6,
        'trajectory_mb': trajectory_bytes / 1e6,
        'total_mb': total / 1e6,
    }
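
The same arithmetic can be inverted to choose a batch size for a given memory budget. The sketch below is a rough estimate only: it counts the input and output arrays and ignores XLA's temporary buffers, so the headroom factor is an assumption to tune. bytes_per_scenario and max_batch_size are illustrative names.

```python
def bytes_per_scenario(n_agents: int, n_steps: int, bytes_per_value: int = 4) -> int:
    """Bytes for one scenario: state (5 fields), actions (2), trajectory (5)."""
    state = 5 * n_agents
    actions = 2 * n_steps * n_agents
    trajectory = 5 * n_steps * n_agents
    return (state + actions + trajectory) * bytes_per_value


def max_batch_size(budget_bytes: int, n_agents: int, n_steps: int,
                   headroom: float = 0.5, multiple_of: int = 1) -> int:
    """Largest batch fitting in headroom * budget, rounded down to a multiple."""
    usable = int(budget_bytes * headroom)
    batch = usable // bytes_per_scenario(n_agents, n_steps)
    return (batch // multiple_of) * multiple_of


per_scenario = bytes_per_scenario(n_agents=16, n_steps=80)
batch = max_batch_size(1_000_000_000, n_agents=16, n_steps=80,
                       headroom=0.5, multiple_of=8)
print(f"{per_scenario} bytes per scenario, batch size {batch}")
```

Rounding down to a multiple of the device count keeps the batch evenly divisible when it is later split across devices.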

Key takeaways from Step 2:

  • jax.vmap + jax.lax.scan is the core pattern for efficient batched simulation.
  • Throughput scales with batch size up to a point, then plateaus (GPU saturated); past the memory limit the run fails with an out-of-memory error.
  • JIT compilation is a one-time cost that must be excluded from benchmarks.
  • The simulation function must be a pure function for vmap/jit to work.

Step 3: Multi-Device Parallelism -- Notebook 02 (5-6 hours)

The second notebook extends the pipeline to multiple devices. The key idea: split the scenario batch across devices, simulate independently on each device, and gather results.

3.1 Understanding JAX Device Mesh

import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

# Discover available devices
devices = jax.devices()
n_devices = len(devices)
print(f"Available devices: {n_devices}")

# Create a 1D mesh (simple data parallelism)
mesh = Mesh(devices, axis_names=('batch',))

# Define sharding: split the first (batch) axis across devices
batch_sharding = NamedSharding(mesh, P('batch'))

# Replicate sharding: every device gets a full copy
replicated_sharding = NamedSharding(mesh, P())
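
To check that a sharding does what you expect, place an array and inspect its shards. The sketch below runs on any device count, including one. The array contents and the choice of 4 scenarios per device are arbitrary values for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n_devices = devices.size
mesh = Mesh(devices, axis_names=("batch",))

# 4 scenarios per device, 16 agents per scenario
batch = 4 * n_devices
states = jnp.arange(batch * 16, dtype=jnp.float32).reshape(batch, 16)

# Split the leading axis across devices
sharded = jax.device_put(states, NamedSharding(mesh, P("batch")))

# Replicate a small constant on every device
dt = jax.device_put(jnp.float32(0.1), NamedSharding(mesh, P()))

for shard in sharded.addressable_shards:
    # Each shard reports its device, its index into the global array,
    # and the shape of the slice it holds
    print(shard.device, shard.index, shard.data.shape)
```

The global shape of the array is unchanged by sharding; only the placement of its slices differs.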

3.2 Distribute Scenarios with pmap

The classic pmap approach: reshape data so the leading dimension equals the number of devices.

@jax.pmap
def pmap_rollout(states, actions):
    """
    Each device simulates its shard of scenarios.
    Input shape per device: (scenarios_per_device, n_agents, ...)
    """
    return jax.vmap(rollout_scenario)(states, actions)

def distribute_and_simulate(all_states, all_actions, n_devices):
    """
    Split scenarios across devices, simulate in parallel, gather results.
    """
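    batch_size = jax.tree_util.tree_leaves(all_states)[0].shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be divisible by n_devices ({n_devices})"
        )
    per_device = batch_size // n_devices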

    # Reshape: (batch,) -> (n_devices, per_device, ...)
    def reshape_for_pmap(x):
        return x.reshape(n_devices, per_device, *x.shape[1:])

    sharded_states = jax.tree_util.tree_map(reshape_for_pmap, all_states)
    sharded_actions = jax.tree_util.tree_map(reshape_for_pmap, all_actions)

    # Execute in parallel across devices
    final_states, trajectories = pmap_rollout(sharded_states, sharded_actions)

    # Reshape back: (n_devices, per_device, ...) -> (batch, ...)
    def reshape_from_pmap(x):
        return x.reshape(batch_size, *x.shape[2:])

    return (jax.tree_util.tree_map(reshape_from_pmap, final_states),
            jax.tree_util.tree_map(reshape_from_pmap, trajectories))
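
The reshape into (n_devices, per_device, ...) only works when the batch size is a multiple of the device count. One way to handle leftover scenarios is to pad the batch and discard the padded rows afterwards, as sketched below. pad_batch_to_multiple and unpad_batch are illustrative helpers, and repeating the last scenario as filler is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def pad_batch_to_multiple(tree, multiple):
    """Pad the leading axis of every leaf; return the tree and the original size."""
    batch = jax.tree_util.tree_leaves(tree)[0].shape[0]
    pad = (-batch) % multiple
    if pad == 0:
        return tree, batch

    def pad_leaf(x):
        filler = jnp.repeat(x[-1:], pad, axis=0)  # copies of the last scenario
        return jnp.concatenate([x, filler], axis=0)

    return jax.tree_util.tree_map(pad_leaf, tree), batch


def unpad_batch(tree, original_batch):
    """Drop the rows that were added by pad_batch_to_multiple."""
    return jax.tree_util.tree_map(lambda x: x[:original_batch], tree)


# 10 scenarios with 3 agents each, to be split across 4 devices
states = {
    "x": jnp.arange(30, dtype=jnp.float32).reshape(10, 3),
    "speed": jnp.ones((10, 3)),
}
padded, original = pad_batch_to_multiple(states, multiple=4)
restored = unpad_batch(padded, original)

print(padded["x"].shape, restored["x"].shape)
```

The padded scenarios cost some wasted compute, so this suits a small final remainder batch better than routine use.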

3.3 Modern Sharding API (jax.sharding)

The newer sharding API is more flexible and doesn't require manual reshaping:

@jax.jit
def sharded_rollout(states, actions):
    """
    JIT-compiled function that automatically distributes based on
    input sharding. No manual reshape needed.
    """
    return jax.vmap(rollout_scenario)(states, actions)

# Place data on devices with explicit sharding
with mesh:
    sharded_states = jax.device_put(all_states, batch_sharding)
    sharded_actions = jax.device_put(all_actions, batch_sharding)

    # JAX infers the parallel execution plan from input sharding
    final_states, trajectories = sharded_rollout(sharded_states, sharded_actions)

3.4 Cross-Device Communication: Collision Checking

Most simulation steps are per-scenario and need no communication. But some operations (like aggregating collision rates or other global metrics across all scenarios) require cross-device data movement.

@jax.pmap
def parallel_sim_with_global_metrics(states, actions):
    """Simulate and compute local metrics on each device."""
    final, traj = jax.vmap(rollout_scenario)(states, actions)

    # Local metrics (per-device, no communication)
    local_speeds = traj['speed'].mean(axis=(0, 1, 2))

    return final, traj, local_speeds


def gather_global_metrics(local_metrics):
    """
    Gather per-device metrics to the host and compute global statistics.

    Averaging per-device means gives the global mean only when every
    device holds the same number of scenarios, as it does under pmap.
    """
    # Option 1: Gather to host and aggregate in NumPy
    all_metrics = jax.device_get(local_metrics)  # (n_devices, ...)
    global_mean = all_metrics.mean()

    return global_mean

For in-device collective operations:

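import functools

# jax.pmap takes the function as its first argument, so bind axis_name with partial
@functools.partial(jax.pmap, axis_name='devices')
def parallel_sim_with_allreduce(states, actions):
    """Use lax.pmean for in-device metric averaging."""
    final, traj = jax.vmap(rollout_scenario)(states, actions)
    # compute_collision_rate is a placeholder for your own metric function
    local_collision_rate = compute_collision_rate(traj)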

    # All-reduce: every device gets the global mean
    global_collision_rate = jax.lax.pmean(local_collision_rate, axis_name='devices')

    return final, traj, global_collision_rate
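
The all-reduce pattern can be tried in isolation. The sketch below averages a toy per-device quantity with lax.pmean and runs on however many local devices are present, including one. The speeds array is a stand-in for real trajectory data, and mean_speed_allreduce is a name chosen for this example.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


@partial(jax.pmap, axis_name="devices")
def mean_speed_allreduce(speeds):
    """Average within each device, then across all devices."""
    local_mean = speeds.mean()
    # After pmean every device holds the same global value
    return jax.lax.pmean(local_mean, axis_name="devices")


# One row of 4 values per device
speeds = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
global_mean = mean_speed_allreduce(speeds)

print(global_mean)  # one copy of the global mean per device
```

The output has one entry per device, all equal. Because every device holds the same number of values, the mean of the per-device means equals the mean over all values.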

3.5 Scaling Benchmark

Measure throughput as you increase the number of devices:

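def scaling_benchmark(batch_size, n_agents, n_steps, device_counts):
    """
    Measure throughput at different device counts.

    Each entry in device_counts must not exceed jax.local_device_count().
    To test more devices than you physically have, emulate them with
    XLA_FLAGS; emulated devices share the same CPU, so the resulting
    numbers validate correctness rather than real scaling.
    """
    results = []

    for n_dev in device_counts:
        # Ensure batch_size is divisible by n_dev
        effective_batch = (batch_size // n_dev) * n_dev
        per_device = effective_batch // n_dev

        # Generate data (generate_batch and generate_actions are helpers you
        # provide); separate keys keep states and actions uncorrelated
        state_key, action_key = jax.random.split(jax.random.PRNGKey(42))
        states = generate_batch(effective_batch, n_agents, state_key)
        actions = generate_actions(effective_batch, n_steps, n_agents, action_key)

        # Distribute: (batch, ...) -> (n_dev, per_device, ...)
        def reshape_for_pmap(x):
            return x.reshape(n_dev, per_device, *x.shape[1:])

        sharded_states = jax.tree_util.tree_map(reshape_for_pmap, states)
        sharded_actions = jax.tree_util.tree_map(reshape_for_pmap, actions)

        # Restrict pmap to the first n_dev devices
        rollout_fn = jax.pmap(jax.vmap(rollout_scenario),
                              devices=jax.local_devices()[:n_dev])

        # Warmup
        jax.block_until_ready(rollout_fn(sharded_states, sharded_actions))

        # Benchmark
        times = []
        for _ in range(5):
            start = time.perf_counter()
            result = rollout_fn(sharded_states, sharded_actions)
            jax.block_until_ready(result)
            times.append(time.perf_counter() - start)

        mean_time = sum(times) / len(times)
        throughput = effective_batch / mean_time

        # Efficiency relative to the first entry of device_counts
        if results:
            base = results[0]
            efficiency = (throughput / base['throughput']) / (n_dev / base['n_devices'])
        else:
            efficiency = 1.0

        results.append({
            'n_devices': n_dev,
            'throughput': throughput,
            'time': mean_time,
            'efficiency': efficiency,
        })

    return results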

3.6 Analyzing Scaling Efficiency

Plot the scaling curve and compare to ideal (linear) scaling:

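import matplotlib.pyplot as plt

def plot_scaling(results):
    """Plot throughput, speedup, and efficiency; assumes results[0] is the single-device baseline."""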
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    devices = [r['n_devices'] for r in results]
    throughputs = [r['throughput'] for r in results]
    speedups = [t / throughputs[0] for t in throughputs]
    efficiencies = [s / d for s, d in zip(speedups, devices)]

    # Throughput
    axes[0].plot(devices, throughputs, 'bo-', linewidth=2, label='Actual')
    axes[0].plot(devices, [throughputs[0] * d for d in devices],
                 'r--', label='Ideal (linear)')
    axes[0].set_xlabel('Number of Devices')
    axes[0].set_ylabel('Scenarios/sec')
    axes[0].set_title('Throughput Scaling')
    axes[0].legend()

    # Speedup
    axes[1].plot(devices, speedups, 'bo-', linewidth=2, label='Actual')
    axes[1].plot(devices, devices, 'r--', label='Ideal')
    axes[1].set_xlabel('Number of Devices')
    axes[1].set_ylabel('Speedup (x)')
    axes[1].set_title('Strong Scaling')
    axes[1].legend()

    # Efficiency
    axes[2].plot(devices, efficiencies, 'go-', linewidth=2)
    axes[2].axhline(y=1.0, color='r', linestyle='--', label='Perfect')
    axes[2].set_xlabel('Number of Devices')
    axes[2].set_ylabel('Efficiency')
    axes[2].set_title('Scaling Efficiency')
    axes[2].set_ylim(0, 1.1)
    axes[2].legend()

    plt.tight_layout()
    plt.show()

Key takeaways from Step 3:

  • pmap is the simplest path to multi-device parallelism; jax.sharding is more flexible.
  • Simulation scenarios are embarrassingly parallel -- no communication needed during rollout.
  • Scaling efficiency drops with more devices due to data transfer and synchronization overhead.
  • The batch size per device matters: too small and the GPU is underutilized; too large and you hit OOM.
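To make the first takeaway concrete, here is a minimal sketch of the jax.sharding route, assuming a toy per-scenario step function (`toy_step` is a hypothetical stand-in, not the project's simulator). Instead of reshaping inputs to a leading device axis as pmap requires, the batch keeps its logical shape and a `NamedSharding` describes how its first axis is split across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def toy_step(state, action):
    # Hypothetical stand-in for the project's per-scenario dynamics
    return state + 0.1 * action


# vmap over scenarios, jit once; the compiled function follows the input sharding
batched_step = jax.jit(jax.vmap(toy_step))

# 1D mesh over all visible devices; the axis name "batch" is our own choice
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
sharding = NamedSharding(mesh, P("batch"))  # split axis 0 across the mesh

# Keep the batch divisible by the device count so it splits evenly
n_scenarios = 8 * len(jax.devices())
states = jax.device_put(jnp.zeros((n_scenarios, 4)), sharding)
actions = jax.device_put(jnp.ones((n_scenarios, 4)), sharding)

new_states = batched_step(states, actions)
print(new_states.shape, new_states.sharding)
```

On a single-device machine the mesh has one entry and the same code still runs, which makes this variant convenient for testing before moving to a multi-GPU host.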

Step 4: Full Pipeline with Profiling -- Notebook 03 (5-6 hours)

The final notebook assembles everything into a production-quality pipeline with async data loading, profiling, and a performance dashboard.

4.1 Async Data Pipeline

import threading
import queue
import time

class AsyncDataPipeline:
    """
    Multi-threaded data pipeline that overlaps loading with simulation.

    Architecture:
        [Loader Thread] --> [Queue] --> [Simulation Loop]

    The loader thread continuously prepares batches while the main
    thread runs simulation on the GPU.
    """

    def __init__(self, scenario_generator, batch_size, n_devices,
                 prefetch_count=2):
        self.scenario_generator = scenario_generator
        self.batch_size = batch_size
        self.n_devices = n_devices
        self.prefetch_count = prefetch_count
        self.queue = queue.Queue(maxsize=prefetch_count)
        self._stop_event = threading.Event()

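    def _put(self, item):
        """Enqueue an item, giving up if stop() is called while the queue is full."""
        while not self._stop_event.is_set():
            try:
                self.queue.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def _loader_worker(self):
        """Background thread: load, preprocess, and enqueue batches."""
        batch_states = []
        batch_actions = []

        try:
            for scenario in self.scenario_generator:
                if self._stop_event.is_set():
                    break

                state, actions = preprocess_scenario(scenario)
                batch_states.append(state)
                batch_actions.append(actions)

                if len(batch_states) == self.batch_size:
                    # Stack into batched arrays
                    batched = stack_scenarios(batch_states, batch_actions)

                    # Shard for multi-device
                    sharded = shard_for_devices(batched, self.n_devices)

                    if not self._put(sharded):
                        break
                    batch_states = []
                    batch_actions = []
            # A trailing partial batch is dropped so that every batch keeps
            # the same shape and does not trigger recompilation.
        finally:
            # None marks the end of the stream, so the consumer does not have
            # to wait for a timeout to learn that no more data is coming.
            self._put(None)

    def start(self):
        self._stop_event.clear()
        self.thread = threading.Thread(target=self._loader_worker, daemon=True)
        self.thread.start()

    def stop(self):
        self._stop_event.set()

    def __iter__(self):
        # Each call starts a new loader thread on the same scenario_generator,
        # so create one iterator per pipeline and reuse it.
        self.start()
        try:
            while True:
                try:
                    batch = self.queue.get(timeout=1.0)
                except queue.Empty:
                    # Fallback for a loader that exited without a sentinel
                    if not self.thread.is_alive():
                        break
                    continue
                if batch is None:
                    break
                yield batch
        finally:
            # Also runs when the generator is closed after an early break.
            self.stop()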
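The queue-based hand-off in the architecture above can be tried in isolation, without any simulator code. The following is a minimal sketch under stated assumptions: `prefetch` and `_END` are hypothetical names, and the "batches" are plain integers. A background thread fills a bounded queue, and a sentinel object marks the end of the stream so the consumer knows when to stop.

```python
import queue
import threading

_END = object()  # Sentinel marking the end of the stream


def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable`, produced ahead of time on a background thread."""
    q = queue.Queue(maxsize=buffer_size)

    def worker():
        try:
            for item in iterable:
                q.put(item)  # Blocks while the buffer is full
        finally:
            q.put(_END)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is _END:
            return
        yield item


batches = list(prefetch(range(5), buffer_size=2))
print(batches)
```

The bounded queue is the design choice that matters: it caps how many prepared batches sit in host memory, so a fast loader cannot run arbitrarily far ahead of the GPU. This sketch does not handle a consumer that stops early; the worker would stay blocked on `q.put` until the process exits.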

4.2 JAX Profiler Integration

def profile_pipeline(pipeline, simulate_fn, n_batches=10,
                     trace_dir="/tmp/jax-sim-trace"):
    """
    Run the pipeline with JAX profiling enabled.

    Produces a trace file viewable in TensorBoard or Chrome trace viewer.
    """
    import os
    os.makedirs(trace_dir, exist_ok=True)

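    # Use one iterator for warmup and profiling: every iter(pipeline) call
    # starts another loader thread on the same scenario generator.
    batches = iter(pipeline)

    # Warmup (outside trace)
    warmup_result = simulate_fn(next(batches))
    jax.block_until_ready(warmup_result)

    # Profiled execution
    with jax.profiler.trace(trace_dir):
        batch_times = []
        for i, batch in enumerate(batches):
            if i >= n_batches:
                break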

            start = time.perf_counter()
            result = simulate_fn(batch)
            jax.block_until_ready(result)
            elapsed = time.perf_counter() - start

            batch_times.append(elapsed)

    print(f"Trace saved to {trace_dir}")
    print(f"View with: tensorboard --logdir {trace_dir}")

    return batch_times
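Traces are easier to read when the regions of interest carry names. The sketch below assumes a toy simulation function (`toy_simulate` is a hypothetical stand-in for simulate_fn) and uses `jax.profiler.TraceAnnotation` and `jax.profiler.StepTraceAnnotation` to label each iteration. Outside an active trace the annotations add little overhead, so the loop runs the same with or without profiling.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_simulate(batch):
    # Hypothetical stand-in for the project's simulate_fn
    return jnp.cumsum(batch, axis=-1)


def annotated_steps(batches):
    results = []
    for step, batch in enumerate(batches):
        # Groups everything in this iteration under one numbered step
        with jax.profiler.StepTraceAnnotation("sim_step", step_num=step):
            # Named region that appears on the host timeline in the trace
            with jax.profiler.TraceAnnotation("simulate"):
                out = toy_simulate(batch)
                jax.block_until_ready(out)
            results.append(out)
    return results


outputs = annotated_steps([jnp.ones((4, 3)) for _ in range(2)])
print(len(outputs), outputs[0].shape)
```

Wrapping the call to `annotated_steps` in `jax.profiler.trace(...)` should then show the "simulate" regions alongside the device activity, which helps separate time spent waiting for data from time spent computing.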

4.3 Bottleneck Identification

def identify_bottlenecks(batch_times, batch_size, loading_times=None):
    """
    Analyze timing data to identify pipeline bottlenecks.
    """
    compute_mean = sum(batch_times) / len(batch_times)
    throughput = batch_size / compute_mean

    print(f"Pipeline Analysis:")
    print(f"  Compute time per batch: {compute_mean:.4f}s")
    print(f"  Throughput: {throughput:.0f} scenarios/sec")

    if loading_times:
        load_mean = sum(loading_times) / len(loading_times)
        print(f"  Loading time per batch: {load_mean:.4f}s")

        if load_mean > compute_mean:
            print(f"  BOTTLENECK: Data loading ({load_mean:.4f}s > {compute_mean:.4f}s)")
            print("  Recommendation: Parallelize or speed up loading; "
                  "a deeper prefetch queue only smooths out jitter")
        else:
            print("  Pipeline is compute-bound (good!)")
            # Loading fits entirely inside compute; headroom is the slack left over
            headroom = 1.0 - (load_mean / compute_mean)
            print(f"  Loading hides fully behind compute, with {headroom:.0%} headroom")
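identify_bottlenecks accepts loading_times, but something has to collect them. One option, sketched here with a hypothetical helper named `timed`, is to wrap the iterator that produces batches and record how long each item takes to produce. The `slow_batches` generator below is a made-up stand-in for real loading and preprocessing.

```python
import time


def timed(iterable, durations):
    """Yield items from `iterable`, appending each item's production time to `durations`."""
    it = iter(iterable)
    while True:
        start = time.perf_counter()
        try:
            item = next(it)
        except StopIteration:
            return
        durations.append(time.perf_counter() - start)
        yield item


def slow_batches(n):
    # Made-up loader: sleeps to imitate disk reads and preprocessing
    for i in range(n):
        time.sleep(0.01)
        yield i


loading_times = []
loaded = list(timed(slow_batches(3), loading_times))
print(loaded, [f"{t:.3f}s" for t in loading_times])
```

For the comparison to be meaningful, the wrapper should sit inside the loader thread, around the code that builds one batch. Timing `queue.get()` on the consumer side measures how long the simulation loop waits, which is close to zero whenever prefetching is working.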

4.4 Optimization: Buffer Donation

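Buffer donation tells JAX that it may reuse an input buffer for an output of the same shape and dtype, avoiding an extra allocation. The donated array is invalid after the call, so the caller must rebind the result (state = sim_step_donate(state, actions)) rather than read the old state again:

import functools

import jax

@functools.partial(jax.jit, donate_argnums=(0,))
def sim_step_donate(state, actions):
    """
    Simulate with buffer donation.

    donate_argnums=(0,) tells JAX it can overwrite the state buffer,
    which is safe because we don't need the old state after stepping.
    """
    return sim_step(state, actions)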
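A usage sketch may help here. It assumes a toy step function (`toy_sim_step` is a hypothetical stand-in for sim_step) whose output has the same shape and dtype as the donated state, which is what makes the buffer reusable. The loop rebinds `state` on every iteration and never touches the previous array.

```python
import functools

import jax
import jax.numpy as jnp


def toy_sim_step(state, actions):
    # Hypothetical stand-in for sim_step; output matches the state's shape and dtype
    return state + 0.1 * actions


@functools.partial(jax.jit, donate_argnums=(0,))
def toy_step_donate(state, actions):
    return toy_sim_step(state, actions)


state = jnp.zeros((1024, 4))
actions = jnp.ones((1024, 4))

for _ in range(3):
    # Rebind the result; the previous `state` must not be read again
    state = toy_step_donate(state, actions)

print(float(state[0, 0]))
```

Only `state` is donated: `actions` is reused across iterations, so donating it would invalidate an array the loop still needs. On backends that cannot honor a donation, JAX emits a warning and falls back to a regular allocation, so the result is the same either way.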

4.5 Optimization: Bucketed Batching for Ragged Scenarios

def bucket_scenarios(scenarios, bucket_boundaries):
    """
    Group scenarios by agent count to minimize padding waste.

    Args:
        scenarios: list of scenarios with varying agent counts
        bucket_boundaries: list of agent count thresholds [8, 16, 32, 64]

    Returns:
        dict mapping bucket_size -> list of scenarios
    """
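    # Sort so that the first boundary that fits is also the smallest
    boundaries = sorted(bucket_boundaries)
    buckets = {b: [] for b in boundaries}

    for scenario in scenarios:
        n_agents = scenario['n_agents']
        # Find smallest bucket that fits
        for b in boundaries:
            if n_agents <= b:
                buckets[b].append(pad_scenario(scenario, b))
                break
        else:
            # Fail loudly instead of silently dropping oversized scenarios
            raise ValueError(
                f"Scenario has {n_agents} agents, exceeding the largest "
                f"bucket ({boundaries[-1]})")

    return buckets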

def pad_scenario(scenario, target_n_agents):
    """Pad a scenario to target_n_agents with invalid agents."""
    current = scenario['n_agents']
    padding = target_n_agents - current

    padded = {}
    for key, val in scenario.items():
        if key == 'n_agents':
            padded[key] = target_n_agents
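        elif (isinstance(val, jnp.ndarray) and val.ndim >= 1
              and val.shape[0] == current):
            # Heuristic: an array whose leading axis equals the agent count
            # is treated as per-agent data and zero-padded along that axis.
            # The ndim check keeps scalar (0-d) arrays from raising IndexError.
            pad_width = [(0, padding)] + [(0, 0)] * (val.ndim - 1)
            padded[key] = jnp.pad(val, pad_width)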
        else:
            padded[key] = val

    # Mark padded agents as invalid
    valid = jnp.zeros(target_n_agents, dtype=jnp.bool_)
    padded['valid'] = valid.at[:current].set(True)

    return padded
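How much bucketing helps depends on the distribution of agent counts, and it can be estimated before running any simulation. The sketch below is illustrative: `padding_waste` is a hypothetical helper and the agent counts are made up. It computes the fraction of padded agent slots that hold no real agent for a given set of bucket boundaries.

```python
import bisect


def padding_waste(agent_counts, bucket_boundaries):
    """Fraction of padded agent slots that hold no real agent."""
    boundaries = sorted(bucket_boundaries)
    real = 0
    padded = 0
    for n in agent_counts:
        idx = bisect.bisect_left(boundaries, n)  # Smallest boundary >= n
        if idx == len(boundaries):
            raise ValueError(f"{n} agents exceed the largest bucket")
        real += n
        padded += boundaries[idx]
    return 1.0 - real / padded


counts = [5, 9, 30]  # Made-up agent counts for illustration
bucketed = padding_waste(counts, [8, 16, 32])
single = padding_waste(counts, [32])
print(f"three buckets: {bucketed:.1%} wasted, one bucket: {single:.1%} wasted")
```

More buckets reduce waste but produce more distinct input shapes, and each distinct shape triggers a separate compilation. A handful of power-of-two boundaries is a common compromise.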

4.6 Full Throughput Benchmark and Dashboard

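def run_full_benchmark(pipeline, simulate_fn, n_batches=50, batch_size=None):
    """
    Run a comprehensive benchmark and produce a performance report.

    The first batch is a warmup (it includes JIT compilation) and is not
    recorded. Batch times cover compute only, not time spent waiting on
    the data pipeline.

    Args:
        batch_size: scenarios per batch. If None, the leading axis of the
            first leaf is used, which undercounts when batches are laid
            out as [n_devices, per_device_batch, ...].
    """
    batch_times = []
    batch_sizes = []
    memory_usage = []
    device = jax.local_devices()[0]

    for i, batch in enumerate(pipeline):
        if i > n_batches:  # Batch 0 is the warmup, so read n_batches + 1
            break

        start = time.perf_counter()
        result = simulate_fn(batch)
        jax.block_until_ready(result)
        elapsed = time.perf_counter() - start

        if i == 0:
            continue  # Warmup batch: do not record

        # memory_stats() may return None on backends that do not report
        # memory (for example CPU), and the available keys depend on the
        # backend, so memory_usage can legitimately stay empty.
        stats = device.memory_stats()
        if stats and 'peak_bytes_in_use' in stats:
            memory_usage.append(stats['peak_bytes_in_use'])

        if batch_size is None:
            n_scenarios = jax.tree_util.tree_leaves(batch)[0].shape[0]
        else:
            n_scenarios = batch_size
        batch_times.append(elapsed)
        batch_sizes.append(n_scenarios)

    return {
        'batch_times': batch_times,
        'batch_sizes': batch_sizes,
        'memory_usage': memory_usage,
        'total_scenarios': sum(batch_sizes),
        'total_time': sum(batch_times),
        'throughput': sum(batch_sizes) / sum(batch_times),
        'mean_batch_time': sum(batch_times) / len(batch_times),
    }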


def plot_dashboard(results, scaling_results=None):
    """
    Create a performance dashboard with 4 panels.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Panel 1: Throughput over time (batch-by-batch)
    cumulative_scenarios = np.cumsum(results['batch_sizes'])
    cumulative_time = np.cumsum(results['batch_times'])
    instantaneous_throughput = (np.array(results['batch_sizes']) /
                                np.array(results['batch_times']))

    axes[0, 0].plot(range(len(instantaneous_throughput)),
                    instantaneous_throughput, 'b-', alpha=0.5)
    axes[0, 0].axhline(y=results['throughput'], color='r', linestyle='--',
                        label=f"Mean: {results['throughput']:.0f} scen/s")
    axes[0, 0].set_xlabel('Batch Index')
    axes[0, 0].set_ylabel('Scenarios/sec')
    axes[0, 0].set_title('Instantaneous Throughput')
    axes[0, 0].legend()

    # Panel 2: Cumulative progress
    axes[0, 1].plot(cumulative_time, cumulative_scenarios, 'g-', linewidth=2)
    axes[0, 1].set_xlabel('Cumulative Compute Time (s)')
    axes[0, 1].set_ylabel('Total Scenarios')
    axes[0, 1].set_title('Cumulative Progress')

    # Panel 3: Batch time distribution
    axes[1, 0].hist(results['batch_times'], bins=20, color='steelblue',
                    edgecolor='white')
    axes[1, 0].axvline(x=results['mean_batch_time'], color='r', linestyle='--',
                        label=f"Mean: {results['mean_batch_time']*1000:.1f}ms")
    axes[1, 0].set_xlabel('Batch Time (s)')
    axes[1, 0].set_ylabel('Count')
    axes[1, 0].set_title('Batch Time Distribution')
    axes[1, 0].legend()

    # Panel 4: Scaling curve (if available)
    if scaling_results:
        devices = [r['n_devices'] for r in scaling_results]
        throughputs = [r['throughput'] for r in scaling_results]
        ideal = [throughputs[0] * d for d in devices]

        axes[1, 1].plot(devices, throughputs, 'bo-', linewidth=2, label='Actual')
        axes[1, 1].plot(devices, ideal, 'r--', linewidth=1, label='Ideal')
        axes[1, 1].set_xlabel('Number of Devices')
        axes[1, 1].set_ylabel('Scenarios/sec')
        axes[1, 1].set_title('Scaling Curve')
        axes[1, 1].legend()
    else:
        axes[1, 1].text(0.5, 0.5, 'Scaling data\nnot available',
                         ha='center', va='center', fontsize=14,
                         transform=axes[1, 1].transAxes)
        axes[1, 1].set_title('Scaling Curve')

    fig.suptitle('Distributed Simulation Pipeline - Performance Dashboard',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
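The batch time histogram is easier to compare across runs when it is paired with a few summary numbers. The following sketch assumes a hypothetical helper named `summarize_batch_times`; it reports the mean, median, nearest-rank 95th percentile, and standard deviation using only the standard library. The sample timings are made up.

```python
import math
import statistics


def summarize_batch_times(batch_times):
    """Return mean, median, p95, and standard deviation in seconds."""
    ordered = sorted(batch_times)
    # Nearest-rank p95: smallest value with at least 95% of samples at or below it
    rank = max(1, math.ceil(0.95 * len(ordered)))
    return {
        'mean': statistics.fmean(ordered),
        'median': statistics.median(ordered),
        'p95': ordered[rank - 1],
        'stdev': statistics.stdev(ordered) if len(ordered) > 1 else 0.0,
    }


sample_times = [0.01, 0.02, 0.03, 0.04]  # Made-up timings in seconds
summary = summarize_batch_times(sample_times)
print({k: round(v, 4) for k, v in summary.items()})
```

A p95 far above the median usually points to occasional stalls, such as recompilation on a new input shape or a loader that falls behind, which a mean alone would hide.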

Key takeaways from Step 4:

  • Async data loading hides I/O latency by overlapping it with GPU compute, provided loading a batch takes no longer than simulating it.
  • JAX profiling reveals whether the pipeline is compute-bound, memory-bound, or I/O-bound.
  • Buffer donation and bucketed batching are practical optimizations that matter at scale.
  • A performance dashboard should show throughput, latency distribution, scaling, and resource utilization.

Summary and Extensions

After completing this project, you will have built a distributed simulation pipeline that:

  1. Batches scenarios efficiently using vmap and lax.scan on a single device.
  2. Distributes across GPUs using pmap or JAX sharding for near-linear scaling.
  3. Overlaps I/O with compute using async prefetching.
  4. Profiles and optimizes using JAX's built-in tools.
  5. Produces reproducible benchmarks with proper warmup and statistical analysis.

Extensions for further exploration:

  • Multi-host scaling: Extend from multi-GPU (single host) to multi-host (cluster) using JAX's distributed runtime and jax.distributed.initialize().
  • Heterogeneous workloads: Mix simple and complex scenarios (varying agent counts, timesteps) with dynamic load balancing.
  • Streaming results to database: Replace in-memory result collection with streaming writes to a database (e.g., DuckDB) for long-running campaigns.
  • Integration with Waymax: Replace the toy kinematic model with Waymax's full simulation loop and benchmark real WOMD scenario throughput.
  • Cost analysis: Add cost estimation ($/scenario on various GPU types) to help teams make informed infrastructure decisions.
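As a starting point for the cost analysis extension, the arithmetic can be sketched in a few lines. Everything here is an assumption for illustration: `cost_per_million_scenarios` is a hypothetical helper, and the throughput and hourly price are placeholders, not measurements or real GPU prices.

```python
def cost_per_million_scenarios(throughput_per_sec, n_devices, price_per_device_hour):
    """Cost of simulating one million scenarios at a given aggregate throughput."""
    scenarios_per_hour = throughput_per_sec * 3600.0
    hourly_cost = n_devices * price_per_device_hour
    return hourly_cost / scenarios_per_hour * 1e6


# Placeholder inputs: substitute measured throughput and your provider's pricing
cost = cost_per_million_scenarios(
    throughput_per_sec=50_000,   # Aggregate scenarios/sec across all devices
    n_devices=8,
    price_per_device_hour=2.0,   # Hypothetical price in dollars
)
print(f"${cost:.4f} per million scenarios")
```

Because throughput rarely scales perfectly, the cost per scenario tends to rise with device count even as wall time falls. Plotting this value against the scaling results makes that trade-off explicit.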