Track C: Agent Behavior & Infrastructure | Advanced | 25-35 hours

Adversarial Scenario Generator

Find edge cases that break perception and planning systems.

Adversarial ML | Scenario Generation | Optimization


Systematically discover the scenarios that break your autonomous driving stack -- because the real world does not wait for you to find bugs by chance, and random testing cannot reach the dark corners where catastrophic failures hide.


Overview

In this project you will build an adversarial scenario generation system that uses gradient-based optimization to automatically discover driving scenarios where an autonomous driving (AD) planner fails. The system takes a nominal driving scenario as input, parameterizes controllable aspects of the scene (initial positions, velocities, trajectories of other agents, timing of events), defines differentiable failure objectives (collision probability, minimum time-to-collision, off-road excursion, uncomfortable braking), and then optimizes the scenario parameters to maximize the probability of planner failure -- all while maintaining physical plausibility constraints so that the discovered scenarios are realistic rather than contrived.

This matters because testing autonomous vehicles is fundamentally a search problem in an astronomically large scenario space. A single 8-second driving scenario at 10 Hz with 32 agents has thousands of continuous parameters. Random sampling of this space is hopelessly inefficient: the failure-inducing regions are thin, high-dimensional manifolds embedded in a vast space of benign scenarios. Industry leaders like Waymo, Applied Intuition, and Cruise have recognized that targeted adversarial generation is the only tractable path to discovering the rare but safety-critical scenarios that matter. The RAND Corporation's "Driving to Safety" analysis estimated that demonstrating an AD system is merely as safe as a human driver would require hundreds of millions to billions of miles of random testing -- adversarial methods aim to compress those miles into thousands of carefully chosen scenarios.

Your deliverable is a modular JAX-based adversarial scenario generation framework consisting of: (1) a differentiable scenario parameterization and kinematic simulation engine, (2) a library of composable failure objectives that can be combined and weighted, (3) a gradient-based adversarial optimization loop with constraint projection to maintain physical plausibility, (4) a multi-restart search strategy for diversity, (5) a failure clustering and analysis pipeline that categorizes discovered failures by type and severity, and (6) a visualization and reporting system that communicates findings to safety engineers. The entire system should be JIT-compilable, vectorizable with jax.vmap, and extensible to new failure modes and constraint sets.


Learning Objectives

By completing this project, you will be able to:

  • Parameterize driving scenarios as differentiable optimization variables by representing agent initial conditions, trajectory waypoints, and behavioral parameters as continuous vectors that JAX can differentiate through.
  • Build a differentiable kinematic simulator in JAX that takes scenario parameters as input, rolls out multi-agent trajectories using a bicycle model, and produces collision/safety metrics as differentiable outputs -- enabling gradient-based search over the scenario space.
  • Design and implement failure objective functions including collision probability, minimum time-to-collision (TTC), off-road rate, uncomfortable braking, and lane violation, each formulated as smooth differentiable surrogates suitable for gradient optimization.
  • Implement gradient-based adversarial optimization using Adam (via Optax) to iteratively perturb scenario parameters toward failure, with constraint projection steps that enforce physical plausibility (speed limits, acceleration bounds, lane adherence) after each gradient update.
  • Apply multi-restart and diversity-promoting strategies to discover a broad set of distinct failure modes rather than converging repeatedly to the same adversarial example, using techniques such as random initialization, repulsion terms, and clustering-based restart selection.
  • Cluster and analyze discovered failures by extracting interpretable features from adversarial scenarios, applying unsupervised clustering to identify distinct failure categories (cut-in, sudden braking, jaywalker, occluded merge), and prioritizing them by severity and estimated real-world likelihood.
  • Evaluate adversarial search efficiency by comparing the failure discovery rate, diversity, and severity distribution against random search baselines, quantifying the value of gradient-based targeting.

Prerequisites

  • Required: JAX proficiency (comfortable with jax.grad, jax.jit, jax.vmap, jax.lax.scan, and pytree manipulation), gradient-based optimization fundamentals (loss functions, learning rates, Adam), and basic familiarity with kinematic vehicle models and driving scenario representation.
  • Recommended: Waymax basics (loading scenarios, stepping simulations, accessing state), experience with Optax for optimizer configuration, and familiarity with clustering algorithms (K-means, DBSCAN) for the failure analysis portion.
  • Deep Dive Reading:
    • Long-Tail Scenarios Deep Dive -- Comprehensive coverage of why rare scenarios dominate AD safety risk, methods for generating and mining edge cases, and the connection between adversarial testing and regulatory validation. Directly relevant to the motivation and approach of this project.

Key Concepts

The Adversarial Testing Problem

Testing an autonomous driving system is fundamentally different from testing conventional software. In conventional software, you can enumerate inputs and check outputs against specifications. In AD, the "input" is the entire state of the world -- every vehicle's position, velocity, and intent; every pedestrian's trajectory; every traffic signal's state; road geometry, weather, lighting, sensor noise -- and the space of possible inputs is continuous and effectively infinite.

The safety-critical scenarios -- the ones where the AD system might fail catastrophically -- are vanishingly rare in the overall distribution. A human driver in the US encounters a police-reported crash roughly once every 500,000 miles. If you test by randomly sampling scenarios from the naturalistic distribution, you will spend nearly all your compute budget on boring, safe scenarios where the planner works fine, and almost never encounter the dangerous ones.

Adversarial scenario generation inverts this: instead of sampling scenarios and hoping to hit failures, you define what failure looks like and then optimize the scenario parameters to produce it. This is a search problem, and gradient-based optimization is the most efficient search algorithm available when the objective is differentiable.

Random Testing:
  Sample scenario ~ P(scenarios)  ->  Run planner  ->  Check for failure
  Efficiency: O(1 / P(failure))  -- astronomically many samples needed

Adversarial Testing:
  Initialize scenario parameters theta
  Loop:
    Roll out scenario(theta) in differentiable sim
    Compute failure_score = f(trajectory)
    theta += alpha * grad(failure_score, theta)
  Until: failure found or budget exhausted
  Efficiency: O(gradient_steps) -- typically 50-500 steps

The key requirement is that the simulation must be differentiable end-to-end: from scenario parameters through the dynamics model to the failure metric. JAX makes this possible through its automatic differentiation capabilities, and its JIT compilation makes the optimization loop fast enough to be practical.
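
As a minimal illustration of this end-to-end pattern (a toy sketch, separate from the project code), the snippet below rolls out a one-dimensional closing-gap dynamic with jax.lax.scan, computes a smooth danger score, and differentiates it with respect to the scenario parameter:

import jax
import jax.numpy as jnp

def toy_failure_score(initial_gap, n_steps=50, dt=0.1):
    """Roll out a 1-D closing gap and return a smooth 'danger' score."""
    closing_speed = 2.0  # m/s, fixed for this toy example

    def step(gap, _):
        new_gap = gap - closing_speed * dt
        return new_gap, new_gap

    _, gaps = jax.lax.scan(step, initial_gap, jnp.arange(n_steps))
    # Higher score when the gap gets small at any point in the rollout
    return jnp.sum(jax.nn.sigmoid(5.0 - gaps))

# Gradient of the failure metric w.r.t. the scenario parameter (the initial gap)
print(jax.grad(toy_failure_score)(20.0))  # negative: shrinking the gap raises the score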

Scenario Parameterization

The first design decision is: what aspects of the scenario can the adversary control? The parameterization defines the search space and determines what kinds of failures can be discovered.

Initial condition parameters (simplest, most common):

  • Positions of other agents: (x_i, y_i) for each agent i
  • Velocities of other agents: (vx_i, vy_i) for each agent i
  • Headings of other agents: theta_i for each agent
  • These give you a parameter vector of size 5 * N_agents

Trajectory parameters (richer, more expressive):

  • Waypoints for each agent: (x_i^t, y_i^t) at selected timesteps
  • Or control inputs: (acceleration_i^t, steering_i^t) at each timestep
  • Interpolation between waypoints using splines for smoothness
  • Parameter vector size: 2 * N_agents * N_waypoints or 2 * N_agents * T

Behavioral parameters (most abstract):

  • Aggressiveness of each agent (maps to IDM parameters)
  • Reaction time, desired gap, politeness factor
  • These indirectly control trajectories through a behavioral model
  • Smaller parameter space, but requires a differentiable behavioral model

For this project, we use a hybrid approach: initial conditions plus trajectory waypoints, connected by a differentiable kinematic model. This gives enough expressiveness to discover diverse failures while keeping the optimization landscape manageable.

@dataclass
class ScenarioParams:
    """Differentiable scenario parameters."""
    ego_init: jnp.ndarray       # (4,) [x, y, heading, speed]
    other_inits: jnp.ndarray    # (N, 4) [x, y, heading, speed] per agent
    other_actions: jnp.ndarray  # (N, T, 2) [accel, steer_rate] per agent per step
    # Total parameters: 4 + N*4 + N*T*2

Differentiable Kinematic Simulation

To compute gradients through the simulation, every operation from parameters to metrics must be differentiable. The kinematic bicycle model is a natural choice:

$$x_{t+1} = x_t + v_t \cos(\theta_t) \cdot \Delta t$$ $$y_{t+1} = y_t + v_t \sin(\theta_t) \cdot \Delta t$$ $$\theta_{t+1} = \theta_t + \omega_t \cdot \Delta t$$ $$v_{t+1} = \max(0, v_t + a_t \cdot \Delta t)$$

where $(a_t, \omega_t)$ are the acceleration and yaw rate at time $t$.

The $\max(0, \cdot)$ operation for speed clamping is not differentiable at zero. We replace it with a smooth approximation: $\text{softplus}_\beta(x) = \frac{1}{\beta}\log(1 + e^{\beta x})$, where $\beta$ controls the sharpness. Similarly, collision detection requires smooth distance functions rather than hard boolean checks.
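
To see why the smooth surrogate matters for optimization, compare gradients of the hard and soft clamps at a slightly negative input (a small sketch; softplus_beta mirrors the helper defined in Step 2):

import jax
import jax.numpy as jnp

def softplus_beta(x, beta=5.0):
    """Smooth approximation to max(0, x); larger beta means a sharper corner."""
    return jax.nn.softplus(beta * x) / beta

x = -0.5  # e.g. speed + accel * dt slightly below zero
print(jax.grad(lambda v: jnp.maximum(0.0, v))(x))  # 0.0 -- no signal for the optimizer
print(jax.grad(softplus_beta)(x))                  # small but nonzero gradient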

The forward pass through the simulation is implemented with jax.lax.scan for efficiency:

def rollout(params, ego_policy, dt=0.1):
    """Differentiable forward simulation (conceptual sketch).

    params: ScenarioParams (initial conditions + actions for other agents)
    ego_policy: function mapping observation -> action (not optimized)
    Helpers such as kinematic_step_batch and build_initial_state are
    implemented in the step-by-step guide below.
    Returns: all trajectories, shape (N+1, T, 4)
    """
    def step_fn(state, t):
        # Ego acts according to its policy (constant, not optimized)
        ego_obs = extract_observation(state)
        ego_action = ego_policy(ego_obs)

        # Other agents follow their parameterized actions
        other_actions = params.other_actions[:, t, :]

        # Kinematic update for all agents
        new_state = kinematic_step_batch(state, ego_action, other_actions, dt)
        return new_state, new_state

    init_state = build_initial_state(params)
    _, trajectory = jax.lax.scan(step_fn, init_state, jnp.arange(T))
    return trajectory

The ego vehicle follows a fixed policy (the system under test). The adversary controls only the other agents' initial conditions and actions. This reflects the real-world threat model: the ego cannot control what other road users do.

Failure Objective Functions

The failure objective defines what the adversary is trying to achieve. It must be:

  1. Differentiable: gradients must flow from the objective back through the simulation to the scenario parameters.
  2. Informative: the gradient should point toward more dangerous scenarios even when the current scenario is safe (no "gradient desert" problem).
  3. Calibrated: the magnitude should reflect severity, enabling comparison across scenarios.

Collision proximity (primary objective): $$L_{\text{collision}} = \sum_{t} \sum_{i} \text{sigmoid}\left(\frac{d_{\text{safe}} - |p_{\text{ego}}^t - p_i^t|_2}{\tau}\right)$$

where $d_{\text{safe}}$ is the safe distance threshold (sum of vehicle half-lengths plus a margin) and $\tau$ is a temperature controlling the sharpness of the sigmoid. When the ego is far from other agents, the loss is near zero but the gradient still points toward reducing distance. As the ego approaches collision distance, the loss rapidly increases.

Time-to-collision (TTC): $$L_{\text{TTC}} = \sum_{t} \max\left(0, \frac{1}{\text{TTC}^t} - \frac{1}{\text{TTC}_{\text{safe}}}\right)$$

TTC is computed as $d / v_{\text{closing}}$ where $v_{\text{closing}} = -\dot{d}$ is the rate at which the gap is shrinking. For differentiability, we use smooth approximations for the division and the max.

Off-road excursion: $$L_{\text{offroad}} = \sum_{t} \text{ReLU}(d_{\text{lane}}(p_{\text{ego}}^t) - w_{\text{lane}}/2)$$

where $d_{\text{lane}}$ is the signed distance from lane center and $w_{\text{lane}}$ is the lane width. This penalizes the ego for leaving the road.

Uncomfortable braking: $$L_{\text{comfort}} = \sum_{t} \text{ReLU}(-a_{\text{ego}}^t - a_{\text{comfort}})$$

This penalizes (rewards, from the adversary's perspective) scenarios that force the ego into hard braking (deceleration exceeding a comfort threshold like $3 \text{ m/s}^2$).

Combined objective: $$L_{\text{total}} = w_1 L_{\text{collision}} + w_2 L_{\text{TTC}} + w_3 L_{\text{offroad}} + w_4 L_{\text{comfort}}$$

The weights $w_i$ control the relative importance of each failure mode and can be adjusted to focus the search on specific types of failures.

Constraint Satisfaction and Physical Plausibility

An unconstrained adversary will produce physically impossible scenarios: vehicles teleporting, accelerating at 100 m/s^2, driving through buildings. The discovered failures would be meaningless because they could never occur in the real world. Constraint satisfaction is therefore essential.

Hard constraints (enforced by projection after each gradient step):

  • Speed: $0 \leq v_i^t \leq v_{\max}$ (typically 35 m/s for highways, 15 m/s for urban)
  • Acceleration: $a_{\min} \leq a_i^t \leq a_{\max}$ (typically $[-6, 4]$ m/s^2)
  • Yaw rate: $|\omega_i^t| \leq \omega_{\max}$ (function of speed and vehicle geometry)
  • Lane adherence: agents should be on or near driveable surface
  • No initial overlap: agents should not start inside each other

Projection step:

import dataclasses

def project_to_constraints(params):
    """Project scenario parameters back to the feasible set."""
    # Clip accelerations
    actions = jnp.clip(params.other_actions[..., 0], -6.0, 4.0)
    # Clip steering rates (a speed-dependent bound could be used instead)
    steer = jnp.clip(params.other_actions[..., 1], -0.5, 0.5)
    # Clip initial speeds
    speeds = jnp.clip(params.other_inits[:, 3], 0.0, 35.0)
    return dataclasses.replace(
        params,
        other_actions=jnp.stack([actions, steer], axis=-1),
        other_inits=params.other_inits.at[:, 3].set(speeds),
    )

This is a projected gradient descent approach: take a gradient step, then project back to the feasible set. For box constraints (min/max bounds), projection is just clipping. For more complex constraints (e.g., "agent must stay on road"), projection requires solving a constrained optimization subproblem or using a soft penalty.

Soft constraints (added as penalty terms to the objective): $$L_{\text{plausibility}} = \lambda_{\text{road}} \sum_{i,t} d_{\text{road}}(p_i^t)^2 + \lambda_{\text{kinematic}} \sum_{i,t} \text{ReLU}(|a_i^t| - a_{\max})^2$$

The advantage of soft constraints is simplicity (no projection step needed), but they require tuning the penalty weights $\lambda$ and do not guarantee strict feasibility.
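
A minimal sketch of these penalty terms, assuming the straight road centered at y = 0 used throughout this project (the helper name and weights are placeholders):

import jax
import jax.numpy as jnp

def plausibility_penalty(other_trajs, other_actions,
                         road_half_width=5.5, a_max=4.0,
                         lambda_road=1.0, lambda_kinematic=1.0):
    """Soft penalties for off-road positions and implausible accelerations."""
    # Off-road term: squared excursion beyond the road edge (straight road assumed)
    lateral = other_trajs[..., 1]                              # (N, T+1)
    road_term = jnp.sum(jax.nn.relu(jnp.abs(lateral) - road_half_width) ** 2)

    # Kinematic term: squared excess acceleration beyond the bound
    accel = other_actions[..., 0]                              # (N, T)
    kinematic_term = jnp.sum(jax.nn.relu(jnp.abs(accel) - a_max) ** 2)

    return lambda_road * road_term + lambda_kinematic * kinematic_term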

Coverage and Diversity

Finding one adversarial scenario is useful; finding a diverse set that covers multiple failure modes is much more valuable. Without explicit diversity promotion, gradient-based optimization tends to converge repeatedly to the same failure -- often the "easiest" one for the optimizer to find.

Multi-restart strategy: Run the optimization from many different random initializations. Each restart may converge to a different local optimum in the failure landscape.

Repulsion term: Add a penalty that pushes new adversarial scenarios away from previously discovered ones: $$L_{\text{diversity}} = -\alpha \sum_{j \in \text{found}} \exp\left(-\frac{|\theta - \theta_j|^2}{2\sigma^2}\right)$$

This kernel-based repulsion (similar to Determinantal Point Processes) encourages the optimizer to explore new regions of the failure space.
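
One way to implement the repulsion term (a sketch; it assumes previously found scenarios are stored as flattened parameter vectors, e.g. produced with jax.flatten_util.ravel_pytree):

import jax.numpy as jnp

def diversity_repulsion(theta_flat, found_thetas, alpha=1.0, sigma=5.0):
    """Kernel repulsion away from previously discovered parameter vectors.

    theta_flat:   (D,) flattened parameters of the current scenario
    found_thetas: (K, D) flattened parameters of previously found failures
    Returns L_diversity, to be ADDED to the objective that is being maximized.
    """
    sq_dists = jnp.sum((found_thetas - theta_flat[None, :]) ** 2, axis=-1)  # (K,)
    return -alpha * jnp.sum(jnp.exp(-sq_dists / (2.0 * sigma ** 2)))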

Feature-space clustering: After running many restarts, extract interpretable features from each discovered failure (e.g., relative position of adversary at collision time, type of maneuver, severity) and cluster them to identify distinct failure categories.

Connection to Safety Analysis

Adversarial scenario generation is not just a clever testing technique -- it connects to formal safety analysis frameworks:

  • STPA (Systems-Theoretic Process Analysis): STPA identifies "unsafe control actions" and "loss scenarios" through systematic analysis. Adversarial generation can be guided by STPA-identified hazards, focusing the search on scenarios where specific control actions are likely to fail.
  • SOTIF (Safety Of The Intended Functionality, ISO 21448): SOTIF explicitly requires testing for "triggering conditions" that cause the system to fail despite correct hardware and software. Adversarial generation directly targets these triggering conditions.
  • Operational Design Domain (ODD): The constraint set in adversarial generation defines the ODD boundary. Scenarios that violate constraints are outside the ODD and irrelevant; scenarios within constraints that cause failure represent genuine safety risks.

Step-by-Step Implementation Guide

Step 1: Environment Setup (30 min)

Set up the project structure and verify that all dependencies work correctly.

1.1 Project Directory Structure

adversarial-scenario-generator/
├── adversarial/
│   ├── __init__.py
│   ├── scenario.py          # Scenario parameterization
│   ├── simulator.py         # Differentiable kinematic simulation
│   ├── objectives.py        # Failure objective functions
│   ├── optimizer.py         # Adversarial optimization loop
│   ├── constraints.py       # Constraint projection
│   ├── analysis.py          # Failure clustering and analysis
│   └── visualization.py     # Plotting and reporting
├── notebooks/
│   ├── 01_scenario_parameterization.ipynb
│   ├── 02_adversarial_optimization.ipynb
│   └── 03_failure_analysis.ipynb
├── results/                 # Output directory for discovered failures
├── requirements.txt
└── README.md

1.2 Install Dependencies

python -m venv .venv
source .venv/bin/activate

pip install jax jaxlib optax
pip install numpy matplotlib scipy pandas seaborn
pip install tqdm ipykernel jupyter

1.3 Verify Setup

import jax
import jax.numpy as jnp
import optax

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

# Verify gradient computation
def test_fn(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
grad_fn = jax.grad(test_fn)
print(f"grad(sum(x^2)) at x=[1,2,3]: {grad_fn(x)}")
# Expected: [2.0, 4.0, 6.0]

# Verify Optax
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(x)
grads = grad_fn(x)
updates, opt_state = optimizer.update(grads, opt_state)
new_x = optax.apply_updates(x, updates)
print(f"After one Adam step: {new_x}")
print("Setup verified.")

Step 2: Scenario Parameterization and Differentiable Simulation (Notebook 01, ~75 min)

Build the foundation: represent driving scenarios as vectors of differentiable parameters and implement a kinematic simulator that JAX can differentiate through.

2.1 Vehicle State and Scenario Representation

import jax
import jax.numpy as jnp
from typing import NamedTuple

class VehicleState(NamedTuple):
    """State of a single vehicle."""
    x: jnp.ndarray       # longitudinal position (m)
    y: jnp.ndarray       # lateral position (m)
    heading: jnp.ndarray  # heading angle (rad)
    speed: jnp.ndarray    # scalar speed (m/s)

class ScenarioParams(NamedTuple):
    """Differentiable scenario parameters.

    These are the variables the adversary can optimize.
    The ego vehicle follows a fixed policy (not optimized).
    """
    ego_init: jnp.ndarray        # (4,) [x, y, heading, speed]
    other_inits: jnp.ndarray     # (N, 4) initial [x, y, heading, speed]
    other_actions: jnp.ndarray   # (N, T, 2) [acceleration, yaw_rate] per step

# Example: 3 other agents, 80 timesteps
N_AGENTS = 3
T = 80
DT = 0.1

# Initialize a benign scenario: agents spread out ahead of the ego in nearby lanes
# (starting every agent at the origin would overlap the ego and break the gradient check in 2.6)
params = ScenarioParams(
    ego_init=jnp.array([0.0, 0.0, 0.0, 15.0]),
    other_inits=jnp.array([
        [30.0, 0.0, 0.0, 12.0],    # lead vehicle in the ego lane
        [20.0, 3.7, 0.0, 14.0],    # vehicle in the left lane
        [40.0, -3.7, 0.0, 13.0],   # vehicle in the right lane
    ]),
    other_actions=jnp.zeros((N_AGENTS, T, 2)),
)

total_params = 4 + N_AGENTS * 4 + N_AGENTS * T * 2
print(f"Total optimizable parameters: {total_params}")
print(f"  Ego init: 4")
print(f"  Other inits: {N_AGENTS * 4}")
print(f"  Other actions: {N_AGENTS * T * 2}")

2.2 Differentiable Kinematic Step

The key insight is that every operation must be composed of JAX-differentiable primitives. The standard max(0, speed) is replaced with jax.nn.softplus for smooth gradients near zero speed.

def softplus(x, beta=5.0):
    """Smooth approximation to max(0, x)."""
    return jax.nn.softplus(x * beta) / beta

def kinematic_step(state, action, dt=0.1):
    """
    One step of the kinematic bicycle model.
    Fully differentiable in JAX.

    state: (4,) [x, y, heading, speed]
    action: (2,) [acceleration, yaw_rate]
    """
    x, y, heading, speed = state[0], state[1], state[2], state[3]
    accel, yaw_rate = action[0], action[1]

    # Speed update with soft clamping (differentiable version of max(0, ...))
    new_speed = softplus(speed + accel * dt)

    # Position update
    new_x = x + speed * jnp.cos(heading) * dt
    new_y = y + speed * jnp.sin(heading) * dt

    # Heading update
    new_heading = heading + yaw_rate * dt

    return jnp.array([new_x, new_y, new_heading, new_speed])

# Test: verify gradients exist
state = jnp.array([0.0, 0.0, 0.0, 10.0])
action = jnp.array([1.0, 0.1])

# Gradient of final x-position w.r.t. action
def final_x(action):
    s = kinematic_step(state, action)
    return s[0]

grad_action = jax.grad(final_x)(action)
print(f"Gradient of x w.r.t. action: {grad_action}")

2.3 Batch Rollout with jax.lax.scan

def rollout_single_agent(init_state, actions, dt=0.1):
    """
    Roll out a single agent's trajectory given initial state and actions.

    init_state: (4,) [x, y, heading, speed]
    actions: (T, 2) [acceleration, yaw_rate] per timestep
    Returns: (T+1, 4) trajectory including initial state
    """
    def step_fn(state, action):
        new_state = kinematic_step(state, action, dt)
        return new_state, new_state

    _, trajectory = jax.lax.scan(step_fn, init_state, actions)
    # Prepend initial state
    trajectory = jnp.concatenate([init_state[None, :], trajectory], axis=0)
    return trajectory

# Vectorize over agents
rollout_batch = jax.vmap(rollout_single_agent, in_axes=(0, 0, None))
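
A quick shape check on the vectorized rollout (this reuses params, DT, N_AGENTS, and T from section 2.1):

other_trajs = rollout_batch(params.other_inits, params.other_actions, DT)
print(other_trajs.shape)  # expected: (N_AGENTS, T + 1, 4)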

2.4 Full Scenario Rollout

def rollout_scenario(params, ego_policy_fn, dt=0.1):
    """
    Roll out a complete scenario.

    The ego follows ego_policy_fn (fixed, not optimized).
    Other agents follow their parameterized actions.

    Returns:
        ego_traj: (T+1, 4)
        other_trajs: (N, T+1, 4)
    """
    # Other agents: simple rollout from parameterized actions
    other_trajs = rollout_batch(
        params.other_inits, params.other_actions, dt
    )

    # Ego: reactive policy that observes other agents
    def ego_step_fn(ego_state, t):
        # Build observation for ego policy
        other_states_at_t = other_trajs[:, t, :]
        ego_action = ego_policy_fn(ego_state, other_states_at_t)
        new_ego = kinematic_step(ego_state, ego_action, dt)
        return new_ego, new_ego

    _, ego_traj = jax.lax.scan(
        ego_step_fn, params.ego_init, jnp.arange(params.other_actions.shape[1])
    )
    ego_traj = jnp.concatenate([params.ego_init[None, :], ego_traj], axis=0)

    return ego_traj, other_trajs

2.5 Simple Ego Policy (System Under Test)

For this project, we use a simple rule-based ego policy as the system under test. In practice, this would be your actual AD planner.

def simple_ego_policy(ego_state, other_states):
    """
    Simple ego policy: constant speed with basic collision avoidance.
    This is the "system under test" that the adversary tries to break.
    """
    desired_speed = 15.0  # m/s
    ego_x, ego_y, ego_heading, ego_speed = (
        ego_state[0], ego_state[1], ego_state[2], ego_state[3]
    )

    # Basic IDM-like following for the closest agent ahead
    # Compute distances to other agents
    dx = other_states[:, 0] - ego_x
    dy = other_states[:, 1] - ego_y

    # Relative position in ego frame
    cos_h, sin_h = jnp.cos(ego_heading), jnp.sin(ego_heading)
    lon_dist = dx * cos_h + dy * sin_h  # longitudinal distance
    lat_dist = -dx * sin_h + dy * cos_h  # lateral distance

    # Find closest agent ahead in same lane (within 2m lateral)
    in_lane = jnp.abs(lat_dist) < 2.0
    ahead = lon_dist > 0
    valid = in_lane & ahead

    # Safe distance = large number if no one ahead
    gap = jnp.where(valid, lon_dist, 1000.0)
    min_gap = jnp.min(gap)

    # IDM acceleration
    s0 = 2.0   # minimum gap
    T_hw = 1.5 # time headway
    a_max = 2.0
    b = 3.0

    s_star = s0 + ego_speed * T_hw
    accel = a_max * (1 - (ego_speed / desired_speed)**4 - (s_star / min_gap)**2)
    accel = jnp.clip(accel, -6.0, a_max)

    # No steering (drives straight)
    steer = 0.0

    return jnp.array([accel, steer])
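
Before wiring up the adversarial objective, it is worth rolling out the nominal scenario from section 2.1 under this policy and checking the output shapes (a small usage sketch):

ego_traj, other_trajs = rollout_scenario(params, simple_ego_policy, dt=DT)
print(ego_traj.shape)     # (T + 1, 4)
print(other_trajs.shape)  # (N_AGENTS, T + 1, 4)
print(f"Ego final x-position: {float(ego_traj[-1, 0]):.1f} m")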

2.6 Verifying Gradient Flow

The critical test: can we differentiate through the entire simulation?

def scenario_loss(params):
    """
    Dummy loss: minimize distance between ego and closest other agent.
    This tests that gradients flow end-to-end.
    """
    ego_traj, other_trajs = rollout_scenario(params, simple_ego_policy)

    # Minimum distance between ego and any other agent across all timesteps
    # ego_traj: (T+1, 4), other_trajs: (N, T+1, 4)
    ego_pos = ego_traj[:, :2]           # (T+1, 2)
    other_pos = other_trajs[:, :, :2]   # (N, T+1, 2)

    # Broadcast and compute distances (small epsilon keeps sqrt differentiable at zero)
    distances = jnp.sqrt(
        jnp.sum((ego_pos[None, :, :] - other_pos) ** 2, axis=-1) + 1e-6
    )  # (N, T+1)

    return jnp.min(distances)

# Compute gradient
grad_fn = jax.grad(scenario_loss)
grads = grad_fn(params)

print("Gradient shapes:")
print(f"  ego_init grad: {grads.ego_init.shape}")
print(f"  other_inits grad: {grads.other_inits.shape}")
print(f"  other_actions grad: {grads.other_actions.shape}")
print(f"  grad norms: ego_init={jnp.linalg.norm(grads.ego_init):.4f}, "
      f"other_inits={jnp.linalg.norm(grads.other_inits):.4f}, "
      f"other_actions={jnp.linalg.norm(grads.other_actions):.4f}")

If all gradient norms are non-zero, the differentiable simulation is working correctly and the adversarial optimizer has a gradient signal to follow.
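
As an optional sanity check (a small sketch, not required for the project), compare one entry of the JAX gradient against a central finite difference; away from ties in the minimum, the two should agree closely:

def finite_difference_check(params, eps=1e-3):
    """Compare jax.grad against a central finite difference on one parameter."""
    def loss_at(x_val):
        perturbed = params._replace(
            other_inits=params.other_inits.at[0, 0].set(x_val)
        )
        return float(scenario_loss(perturbed))

    x0 = float(params.other_inits[0, 0])
    fd = (loss_at(x0 + eps) - loss_at(x0 - eps)) / (2 * eps)
    ad = float(jax.grad(scenario_loss)(params).other_inits[0, 0])
    print(f"finite difference: {fd:.5f} | jax.grad: {ad:.5f}")

finite_difference_check(params)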

2.7 Visualization

import matplotlib.pyplot as plt

def plot_scenario(ego_traj, other_trajs, title="Scenario", ax=None):
    """Bird's-eye view of a multi-agent scenario."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))

    # Road
    ax.axhspan(-5.5, 5.5, color='#e0e0e0', alpha=0.3)
    ax.axhline(y=-1.85, color='white', linestyle='--', linewidth=0.8)
    ax.axhline(y=1.85, color='white', linestyle='--', linewidth=0.8)
    ax.axhline(y=-5.5, color='black', linewidth=2)
    ax.axhline(y=5.5, color='black', linewidth=2)

    # Ego trajectory
    ax.plot(ego_traj[:, 0], ego_traj[:, 1], 'b-', linewidth=2, label='Ego')
    ax.plot(ego_traj[0, 0], ego_traj[0, 1], 'bs', markersize=8)
    ax.plot(ego_traj[-1, 0], ego_traj[-1, 1], 'b^', markersize=8)

    # Other agent trajectories
    colors = ['red', 'orange', 'purple', 'green', 'brown']
    for i in range(other_trajs.shape[0]):
        c = colors[i % len(colors)]
        ax.plot(other_trajs[i, :, 0], other_trajs[i, :, 1],
                '-', color=c, linewidth=2, label=f'Agent {i}')
        ax.plot(other_trajs[i, 0, 0], other_trajs[i, 0, 1],
                's', color=c, markersize=8)
        ax.plot(other_trajs[i, -1, 0], other_trajs[i, -1, 1],
                '^', color=c, markersize=8)

    ax.set_xlabel('x (m)')
    ax.set_ylabel('y (m)')
    ax.set_title(title)
    ax.legend(fontsize=9)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    return ax

Key implementation notes for Step 2:

  • Use jax.lax.scan instead of Python for-loops for the simulation rollout. Python loops are unrolled at trace time, which leads to very long jax.jit compilation and, if left uncompiled, extremely slow execution.
  • The softplus approximation for speed clamping is critical. Without it, gradients die at zero speed, and the optimizer cannot learn to create scenarios where agents start from rest and then move.
  • NamedTuple is a natural pytree in JAX, so ScenarioParams works seamlessly with jax.grad, jax.jit, and jax.vmap.
  • The ego policy must also be JAX-compatible (no Python control flow that depends on array values). Use jnp.where instead of if/else statements.
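
For the last point, a tiny illustration of the JAX-compatible pattern (hypothetical snippet): value-dependent Python if/else fails when traced under jit, while jnp.where works.

import jax
import jax.numpy as jnp

@jax.jit
def brake_if_close(gap, accel_cmd):
    # A Python `if gap < 5.0:` here would fail under jit (control flow on a traced value)
    return jnp.where(gap < 5.0, -6.0, accel_cmd)

print(brake_if_close(3.0, 1.5))   # -6.0
print(brake_if_close(20.0, 1.5))  # 1.5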

Step 3: Adversarial Optimization (Notebook 02, ~90 min)

Implement the gradient-based adversarial search loop that perturbs scenario parameters to maximize the probability of planner failure.

3.1 Failure Objective Functions

# adversarial/objectives.py

def collision_objective(ego_traj, other_trajs, safe_dist=5.0, temperature=1.0):
    """
    Soft collision proximity loss.
    Higher values = closer to collision = worse for ego = better for adversary.

    Uses sigmoid to create smooth gradient field even when vehicles are far apart.
    """
    ego_pos = ego_traj[:, :2]           # (T, 2)
    other_pos = other_trajs[:, :, :2]   # (N, T, 2)

    # Pairwise distances: (N, T)
    distances = jnp.sqrt(
        jnp.sum((ego_pos[None, :, :] - other_pos) ** 2, axis=-1) + 1e-6
    )

    # Sigmoid: approaches 1 when distance < safe_dist
    collision_score = jax.nn.sigmoid((safe_dist - distances) / temperature)

    # Sum over agents and time (higher = more collision-prone)
    return jnp.sum(collision_score)


def min_ttc_objective(ego_traj, other_trajs, dt=0.1, ttc_threshold=3.0):
    """
    Time-to-collision objective.
    Penalizes (rewards for adversary) low TTC values.
    """
    ego_pos = ego_traj[:, :2]
    ego_vel = jnp.diff(ego_pos, axis=0) / dt  # (T-1, 2)
    other_pos = other_trajs[:, :, :2]
    other_vel = jnp.diff(other_pos, axis=1) / dt  # (N, T-1, 2)

    # Relative position and velocity
    rel_pos = other_pos[:, :-1, :] - ego_pos[None, :-1, :]  # (N, T-1, 2)
    rel_vel = other_vel - ego_vel[None, :, :]                # (N, T-1, 2)

    # Distance and closing speed
    dist = jnp.sqrt(jnp.sum(rel_pos ** 2, axis=-1) + 1e-6)  # (N, T-1)
    # Closing speed = -d(distance)/dt (positive when approaching)
    closing_speed = -jnp.sum(rel_pos * rel_vel, axis=-1) / (dist + 1e-6)

    # TTC = distance / closing_speed (when closing_speed > 0)
    # Use softplus to handle negative closing speeds smoothly
    safe_closing = softplus(closing_speed, beta=2.0) + 1e-6
    ttc = dist / safe_closing

    # Reward low TTC: 1/TTC clipped for stability
    inv_ttc = 1.0 / (ttc + 0.1)
    ttc_penalty = jnp.sum(jnp.where(ttc < ttc_threshold, inv_ttc, 0.0))

    return ttc_penalty


def offroad_objective(ego_traj, lane_center_y=0.0, lane_width=3.7):
    """
    Off-road excursion loss.
    Rewards the adversary for pushing the ego out of its lane.
    """
    ego_y = ego_traj[:, 1]
    dist_from_center = jnp.abs(ego_y - lane_center_y)
    offroad_amount = jax.nn.relu(dist_from_center - lane_width / 2)
    return jnp.sum(offroad_amount)


def hard_braking_objective(ego_traj, dt=0.1, comfort_decel=3.0):
    """
    Uncomfortable braking loss.
    Rewards scenarios that force the ego into hard braking.
    """
    ego_speed = ego_traj[:, 3]
    accel = jnp.diff(ego_speed) / dt
    hard_brake = jax.nn.relu(-accel - comfort_decel)  # amount beyond comfort threshold
    return jnp.sum(hard_brake)


def combined_objective(ego_traj, other_trajs, weights=None):
    """
    Weighted combination of all failure objectives.
    Returns scalar loss (higher = more adversarial).
    """
    if weights is None:
        weights = {
            'collision': 1.0,
            'ttc': 0.5,
            'offroad': 0.3,
            'braking': 0.2,
        }

    loss = 0.0
    loss += weights['collision'] * collision_objective(ego_traj, other_trajs)
    loss += weights['ttc'] * min_ttc_objective(ego_traj, other_trajs)
    loss += weights['offroad'] * offroad_objective(ego_traj)
    loss += weights['braking'] * hard_braking_objective(ego_traj)

    return loss
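
A quick smoke test of the objectives on the nominal scenario and policy from Step 2 (a usage sketch; all values should be finite scalars before you launch the optimizer):

ego_traj, other_trajs = rollout_scenario(params, simple_ego_policy)
print(f"collision:  {float(collision_objective(ego_traj, other_trajs)):.3f}")
print(f"ttc:        {float(min_ttc_objective(ego_traj, other_trajs)):.3f}")
print(f"offroad:    {float(offroad_objective(ego_traj)):.3f}")
print(f"hard brake: {float(hard_braking_objective(ego_traj)):.3f}")
print(f"combined:   {float(combined_objective(ego_traj, other_trajs)):.3f}")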

3.2 Constraint Projection

# adversarial/constraints.py

def project_params(params, config=None):
    """
    Project scenario parameters back to the physically plausible set.
    Called after each gradient step.
    """
    if config is None:
        config = {
            'max_speed': 30.0,       # m/s (~108 km/h)
            'min_speed': 0.0,
            'max_accel': 4.0,        # m/s^2
            'min_accel': -8.0,       # m/s^2 (hard braking)
            'max_yaw_rate': 0.5,     # rad/s
            'road_y_min': -5.5,      # road boundaries
            'road_y_max': 5.5,
            'x_min': -50.0,          # longitudinal range
            'x_max': 200.0,
        }

    # Clip other agent actions
    accel = jnp.clip(
        params.other_actions[:, :, 0],
        config['min_accel'], config['max_accel']
    )
    yaw_rate = jnp.clip(
        params.other_actions[:, :, 1],
        -config['max_yaw_rate'], config['max_yaw_rate']
    )
    clipped_actions = jnp.stack([accel, yaw_rate], axis=-1)

    # Clip initial conditions
    clipped_inits = params.other_inits.at[:, 3].set(
        jnp.clip(params.other_inits[:, 3], config['min_speed'], config['max_speed'])
    )
    clipped_inits = clipped_inits.at[:, 1].set(
        jnp.clip(clipped_inits[:, 1], config['road_y_min'], config['road_y_max'])
    )
    clipped_inits = clipped_inits.at[:, 0].set(
        jnp.clip(clipped_inits[:, 0], config['x_min'], config['x_max'])
    )

    return ScenarioParams(
        ego_init=params.ego_init,  # ego init is fixed
        other_inits=clipped_inits,
        other_actions=clipped_actions,
    )


def smoothness_penalty(params, weight=0.01):
    """
    Penalty for jerky, unrealistic agent behavior.
    Encourages smooth acceleration and steering profiles.
    """
    # Finite difference of actions (jerk)
    action_diff = jnp.diff(params.other_actions, axis=1)
    jerk_penalty = jnp.sum(action_diff ** 2)
    return weight * jerk_penalty

3.3 Adversarial Optimization Loop

# adversarial/optimizer.py
import optax
from tqdm import tqdm

def adversarial_search(
    initial_params,
    ego_policy_fn,
    n_steps=200,
    learning_rate=0.01,
    objective_weights=None,
    constraint_config=None,
    smoothness_weight=0.01,
    verbose=True,
):
    """
    Gradient-based adversarial scenario search.

    Maximizes the failure objective while maintaining physical plausibility.
    Uses Optax Adam optimizer with constraint projection after each step.

    Args:
        initial_params: ScenarioParams to start from
        ego_policy_fn: the system under test
        n_steps: number of optimization steps
        learning_rate: Adam learning rate
        objective_weights: weights for combined objective
        constraint_config: physical plausibility constraints
        smoothness_weight: weight for smoothness regularization

    Returns:
        best_params: scenario parameters that maximize failure
        history: dict of loss values over optimization
    """
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(initial_params)

    @jax.jit
    def loss_fn(params):
        ego_traj, other_trajs = rollout_scenario(params, ego_policy_fn)
        adv_loss = combined_objective(ego_traj, other_trajs, objective_weights)
        smooth_loss = smoothness_penalty(params, smoothness_weight)
        # Negate because we want to MAXIMIZE failure (minimize negative loss)
        return -(adv_loss - smooth_loss)

    grad_fn = jax.grad(loss_fn)

    params = initial_params
    best_params = params
    best_loss = float('inf')
    history = {'loss': [], 'collision_score': [], 'min_distance': []}

    for step in range(n_steps):
        # Compute gradients
        grads = grad_fn(params)

        # Adam update
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        # Project to feasible set
        params = project_params(params, constraint_config)

        # Evaluate current scenario
        ego_traj, other_trajs = rollout_scenario(params, ego_policy_fn)
        current_loss = float(loss_fn(params))

        # Track metrics
        distances = jnp.sqrt(jnp.sum(
            (ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
        ))
        min_dist = float(jnp.min(distances))
        col_score = float(collision_objective(ego_traj, other_trajs))

        history['loss'].append(current_loss)
        history['collision_score'].append(col_score)
        history['min_distance'].append(min_dist)

        if current_loss < best_loss:
            best_loss = current_loss
            best_params = params

        if verbose and step % 20 == 0:
            print(f"Step {step:4d} | loss={current_loss:.4f} | "
                  f"min_dist={min_dist:.2f}m | col_score={col_score:.2f}")

    return best_params, history

3.4 Multi-Restart for Diversity

def multi_restart_search(
    base_params,
    ego_policy_fn,
    n_restarts=10,
    n_steps_per_restart=200,
    key=jax.random.PRNGKey(0),
    **kwargs,
):
    """
    Run adversarial search from multiple random initializations.
    Returns a list of discovered adversarial scenarios sorted by severity.
    """
    all_results = []

    for i in range(n_restarts):
        key, subkey = jax.random.split(key)

        # Random perturbation of initial scenario
        perturbed_params = random_perturbation(base_params, subkey)

        print(f"\n--- Restart {i+1}/{n_restarts} ---")
        best_params, history = adversarial_search(
            perturbed_params, ego_policy_fn,
            n_steps=n_steps_per_restart, **kwargs
        )

        # Evaluate final scenario
        ego_traj, other_trajs = rollout_scenario(best_params, ego_policy_fn)
        min_dist = float(jnp.min(jnp.sqrt(jnp.sum(
            (ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
        ))))

        all_results.append({
            'params': best_params,
            'history': history,
            'min_distance': min_dist,
            'final_loss': history['loss'][-1],
            'restart_id': i,
        })

    # Sort by severity (lowest min_distance = most dangerous)
    all_results.sort(key=lambda r: r['min_distance'])
    return all_results


def random_perturbation(params, key, position_std=10.0, speed_std=3.0, action_std=0.5):
    """Generate a random perturbation of scenario parameters."""
    keys = jax.random.split(key, 3)

    # Perturb other agent initial conditions
    pos_noise = jax.random.normal(keys[0], params.other_inits[:, :2].shape) * position_std
    speed_noise = jax.random.normal(keys[1], params.other_inits[:, 3:4].shape) * speed_std

    new_inits = params.other_inits.at[:, :2].add(pos_noise)
    new_inits = new_inits.at[:, 3:4].add(speed_noise)

    # Perturb actions
    action_noise = jax.random.normal(keys[2], params.other_actions.shape) * action_std
    new_actions = params.other_actions + action_noise

    return ScenarioParams(
        ego_init=params.ego_init,
        other_inits=new_inits,
        other_actions=new_actions,
    )

3.5 Comparing with Random Search Baseline

def random_search_baseline(
    base_params,
    ego_policy_fn,
    n_samples=1000,
    key=jax.random.PRNGKey(42),
):
    """
    Random search: sample many random scenarios and keep the most adversarial.
    This is the baseline that gradient-based search should beat.
    """
    results = []

    for i in range(n_samples):
        key, subkey = jax.random.split(key)
        random_params = random_perturbation(base_params, subkey)
        random_params = project_params(random_params)

        ego_traj, other_trajs = rollout_scenario(random_params, ego_policy_fn)
        distances = jnp.sqrt(jnp.sum(
            (ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
        ))
        min_dist = float(jnp.min(distances))
        col_score = float(collision_objective(ego_traj, other_trajs))

        results.append({
            'params': random_params,
            'min_distance': min_dist,
            'collision_score': col_score,
        })

    results.sort(key=lambda r: r['min_distance'])
    return results
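
If the Python loop above becomes a bottleneck at thousands of samples, one option is to batch the evaluation with jax.vmap (a sketch, assuming the rollout and shapes defined earlier):

def batched_min_distance(params_batch, ego_policy_fn):
    """Minimum ego-to-agent distance for a whole batch of scenarios at once."""
    def min_dist_one(p):
        ego_traj, other_trajs = rollout_scenario(p, ego_policy_fn)
        d = jnp.sqrt(jnp.sum(
            (ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
        ) + 1e-6)
        return jnp.min(d)
    return jax.vmap(min_dist_one)(params_batch)

A batch of ScenarioParams can be built by stacking individually perturbed parameter sets leaf-wise, for example with jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *params_list).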

3.6 Visualization of Optimization Trajectory

def plot_optimization_trajectory(history, title="Adversarial Optimization"):
    """Plot loss, collision score, and min distance over optimization steps."""
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))

    axes[0].plot(history['loss'], 'b-', linewidth=1.5)
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Loss (negated)')
    axes[0].set_title('Optimization Loss')
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(history['collision_score'], 'r-', linewidth=1.5)
    axes[1].set_xlabel('Step')
    axes[1].set_ylabel('Collision Score')
    axes[1].set_title('Collision Proximity')
    axes[1].grid(True, alpha=0.3)

    axes[2].plot(history['min_distance'], 'g-', linewidth=1.5)
    axes[2].axhline(y=5.0, color='red', linestyle='--', label='Safe distance')
    axes[2].set_xlabel('Step')
    axes[2].set_ylabel('Min Distance (m)')
    axes[2].set_title('Minimum Ego-Other Distance')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def plot_before_after(initial_params, adversarial_params, ego_policy_fn):
    """Side-by-side comparison of initial and adversarial scenarios."""
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    ego_init, others_init = rollout_scenario(initial_params, ego_policy_fn)
    plot_scenario(ego_init, others_init, title="Initial (Benign) Scenario", ax=axes[0])

    ego_adv, others_adv = rollout_scenario(adversarial_params, ego_policy_fn)
    plot_scenario(ego_adv, others_adv, title="Adversarial Scenario", ax=axes[1])

    plt.tight_layout()
    plt.show()

    # Print comparison stats
    init_dist = float(jnp.min(jnp.sqrt(jnp.sum(
        (ego_init[None, :, :2] - others_init[:, :, :2]) ** 2, axis=-1))))
    adv_dist = float(jnp.min(jnp.sqrt(jnp.sum(
        (ego_adv[None, :, :2] - others_adv[:, :, :2]) ** 2, axis=-1))))

    print(f"Initial scenario min distance: {init_dist:.2f} m")
    print(f"Adversarial scenario min distance: {adv_dist:.2f} m")
    print(f"Distance reduction: {init_dist - adv_dist:.2f} m ({(1 - adv_dist/init_dist)*100:.1f}%)")

Key implementation notes for Step 3:

  • The loss function is negated because Optax minimizes and we want to maximize failure. Alternatively, you can negate the gradients before passing them to optimizer.update so that Adam ascends the failure objective directly.
  • JIT-compile the loss function and gradient computation (shown with @jax.jit decorator). Without JIT, each optimization step takes seconds instead of milliseconds.
  • The constraint projection step is crucial. Without it, the optimizer will quickly find "adversarial" scenarios where other vehicles materialize at 200 m/s directly in front of the ego -- physically impossible and useless for safety analysis.
  • Start with a small learning rate (0.001-0.01) and observe the optimization trajectory. If the loss oscillates wildly, reduce the learning rate. If progress is too slow, increase it.
  • The smoothness penalty prevents the optimizer from discovering scenarios with discontinuous, physically impossible agent behavior (e.g., an agent that alternates between full throttle and full brake every timestep).

Step 4: Failure Mining and Analysis (Notebook 03, ~75 min)

Run adversarial search at scale, cluster discovered failures, and build a comprehensive failure report.

4.1 Running Adversarial Search at Scale

# Run multi-restart adversarial search
results = multi_restart_search(
    base_params=nominal_scenario_params,
    ego_policy_fn=simple_ego_policy,
    n_restarts=20,
    n_steps_per_restart=300,
    learning_rate=0.01,
)

# Also run random search for comparison
random_results = random_search_baseline(
    base_params=nominal_scenario_params,
    ego_policy_fn=simple_ego_policy,
    n_samples=2000,
)

print(f"Adversarial search: found {len(results)} scenarios")
print(f"Random search: found {len(random_results)} scenarios")

4.2 Feature Extraction

Before clustering, extract interpretable features from each discovered failure.

def extract_failure_features(params, ego_policy_fn, dt=0.1):
    """
    Extract interpretable features from an adversarial scenario.
    These features are used for clustering failures into categories.
    """
    ego_traj, other_trajs = rollout_scenario(params, ego_policy_fn)

    # Find the critical moment (minimum distance)
    distances = jnp.sqrt(jnp.sum(
        (ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
    ))
    min_idx = jnp.unravel_index(jnp.argmin(distances), distances.shape)
    critical_agent = int(min_idx[0])
    critical_time = int(min_idx[1])
    min_distance = float(distances[min_idx])

    # Relative geometry at critical moment
    ego_pos = ego_traj[critical_time, :2]
    ego_heading = ego_traj[critical_time, 2]
    ego_speed = ego_traj[critical_time, 3]

    other_pos = other_trajs[critical_agent, critical_time, :2]
    other_heading = other_trajs[critical_agent, critical_time, 2]
    other_speed = other_trajs[critical_agent, critical_time, 3]

    # Relative position in ego frame
    dx = other_pos[0] - ego_pos[0]
    dy = other_pos[1] - ego_pos[1]
    cos_h, sin_h = jnp.cos(ego_heading), jnp.sin(ego_heading)
    lon_rel = dx * cos_h + dy * sin_h   # positive = ahead
    lat_rel = -dx * sin_h + dy * cos_h  # positive = left

    # Relative heading
    heading_diff = other_heading - ego_heading
    heading_diff = (heading_diff + jnp.pi) % (2 * jnp.pi) - jnp.pi

    # Speed ratio
    speed_ratio = other_speed / (ego_speed + 1e-6)

    # Ego deceleration at critical moment
    ego_accel = jnp.diff(ego_traj[:, 3]) / dt
    critical_decel = float(ego_accel[min(critical_time, len(ego_accel)-1)])

    # Approach angle (from which direction does the threat come?)
    approach_angle = jnp.arctan2(lat_rel, lon_rel)

    features = {
        'min_distance': min_distance,
        'critical_time': critical_time * dt,
        'lon_relative': float(lon_rel),
        'lat_relative': float(lat_rel),
        'heading_diff': float(heading_diff),
        'speed_ratio': float(speed_ratio),
        'ego_decel': critical_decel,
        'approach_angle': float(approach_angle),
        'other_speed': float(other_speed),
        'ego_speed': float(ego_speed),
    }
    return features

4.3 Failure Clustering

import numpy as np
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import StandardScaler

def cluster_failures(results, ego_policy_fn, n_clusters=5):
    """
    Cluster discovered adversarial scenarios into failure categories.
    """
    # Extract features for all failures
    all_features = []
    for r in results:
        feat = extract_failure_features(r['params'], ego_policy_fn)
        all_features.append(feat)

    # Convert to array for clustering
    feature_names = list(all_features[0].keys())
    X = np.array([[f[k] for k in feature_names] for f in all_features])

    # Standardize
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # K-means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    labels = kmeans.fit_predict(X_scaled)

    # Assign human-readable names based on cluster centroids
    cluster_info = []
    for c in range(n_clusters):
        mask = labels == c
        cluster_features = {k: np.mean([f[k] for f, m in zip(all_features, mask) if m])
                          for k in feature_names}

        # Heuristic naming based on features
        name = classify_failure_type(cluster_features)

        cluster_info.append({
            'id': c,
            'name': name,
            'count': int(np.sum(mask)),
            'mean_features': cluster_features,
            'severity': np.mean([r['min_distance'] for r, m in zip(results, mask) if m]),
        })

    return labels, cluster_info, all_features


def classify_failure_type(features):
    """
    Heuristic classification of failure type based on extracted features.
    """
    lon = features['lon_relative']
    lat = features['lat_relative']
    heading_diff = features['heading_diff']
    speed_ratio = features['speed_ratio']

    if abs(lat) > 2.0 and lon > 0:
        if heading_diff < -0.3:
            return "Cut-in from left"
        elif heading_diff > 0.3:
            return "Cut-in from right"
        else:
            return "Side swipe"
    elif lon > 0 and abs(lat) < 2.0:
        if speed_ratio < 0.5:
            return "Sudden braking (lead vehicle)"
        else:
            return "Rear approach conflict"
    elif lon < 0 and abs(lat) < 2.0:
        return "Rear-end (ego is lead)"
    elif abs(heading_diff) > 2.5:
        return "Head-on approach"
    elif abs(lat) > 3.0:
        return "Lateral encroachment"
    else:
        return "Complex interaction"

4.4 Coverage Analysis

import scipy.spatial.distance

def coverage_analysis(adversarial_results, random_results, ego_policy_fn):
    """
    Compare failure coverage between adversarial and random search.
    """
    # Adversarial failures
    adv_features = [extract_failure_features(r['params'], ego_policy_fn)
                    for r in adversarial_results]

    # Random search failures (only keep those with min_dist < threshold)
    threshold = 5.0  # meters
    random_failures = [r for r in random_results if r['min_distance'] < threshold]
    rand_features = [extract_failure_features(r['params'], ego_policy_fn)
                     for r in random_failures]

    print(f"Adversarial: {len(adversarial_results)} scenarios total")
    print(f"  Failures (min_dist < {threshold}m): "
          f"{sum(1 for r in adversarial_results if r['min_distance'] < threshold)}")
    print(f"  Failure rate: "
          f"{sum(1 for r in adversarial_results if r['min_distance'] < threshold) / len(adversarial_results) * 100:.1f}%")

    print(f"\nRandom: {len(random_results)} scenarios total")
    print(f"  Failures (min_dist < {threshold}m): {len(random_failures)}")
    print(f"  Failure rate: {len(random_failures) / len(random_results) * 100:.1f}%")

    # Diversity comparison
    if len(adv_features) > 1:
        adv_X = np.array([[f[k] for k in adv_features[0].keys()] for f in adv_features])
        adv_diversity = np.mean(scipy.spatial.distance.pdist(
            StandardScaler().fit_transform(adv_X)))
        print(f"\nAdversarial diversity (mean pairwise distance): {adv_diversity:.3f}")

    if len(rand_features) > 1:
        rand_X = np.array([[f[k] for k in rand_features[0].keys()] for f in rand_features])
        rand_diversity = np.mean(scipy.spatial.distance.pdist(
            StandardScaler().fit_transform(rand_X)))
        print(f"Random diversity (mean pairwise distance): {rand_diversity:.3f}")

4.5 Severity Prioritization

def prioritize_failures(results, ego_policy_fn):
    """
    Rank failures by a composite severity score.

    Severity considers:
    - How close to collision (min distance)
    - How hard the ego had to brake (deceleration)
    - How early in the scenario the failure occurs (earlier = harder to avoid)
    - Plausibility of the adversarial scenario
    """
    scored_results = []

    for r in results:
        features = extract_failure_features(r['params'], ego_policy_fn)

        # Proximity score: 1/distance, capped
        proximity = 1.0 / (features['min_distance'] + 0.1)

        # Braking severity: how hard ego braked
        brake_severity = max(0, -features['ego_decel'] / 6.0)  # normalized by max decel

        # Time pressure: earlier failures are harder to react to
        time_pressure = 1.0 - (features['critical_time'] / 8.0)

        # Composite severity
        severity = 0.5 * proximity + 0.3 * brake_severity + 0.2 * time_pressure

        scored_results.append({
            **r,
            'features': features,
            'severity_score': severity,
        })

    scored_results.sort(key=lambda r: -r['severity_score'])
    return scored_results

4.6 Failure Report Generation

def generate_failure_report(scored_results, cluster_info, ego_policy_fn):
    """
    Generate a comprehensive failure analysis report with visualizations.
    """
    fig = plt.figure(figsize=(20, 16))
    gs = fig.add_gridspec(3, 3, hspace=0.4, wspace=0.3)

    # 1. Severity distribution
    ax1 = fig.add_subplot(gs[0, 0])
    severities = [r['severity_score'] for r in scored_results]
    ax1.hist(severities, bins=20, color='#d62728', edgecolor='white', alpha=0.8)
    ax1.set_xlabel('Severity Score')
    ax1.set_ylabel('Count')
    ax1.set_title('Severity Distribution')
    ax1.grid(True, alpha=0.3)

    # 2. Failure type distribution (pie chart)
    ax2 = fig.add_subplot(gs[0, 1])
    types = [ci['name'] for ci in cluster_info]
    counts = [ci['count'] for ci in cluster_info]
    ax2.pie(counts, labels=types, autopct='%1.1f%%', colors=plt.cm.Set3.colors)
    ax2.set_title('Failure Categories')

    # 3. Min distance vs severity
    ax3 = fig.add_subplot(gs[0, 2])
    min_dists = [r['min_distance'] for r in scored_results]
    ax3.scatter(min_dists, severities, c='#1f77b4', alpha=0.6, edgecolors='k', linewidths=0.5)
    ax3.set_xlabel('Min Distance (m)')
    ax3.set_ylabel('Severity Score')
    ax3.set_title('Distance vs Severity')
    ax3.grid(True, alpha=0.3)

    # 4-6. Top 3 most severe failures (BEV plots)
    for i in range(min(3, len(scored_results))):
        ax = fig.add_subplot(gs[1, i])
        r = scored_results[i]
        ego_traj, other_trajs = rollout_scenario(r['params'], ego_policy_fn)
        plot_scenario(ego_traj, other_trajs,
                     title=f"#{i+1}: severity={r['severity_score']:.2f}, "
                           f"dist={r['min_distance']:.1f}m", ax=ax)

    # 7. Feature scatter (approach angle vs speed ratio)
    ax7 = fig.add_subplot(gs[2, 0])
    approach_angles = [r['features']['approach_angle'] for r in scored_results]
    speed_ratios = [r['features']['speed_ratio'] for r in scored_results]
    ax7.scatter(np.degrees(approach_angles), speed_ratios,
               c=severities, cmap='Reds', alpha=0.7, edgecolors='k', linewidths=0.5)
    ax7.set_xlabel('Approach Angle (deg)')
    ax7.set_ylabel('Speed Ratio (other/ego)')
    ax7.set_title('Threat Geometry')
    ax7.grid(True, alpha=0.3)

    # 8. Cluster severity comparison
    ax8 = fig.add_subplot(gs[2, 1])
    cluster_names = [ci['name'] for ci in cluster_info]
    cluster_severities = [ci['severity'] for ci in cluster_info]
    ax8.barh(cluster_names, cluster_severities, color=plt.cm.Set2.colors)
    ax8.set_xlabel('Mean Min Distance (m)')
    ax8.set_title('Severity by Failure Type')

    # 9. Summary table
    ax9 = fig.add_subplot(gs[2, 2])
    ax9.axis('off')
    summary_text = (
        f"Total scenarios tested: {len(scored_results)}\n"
        f"Failures found (< 5m): {sum(1 for r in scored_results if r['min_distance'] < 5)}\n"
        f"Near-collisions (< 2m): {sum(1 for r in scored_results if r['min_distance'] < 2)}\n"
        f"Distinct failure types: {len(cluster_info)}\n"
        f"Most severe category: {cluster_info[0]['name'] if cluster_info else 'N/A'}\n"
        f"Mean severity: {np.mean(severities):.3f}\n"
        f"Max severity: {max(severities):.3f}"
    )
    ax9.text(0.1, 0.5, summary_text, fontsize=11, fontfamily='monospace',
            verticalalignment='center', transform=ax9.transAxes,
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
    ax9.set_title('Summary')

    fig.suptitle('Adversarial Scenario Generation -- Failure Report',
                fontsize=16, fontweight='bold')
    plt.savefig('failure_report.png', dpi=150, bbox_inches='tight')
    plt.show()

Key implementation notes for Step 4:

  • Feature extraction should be fast because it runs on every discovered scenario. JIT-compile the rollout_scenario function if not already done.
  • The heuristic failure classification function (classify_failure_type) is a starting point. In practice, you would refine the thresholds based on your specific ODD and scenario distribution.
  • Use DBSCAN instead of K-means if you do not know the number of failure clusters in advance. DBSCAN automatically determines the number of clusters based on density, but requires tuning the eps and min_samples parameters.
  • The coverage analysis is the key result that justifies adversarial generation over random testing. Expect to see 5-50x higher failure discovery rates with gradient-based search compared to random sampling.
  • Consider exporting discovered failures in a format compatible with your simulation platform (e.g., Waymax scenario format, OpenSCENARIO) so they can be replayed and shared with the safety team.
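
For the last point, a minimal export sketch (the JSON layout below is ad hoc, not OpenSCENARIO or the Waymax format): serialize the raw parameter arrays plus the key metrics so a scenario can be reloaded and replayed later.

import json
import numpy as np

def export_failure(result, path):
    """Write one discovered adversarial scenario to a JSON file (ad-hoc layout)."""
    p = result['params']
    record = {
        'ego_init': np.asarray(p.ego_init).tolist(),
        'other_inits': np.asarray(p.other_inits).tolist(),
        'other_actions': np.asarray(p.other_actions).tolist(),
        'min_distance': result['min_distance'],
        'restart_id': result['restart_id'],
    }
    with open(path, 'w') as f:
        json.dump(record, f, indent=2)

# Example: export_failure(results[0], 'results/failure_000.json')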

Exercises

Exercise 1: Custom Failure Objective (in Notebook 02)

Implement a new failure objective that targets a specific failure mode: the ego vehicle being forced to make an emergency lane change. The objective should reward scenarios where:

  1. An obstacle appears ahead in the ego's lane, requiring evasive action.
  2. The adjacent lane has limited gap, making the lane change dangerous.
  3. The ego must choose between braking hard or changing lanes into a tight gap.

Hints:

  • Define "lane change pressure" as a function of the gap in the ego's lane and the gaps in adjacent lanes.
  • Use a product of terms: (small gap ahead) * (small gap beside) = high pressure.
  • Test your objective by running adversarial search and inspecting the discovered scenarios visually.
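
A possible skeleton for the objective (a starting point only; the lane geometry, thresholds, and gap definitions are placeholder assumptions you should refine):

def lane_change_pressure_objective(ego_traj, other_trajs, lane_width=3.7, tau=2.0):
    """Reward scenarios that squeeze the ego between a blocked lane and a tight adjacent gap."""
    rel = other_trajs[:, :, :2] - ego_traj[None, :, :2]        # (N, T+1, 2)
    lon, lat = rel[..., 0], rel[..., 1]

    # Gap ahead in the ego's own lane (large constant where no agent qualifies)
    same_lane_ahead = (jnp.abs(lat) < lane_width / 2) & (lon > 0)
    gap_ahead = jnp.min(jnp.where(same_lane_ahead, lon, 1e3), axis=0)      # (T+1,)

    # Gap to the nearest agent occupying the adjacent (left) lane
    adjacent_lane = jnp.abs(lat - lane_width) < lane_width / 2
    gap_beside = jnp.min(jnp.where(adjacent_lane, jnp.abs(lon), 1e3), axis=0)

    # High "pressure" only when BOTH gaps are small (product of smooth terms)
    pressure = (jax.nn.sigmoid((15.0 - gap_ahead) / tau)
                * jax.nn.sigmoid((10.0 - gap_beside) / tau))
    return jnp.sum(pressure)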

Exercise 2: Constrained Adversary (in Notebook 02)

Modify the adversarial optimizer to enforce an additional constraint: all other agents must follow plausible driving behavior (not just kinematically feasible, but behaviorally plausible). Specifically:

  1. Each other agent should approximately follow IDM with respect to its own leader.
  2. The adversary can only modify the IDM parameters (desired speed, desired headway, politeness) rather than directly controlling actions.
  3. This creates a "naturalistic adversary" that is much harder to distinguish from real-world driving.

Hints:

  • Parameterize each agent by IDM parameters: [v_desired, s0, T, a_max, b].
  • Inside the rollout, compute actions from IDM given these parameters.
  • The optimization landscape is lower-dimensional but more complex.

Exercise 3: Adversarial Weather Conditions (in Notebook 03)

Extend the failure analysis to consider how environmental factors amplify adversarial scenarios. Add parameterized environmental conditions:

  1. Reduced visibility (fog/rain): modeled as increased sensor noise or reduced detection range for the ego's perception.
  2. Reduced friction: modeled as lower maximum braking capability for all vehicles.
  3. Increased reaction time: modeled as a delay in the ego's policy response.

Run adversarial search with and without these environmental factors and compare:

  • Does the set of discovered failures change?
  • Are some failure types only reachable under adverse conditions?
  • How does the severity distribution shift?

Summary and Next Steps

In this project you built a complete adversarial scenario generation framework:

  1. Differentiable scenario parameterization: Driving scenarios represented as continuous vectors that JAX can differentiate through, enabling gradient-based search.
  2. Differentiable kinematic simulation: A bicycle-model simulator implemented entirely in JAX with smooth approximations for non-differentiable operations (speed clamping, distance thresholds).
  3. Composable failure objectives: A library of differentiable failure metrics (collision proximity, TTC, off-road, hard braking) that can be weighted and combined to target specific failure modes.
  4. Gradient-based adversarial optimization: An Optax-based optimizer that iteratively perturbs scenario parameters toward failure, with constraint projection to maintain physical plausibility.
  5. Multi-restart diversity: A strategy for discovering diverse failure modes rather than converging to the same adversarial example repeatedly.
  6. Failure analysis pipeline: Feature extraction, clustering, coverage analysis, and severity prioritization that transforms raw adversarial scenarios into actionable safety insights.

Where to go from here:

  • Integration with Waymax: Replace the toy kinematic simulator with Waymax's full dynamics model for more realistic scenarios. Waymax's differentiable simulation enables the same gradient-based approach but with realistic multi-agent dynamics, road geometry, and traffic rules.
  • Learned ego policies: Replace the simple rule-based ego policy with a learned neural network policy (e.g., from the RL Training phase). Adversarial testing of learned policies often reveals failure modes that unit tests miss.
  • Scenario libraries: Build a curated library of adversarial scenarios organized by failure type, severity, and ODD region. This library becomes a regression test suite that every new policy version must pass.
  • Automated CI/CD integration: Run adversarial search as part of the continuous integration pipeline. Every code change to the planner triggers adversarial testing, and new failures are automatically filed as issues.
  • SOTIF compliance: Map discovered failure types to SOTIF hazard categories and use the coverage analysis to argue that your testing has explored the relevant portions of the scenario space.

References

  • Ding, W., et al. (2021). "Multimodal Safety-Critical Scenarios Generation for Decision-Making Algorithms Evaluation." IEEE Robotics and Automation Letters.
  • Rempe, D., et al. (2022). "Generating Useful Accident-Prone Driving Scenarios via a Learned Traffic Model." CVPR 2022.
  • Sun, C., et al. (2021). "Adversarial Evaluation of Autonomous Vehicles in Lane-Change Scenarios." IEEE Transactions on Intelligent Transportation Systems.
  • Koren, M., et al. (2018). "Adaptive Stress Testing for Autonomous Vehicles." IEEE Intelligent Vehicles Symposium.
  • Corso, A., et al. (2021). "A Survey of Algorithms for Black-Box Safety Validation of Cyber-Physical Systems." Journal of Artificial Intelligence Research.
  • Wachi, A., and Sui, Y. (2020). "Safe Reinforcement Learning in Constrained Markov Decision Processes." ICML 2020.
  • ISO 21448:2022. "Road vehicles -- Safety of the intended functionality."