Adversarial Scenario Generator
Systematically discover the scenarios that break your autonomous driving stack -- because the real world does not wait for you to find bugs by chance, and random testing cannot reach the dark corners where catastrophic failures hide.
Overview
In this project you will build an adversarial scenario generation system that uses gradient-based optimization to automatically discover driving scenarios where an autonomous driving (AD) planner fails. The system takes a nominal driving scenario as input, parameterizes controllable aspects of the scene (initial positions, velocities, trajectories of other agents, timing of events), defines differentiable failure objectives (collision probability, minimum time-to-collision, off-road excursion, uncomfortable braking), and then optimizes the scenario parameters to maximize the probability of planner failure -- all while maintaining physical plausibility constraints so that the discovered scenarios are realistic rather than contrived.
This matters because testing autonomous vehicles is fundamentally a search problem in an astronomically large scenario space. A single 8-second driving scenario at 10 Hz with 32 agents has thousands of continuous parameters. Random sampling of this space is hopelessly inefficient: the failure-inducing regions are thin, high-dimensional manifolds embedded in a vast space of benign scenarios. Industry leaders like Waymo, Applied Intuition, and Cruise have recognized that targeted adversarial generation is the only tractable path to discovering the rare but safety-critical scenarios that matter. RAND Corporation researchers have estimated that demonstrating an AD system is merely as safe as a human driver would require hundreds of millions to billions of miles of random driving -- adversarial methods aim to compress those miles into thousands of carefully chosen scenarios.
Your deliverable is a modular JAX-based adversarial scenario generation framework consisting of: (1) a differentiable scenario parameterization and kinematic simulation engine, (2) a library of composable failure objectives that can be combined and weighted, (3) a gradient-based adversarial optimization loop with constraint projection to maintain physical plausibility, (4) a multi-restart search strategy for diversity, (5) a failure clustering and analysis pipeline that categorizes discovered failures by type and severity, and (6) a visualization and reporting system that communicates findings to safety engineers. The entire system should be JIT-compilable, vectorizable with jax.vmap, and extensible to new failure modes and constraint sets.
Learning Objectives
By completing this project, you will be able to:
- Parameterize driving scenarios as differentiable optimization variables by representing agent initial conditions, trajectory waypoints, and behavioral parameters as continuous vectors that JAX can differentiate through.
- Build a differentiable kinematic simulator in JAX that takes scenario parameters as input, rolls out multi-agent trajectories using a bicycle model, and produces collision/safety metrics as differentiable outputs -- enabling gradient-based search over the scenario space.
- Design and implement failure objective functions including collision probability, minimum time-to-collision (TTC), off-road rate, uncomfortable braking, and lane violation, each formulated as smooth differentiable surrogates suitable for gradient optimization.
- Implement gradient-based adversarial optimization using Adam (via Optax) to iteratively perturb scenario parameters toward failure, with constraint projection steps that enforce physical plausibility (speed limits, acceleration bounds, lane adherence) after each gradient update.
- Apply multi-restart and diversity-promoting strategies to discover a broad set of distinct failure modes rather than converging repeatedly to the same adversarial example, using techniques such as random initialization, repulsion terms, and clustering-based restart selection.
- Cluster and analyze discovered failures by extracting interpretable features from adversarial scenarios, applying unsupervised clustering to identify distinct failure categories (cut-in, sudden braking, jaywalker, occluded merge), and prioritizing them by severity and estimated real-world likelihood.
- Evaluate adversarial search efficiency by comparing the failure discovery rate, diversity, and severity distribution against random search baselines, quantifying the value of gradient-based targeting.
Prerequisites
- Required: JAX proficiency (comfortable with jax.grad, jax.jit, jax.vmap, jax.lax.scan, and pytree manipulation), gradient-based optimization fundamentals (loss functions, learning rates, Adam), and basic familiarity with kinematic vehicle models and driving scenario representation.
- Recommended: Waymax basics (loading scenarios, stepping simulations, accessing state), experience with Optax for optimizer configuration, and familiarity with clustering algorithms (K-means, DBSCAN) for the failure analysis portion.
- Deep Dive Reading:
- Long-Tail Scenarios Deep Dive -- Comprehensive coverage of why rare scenarios dominate AD safety risk, methods for generating and mining edge cases, and the connection between adversarial testing and regulatory validation. Directly relevant to the motivation and approach of this project.
Key Concepts
The Adversarial Testing Problem
Testing an autonomous driving system is fundamentally different from testing conventional software. In conventional software, you can enumerate inputs and check outputs against specifications. In AD, the "input" is the entire state of the world -- every vehicle's position, velocity, and intent; every pedestrian's trajectory; every traffic signal's state; road geometry, weather, lighting, sensor noise -- and the space of possible inputs is continuous and effectively infinite.
The safety-critical scenarios -- the ones where the AD system might fail catastrophically -- are vanishingly rare in the overall distribution. A human driver in the US encounters a police-reported crash roughly once every 500,000 miles. If you test by randomly sampling scenarios from the naturalistic distribution, you will spend nearly all your compute budget on boring, safe scenarios where the planner works fine, and almost never encounter the dangerous ones.
Adversarial scenario generation inverts this: instead of sampling scenarios and hoping to hit failures, you define what failure looks like and then optimize the scenario parameters to produce it. This is a search problem, and gradient-based optimization is the most efficient search algorithm available when the objective is differentiable.
Random Testing:
Sample scenario ~ P(scenarios) -> Run planner -> Check for failure
Efficiency: O(1 / P(failure)) -- astronomically many samples needed
Adversarial Testing:
Initialize scenario parameters theta
Loop:
Roll out scenario(theta) in differentiable sim
Compute failure_score = f(trajectory)
theta += alpha * grad(failure_score, theta)
Until: failure found or budget exhausted
Efficiency: O(gradient_steps) -- typically 50-500 steps
The key requirement is that the simulation must be differentiable end-to-end: from scenario parameters through the dynamics model to the failure metric. JAX makes this possible through its automatic differentiation capabilities, and its JIT compilation makes the optimization loop fast enough to be practical.
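To make this requirement concrete, here is a minimal toy example (not part of the project code) of differentiating through a rollout to a failure score: a hypothetical 1-D adversary driving toward a stationary ego. All names and numbers are illustrative.

```python
import jax
import jax.numpy as jnp

def toy_failure_score(theta, dt=0.1, steps=50):
    """theta = [initial_gap, speed] of a 1-D adversary driving toward
    a stationary ego at x = 0; returns a smooth proximity score."""
    gap, speed = theta
    # Constant-speed rollout: adversary position at each timestep
    positions = gap - speed * dt * jnp.arange(steps)
    # Smooth closeness score: grows as the adversary nears the ego
    return jnp.sum(jax.nn.sigmoid(2.0 - jnp.abs(positions)))

theta = jnp.array([30.0, 5.0])  # starts 30 m away at 5 m/s
grad = jax.grad(toy_failure_score)(theta)
# The gradient tells the adversary to shrink the gap and speed up
print(grad)
```

Even though the scenario is far from a collision, the smooth score yields a nonzero gradient pointing toward danger: negative with respect to the initial gap, positive with respect to speed. This is the same mechanism the full framework uses, just in one dimension.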
Scenario Parameterization
The first design decision is: what aspects of the scenario can the adversary control? The parameterization defines the search space and determines what kinds of failures can be discovered.
Initial condition parameters (simplest, most common):
- Positions of other agents: (x_i, y_i) for each agent i
- Velocities of other agents: (vx_i, vy_i) for each agent i
- Headings of other agents: theta_i for each agent i
- Together these give a parameter vector of size 5 * N_agents
Trajectory parameters (richer, more expressive):
- Waypoints for each agent: (x_i^t, y_i^t) at selected timesteps
- Or control inputs: (acceleration_i^t, steering_i^t) at each timestep
- Interpolation between waypoints using splines for smoothness
- Parameter vector size: 2 * N_agents * N_waypoints or 2 * N_agents * T
Behavioral parameters (most abstract):
- Aggressiveness of each agent (maps to IDM parameters)
- Reaction time, desired gap, politeness factor
- These indirectly control trajectories through a behavioral model
- Smaller parameter space, but requires a differentiable behavioral model
For this project, we use a hybrid approach: initial conditions plus trajectory waypoints, connected by a differentiable kinematic model. This gives enough expressiveness to discover diverse failures while keeping the optimization landscape manageable.
class ScenarioParams(NamedTuple):
    """Differentiable scenario parameters (NamedTuple, a natural JAX pytree)."""
    ego_init: jnp.ndarray       # (4,) [x, y, heading, speed]
    other_inits: jnp.ndarray    # (N, 4) [x, y, heading, speed] per agent
    other_actions: jnp.ndarray  # (N, T, 2) [accel, steer_rate] per agent per step
    # Total parameters: 4 + N*4 + N*T*2
Differentiable Kinematic Simulation
To compute gradients through the simulation, every operation from parameters to metrics must be differentiable. A kinematic model with direct yaw-rate control (a simplified form of the bicycle model) is a natural choice:
$$x_{t+1} = x_t + v_t \cos(\theta_t) \cdot \Delta t$$ $$y_{t+1} = y_t + v_t \sin(\theta_t) \cdot \Delta t$$ $$\theta_{t+1} = \theta_t + \omega_t \cdot \Delta t$$ $$v_{t+1} = \max(0, v_t + a_t \cdot \Delta t)$$
where $(a_t, \omega_t)$ are the acceleration and yaw rate at time $t$.
The $\max(0, \cdot)$ operation for speed clamping is not differentiable at zero. We replace it with a smooth approximation: $\text{softplus}_\beta(x) = \log(1 + e^{\beta x}) / \beta$, where $\beta$ controls the sharpness (the hard max is recovered as $\beta \to \infty$). Similarly, collision detection requires smooth distance functions rather than hard boolean checks.
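The difference matters exactly at and below zero speed. A small comparison sketch (the helper name softplus_beta is ours, mirroring the jax.nn.softplus-based helper used later in Step 2):

```python
import jax
import jax.numpy as jnp

def softplus_beta(x, beta=5.0):
    # log(1 + exp(beta * x)) / beta -> max(0, x) as beta -> infinity
    return jax.nn.softplus(x * beta) / beta

hard = jax.grad(jax.nn.relu)            # relu(x) = max(0, x)
soft = jax.grad(lambda v: softplus_beta(v))

# JAX defines relu'(0) = 0, so the hard clamp kills the gradient for a
# vehicle at rest; the smooth version still passes signal through.
print(hard(0.0), soft(0.0))    # 0.0 vs 0.5
print(hard(-0.2), soft(-0.2))  # 0.0 vs sigmoid(-1) ~ 0.27
```

With the hard clamp, an agent parameterized to start from rest sits in a zero-gradient region and the optimizer can never learn to make it move; the softplus keeps that direction explorable.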
The forward pass through the simulation is implemented with jax.lax.scan for efficiency:
def rollout(params, ego_policy, dt=0.1):
"""Differentiable forward simulation.
params: ScenarioParams (initial conditions + actions for other agents)
ego_policy: function mapping observation -> action (not optimized)
Returns: all trajectories, shape (N+1, T, 4)
"""
def step_fn(state, t):
# Ego acts according to its policy (constant, not optimized)
ego_obs = extract_observation(state)
ego_action = ego_policy(ego_obs)
# Other agents follow their parameterized actions
other_actions = params.other_actions[:, t, :]
# Kinematic update for all agents
new_state = kinematic_step_batch(state, ego_action, other_actions, dt)
return new_state, new_state
    init_state = build_initial_state(params)
    T = params.other_actions.shape[1]  # horizon comes from the action array
    _, trajectory = jax.lax.scan(step_fn, init_state, jnp.arange(T))
    return trajectory
The ego vehicle follows a fixed policy (the system under test). The adversary controls only the other agents' initial conditions and actions. This reflects the real-world threat model: the ego cannot control what other road users do.
Failure Objective Functions
The failure objective defines what the adversary is trying to achieve. It must be:
- Differentiable: gradients must flow from the objective back through the simulation to the scenario parameters.
- Informative: the gradient should point toward more dangerous scenarios even when the current scenario is safe (no "gradient desert" problem).
- Calibrated: the magnitude should reflect severity, enabling comparison across scenarios.
Collision proximity (primary objective): $$L_{\text{collision}} = \sum_{t} \sum_{i} \text{sigmoid}\left(\frac{d_{\text{safe}} - |p_{\text{ego}}^t - p_i^t|_2}{\tau}\right)$$
where $d_{\text{safe}}$ is the safe distance threshold (sum of vehicle half-lengths plus a margin) and $\tau$ is a temperature controlling the sharpness of the sigmoid. When the ego is far from other agents, the loss is near zero but the gradient still points toward reducing distance. As the ego approaches collision distance, the loss rapidly increases.
Time-to-collision (TTC): $$L_{\text{TTC}} = \sum_{t} \max\left(0, \frac{1}{\text{TTC}^t} - \frac{1}{\text{TTC}_{\text{safe}}}\right)$$
TTC is computed as $d / v_{\text{closing}}$ where $v_{\text{closing}} = -\dot{d}$ is the rate at which the gap is shrinking. For differentiability, we use smooth approximations for the division and the max.
Off-road excursion: $$L_{\text{offroad}} = \sum_{t} \text{ReLU}(d_{\text{lane}}(p_{\text{ego}}^t) - w_{\text{lane}}/2)$$
where $d_{\text{lane}}$ is the signed distance from lane center and $w_{\text{lane}}$ is the lane width. This penalizes the ego for leaving the road.
Uncomfortable braking: $$L_{\text{comfort}} = \sum_{t} \text{ReLU}(-a_{\text{ego}}^t - a_{\text{comfort}})$$
This penalizes (rewards, from the adversary's perspective) scenarios that force the ego into hard braking (deceleration exceeding a comfort threshold like $3 \text{ m/s}^2$).
Combined objective: $$L_{\text{total}} = w_1 L_{\text{collision}} + w_2 L_{\text{TTC}} + w_3 L_{\text{offroad}} + w_4 L_{\text{comfort}}$$
The weights $w_i$ control the relative importance of each failure mode and can be adjusted to focus the search on specific types of failures.
Constraint Satisfaction and Physical Plausibility
An unconstrained adversary will produce physically impossible scenarios: vehicles teleporting, accelerating at 100 m/s^2, driving through buildings. The discovered failures would be meaningless because they could never occur in the real world. Constraint satisfaction is therefore essential.
Hard constraints (enforced by projection after each gradient step):
- Speed: $0 \leq v_i^t \leq v_{\max}$ (typically 35 m/s for highways, 15 m/s for urban)
- Acceleration: $a_{\min} \leq a_i^t \leq a_{\max}$ (typically $[-6, 4]$ m/s^2)
- Yaw rate: $|\omega_i^t| \leq \omega_{\max}$ (function of speed and vehicle geometry)
- Lane adherence: agents should be on or near driveable surface
- No initial overlap: agents should not start inside each other
Projection step:
def project_to_constraints(params):
"""Project scenario parameters back to the feasible set."""
# Clip accelerations
actions = jnp.clip(params.other_actions[..., 0], -6.0, 4.0)
# Clip steering rates (speed-dependent)
steer = jnp.clip(params.other_actions[..., 1], -0.5, 0.5)
# Clip initial speeds
speeds = jnp.clip(params.other_inits[:, 3], 0.0, 35.0)
    return params._replace(  # NamedTuple update; use dataclasses.replace for a dataclass
        other_actions=jnp.stack([actions, steer], axis=-1),
        other_inits=params.other_inits.at[:, 3].set(speeds),
    )
This is a projected gradient descent approach: take a gradient step, then project back to the feasible set. For box constraints (min/max bounds), projection is just clipping. For more complex constraints (e.g., "agent must stay on road"), projection requires solving a constrained optimization subproblem or using a soft penalty.
Soft constraints (added as penalty terms to the objective): $$L_{\text{plausibility}} = \lambda_{\text{road}} \sum_{i,t} d_{\text{road}}(p_i^t)^2 + \lambda_{\text{kinematic}} \sum_{i,t} \text{ReLU}(|a_i^t| - a_{\max})^2$$
The advantage of soft constraints is simplicity (no projection step needed), but they require tuning the penalty weights $\lambda$ and do not guarantee strict feasibility.
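The soft penalty above might be sketched as follows for a hypothetical straight road centered on y = 0; the distance function and parameter names are illustrative, not part of the project API:

```python
import jax
import jax.numpy as jnp

def plausibility_penalty(positions, accels, lam_road=1.0, lam_kin=1.0,
                         road_half_width=5.5, a_max=4.0):
    """Soft plausibility penalty, assuming a straight east-west road.

    positions: (N, T, 2) agent positions; accels: (N, T) accelerations.
    d_road is the distance outside the drivable surface (0 when on-road).
    """
    d_road = jax.nn.relu(jnp.abs(positions[..., 1]) - road_half_width)
    road_term = lam_road * jnp.sum(d_road ** 2)
    kin_term = lam_kin * jnp.sum(jax.nn.relu(jnp.abs(accels) - a_max) ** 2)
    return road_term + kin_term

# Feasible behavior incurs zero penalty; violations grow quadratically
print(plausibility_penalty(jnp.zeros((2, 5, 2)), jnp.ones((2, 5))))  # 0.0
```

Because both terms are squared hinge penalties, the gradient is zero inside the feasible region and grows with the size of the violation, which is what makes the weights $\lambda$ tunable trade-offs rather than hard walls.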
Coverage and Diversity
Finding one adversarial scenario is useful; finding a diverse set that covers multiple failure modes is much more valuable. Without explicit diversity promotion, gradient-based optimization tends to converge repeatedly to the same failure -- often the "easiest" one for the optimizer to find.
Multi-restart strategy: Run the optimization from many different random initializations. Each restart may converge to a different local optimum in the failure landscape.
Repulsion term: Add a penalty that pushes new adversarial scenarios away from previously discovered ones: $$L_{\text{diversity}} = -\alpha \sum_{j \in \text{found}} \exp\left(-\frac{|\theta - \theta_j|^2}{2\sigma^2}\right)$$
This kernel-based repulsion (similar to Determinantal Point Processes) encourages the optimizer to explore new regions of the failure space.
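One possible implementation of this repulsion term, assuming scenario parameters have been flattened to vectors (e.g. with jax.flatten_util.ravel_pytree); names and constants are illustrative:

```python
import jax
import jax.numpy as jnp

def repulsion_penalty(theta, found_thetas, alpha=1.0, sigma=5.0):
    """RBF repulsion from previously discovered adversarial parameters.

    theta: (D,) current flattened parameters.
    found_thetas: (K, D) previously discovered solutions.
    Most negative near old solutions, so adding it to the (maximized)
    failure objective pushes the search toward unexplored regions.
    """
    sq_dists = jnp.sum((found_thetas - theta[None, :]) ** 2, axis=-1)  # (K,)
    return -alpha * jnp.sum(jnp.exp(-sq_dists / (2.0 * sigma ** 2)))

# The ascent direction points away from previously found optima:
theta = jnp.array([0.0, 0.0])
found = jnp.array([[1.0, 0.0], [0.0, 3.0]])
g = jax.grad(repulsion_penalty)(theta, found)
print(g)  # both components negative: move away from the found points
```

The bandwidth sigma sets how far apart two scenarios must be before they count as "different"; too small and the optimizer rediscovers near-duplicates, too large and it is pushed away from entire failure modes.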
Feature-space clustering: After running many restarts, extract interpretable features from each discovered failure (e.g., relative position of adversary at collision time, type of maneuver, severity) and cluster them to identify distinct failure categories.
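One way to sketch this clustering step uses scipy's kmeans2 (scipy is already in the requirements; scikit-learn would work equally well). The feature columns and the synthetic data here are purely hypothetical:

```python
import numpy as np
from scipy.cluster.vq import kmeans2

# Hypothetical feature rows for discovered failures:
# [relative x at closest approach, relative y, closing speed]
rng = np.random.default_rng(0)
cut_ins = rng.normal([10.0, 2.0, 5.0], 0.5, size=(20, 3))
braking = rng.normal([15.0, 0.0, 9.0], 0.5, size=(20, 3))
features = np.vstack([cut_ins, braking])

# K-means into two candidate failure categories
centroids, labels = kmeans2(features, k=2, minit='++', seed=0)
print(centroids.shape, np.bincount(labels))
```

In the real pipeline the number of clusters is not known in advance, which is why the project also suggests DBSCAN; density-based methods discover the number of failure categories instead of requiring it as input.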
Connection to Safety Analysis
Adversarial scenario generation is not just a clever testing technique -- it connects to formal safety analysis frameworks:
- STPA (Systems-Theoretic Process Analysis): STPA identifies "unsafe control actions" and "loss scenarios" through systematic analysis. Adversarial generation can be guided by STPA-identified hazards, focusing the search on scenarios where specific control actions are likely to fail.
- SOTIF (Safety Of The Intended Functionality, ISO 21448): SOTIF explicitly requires testing for "triggering conditions" that cause the system to fail despite correct hardware and software. Adversarial generation directly targets these triggering conditions.
- Operational Design Domain (ODD): The constraint set in adversarial generation defines the ODD boundary. Scenarios that violate constraints are outside the ODD and irrelevant; scenarios within constraints that cause failure represent genuine safety risks.
Step-by-Step Implementation Guide
Step 1: Environment Setup (30 min)
Set up the project structure and verify that all dependencies work correctly.
1.1 Project Directory Structure
adversarial-scenario-generator/
├── adversarial/
│ ├── __init__.py
│ ├── scenario.py # Scenario parameterization
│ ├── simulator.py # Differentiable kinematic simulation
│ ├── objectives.py # Failure objective functions
│ ├── optimizer.py # Adversarial optimization loop
│ ├── constraints.py # Constraint projection
│ ├── analysis.py # Failure clustering and analysis
│ └── visualization.py # Plotting and reporting
├── notebooks/
│ ├── 01_scenario_parameterization.ipynb
│ ├── 02_adversarial_optimization.ipynb
│ └── 03_failure_analysis.ipynb
├── results/ # Output directory for discovered failures
├── requirements.txt
└── README.md
1.2 Install Dependencies
python -m venv .venv
source .venv/bin/activate
pip install jax jaxlib optax
pip install numpy matplotlib scipy pandas seaborn
pip install tqdm ipykernel jupyter
1.3 Verify Setup
import jax
import jax.numpy as jnp
import optax
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
# Verify gradient computation
def test_fn(x):
return jnp.sum(x ** 2)
x = jnp.array([1.0, 2.0, 3.0])
grad_fn = jax.grad(test_fn)
print(f"grad(sum(x^2)) at x=[1,2,3]: {grad_fn(x)}")
# Expected: [2.0, 4.0, 6.0]
# Verify Optax
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(x)
grads = grad_fn(x)
updates, opt_state = optimizer.update(grads, opt_state)
new_x = optax.apply_updates(x, updates)
print(f"After one Adam step: {new_x}")
print("Setup verified.")
Step 2: Scenario Parameterization and Differentiable Simulation (Notebook 01, ~75 min)
Build the foundation: represent driving scenarios as vectors of differentiable parameters and implement a kinematic simulator that JAX can differentiate through.
2.1 Vehicle State and Scenario Representation
import jax
import jax.numpy as jnp
from typing import NamedTuple
class VehicleState(NamedTuple):
"""State of a single vehicle."""
x: jnp.ndarray # longitudinal position (m)
y: jnp.ndarray # lateral position (m)
heading: jnp.ndarray # heading angle (rad)
speed: jnp.ndarray # scalar speed (m/s)
class ScenarioParams(NamedTuple):
"""Differentiable scenario parameters.
These are the variables the adversary can optimize.
The ego vehicle follows a fixed policy (not optimized).
"""
ego_init: jnp.ndarray # (4,) [x, y, heading, speed]
other_inits: jnp.ndarray # (N, 4) initial [x, y, heading, speed]
other_actions: jnp.ndarray # (N, T, 2) [acceleration, yaw_rate] per step
# Example: 3 other agents, 80 timesteps
N_AGENTS = 3
T = 80
DT = 0.1
# Initialize with zeros (benign scenario)
params = ScenarioParams(
ego_init=jnp.array([0.0, 0.0, 0.0, 15.0]),
other_inits=jnp.zeros((N_AGENTS, 4)),
other_actions=jnp.zeros((N_AGENTS, T, 2)),
)
total_params = 4 + N_AGENTS * 4 + N_AGENTS * T * 2
print(f"Total optimizable parameters: {total_params}")
print(f" Ego init: 4")
print(f" Other inits: {N_AGENTS * 4}")
print(f" Other actions: {N_AGENTS * T * 2}")
2.2 Differentiable Kinematic Step
The key insight is that every operation must be composed of JAX-differentiable primitives. The standard max(0, speed) is replaced with jax.nn.softplus for smooth gradients near zero speed.
def softplus(x, beta=5.0):
"""Smooth approximation to max(0, x)."""
return jax.nn.softplus(x * beta) / beta
def kinematic_step(state, action, dt=0.1):
"""
One step of the kinematic bicycle model.
Fully differentiable in JAX.
state: (4,) [x, y, heading, speed]
action: (2,) [acceleration, yaw_rate]
"""
x, y, heading, speed = state[0], state[1], state[2], state[3]
accel, yaw_rate = action[0], action[1]
# Speed update with soft clamping (differentiable version of max(0, ...))
new_speed = softplus(speed + accel * dt)
# Position update
new_x = x + speed * jnp.cos(heading) * dt
new_y = y + speed * jnp.sin(heading) * dt
# Heading update
new_heading = heading + yaw_rate * dt
return jnp.array([new_x, new_y, new_heading, new_speed])
# Test: verify gradients exist
state = jnp.array([0.0, 0.0, 0.0, 10.0])
action = jnp.array([1.0, 0.1])
# Gradient of final x-position w.r.t. action
def final_x(action):
s = kinematic_step(state, action)
return s[0]
grad_action = jax.grad(final_x)(action)
print(f"Gradient of x w.r.t. action: {grad_action}")
2.3 Batch Rollout with jax.lax.scan
def rollout_single_agent(init_state, actions, dt=0.1):
"""
Roll out a single agent's trajectory given initial state and actions.
init_state: (4,) [x, y, heading, speed]
actions: (T, 2) [acceleration, yaw_rate] per timestep
Returns: (T+1, 4) trajectory including initial state
"""
def step_fn(state, action):
new_state = kinematic_step(state, action, dt)
return new_state, new_state
_, trajectory = jax.lax.scan(step_fn, init_state, actions)
# Prepend initial state
trajectory = jnp.concatenate([init_state[None, :], trajectory], axis=0)
return trajectory
# Vectorize over agents
rollout_batch = jax.vmap(rollout_single_agent, in_axes=(0, 0, None))
2.4 Full Scenario Rollout
def rollout_scenario(params, ego_policy_fn, dt=0.1):
"""
Roll out a complete scenario.
The ego follows ego_policy_fn (fixed, not optimized).
Other agents follow their parameterized actions.
Returns:
ego_traj: (T+1, 4)
other_trajs: (N, T+1, 4)
"""
# Other agents: simple rollout from parameterized actions
other_trajs = rollout_batch(
params.other_inits, params.other_actions, dt
)
# Ego: reactive policy that observes other agents
def ego_step_fn(ego_state, t):
# Build observation for ego policy
other_states_at_t = other_trajs[:, t, :]
ego_action = ego_policy_fn(ego_state, other_states_at_t)
new_ego = kinematic_step(ego_state, ego_action, dt)
return new_ego, new_ego
_, ego_traj = jax.lax.scan(
ego_step_fn, params.ego_init, jnp.arange(params.other_actions.shape[1])
)
ego_traj = jnp.concatenate([params.ego_init[None, :], ego_traj], axis=0)
return ego_traj, other_trajs
2.5 Simple Ego Policy (System Under Test)
For this project, we use a simple rule-based ego policy as the system under test. In practice, this would be your actual AD planner.
def simple_ego_policy(ego_state, other_states):
"""
Simple ego policy: constant speed with basic collision avoidance.
This is the "system under test" that the adversary tries to break.
"""
desired_speed = 15.0 # m/s
ego_x, ego_y, ego_heading, ego_speed = (
ego_state[0], ego_state[1], ego_state[2], ego_state[3]
)
# Basic IDM-like following for the closest agent ahead
# Compute distances to other agents
dx = other_states[:, 0] - ego_x
dy = other_states[:, 1] - ego_y
# Relative position in ego frame
cos_h, sin_h = jnp.cos(ego_heading), jnp.sin(ego_heading)
lon_dist = dx * cos_h + dy * sin_h # longitudinal distance
lat_dist = -dx * sin_h + dy * cos_h # lateral distance
# Find closest agent ahead in same lane (within 2m lateral)
in_lane = jnp.abs(lat_dist) < 2.0
ahead = lon_dist > 0
valid = in_lane & ahead
# Safe distance = large number if no one ahead
gap = jnp.where(valid, lon_dist, 1000.0)
min_gap = jnp.min(gap)
    # IDM-style acceleration (simplified: the velocity-closure term in s* is omitted)
    s0 = 2.0       # minimum gap (m)
    T_hw = 1.5     # desired time headway (s)
    a_max = 2.0    # maximum acceleration (m/s^2)
    s_star = s0 + ego_speed * T_hw
    accel = a_max * (1 - (ego_speed / desired_speed)**4 - (s_star / min_gap)**2)
    accel = jnp.clip(accel, -6.0, a_max)
# No steering (drives straight)
steer = 0.0
return jnp.array([accel, steer])
2.6 Verifying Gradient Flow
The critical test: can we differentiate through the entire simulation?
def scenario_loss(params):
"""
Dummy loss: minimize distance between ego and closest other agent.
This tests that gradients flow end-to-end.
"""
ego_traj, other_trajs = rollout_scenario(params, simple_ego_policy)
# Minimum distance between ego and any other agent across all timesteps
# ego_traj: (T+1, 4), other_trajs: (N, T+1, 4)
ego_pos = ego_traj[:, :2] # (T+1, 2)
other_pos = other_trajs[:, :, :2] # (N, T+1, 2)
# Broadcast and compute distances
distances = jnp.sqrt(
jnp.sum((ego_pos[None, :, :] - other_pos) ** 2, axis=-1)
) # (N, T+1)
return jnp.min(distances)
# Compute gradient
grad_fn = jax.grad(scenario_loss)
grads = grad_fn(params)
print("Gradient shapes:")
print(f" ego_init grad: {grads.ego_init.shape}")
print(f" other_inits grad: {grads.other_inits.shape}")
print(f" other_actions grad: {grads.other_actions.shape}")
print(f" grad norms: ego_init={jnp.linalg.norm(grads.ego_init):.4f}, "
f"other_inits={jnp.linalg.norm(grads.other_inits):.4f}, "
f"other_actions={jnp.linalg.norm(grads.other_actions):.4f}")
If all gradient norms are non-zero, the differentiable simulation is working correctly and the adversarial optimizer has a gradient signal to follow.
2.7 Visualization
import matplotlib.pyplot as plt

def plot_scenario(ego_traj, other_trajs, title="Scenario", ax=None):
    """Bird's-eye view of a multi-agent scenario."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 6))
# Road
ax.axhspan(-5.5, 5.5, color='#e0e0e0', alpha=0.3)
ax.axhline(y=-1.85, color='white', linestyle='--', linewidth=0.8)
ax.axhline(y=1.85, color='white', linestyle='--', linewidth=0.8)
ax.axhline(y=-5.5, color='black', linewidth=2)
ax.axhline(y=5.5, color='black', linewidth=2)
# Ego trajectory
ax.plot(ego_traj[:, 0], ego_traj[:, 1], 'b-', linewidth=2, label='Ego')
ax.plot(ego_traj[0, 0], ego_traj[0, 1], 'bs', markersize=8)
ax.plot(ego_traj[-1, 0], ego_traj[-1, 1], 'b^', markersize=8)
# Other agent trajectories
colors = ['red', 'orange', 'purple', 'green', 'brown']
for i in range(other_trajs.shape[0]):
c = colors[i % len(colors)]
        ax.plot(other_trajs[i, :, 0], other_trajs[i, :, 1],
                '-', color=c, linewidth=2, label=f'Agent {i}')
ax.plot(other_trajs[i, 0, 0], other_trajs[i, 0, 1],
's', color=c, markersize=8)
ax.plot(other_trajs[i, -1, 0], other_trajs[i, -1, 1],
'^', color=c, markersize=8)
ax.set_xlabel('x (m)')
ax.set_ylabel('y (m)')
ax.set_title(title)
ax.legend(fontsize=9)
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)
return ax
Key implementation notes for Step 2:
- Use jax.lax.scan instead of Python for-loops for the simulation rollout. Python loops are unrolled at trace time, so jax.jit compiles them slowly and the result is inefficient.
- The softplus approximation for speed clamping is critical. Without it, gradients die at zero speed, and the optimizer cannot learn to create scenarios where agents start from rest and then move.
- NamedTuple is a natural pytree in JAX, so ScenarioParams works seamlessly with jax.grad, jax.jit, and jax.vmap.
- The ego policy must also be JAX-compatible (no Python control flow that depends on array values). Use jnp.where instead of if/else statements.
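The last point can be demonstrated directly: under jax.jit, a Python `if` on a traced value raises an error, while jnp.where traces both branches. The penalty function here is a toy stand-in:

```python
import jax
import jax.numpy as jnp

@jax.jit
def gap_penalty_where(gap):
    # Branch on array values with jnp.where: both sides are traced
    return jnp.where(gap < 5.0, (5.0 - gap) ** 2, 0.0)

@jax.jit
def gap_penalty_if(gap):
    if gap < 5.0:  # Python control flow on a traced value: fails under jit
        return (5.0 - gap) ** 2
    return 0.0

print(gap_penalty_where(jnp.array(3.0)))  # 4.0
try:
    gap_penalty_if(jnp.array(3.0))
except jax.errors.ConcretizationTypeError:
    print("Python `if` on a traced value fails under jit")
```

The same constraint is why the ego policy in Section 2.5 selects the lead vehicle with jnp.where and jnp.min rather than a Python loop with conditionals.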
Step 3: Adversarial Optimization (Notebook 02, ~90 min)
Implement the gradient-based adversarial search loop that perturbs scenario parameters to maximize the probability of planner failure.
3.1 Failure Objective Functions
# adversarial/objectives.py
def collision_objective(ego_traj, other_trajs, safe_dist=5.0, temperature=1.0):
"""
Soft collision proximity loss.
Higher values = closer to collision = worse for ego = better for adversary.
Uses sigmoid to create smooth gradient field even when vehicles are far apart.
"""
ego_pos = ego_traj[:, :2] # (T, 2)
other_pos = other_trajs[:, :, :2] # (N, T, 2)
# Pairwise distances: (N, T)
distances = jnp.sqrt(
jnp.sum((ego_pos[None, :, :] - other_pos) ** 2, axis=-1) + 1e-6
)
# Sigmoid: approaches 1 when distance < safe_dist
collision_score = jax.nn.sigmoid((safe_dist - distances) / temperature)
# Sum over agents and time (higher = more collision-prone)
return jnp.sum(collision_score)
def min_ttc_objective(ego_traj, other_trajs, dt=0.1, ttc_threshold=3.0):
"""
Time-to-collision objective.
Penalizes (rewards for adversary) low TTC values.
"""
ego_pos = ego_traj[:, :2]
ego_vel = jnp.diff(ego_pos, axis=0) / dt # (T-1, 2)
other_pos = other_trajs[:, :, :2]
other_vel = jnp.diff(other_pos, axis=1) / dt # (N, T-1, 2)
# Relative position and velocity
rel_pos = other_pos[:, :-1, :] - ego_pos[None, :-1, :] # (N, T-1, 2)
rel_vel = other_vel - ego_vel[None, :, :] # (N, T-1, 2)
# Distance and closing speed
dist = jnp.sqrt(jnp.sum(rel_pos ** 2, axis=-1) + 1e-6) # (N, T-1)
# Closing speed = -d(distance)/dt (positive when approaching)
closing_speed = -jnp.sum(rel_pos * rel_vel, axis=-1) / (dist + 1e-6)
# TTC = distance / closing_speed (when closing_speed > 0)
# Use softplus to handle negative closing speeds smoothly
safe_closing = softplus(closing_speed, beta=2.0) + 1e-6
ttc = dist / safe_closing
# Reward low TTC: 1/TTC clipped for stability
inv_ttc = 1.0 / (ttc + 0.1)
ttc_penalty = jnp.sum(jnp.where(ttc < ttc_threshold, inv_ttc, 0.0))
return ttc_penalty
def offroad_objective(ego_traj, lane_center_y=0.0, lane_width=3.7):
"""
Off-road excursion loss.
Rewards the adversary for pushing the ego out of its lane.
"""
ego_y = ego_traj[:, 1]
dist_from_center = jnp.abs(ego_y - lane_center_y)
offroad_amount = jax.nn.relu(dist_from_center - lane_width / 2)
return jnp.sum(offroad_amount)
def hard_braking_objective(ego_traj, dt=0.1, comfort_decel=3.0):
"""
Uncomfortable braking loss.
Rewards scenarios that force the ego into hard braking.
"""
ego_speed = ego_traj[:, 3]
accel = jnp.diff(ego_speed) / dt
hard_brake = jax.nn.relu(-accel - comfort_decel) # amount beyond comfort threshold
return jnp.sum(hard_brake)
def combined_objective(ego_traj, other_trajs, weights=None):
"""
Weighted combination of all failure objectives.
Returns scalar loss (higher = more adversarial).
"""
if weights is None:
weights = {
'collision': 1.0,
'ttc': 0.5,
'offroad': 0.3,
'braking': 0.2,
}
loss = 0.0
loss += weights['collision'] * collision_objective(ego_traj, other_trajs)
loss += weights['ttc'] * min_ttc_objective(ego_traj, other_trajs)
loss += weights['offroad'] * offroad_objective(ego_traj)
loss += weights['braking'] * hard_braking_objective(ego_traj)
return loss
3.2 Constraint Projection
# adversarial/constraints.py
def project_params(params, config=None):
"""
Project scenario parameters back to the physically plausible set.
Called after each gradient step.
"""
if config is None:
config = {
'max_speed': 30.0, # m/s (~108 km/h)
'min_speed': 0.0,
'max_accel': 4.0, # m/s^2
'min_accel': -8.0, # m/s^2 (hard braking)
'max_yaw_rate': 0.5, # rad/s
'road_y_min': -5.5, # road boundaries
'road_y_max': 5.5,
'x_min': -50.0, # longitudinal range
'x_max': 200.0,
}
# Clip other agent actions
accel = jnp.clip(
params.other_actions[:, :, 0],
config['min_accel'], config['max_accel']
)
yaw_rate = jnp.clip(
params.other_actions[:, :, 1],
-config['max_yaw_rate'], config['max_yaw_rate']
)
clipped_actions = jnp.stack([accel, yaw_rate], axis=-1)
# Clip initial conditions
clipped_inits = params.other_inits.at[:, 3].set(
jnp.clip(params.other_inits[:, 3], config['min_speed'], config['max_speed'])
)
clipped_inits = clipped_inits.at[:, 1].set(
jnp.clip(clipped_inits[:, 1], config['road_y_min'], config['road_y_max'])
)
clipped_inits = clipped_inits.at[:, 0].set(
jnp.clip(clipped_inits[:, 0], config['x_min'], config['x_max'])
)
return ScenarioParams(
ego_init=params.ego_init, # ego init is fixed
other_inits=clipped_inits,
other_actions=clipped_actions,
)
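A quick demonstration of the projection on a deliberately infeasible scenario; the `ScenarioParams` NamedTuple here is a minimal stand-in for the pytree defined in the earlier steps (an assumption to keep the snippet self-contained), and the clipping logic repeats the function above with the default config inlined.

```python
from typing import NamedTuple
import jax.numpy as jnp

class ScenarioParams(NamedTuple):
    # Minimal stand-in for the pytree defined earlier (assumption for this snippet).
    ego_init: jnp.ndarray       # [x, y, heading, speed]
    other_inits: jnp.ndarray    # (n_agents, 4)
    other_actions: jnp.ndarray  # (n_agents, T, 2): [accel, yaw_rate]

def project_params(params):
    # Same clipping as above, default config inlined.
    accel = jnp.clip(params.other_actions[:, :, 0], -8.0, 4.0)
    yaw_rate = jnp.clip(params.other_actions[:, :, 1], -0.5, 0.5)
    inits = params.other_inits
    inits = inits.at[:, 3].set(jnp.clip(inits[:, 3], 0.0, 30.0))
    inits = inits.at[:, 1].set(jnp.clip(inits[:, 1], -5.5, 5.5))
    inits = inits.at[:, 0].set(jnp.clip(inits[:, 0], -50.0, 200.0))
    return ScenarioParams(params.ego_init, inits,
                          jnp.stack([accel, yaw_rate], axis=-1))

wild = ScenarioParams(
    ego_init=jnp.array([0.0, 0.0, 0.0, 10.0]),
    other_inits=jnp.array([[300.0, 9.0, 0.0, 50.0]]),  # off-road, impossibly fast
    other_actions=jnp.full((1, 5, 2), 10.0),           # infeasible accel / yaw rate
)
safe = project_params(wild)
print(safe.other_inits[0])  # x -> 200, y -> 5.5, speed -> 30
```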
def smoothness_penalty(params, weight=0.01):
"""
Penalty for jerky, unrealistic agent behavior.
Encourages smooth acceleration and steering profiles.
"""
# Finite difference of actions (jerk)
action_diff = jnp.diff(params.other_actions, axis=1)
jerk_penalty = jnp.sum(action_diff ** 2)
return weight * jerk_penalty
3.3 Adversarial Optimization Loop
# adversarial/optimizer.py
import optax
def adversarial_search(
initial_params,
ego_policy_fn,
n_steps=200,
learning_rate=0.01,
objective_weights=None,
constraint_config=None,
smoothness_weight=0.01,
verbose=True,
):
"""
Gradient-based adversarial scenario search.
Maximizes the failure objective while maintaining physical plausibility.
Uses Optax Adam optimizer with constraint projection after each step.
Args:
initial_params: ScenarioParams to start from
ego_policy_fn: the system under test
n_steps: number of optimization steps
learning_rate: Adam learning rate
objective_weights: weights for combined objective
constraint_config: physical plausibility constraints
smoothness_weight: weight for smoothness regularization
Returns:
best_params: scenario parameters that maximize failure
history: dict of loss values over optimization
"""
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(initial_params)
@jax.jit
def loss_fn(params):
ego_traj, other_trajs = rollout_scenario(params, ego_policy_fn)
adv_loss = combined_objective(ego_traj, other_trajs, objective_weights)
smooth_loss = smoothness_penalty(params, smoothness_weight)
# Negate because we want to MAXIMIZE failure (minimize negative loss)
return -(adv_loss - smooth_loss)
grad_fn = jax.grad(loss_fn)
params = initial_params
best_params = params
best_loss = float('inf')
history = {'loss': [], 'collision_score': [], 'min_distance': []}
for step in range(n_steps):
# Compute gradients
grads = grad_fn(params)
# Adam update
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
# Project to feasible set
params = project_params(params, constraint_config)
# Evaluate current scenario
ego_traj, other_trajs = rollout_scenario(params, ego_policy_fn)
current_loss = float(loss_fn(params))
# Track metrics
distances = jnp.sqrt(jnp.sum(
(ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
))
min_dist = float(jnp.min(distances))
col_score = float(collision_objective(ego_traj, other_trajs))
history['loss'].append(current_loss)
history['collision_score'].append(col_score)
history['min_distance'].append(min_dist)
if current_loss < best_loss:
best_loss = current_loss
best_params = params
if verbose and step % 20 == 0:
print(f"Step {step:4d} | loss={current_loss:.4f} | "
f"min_dist={min_dist:.2f}m | col_score={col_score:.2f}")
return best_params, history
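Note that the loop above evaluates the rollout more than once per step (inside jax.grad and again for logging). One possible refinement, sketched here on a toy quadratic stand-in for loss_fn, is jax.value_and_grad, which returns the loss and its gradient from a single evaluation:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for loss_fn; the real one would wrap rollout_scenario.
def loss_fn(params):
    return jnp.sum(params ** 2)

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
params = jnp.array([1.0, -2.0, 3.0])
loss, grads = value_and_grad_fn(params)  # one forward pass yields both
print(float(loss))  # 14.0
print(grads)        # [ 2. -4.  6.]
```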
3.4 Multi-Restart for Diversity
def multi_restart_search(
base_params,
ego_policy_fn,
n_restarts=10,
n_steps_per_restart=200,
key=jax.random.PRNGKey(0),
**kwargs,
):
"""
Run adversarial search from multiple random initializations.
Returns a list of discovered adversarial scenarios sorted by severity.
"""
all_results = []
for i in range(n_restarts):
key, subkey = jax.random.split(key)
# Random perturbation of initial scenario
perturbed_params = random_perturbation(base_params, subkey)
print(f"\n--- Restart {i+1}/{n_restarts} ---")
best_params, history = adversarial_search(
perturbed_params, ego_policy_fn,
n_steps=n_steps_per_restart, **kwargs
)
# Evaluate final scenario
ego_traj, other_trajs = rollout_scenario(best_params, ego_policy_fn)
min_dist = float(jnp.min(jnp.sqrt(jnp.sum(
(ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
))))
all_results.append({
'params': best_params,
'history': history,
'min_distance': min_dist,
'final_loss': history['loss'][-1],
'restart_id': i,
})
# Sort by severity (lowest min_distance = most dangerous)
all_results.sort(key=lambda r: r['min_distance'])
return all_results
def random_perturbation(params, key, position_std=10.0, speed_std=3.0, action_std=0.5):
"""Generate a random perturbation of scenario parameters."""
keys = jax.random.split(key, 3)
# Perturb other agent initial conditions
pos_noise = jax.random.normal(keys[0], params.other_inits[:, :2].shape) * position_std
speed_noise = jax.random.normal(keys[1], params.other_inits[:, 3:4].shape) * speed_std
new_inits = params.other_inits.at[:, :2].add(pos_noise)
new_inits = new_inits.at[:, 3:4].add(speed_noise)
# Perturb actions
action_noise = jax.random.normal(keys[2], params.other_actions.shape) * action_std
new_actions = params.other_actions + action_noise
return ScenarioParams(
ego_init=params.ego_init,
other_inits=new_inits,
other_actions=new_actions,
)
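Because ScenarioParams is a pytree of arrays, the per-restart perturbations (and, by extension, the rollouts themselves) can also be batched with jax.vmap instead of a Python loop. A minimal sketch on a stand-in pytree (the two-field NamedTuple layout here is an assumption for self-containment):

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class ScenarioParams(NamedTuple):
    # Minimal stand-in pytree for this sketch.
    other_inits: jnp.ndarray    # (n_agents, 4)
    other_actions: jnp.ndarray  # (n_agents, T, 2)

def random_perturbation(params, key, position_std=10.0, action_std=0.5):
    k1, k2 = jax.random.split(key)
    pos_noise = jax.random.normal(k1, params.other_inits[:, :2].shape) * position_std
    new_inits = params.other_inits.at[:, :2].add(pos_noise)
    action_noise = jax.random.normal(k2, params.other_actions.shape) * action_std
    return ScenarioParams(new_inits, params.other_actions + action_noise)

base = ScenarioParams(jnp.zeros((3, 4)), jnp.zeros((3, 20, 2)))
keys = jax.random.split(jax.random.PRNGKey(0), 10)  # one key per restart
# vmap over the keys only; the base scenario is broadcast (in_axes=(None, 0)).
batch = jax.vmap(random_perturbation, in_axes=(None, 0))(base, keys)
print(batch.other_inits.shape)    # (10, 3, 4)
print(batch.other_actions.shape)  # (10, 3, 20, 2)
```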
3.5 Comparing with Random Search Baseline
def random_search_baseline(
base_params,
ego_policy_fn,
n_samples=1000,
key=jax.random.PRNGKey(42),
):
"""
Random search: sample many random scenarios and keep the most adversarial.
This is the baseline that gradient-based search should beat.
"""
results = []
for i in range(n_samples):
key, subkey = jax.random.split(key)
random_params = random_perturbation(base_params, subkey)
random_params = project_params(random_params)
ego_traj, other_trajs = rollout_scenario(random_params, ego_policy_fn)
distances = jnp.sqrt(jnp.sum(
(ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
))
min_dist = float(jnp.min(distances))
col_score = float(collision_objective(ego_traj, other_trajs))
results.append({
'params': random_params,
'min_distance': min_dist,
'collision_score': col_score,
})
results.sort(key=lambda r: r['min_distance'])
return results
3.6 Visualization of Optimization Trajectory
def plot_optimization_trajectory(history, title="Adversarial Optimization"):
"""Plot loss, collision score, and min distance over optimization steps."""
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
axes[0].plot(history['loss'], 'b-', linewidth=1.5)
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Loss (negated)')
axes[0].set_title('Optimization Loss')
axes[0].grid(True, alpha=0.3)
axes[1].plot(history['collision_score'], 'r-', linewidth=1.5)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('Collision Score')
axes[1].set_title('Collision Proximity')
axes[1].grid(True, alpha=0.3)
axes[2].plot(history['min_distance'], 'g-', linewidth=1.5)
axes[2].axhline(y=5.0, color='red', linestyle='--', label='Safe distance')
axes[2].set_xlabel('Step')
axes[2].set_ylabel('Min Distance (m)')
axes[2].set_title('Minimum Ego-Other Distance')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
fig.suptitle(title, fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
def plot_before_after(initial_params, adversarial_params, ego_policy_fn):
"""Side-by-side comparison of initial and adversarial scenarios."""
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
ego_init, others_init = rollout_scenario(initial_params, ego_policy_fn)
plot_scenario(ego_init, others_init, title="Initial (Benign) Scenario", ax=axes[0])
ego_adv, others_adv = rollout_scenario(adversarial_params, ego_policy_fn)
plot_scenario(ego_adv, others_adv, title="Adversarial Scenario", ax=axes[1])
plt.tight_layout()
plt.show()
# Print comparison stats
init_dist = float(jnp.min(jnp.sqrt(jnp.sum(
(ego_init[None, :, :2] - others_init[:, :, :2]) ** 2, axis=-1))))
adv_dist = float(jnp.min(jnp.sqrt(jnp.sum(
(ego_adv[None, :, :2] - others_adv[:, :, :2]) ** 2, axis=-1))))
print(f"Initial scenario min distance: {init_dist:.2f} m")
print(f"Adversarial scenario min distance: {adv_dist:.2f} m")
print(f"Distance reduction: {init_dist - adv_dist:.2f} m ({(1 - adv_dist/init_dist)*100:.1f}%)")
Key implementation notes for Step 3:
- The loss function is negated because Optax minimizes and we want to maximize failure. Alternatively, keep the objective positive and negate the gradients before the optimizer update.
- JIT-compile the loss function and gradient computation (shown with the @jax.jit decorator). Without JIT, each optimization step takes seconds instead of milliseconds.
- The constraint projection step is crucial. Without it, the optimizer will quickly find "adversarial" scenarios where other vehicles materialize at 200 m/s directly in front of the ego -- physically impossible and useless for safety analysis.
- Start with a small learning rate (0.001-0.01) and observe the optimization trajectory. If the loss oscillates wildly, reduce the learning rate; if progress is too slow, increase it.
- The smoothness penalty prevents the optimizer from discovering scenarios with discontinuous, physically impossible agent behavior (e.g., an agent that alternates between full throttle and full brake every timestep).
Step 4: Failure Mining and Analysis (Notebook 03, ~75 min)
Run adversarial search at scale, cluster discovered failures, and build a comprehensive failure report.
4.1 Running Adversarial Search at Scale
# Run multi-restart adversarial search
results = multi_restart_search(
base_params=nominal_scenario_params,
ego_policy_fn=simple_ego_policy,
n_restarts=20,
n_steps_per_restart=300,
learning_rate=0.01,
)
# Also run random search for comparison
random_results = random_search_baseline(
base_params=nominal_scenario_params,
ego_policy_fn=simple_ego_policy,
n_samples=2000,
)
print(f"Adversarial search: found {len(results)} scenarios")
print(f"Random search: found {len(random_results)} scenarios")
4.2 Feature Extraction
Before clustering, extract interpretable features from each discovered failure.
def extract_failure_features(params, ego_policy_fn, dt=0.1):
"""
Extract interpretable features from an adversarial scenario.
These features are used for clustering failures into categories.
"""
ego_traj, other_trajs = rollout_scenario(params, ego_policy_fn)
# Find the critical moment (minimum distance)
distances = jnp.sqrt(jnp.sum(
(ego_traj[None, :, :2] - other_trajs[:, :, :2]) ** 2, axis=-1
))
min_idx = jnp.unravel_index(jnp.argmin(distances), distances.shape)
critical_agent = int(min_idx[0])
critical_time = int(min_idx[1])
min_distance = float(distances[min_idx])
# Relative geometry at critical moment
ego_pos = ego_traj[critical_time, :2]
ego_heading = ego_traj[critical_time, 2]
ego_speed = ego_traj[critical_time, 3]
other_pos = other_trajs[critical_agent, critical_time, :2]
other_heading = other_trajs[critical_agent, critical_time, 2]
other_speed = other_trajs[critical_agent, critical_time, 3]
# Relative position in ego frame
dx = other_pos[0] - ego_pos[0]
dy = other_pos[1] - ego_pos[1]
cos_h, sin_h = jnp.cos(ego_heading), jnp.sin(ego_heading)
lon_rel = dx * cos_h + dy * sin_h # positive = ahead
lat_rel = -dx * sin_h + dy * cos_h # positive = left
# Relative heading
heading_diff = other_heading - ego_heading
heading_diff = (heading_diff + jnp.pi) % (2 * jnp.pi) - jnp.pi
# Speed ratio
speed_ratio = other_speed / (ego_speed + 1e-6)
# Ego deceleration at critical moment
ego_accel = jnp.diff(ego_traj[:, 3]) / dt
critical_decel = float(ego_accel[min(critical_time, len(ego_accel)-1)])
# Approach angle (from which direction does the threat come?)
approach_angle = jnp.arctan2(lat_rel, lon_rel)
features = {
'min_distance': min_distance,
'critical_time': critical_time * dt,
'lon_relative': float(lon_rel),
'lat_relative': float(lat_rel),
'heading_diff': float(heading_diff),
'speed_ratio': float(speed_ratio),
'ego_decel': critical_decel,
'approach_angle': float(approach_angle),
'other_speed': float(other_speed),
'ego_speed': float(ego_speed),
}
return features
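The ego-frame projection inside this function is a plain 2-D rotation; a small self-contained check confirms the sign conventions (positive lon = ahead, positive lat = left):

```python
import jax.numpy as jnp

def to_ego_frame(dx, dy, ego_heading):
    # Rotate a world-frame offset into the ego frame (same math as above).
    cos_h, sin_h = jnp.cos(ego_heading), jnp.sin(ego_heading)
    lon_rel = dx * cos_h + dy * sin_h    # positive = ahead
    lat_rel = -dx * sin_h + dy * cos_h   # positive = left
    return lon_rel, lat_rel

# With the ego heading north (pi/2), an agent 5 m north is directly ahead...
lon_ahead, lat_ahead = to_ego_frame(0.0, 5.0, jnp.pi / 2)
print(float(lon_ahead), float(lat_ahead))  # ~5.0, ~0.0
# ...and an agent 3 m west is to the ego's left.
lon_left, lat_left = to_ego_frame(-3.0, 0.0, jnp.pi / 2)
print(float(lon_left), float(lat_left))    # ~0.0, ~3.0
```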
4.3 Failure Clustering
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import StandardScaler
import scipy.spatial  # used by the coverage analysis below
def cluster_failures(results, ego_policy_fn, n_clusters=5):
"""
Cluster discovered adversarial scenarios into failure categories.
"""
# Extract features for all failures
all_features = []
for r in results:
feat = extract_failure_features(r['params'], ego_policy_fn)
all_features.append(feat)
# Convert to array for clustering
feature_names = list(all_features[0].keys())
X = np.array([[f[k] for k in feature_names] for f in all_features])
# Standardize
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# K-means clustering
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
labels = kmeans.fit_predict(X_scaled)
# Assign human-readable names based on cluster centroids
cluster_info = []
for c in range(n_clusters):
mask = labels == c
cluster_features = {k: np.mean([f[k] for f, m in zip(all_features, mask) if m])
for k in feature_names}
# Heuristic naming based on features
name = classify_failure_type(cluster_features)
cluster_info.append({
'id': c,
'name': name,
'count': int(np.sum(mask)),
'mean_features': cluster_features,
'severity': np.mean([r['min_distance'] for r, m in zip(results, mask) if m]),
})
return labels, cluster_info, all_features
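When the number of failure categories is unknown, DBSCAN is the density-based alternative mentioned in the notes. A hedged sketch on synthetic standardized features (the eps and min_samples values here suit this toy data and would need tuning on real failure features):

```python
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Two tight synthetic failure modes plus one isolated outlier in 2-D feature space.
X = np.vstack([
    rng.normal([0.0, 0.0], 0.1, size=(20, 2)),  # e.g. a cut-in-like cluster
    rng.normal([5.0, 5.0], 0.1, size=(20, 2)),  # e.g. a hard-braking-like cluster
    [[20.0, -20.0]],                            # isolated outlier
])
labels = DBSCAN(eps=0.5, min_samples=5).fit_predict(StandardScaler().fit_transform(X))
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
print(n_clusters)            # DBSCAN recovers the number of clusters itself
print(np.sum(labels == -1))  # points flagged as noise (label -1)
```

Unlike K-means, the outlier is labeled noise rather than forced into a cluster, which is often what you want for one-off adversarial scenarios.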
def classify_failure_type(features):
"""
Heuristic classification of failure type based on extracted features.
"""
lon = features['lon_relative']
lat = features['lat_relative']
heading_diff = features['heading_diff']
speed_ratio = features['speed_ratio']
if abs(lat) > 2.0 and lon > 0:
if heading_diff < -0.3:
return "Cut-in from left"
elif heading_diff > 0.3:
return "Cut-in from right"
else:
return "Side swipe"
elif lon > 0 and abs(lat) < 2.0:
if speed_ratio < 0.5:
return "Sudden braking (lead vehicle)"
else:
return "Rear approach conflict"
elif lon < 0 and abs(lat) < 2.0:
return "Rear-end (ego is lead)"
elif abs(heading_diff) > 2.5:
return "Head-on approach"
elif abs(lat) > 3.0:
return "Lateral encroachment"
else:
return "Complex interaction"
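A few synthetic feature dictionaries exercise the branches above (the classifier is repeated verbatim so the snippet runs standalone; the feature values are chosen to sit clearly on one side of each threshold):

```python
def classify_failure_type(features):
    # Same heuristic as in the text, repeated for self-containment.
    lon = features['lon_relative']
    lat = features['lat_relative']
    heading_diff = features['heading_diff']
    speed_ratio = features['speed_ratio']
    if abs(lat) > 2.0 and lon > 0:
        if heading_diff < -0.3:
            return "Cut-in from left"
        elif heading_diff > 0.3:
            return "Cut-in from right"
        else:
            return "Side swipe"
    elif lon > 0 and abs(lat) < 2.0:
        if speed_ratio < 0.5:
            return "Sudden braking (lead vehicle)"
        else:
            return "Rear approach conflict"
    elif lon < 0 and abs(lat) < 2.0:
        return "Rear-end (ego is lead)"
    elif abs(heading_diff) > 2.5:
        return "Head-on approach"
    elif abs(lat) > 3.0:
        return "Lateral encroachment"
    else:
        return "Complex interaction"

cut_in_left = {'lon_relative': 12.0, 'lat_relative': 3.0,
               'heading_diff': -0.5, 'speed_ratio': 1.0}
lead_braking = {'lon_relative': 15.0, 'lat_relative': 0.5,
                'heading_diff': 0.0, 'speed_ratio': 0.3}
head_on = {'lon_relative': -10.0, 'lat_relative': 2.5,
           'heading_diff': 3.0, 'speed_ratio': 1.0}
print(classify_failure_type(cut_in_left))   # Cut-in from left
print(classify_failure_type(lead_braking))  # Sudden braking (lead vehicle)
print(classify_failure_type(head_on))       # Head-on approach
```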
4.4 Coverage Analysis
def coverage_analysis(adversarial_results, random_results, ego_policy_fn):
"""
Compare failure coverage between adversarial and random search.
"""
# Adversarial failures
adv_features = [extract_failure_features(r['params'], ego_policy_fn)
for r in adversarial_results]
# Random search failures (only keep those with min_dist < threshold)
threshold = 5.0 # meters
random_failures = [r for r in random_results if r['min_distance'] < threshold]
rand_features = [extract_failure_features(r['params'], ego_policy_fn)
for r in random_failures]
print(f"Adversarial: {len(adversarial_results)} scenarios total")
print(f" Failures (min_dist < {threshold}m): "
f"{sum(1 for r in adversarial_results if r['min_distance'] < threshold)}")
print(f" Failure rate: "
f"{sum(1 for r in adversarial_results if r['min_distance'] < threshold) / len(adversarial_results) * 100:.1f}%")
print(f"\nRandom: {len(random_results)} scenarios total")
print(f" Failures (min_dist < {threshold}m): {len(random_failures)}")
print(f" Failure rate: {len(random_failures) / len(random_results) * 100:.1f}%")
# Diversity comparison
if len(adv_features) > 1:
adv_X = np.array([[f[k] for k in adv_features[0].keys()] for f in adv_features])
adv_diversity = np.mean(scipy.spatial.distance.pdist(
StandardScaler().fit_transform(adv_X)))
print(f"\nAdversarial diversity (mean pairwise distance): {adv_diversity:.3f}")
if len(rand_features) > 1:
rand_X = np.array([[f[k] for k in rand_features[0].keys()] for f in rand_features])
rand_diversity = np.mean(scipy.spatial.distance.pdist(
StandardScaler().fit_transform(rand_X)))
print(f"Random diversity (mean pairwise distance): {rand_diversity:.3f}")
4.5 Severity Prioritization
def prioritize_failures(results, ego_policy_fn):
"""
Rank failures by a composite severity score.
Severity considers:
- How close to collision (min distance)
- How hard the ego had to brake (deceleration)
- How early in the scenario the failure occurs (earlier = harder to avoid)
- Plausibility of the adversarial scenario
"""
scored_results = []
for r in results:
features = extract_failure_features(r['params'], ego_policy_fn)
# Proximity score: 1/distance, capped
proximity = 1.0 / (features['min_distance'] + 0.1)
# Braking severity: how hard ego braked
brake_severity = max(0, -features['ego_decel'] / 6.0) # normalized by max decel
        # Time pressure: earlier failures are harder to react to (8 s scenario)
        time_pressure = max(0.0, 1.0 - (features['critical_time'] / 8.0))
# Composite severity
severity = 0.5 * proximity + 0.3 * brake_severity + 0.2 * time_pressure
scored_results.append({
**r,
'features': features,
'severity_score': severity,
})
scored_results.sort(key=lambda r: -r['severity_score'])
return scored_results
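Worked through on one hypothetical failure (the feature values are invented for illustration), the composite score combines the three terms as follows:

```python
# Hypothetical failure: 1.9 m miss distance, 6 m/s^2 braking, critical moment at t=2 s.
features = {'min_distance': 1.9, 'ego_decel': -6.0, 'critical_time': 2.0}
proximity = 1.0 / (features['min_distance'] + 0.1)       # 1/2.0 = 0.5
brake_severity = max(0, -features['ego_decel'] / 6.0)    # 6/6   = 1.0
time_pressure = 1.0 - (features['critical_time'] / 8.0)  # 1-0.25 = 0.75
severity = 0.5 * proximity + 0.3 * brake_severity + 0.2 * time_pressure
print(severity)  # 0.25 + 0.30 + 0.15 = ~0.70
```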
4.6 Failure Report Generation
def generate_failure_report(scored_results, cluster_info, ego_policy_fn):
"""
Generate a comprehensive failure analysis report with visualizations.
"""
fig = plt.figure(figsize=(20, 16))
gs = fig.add_gridspec(3, 3, hspace=0.4, wspace=0.3)
# 1. Severity distribution
ax1 = fig.add_subplot(gs[0, 0])
severities = [r['severity_score'] for r in scored_results]
ax1.hist(severities, bins=20, color='#d62728', edgecolor='white', alpha=0.8)
ax1.set_xlabel('Severity Score')
ax1.set_ylabel('Count')
ax1.set_title('Severity Distribution')
ax1.grid(True, alpha=0.3)
# 2. Failure type distribution (pie chart)
ax2 = fig.add_subplot(gs[0, 1])
types = [ci['name'] for ci in cluster_info]
counts = [ci['count'] for ci in cluster_info]
ax2.pie(counts, labels=types, autopct='%1.1f%%', colors=plt.cm.Set3.colors)
ax2.set_title('Failure Categories')
# 3. Min distance vs severity
ax3 = fig.add_subplot(gs[0, 2])
min_dists = [r['min_distance'] for r in scored_results]
ax3.scatter(min_dists, severities, c='#1f77b4', alpha=0.6, edgecolors='k', linewidths=0.5)
ax3.set_xlabel('Min Distance (m)')
ax3.set_ylabel('Severity Score')
ax3.set_title('Distance vs Severity')
ax3.grid(True, alpha=0.3)
# 4-6. Top 3 most severe failures (BEV plots)
for i in range(min(3, len(scored_results))):
ax = fig.add_subplot(gs[1, i])
r = scored_results[i]
ego_traj, other_trajs = rollout_scenario(r['params'], ego_policy_fn)
plot_scenario(ego_traj, other_trajs,
title=f"#{i+1}: severity={r['severity_score']:.2f}, "
f"dist={r['min_distance']:.1f}m", ax=ax)
# 7. Feature scatter (approach angle vs speed ratio)
ax7 = fig.add_subplot(gs[2, 0])
approach_angles = [r['features']['approach_angle'] for r in scored_results]
speed_ratios = [r['features']['speed_ratio'] for r in scored_results]
ax7.scatter(np.degrees(approach_angles), speed_ratios,
c=severities, cmap='Reds', alpha=0.7, edgecolors='k', linewidths=0.5)
ax7.set_xlabel('Approach Angle (deg)')
ax7.set_ylabel('Speed Ratio (other/ego)')
ax7.set_title('Threat Geometry')
ax7.grid(True, alpha=0.3)
# 8. Cluster severity comparison
ax8 = fig.add_subplot(gs[2, 1])
cluster_names = [ci['name'] for ci in cluster_info]
cluster_severities = [ci['severity'] for ci in cluster_info]
ax8.barh(cluster_names, cluster_severities, color=plt.cm.Set2.colors)
ax8.set_xlabel('Mean Min Distance (m)')
ax8.set_title('Severity by Failure Type')
# 9. Summary table
ax9 = fig.add_subplot(gs[2, 2])
ax9.axis('off')
summary_text = (
f"Total scenarios tested: {len(scored_results)}\n"
f"Failures found (< 5m): {sum(1 for r in scored_results if r['min_distance'] < 5)}\n"
f"Near-collisions (< 2m): {sum(1 for r in scored_results if r['min_distance'] < 2)}\n"
f"Distinct failure types: {len(cluster_info)}\n"
        f"Most severe category: {min(cluster_info, key=lambda ci: ci['severity'])['name'] if cluster_info else 'N/A'}\n"
f"Mean severity: {np.mean(severities):.3f}\n"
f"Max severity: {max(severities):.3f}"
)
ax9.text(0.1, 0.5, summary_text, fontsize=11, fontfamily='monospace',
verticalalignment='center', transform=ax9.transAxes,
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
ax9.set_title('Summary')
fig.suptitle('Adversarial Scenario Generation -- Failure Report',
fontsize=16, fontweight='bold')
plt.savefig('failure_report.png', dpi=150, bbox_inches='tight')
plt.show()
Key implementation notes for Step 4:
- Feature extraction should be fast because it runs on every discovered scenario. JIT-compile the rollout_scenario function if not already done.
- The heuristic failure classification function (classify_failure_type) is a starting point. In practice, you would refine the thresholds based on your specific ODD and scenario distribution.
- Use DBSCAN instead of K-means if you do not know the number of failure clusters in advance. DBSCAN determines the number of clusters from density, but requires tuning the eps and min_samples parameters.
- The coverage analysis is the key result that justifies adversarial generation over random testing. Expect to see 5-50x higher failure discovery rates with gradient-based search compared to random sampling.
- Consider exporting discovered failures in a format compatible with your simulation platform (e.g., Waymax scenario format, OpenSCENARIO) so they can be replayed and shared with the safety team.
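As a starting point for the export suggestion, here is a hedged sketch that serializes one failure to plain JSON (an invented schema for illustration, not the Waymax or OpenSCENARIO format; the `ScenarioParams` NamedTuple is a stand-in for the project's pytree):

```python
import json
import tempfile
import numpy as np
from typing import NamedTuple

class ScenarioParams(NamedTuple):
    # Stand-in for the pytree used throughout (assumption for this sketch).
    ego_init: np.ndarray
    other_inits: np.ndarray
    other_actions: np.ndarray

def export_failure_json(params, features, path):
    """Write one discovered failure as JSON (illustrative schema, not a standard)."""
    record = {
        'ego_init': np.asarray(params.ego_init).tolist(),
        'other_inits': np.asarray(params.other_inits).tolist(),
        'other_actions': np.asarray(params.other_actions).tolist(),
        'features': {k: float(v) for k, v in features.items()},
    }
    with open(path, 'w') as f:
        json.dump(record, f, indent=2)

params = ScenarioParams(np.zeros(4), np.zeros((2, 4)), np.zeros((2, 10, 2)))
path = tempfile.NamedTemporaryFile(suffix='.json', delete=False).name
export_failure_json(params, {'min_distance': 1.9}, path)
loaded = json.load(open(path))
print(sorted(loaded))  # ['ego_init', 'features', 'other_actions', 'other_inits']
```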
Exercises
Exercise 1: Custom Failure Objective (in Notebook 02)
Implement a new failure objective that targets a specific failure mode: the ego vehicle being forced to make an emergency lane change. The objective should reward scenarios where:
- An obstacle appears ahead in the ego's lane, requiring evasive action.
- The adjacent lane has limited gap, making the lane change dangerous.
- The ego must choose between braking hard or changing lanes into a tight gap.
Hints:
- Define "lane change pressure" as a function of the gap in the ego's lane and the gaps in adjacent lanes.
- Use a product of terms: (small gap ahead) * (small gap beside) = high pressure.
- Test your objective by running adversarial search and inspecting the discovered scenarios visually.
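One hedged way to realize the product-of-terms hint (extracting the gaps from the rollout is left to you, and the exponential length scale of 10 m is an invented choice, not part of the project spec):

```python
import jax.numpy as jnp

def lane_change_pressure(gap_ahead, gap_adjacent, scale=10.0):
    # Product of terms: pressure is high only when BOTH gaps are small,
    # and it is smooth and differentiable in both arguments.
    return jnp.exp(-gap_ahead / scale) * jnp.exp(-gap_adjacent / scale)

p_tight = float(lane_change_pressure(2.0, 3.0))        # both gaps small
p_ahead_only = float(lane_change_pressure(2.0, 50.0))  # escape lane is open
p_beside_only = float(lane_change_pressure(50.0, 3.0)) # own lane is open
print(p_tight, p_ahead_only, p_beside_only)  # ~0.61 vs ~0.0055 vs ~0.0050
```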
Exercise 2: Constrained Adversary (in Notebook 02)
Modify the adversarial optimizer to enforce an additional constraint: all other agents must follow plausible driving behavior (not just kinematically feasible, but behaviorally plausible). Specifically:
- Each other agent should approximately follow IDM with respect to its own leader.
- The adversary can only modify the IDM parameters (desired speed, desired headway, politeness) rather than directly controlling actions.
- This creates a "naturalistic adversary" that is much harder to distinguish from real-world driving.
Hints:
- Parameterize each agent by IDM parameters: [v_desired, s0, T, a_max, b].
- Inside the rollout, compute actions from IDM given these parameters.
- The optimization landscape is lower-dimensional but more complex.
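For reference, the standard IDM acceleration law written in JAX (the default parameter values are common textbook choices, an assumption rather than part of this project):

```python
import jax.numpy as jnp

def idm_accel(v, v_lead, gap, v_desired=30.0, s0=2.0, T=1.5, a_max=1.5, b=2.0):
    """Intelligent Driver Model acceleration from own speed, leader speed, and gap."""
    dv = v - v_lead  # closing speed (positive when approaching the leader)
    s_star = s0 + v * T + v * dv / (2.0 * jnp.sqrt(a_max * b))  # desired gap
    return a_max * (1.0 - (v / v_desired) ** 4 - (s_star / gap) ** 2)

# Free road, well below desired speed -> accelerate.
a_free = float(idm_accel(10.0, 10.0, 1000.0))
# Tailgating a slower leader at close range -> brake hard.
a_tail = float(idm_accel(20.0, 10.0, 5.0))
print(a_free, a_tail)  # positive, strongly negative
```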
Exercise 3: Adversarial Weather Conditions (in Notebook 03)
Extend the failure analysis to consider how environmental factors amplify adversarial scenarios. Add parameterized environmental conditions:
- Reduced visibility (fog/rain): modeled as increased sensor noise or reduced detection range for the ego's perception.
- Reduced friction: modeled as lower maximum braking capability for all vehicles.
- Increased reaction time: modeled as a delay in the ego's policy response.
Run adversarial search with and without these environmental factors and compare:
- Does the set of discovered failures change?
- Are some failure types only reachable under adverse conditions?
- How does the severity distribution shift?
Summary and Next Steps
In this project you built a complete adversarial scenario generation framework:
- Differentiable scenario parameterization: Driving scenarios represented as continuous vectors that JAX can differentiate through, enabling gradient-based search.
- Differentiable kinematic simulation: A bicycle-model simulator implemented entirely in JAX with smooth approximations for non-differentiable operations (speed clamping, distance thresholds).
- Composable failure objectives: A library of differentiable failure metrics (collision proximity, TTC, off-road, hard braking) that can be weighted and combined to target specific failure modes.
- Gradient-based adversarial optimization: An Optax-based optimizer that iteratively perturbs scenario parameters toward failure, with constraint projection to maintain physical plausibility.
- Multi-restart diversity: A strategy for discovering diverse failure modes rather than converging to the same adversarial example repeatedly.
- Failure analysis pipeline: Feature extraction, clustering, coverage analysis, and severity prioritization that transforms raw adversarial scenarios into actionable safety insights.
Where to go from here:
- Integration with Waymax: Replace the toy kinematic simulator with Waymax's full dynamics model for more realistic scenarios. Waymax's differentiable simulation enables the same gradient-based approach but with realistic multi-agent dynamics, road geometry, and traffic rules.
- Learned ego policies: Replace the simple rule-based ego policy with a learned neural network policy (e.g., from the RL Training phase). Adversarial testing of learned policies often reveals failure modes that unit tests miss.
- Scenario libraries: Build a curated library of adversarial scenarios organized by failure type, severity, and ODD region. This library becomes a regression test suite that every new policy version must pass.
- Automated CI/CD integration: Run adversarial search as part of the continuous integration pipeline. Every code change to the planner triggers adversarial testing, and new failures are automatically filed as issues.
- SOTIF compliance: Map discovered failure types to SOTIF hazard categories and use the coverage analysis to argue that your testing has explored the relevant portions of the scenario space.
References
- Ding, W., et al. (2021). "Multimodal Safety-Critical Scenarios Generation for Decision-Making Algorithms Evaluation." IEEE Robotics and Automation Letters.
- Rempe, D., et al. (2022). "Generating Useful Accident-Prone Driving Scenarios via a Learned Traffic Model." CVPR 2022.
- Sun, C., et al. (2021). "Adversarial Evaluation of Autonomous Vehicles in Lane-Change Scenarios." IEEE Transactions on Intelligent Transportation Systems.
- Koren, M., et al. (2018). "Adaptive Stress Testing for Autonomous Vehicles." IEEE Intelligent Vehicles Symposium.
- Corso, A., et al. (2021). "A Survey of Algorithms for Black-Box Safety Validation of Cyber-Physical Systems." Journal of Artificial Intelligence Research.
- Wachi, A., and Sui, Y. (2020). "Safe Reinforcement Learning in Constrained Markov Decision Processes." ICML 2020.
- ISO 21448:2022. "Road vehicles -- Safety of the intended functionality."