MARL Training Infrastructure¶
Overview¶
Implementation Status
The Training Infrastructure is fully implemented with five RL algorithms, checkpoint management, metrics tracking, TensorBoard integration, and convergence detection.
The MARL Training Infrastructure provides the complete pipeline for training and evaluating reinforcement learning agents in OpenCDA-MARL scenarios.
```text
Training Infrastructure
├── MARLManager            # Algorithm orchestrator
├── BaseAlgorithm          # Common algorithm interface
├── Algorithms             # TD3, DQN, Q-Learning, MAPPO, SAC
├── ObservationExtractor   # Feature extraction for RL
├── CheckpointManager      # Model saving/loading
├── TrainingMetrics        # Episode statistics & CSV export
├── Replay Buffers         # Smart, Prioritized, Rollout
└── TensorBoard            # Training visualization
```
Supported Algorithms¶
| Algorithm | Type | Status | Description |
|---|---|---|---|
| TD3 | Continuous, Off-policy | ✅ Implemented | Twin Delayed DDPG with LSTM encoder |
| DQN | Discrete, Off-policy | ✅ Implemented | Deep Q-Network with target network |
| Q-Learning | Discrete, Tabular | ✅ Implemented | Tabular Q-Learning with state bins |
| MAPPO | Continuous, On-policy | ✅ Implemented | Multi-Agent PPO with GAE |
| SAC | Continuous, Off-policy | ✅ Implemented | Soft Actor-Critic with entropy tuning |
Core Classes¶
MARLManager¶
The algorithm orchestrator that selects and manages the active RL algorithm.
```python
from opencda_marl.core.marl.marl_manager import MARLManager

manager = MARLManager(config)

# Select action based on observations
action = manager.select_action(
    multi_agent_obs=observations,
    ego_agent_id="agent_001",
    training=True
)

# Store experience
manager.store_transition(obs, ego_id, action, reward, next_obs, done)

# Update algorithm (returns loss dict)
losses = manager.update()

# Get training statistics
info = manager.get_training_info()
```
BaseAlgorithm¶
Abstract base class that all algorithms implement:
```python
from abc import ABC
from typing import Dict

# Defined in opencda_marl/core/marl/algorithms/base_algorithm.py
class BaseAlgorithm(ABC):
    """Common interface for all RL algorithms."""
    def select_action(self, state, training=True) -> float: ...
    def store_transition(self, state, action, reward, next_state, done): ...
    def update(self) -> Dict[str, float]: ...
    def reset_episode(self): ...
    def get_training_info(self) -> Dict: ...
    def save(self, path: str): ...
    def load(self, path: str): ...
```
Built-in features: TensorBoard logging, convergence detection, episode metrics.
ObservationExtractor¶
Converts raw CARLA vehicle data into normalized observation vectors:
```python
from opencda_marl.core.marl.extractor import ObservationExtractor

extractor = ObservationExtractor(config)
obs_vector = extractor.extract(vehicle_data)
```
Supported features: relative position, heading, speed, distance to intersection, distance to front vehicle, lane position, waypoint buffer, min TTC, distance to destination.
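As a concrete illustration, features such as speed and distances are typically scaled into fixed ranges before being fed to the policy network. The sketch below is hypothetical: the function name and scale constants are assumptions for illustration, not the extractor's actual values.

```python
# Hypothetical normalization sketch -- the scale constants below are
# illustrative assumptions, not values taken from ObservationExtractor.
def normalize_features(speed_mps, dist_to_front_m,
                       max_speed_mps=18.0, sensing_range_m=100.0):
    """Scale raw measurements into [0, 1] so no feature dominates training."""
    speed_norm = min(speed_mps / max_speed_mps, 1.0)
    dist_norm = min(dist_to_front_m / sensing_range_m, 1.0)
    return [speed_norm, dist_norm]
```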
Algorithm Details¶
TD3¶
The primary algorithm, with LSTM multi-agent context encoding.
```python
# Architecture: 8D ego state + LSTM-encoded context → speed action [0, 65 km/h]
# Actor:  [8+256, 1024, 1024, 512, 256] → LayerNorm → tanh
# Critic: twin Q-networks for reduced overestimation
```
Key features:
- LSTM encoder for processing multi-agent observations
- LayerNorm before tanh to prevent gradient vanishing
- Delayed policy updates (every 2 steps)
- Target policy smoothing with noise
- Exploration noise decay (0.3 → 0.05)
- Warmup phase (1000 steps of vanilla agent)
- SmartReplayBuffer or PrioritizedReplayBuffer (configurable)
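Two of these mechanics can be sketched in a few lines of Python. The decay constants mirror the feature list above; the target-noise parameters (0.2 std, 0.5 clip) are assumed typical TD3 values, not necessarily this implementation's.

```python
import random

# Sketch of TD3's exploration-noise decay and target policy smoothing.
# noise_decay/min_noise mirror the feature list above; the std/clip for
# target noise are assumed values.
def decay_exploration_noise(noise, decay=0.998, floor=0.05):
    """Multiplicative decay toward a fixed floor (0.3 -> 0.05)."""
    return max(floor, noise * decay)

def smoothed_target_action(action, noise_std=0.2, noise_clip=0.5,
                           low=0.0, high=65.0):
    """Add clipped Gaussian noise to the target action, then clamp to bounds."""
    noise = max(-noise_clip, min(noise_clip, random.gauss(0.0, noise_std)))
    return max(low, min(high, action + noise))
```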
DQN¶
Discrete-action deep Q-network.
```python
# Network: [state_dim, 64, 32] → Q-values for discrete speed actions
# Actions: predefined speed levels (e.g., [0, 5, 8, 12, 15] m/s)
```
Key features: Epsilon-greedy exploration, target network (updated every 100 steps), gradient clipping.
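The action-selection step can be sketched as epsilon-greedy over the example speed levels above (the function name is illustrative, not the DQN class's actual API):

```python
import random

# Epsilon-greedy selection over discrete speed levels (sketch).
SPEED_LEVELS = [0, 5, 8, 12, 15]  # m/s, the example levels above

def epsilon_greedy(q_values, epsilon):
    """With probability epsilon explore uniformly; otherwise pick argmax Q."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda i: q_values[i])
```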
MAPPO¶
Multi-Agent Proximal Policy Optimization.
Key features: Generalized Advantage Estimation (GAE), Gaussian actor for continuous actions, rollout buffer, clipped surrogate objective.
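GAE, the core of the advantage computation, can be sketched as a backward recursion over a finished rollout. Gamma and lambda below are typical defaults, not necessarily the values this implementation uses:

```python
# GAE over a finished rollout (sketch).
def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """Backward recursion: A_t = delta_t + gamma*lam*(1-done_t)*A_{t+1}."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        nonterminal = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        gae = delta + gamma * lam * nonterminal * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages
```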
SAC¶
Soft Actor-Critic with automatic entropy tuning.
Key features: Entropy-regularized objective, automatic temperature (alpha) tuning, twin Q-networks, reparameterization trick.
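The automatic temperature tuning can be illustrated as a single gradient step on log-alpha. The target entropy and learning rate below are assumed values; the real implementation optimizes log-alpha with an optimizer over batches of log-probabilities:

```python
import math

# One gradient step of SAC's automatic temperature tuning (sketch;
# target_entropy and lr are assumed values, not the implementation's).
def update_log_alpha(log_alpha, mean_log_prob, target_entropy=-1.0, lr=3e-4):
    """Minimize  -log_alpha * (log_prob + target_entropy)  w.r.t. log_alpha."""
    grad = -(mean_log_prob + target_entropy)  # d(loss)/d(log_alpha)
    return log_alpha - lr * grad

# alpha = exp(log_alpha) stays positive by construction
new_alpha = math.exp(update_log_alpha(0.0, mean_log_prob=-0.5))
```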
Checkpoint Management¶
```python
from opencda_marl.core.marl.checkpoint import CheckpointManager

checkpoint_mgr = CheckpointManager(config)

# Save checkpoints (automatic in training)
checkpoint_mgr.save(algorithm, episode=100, reward=350.0)

# Load checkpoints
checkpoint_mgr.load(algorithm, mode="latest")  # or "best"
```
Three checkpoint types are saved: `latest_checkpoint.pth`, `best_checkpoint.pth`, and `episode_XXXX.pth`.
Metrics & Logging¶
TrainingMetrics¶
Tracks per episode: total reward, success/collision rates, average speed, target-speed gap, throughput, near-miss count, and TTC violations.
Exports to CSV in the `metrics_history/` directory.
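A minimal sketch of the CSV export, with assumed column names mirroring the metrics above (not necessarily the actual TrainingMetrics schema):

```python
import csv
import io

# Illustrative per-episode CSV export; the column names are assumptions.
def export_metrics_csv(rows):
    fieldnames = ["episode", "total_reward", "avg_speed", "near_miss_count"]
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()
```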
TensorBoard¶
Metrics logged to TensorBoard:
- `Loss/critic`, `Loss/actor`: training losses
- `Q_values/Q1_mean`, `Q_values/Q2_mean`: value estimates
- `Gradients/critic_pre_clip`, `Gradients/actor_pre_clip`: gradient norms
- `TD3/exploration_noise`: current noise level
- `Learning/reward_moving_avg`: smoothed reward trend
- `Safety/near_miss_count`, `Safety/ttc_violation_rate`: safety metrics
Convergence Detection¶
Automatic convergence detection checks:
- Reward CV < 15% over 10-episode window
- Success rate CV < 20%
- Collision rate improving (second half ≤ first half × 1.1)
- Minimum 20 episodes before checking
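The reward-stability criterion can be sketched as a coefficient-of-variation check over a trailing window. The thresholds mirror the list above, but the function itself is illustrative, not the detector's actual code:

```python
import statistics

# Sketch: reward is "stable" when stdev/mean over the last `window`
# episodes falls below cv_threshold, after a minimum warm-up.
def reward_converged(rewards, window=10, cv_threshold=0.15, min_episodes=20):
    if len(rewards) < min_episodes:
        return False  # too little data to judge
    recent = rewards[-window:]
    mean = statistics.mean(recent)
    if mean == 0:
        return False
    cv = statistics.stdev(recent) / abs(mean)
    return cv < cv_threshold
```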
Replay Buffers¶
SmartReplayBuffer¶
Default off-policy buffer with recency bias.
- Pre-allocated numpy arrays for O(1) random access
- Circular buffer with O(1) push
- Sampling: 50% recent experiences (last 20%) + 50% diverse
- Pre-emptive clearing at 90% capacity
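The recency-biased sampling rule can be sketched as an index split per the description above (names and structure are illustrative):

```python
import random

# Sketch: half the batch from the newest 20% of the buffer,
# half sampled uniformly over the whole buffer.
def recency_biased_indices(size, batch_size, recent_frac=0.2, recent_share=0.5):
    n_recent = int(batch_size * recent_share)
    recent_start = int(size * (1.0 - recent_frac))
    recent = [random.randrange(recent_start, size) for _ in range(n_recent)]
    diverse = [random.randrange(0, size) for _ in range(batch_size - n_recent)]
    return recent + diverse
```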
PrioritizedReplayBuffer¶
Optional TD-error prioritized sampling.
- Alpha (priority exponent): 0.6
- Beta annealing: 0.4 → 1.0 for importance sampling
- Enabled via config: `td3.use_per: true`
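The priority-to-probability conversion and importance-sampling correction can be sketched as follows (`eps` is an assumed small stabilizer; both function names are illustrative):

```python
# Sketch of PER sampling math: p_i = (|delta_i| + eps)^alpha, normalized;
# beta controls the importance-sampling correction as it anneals to 1.
def per_probabilities(td_errors, alpha=0.6, eps=1e-6):
    prios = [(abs(e) + eps) ** alpha for e in td_errors]
    total = sum(prios)
    return [p / total for p in prios]

def importance_weight(prob, buffer_size, beta):
    """w_i = (N * P(i))^-beta, annealed as beta -> 1."""
    return (buffer_size * prob) ** (-beta)
```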
RolloutBuffer¶
On-policy buffer for MAPPO.
- Stores complete episodes for GAE computation
- Cleared after each policy update
Usage¶
When training is disabled, the CheckpointManager automatically loads the best model.
```yaml
# configs/marl/td3_simple_v4.yaml
MARL:
  algorithm: "td3"
  state_dim: 8
  action_dim: 1
  training: true
  td3:
    learning_rate_actor: 0.0001
    learning_rate_critic: 0.001
    batch_size: 256
    memory_size: 100000
    gamma: 0.99
    tau: 0.005
    exploration_noise: 0.3
    noise_decay: 0.998
    min_noise: 0.05
    warmup_steps: 1000
    lstm_hidden: 256
```
- Location: `opencda_marl/core/marl/`
- Algorithms: `opencda_marl/core/marl/algorithms/`