Training details
Replay buffers
In single-agent RL, the replay buffer typically stores individual environment transitions of the form (obs, action, reward, done, next_obs).
However, many MARL algorithms use replay buffers that store entire episodes instead.
Except for VDN, we use episode-based replay buffers for the implemented off-policy algorithms (QMIX, MADDPG, FACMAC).
Episode-based replay buffers are easier to handle when training recurrent policies, since we have access to full sequences and can use accurate hidden states.
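To make the idea concrete, here is a minimal sketch of an episode-based buffer in plain Python. The class name EpisodeReplayBuffer, its methods, and the padding scheme (None for padded steps plus a 0/1 mask) are illustrative assumptions and not the buffers used in the implementations.

```python
import random
from collections import deque


class EpisodeReplayBuffer:
    """Hypothetical episode-level buffer: each entry is one full episode."""

    def __init__(self, capacity, seed=None):
        # deque(maxlen=...) drops the oldest episode once capacity is reached
        self.episodes = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add_episode(self, transitions):
        # transitions: list of (obs, action, reward, done, next_obs) tuples
        self.episodes.append(list(transitions))

    def sample(self, batch_size):
        batch = self.rng.sample(list(self.episodes), batch_size)
        max_len = max(len(ep) for ep in batch)
        padded, mask = [], []
        for ep in batch:
            pad = max_len - len(ep)
            padded.append(ep + [None] * pad)        # None marks a padding step
            mask.append([1] * len(ep) + [0] * pad)  # 1 = real step, 0 = padding
        return padded, mask


buffer = EpisodeReplayBuffer(capacity=100, seed=0)
buffer.add_episode([("o", 0, 1.0, False, "o")] * 3)
buffer.add_episode([("o", 1, 0.0, False, "o")] * 5)
episodes, mask = buffer.sample(batch_size=2)
```

Because whole episodes are returned in order, a recurrent policy can be unrolled from the first step with a freshly initialized hidden state, and the mask lets the loss ignore padded steps.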
For VDN, we use a transition-based buffer instead, in which we store sequences of transitions.
When we sample a sequence to update the policy, we first initialize the hidden state h = None.
Then, we use the first burn_in steps to warm up the hidden state (without backpropagation), and use the rest of the sequence for the policy update.
h_target = None
h_utility = None
with torch.no_grad():  # First, burn in some steps
    for t in range(args.burn_in):
        ...
        _, h_target = target_network(batch_next_obs_t, h=h_target)
        _, h_utility = utility_network(batch_obs_t, h=h_utility)
loss = 0  # From here on, we backpropagate
for t in range(args.burn_in, args.seq_length):
    with torch.no_grad():
        ...
        q_next, h_target = target_network(batch_next_obs_t, h=h_target, avail_action=batch_next_avail_action_t)
        ...
        targets = batch_reward[:, t] + args.gamma * (1 - batch_done[:, t]) * vdn_q_max
    ...
    q_values, h_utility = utility_network(batch_obs_t, h=h_utility)
    ...
    loss += F.mse_loss(targets, vqn_q_values)
loss = loss / (args.seq_length - args.burn_in)
optimizer.zero_grad()
loss.backward()
Parallel environments
Parallel environments are particularly useful for on-policy algorithms, since we need to collect new trajectories at each training step using the current policy. For off-policy algorithms, however, the training data is sampled from a replay buffer, so how fast we fill the buffer matters less because off-policy methods can be trained using data from older policies.
For off-policy algorithms, a parallel implementation may even perform worse than its single-environment counterpart.
If we update the policies every train_freq steps in a single-environment setting, the same schedule in a parallel setup updates them only once every num_parallel_envs * train_freq collected transitions.
To compensate for this, instead of performing a single gradient update per training step, we perform multiple epochs of updates per training step in the parallel setting.
We also keep track of num_updates to monitor the total number of policy updates for each algorithm.
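The arithmetic behind this can be sketched with a small counter. The function count_updates and its parameters are hypothetical names used only for illustration; it assumes that one vectorized step collects one transition per environment and that an update is triggered every train_freq vectorized steps.

```python
def count_updates(total_transitions, train_freq, num_envs=1, epochs=1):
    """Count gradient updates performed while collecting `total_transitions`."""
    num_updates = 0
    transitions = 0
    step = 0
    while transitions < total_transitions:
        step += 1
        transitions += num_envs      # one transition per environment per step
        if step % train_freq == 0:
            num_updates += epochs    # `epochs` gradient updates per training step
    return num_updates


single = count_updates(1000, train_freq=4)                               # one environment
parallel = count_updates(1000, train_freq=4, num_envs=8)                 # 8 envs, 1 epoch
compensated = count_updates(1000, train_freq=4, num_envs=8, epochs=8)    # 8 envs, 8 epochs
```

Under these assumptions, the parallel run performs roughly num_envs times fewer updates for the same amount of data, and setting the number of epochs to num_envs brings the update count back close to the single-environment value.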
When using multiple environments in parallel with recurrent policies, some episodes may complete before others, so we track which environments are still alive at each timestep. This is critical for RNN policies: the hidden state initially has shape (num_envs x num_agents, hidden_dim), but it shrinks to (num_alive_envs x num_agents, hidden_dim) once some episodes finish.
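The shape change can be illustrated with plain lists standing in for tensors. This is a sketch under an assumption that is not stated in the text: rows of the hidden state are ordered env-major, so the rows of environment e occupy indices e * num_agents to (e + 1) * num_agents - 1. The helper keep_alive_rows is a hypothetical name.

```python
def keep_alive_rows(hidden, alive, num_agents):
    """Keep only the hidden-state rows that belong to alive environments.

    hidden: list of rows, one per (env, agent) pair, ordered env-major
    alive:  list of booleans, one per environment currently in `hidden`
    """
    kept = []
    for env_idx, is_alive in enumerate(alive):
        if is_alive:
            start = env_idx * num_agents
            kept.extend(hidden[start:start + num_agents])
    return kept


num_envs, num_agents, hidden_dim = 3, 2, 4
# Row for (env, agent) is filled with env * 10 + agent so rows are easy to identify
hidden = [[env * 10 + agent] * hidden_dim
          for env in range(num_envs) for agent in range(num_agents)]
alive = [True, False, True]  # the second environment has finished its episode
hidden_alive = keep_alive_rows(hidden, alive, num_agents)
```

Here hidden has num_envs x num_agents rows and hidden_alive has num_alive_envs x num_agents rows, so the batch of observations passed to the network at the next step must be filtered with the same mask.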
TD returns & advantages
For TD(λ) returns, we use the recursive formula from Reconciling λ-Returns with Experience Replay (Equation 3). We start from \(R^{\lambda}_T = 0\) and compute the returns backward in time.
We don’t handle time-outs except in the coma_lbf.py implementation, where we keep track of whether the last step is truncated or done.
If it’s truncated, we bootstrap by adding the action-value of the last observation.
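A backward recursion of this kind can be sketched as follows. The sketch assumes the common form \(R^{\lambda}_t = r_t + \gamma[(1-\lambda) V(s_{t+1}) + \lambda R^{\lambda}_{t+1}]\), where \(V(s_{t+1})\) stands for whatever bootstrap value the algorithm uses (for example a target action-value). The function td_lambda_returns and its arguments are hypothetical names, not the code of the implementations.

```python
def td_lambda_returns(rewards, next_values, gamma, lam, truncated=False):
    """Compute TD(lambda) returns for one episode by a backward recursion.

    rewards[t]     : reward received at step t
    next_values[t] : bootstrap value of the state reached after step t
    truncated      : True if the episode ended by a time-out, not a terminal state
    """
    T = len(rewards)
    returns = [0.0] * T
    # Return "after the last step": 0 at a terminal state, bootstrap value on time-out
    next_return = next_values[-1] if truncated else 0.0
    for t in reversed(range(T)):
        v_next = next_values[t]
        if t == T - 1 and not truncated:
            v_next = 0.0  # no value beyond a true terminal state
        returns[t] = rewards[t] + gamma * ((1 - lam) * v_next + lam * next_return)
        next_return = returns[t]
    return returns


rewards = [1.0, 1.0, 1.0]
next_values = [0.5, 0.5, 0.5]
monte_carlo = td_lambda_returns(rewards, next_values, gamma=0.9, lam=1.0)
one_step = td_lambda_returns(rewards, next_values, gamma=0.9, lam=0.0)
timed_out = td_lambda_returns(rewards, next_values, gamma=0.9, lam=1.0, truncated=True)
```

With lam=1.0 the recursion reduces to discounted Monte Carlo returns, and with lam=0.0 to one-step TD targets. In the truncated case, the last return additionally includes the discounted value of the final observation.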
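We don’t estimate the advantages with a separate GAE computation. Instead, we reuse the TD(λ) returns by exploiting the identity \(\sum_{k \ge 0} (\gamma\lambda)^k \delta_{t+k} = R^{\lambda}_t - V(s_t)\), where \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) is the TD error. This can be found on page 47 of David Silver’s lecture 4.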
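The relation between GAE and TD(λ) returns can be checked numerically. The sketch below assumes the standard identity that the GAE advantage, the discounted sum of TD errors with factor \(\gamma\lambda\), equals the TD(λ) return minus the value estimate \(V(s_t)\). Both function names are hypothetical, and values is assumed to hold T+1 entries, the last one being the bootstrap value after the final step.

```python
def gae_advantages(rewards, values, gamma, lam):
    """GAE computed as a discounted sum of TD errors."""
    T = len(rewards)
    advantages = [0.0] * T
    running = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


def lambda_return_advantages(rewards, values, gamma, lam):
    """Advantages computed as TD(lambda) return minus the value estimate."""
    T = len(rewards)
    advantages = [0.0] * T
    next_return = values[T]  # bootstrap value after the last step
    for t in reversed(range(T)):
        ret = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * next_return)
        advantages[t] = ret - values[t]
        next_return = ret
    return advantages


rewards = [1.0, 0.0, 2.0, -1.0]
values = [0.3, 0.8, -0.2, 0.5, 0.0]  # last entry is 0.0: terminal state
adv_gae = gae_advantages(rewards, values, gamma=0.99, lam=0.95)
adv_lambda = lambda_return_advantages(rewards, values, gamma=0.99, lam=0.95)
```

Both lists agree up to floating-point error, which is why a single pass computing TD(λ) returns is enough to obtain the critic targets and the advantages.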
Recurrent policies
Instead of unrolling recurrent policies over the entire episode, we split each episode into short chunks and backpropagate only within those chunks.
The chunk size is configured using the tbptt argument.
After every tbptt steps, the hidden state is detached, so gradients do not flow across chunk boundaries.
To unroll over the full episode, set tbptt to the maximum episode length (the environment’s time_limit).
truncated_actor_loss = None
actor_loss_denominator = None
T = None
for t in range(b_obs.size(1)):
    ...
    actor_loss = -pg_loss - args.entropy_coef * entropy_loss
    total_actor_loss += actor_loss
    if truncated_actor_loss is None:
        truncated_actor_loss = actor_loss
        actor_loss_denominator = b_mask[:, t].sum()
        T = 1
    else:
        truncated_actor_loss += actor_loss
        actor_loss_denominator += b_mask[:, t].sum()
        T += 1
    if ((t + 1) % args.tbptt == 0) or (t == b_obs.size(1) - 1):
        truncated_actor_loss = truncated_actor_loss / (actor_loss_denominator * T)
        actor_optimizer.zero_grad()
        truncated_actor_loss.backward()
        h = h.detach()
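To see which timesteps end a chunk, the boundary condition from the loop above can be isolated in a small helper. The name tbptt_chunks is hypothetical; it only reproduces the condition ((t + 1) % tbptt == 0) or (t == last step) and returns the inclusive (start, end) indices of each chunk.

```python
def tbptt_chunks(episode_length, tbptt):
    """Return the (start, end) timestep indices of each backpropagation chunk."""
    chunks = []
    start = 0
    for t in range(episode_length):
        # Same boundary test as in the training loop: every tbptt steps, or at the end
        if (t + 1) % tbptt == 0 or t == episode_length - 1:
            chunks.append((start, t))
            start = t + 1
    return chunks


short_chunks = tbptt_chunks(episode_length=10, tbptt=4)    # [(0, 3), (4, 7), (8, 9)]
full_episode = tbptt_chunks(episode_length=10, tbptt=10)   # [(0, 9)]
```

The last chunk may be shorter than tbptt, and setting tbptt to the episode length yields a single chunk, which corresponds to unrolling over the full episode.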