Replay Buffers

Multiple implementations of replay buffers are available in the library. These are listed below:

Multi-Objective Replay Buffer

class morl_baselines.common.buffer.ReplayBuffer(obs_shape, action_dim, rew_dim=1, max_size=100000, obs_dtype=<class 'numpy.float32'>, action_dtype=<class 'numpy.float32'>)

Multi-objective replay buffer for multi-objective reinforcement learning.

Initialize the replay buffer.

Parameters:
  • obs_shape – Shape of the observations

  • action_dim – Dimension of the actions

  • rew_dim – Dimension of the rewards

  • max_size – Maximum size of the buffer

  • obs_dtype – Data type of the observations

  • action_dtype – Data type of the actions
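
A minimal construction sketch; the observation, action and reward sizes below are illustrative assumptions, not values prescribed by the library:

    import numpy as np
    from morl_baselines.common.buffer import ReplayBuffer

    # Hypothetical problem sizes: 4-dimensional observations, 2-dimensional
    # continuous actions and 3 objectives (vector rewards of length 3).
    buffer = ReplayBuffer(
        obs_shape=(4,),
        action_dim=2,
        rew_dim=3,
        max_size=50_000,
    )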

add(obs, action, reward, next_obs, done)

Add a new experience to the buffer.

Parameters:
  • obs – Observation

  • action – Action

  • reward – Reward

  • next_obs – Next observation

  • done – Done
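
A sketch of storing one transition, reusing the buffer created above; random placeholder data stands in for a real environment step:

    # Placeholder transition data; in practice these come from an environment step.
    obs = np.random.rand(4).astype(np.float32)
    action = np.random.rand(2).astype(np.float32)
    reward = np.random.rand(3).astype(np.float32)   # one entry per objective
    next_obs = np.random.rand(4).astype(np.float32)
    buffer.add(obs, action, reward, next_obs, done=False)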

get_all_data(max_samples=None)

Get all the data in the buffer, up to an optional maximum number of samples.

Parameters:

max_samples – Maximum number of samples to return

Returns:

A tuple of (observations, actions, rewards, next observations, dones)

sample(batch_size, replace=True, use_cer=False, to_tensor=False, device=None)

Sample a batch of experiences from the buffer.

Parameters:
  • batch_size – Batch size

  • replace – Whether to sample with replacement

  • use_cer – Whether to use CER (Combined Experience Replay, i.e., always include the most recently added experience in the batch)

  • to_tensor – Whether to convert the data to PyTorch tensors

  • device – Device to use

Returns:

A tuple of (observations, actions, rewards, next observations, dones)
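
A sampling sketch, assuming PyTorch is installed when to_tensor=True is requested:

    import torch

    obs_b, act_b, rew_b, next_obs_b, done_b = buffer.sample(
        batch_size=256,
        to_tensor=True,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    # rew_b has shape (256, rew_dim): one column per objective.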

sample_obs(batch_size, replace=True, to_tensor=False, device=None)

Sample a batch of observations from the buffer.

Parameters:
  • batch_size – Batch size

  • replace – Whether to sample with replacement

  • to_tensor – Whether to convert the data to PyTorch tensors

  • device – Device to use

Returns:

A batch of observations

Diverse Replay Buffer

class morl_baselines.common.diverse_buffer.DiverseMemory(main_capacity: int, sec_capacity: int = 0, trace_diversity: bool = True, crowding_diversity: bool = True, value_function=<function DiverseMemory.<lambda>>, e: float = 0.01, a: float = 2)

Prioritized Replay Buffer with integrated secondary Diverse Replay Buffer. Code extracted from https://github.com/axelabels/DynMORL.

Initializes the DiverseMemory.

Parameters:
  • main_capacity – Normal prioritized replay capacity

  • sec_capacity – Size of the secondary diverse replay buffer; if 0, the buffer functions as a normal prioritized replay buffer (default: {0})

  • trace_diversity – Whether diversity should be enforced at trace-level (True) or at transition-level (False)

  • crowding_diversity – Whether a crowding distance is applied to compute diversity

  • value_function – When applied to a trace, this function should return the trace’s value to be used in the crowding distance computation

  • e – epsilon to be added to errors (default: {0.01})

  • a – Power to which the error is raised; if a == 0, the buffer behaves as a plain replay buffer without prioritization (default: {2})
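
A construction sketch. The layout of a stored sample is user-defined; the value_function below assumes each sample is a (obs, action, reward, next_obs, done) tuple whose third element is the vector reward, which is an assumption made for illustration only:

    import numpy as np
    from morl_baselines.common.diverse_buffer import DiverseMemory

    def trace_value(trace):
        # Assumed sample layout: sample[2] is the vector reward.
        return np.sum([np.asarray(sample[2]) for sample in trace], axis=0)

    memory = DiverseMemory(
        main_capacity=10_000,   # prioritized part
        sec_capacity=2_000,     # diverse part; 0 disables it
        trace_diversity=True,
        crowding_diversity=True,
        value_function=trace_value,
    )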

add(error, sample, trace_id=None, pred_idx=None, tree_id=None)

Add the sample to the replay buffer, with a priority proportional to its error. If trace_id is provided, the sample and the other samples with the same id will be treated as a trace when determining diversity.

Parameters:
  • error – Error

  • sample – The transition to be stored

  • trace_id – The trace’s identifier (default: {None})

  • tree_id – The tree for which the error is relevant (default: {None})

Returns:

The index of the node in which the sample was stored
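
A sketch of adding a transition. The sample layout, the scalar error and the trace identifier below are placeholders; depending on the trees configured, the error may need to be a per-tree dictionary (see add_sample below):

    # Placeholder transition following the layout assumed above.
    transition = (np.zeros(4), 0, np.zeros(3), np.zeros(4), False)
    idx = memory.add(error=1.0, sample=transition, trace_id=0)
    # All samples added with the same trace_id are treated as one trace.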

add_sample(transition, error, write=None)

Stores the transition into the priority tree.

Parameters:
  • transition – Tuple containing the trace id, the sample and the previous sample’s index

  • error – Dictionary containing the error for each tree

  • write – Index to write the transition to

add_tree(tree_id)

Adds a secondary priority tree.

Parameters:

tree_id – The secondary tree’s id

dupe(trg_i, src_i)

Copies the tree src_i into a new tree trg_i.

Parameters:
  • trg_i – target tree identifier

  • src_i – source tree identifier

extract_trace(start: int)

Determines the end of the trace starting at position start.

Parameters:

start – Trace’s starting position

Returns:

The trace’s end position

get(indices: list)

Given a list of node indices, this method returns the data stored at those indices.

Parameters:

indices – List of indices

Returns:

Array of transitions

get_data(include_indices: bool = False)

Get all the data stored in the replay buffer.

Parameters:

include_indices – Whether to include each sample’s position in the replay buffer (default: {False})

Returns:

The data

get_error(idx, tree_id=None)

Given a node’s idx, this method returns the corresponding error in the tree identified by tree_id.

Parameters:
  • idx – Node’s index

  • tree_id – Identifies the tree to update (default: {None})

Returns:

Error

get_sec_write(secondary_traces, trace, reserved_idx=None)

Given a trace, finds free spots in the secondary memory to store it, recursively removing past traces with a low crowding distance.

get_trace_value(trace_tuple)

Applies the value_function to the trace’s data to compute its value.

Parameters:

trace_tuple – Tuple containing the trace and the trace’s indices

Returns:

The trace’s value

main_mem_is_full()

Check whether the main memory is full. Because the memory is filled in a circular fashion, checking whether the current write position is free is sufficient to know if the memory is full.

move_to_sec(start: int, end: int)

Move the trace spanning from start to end to the secondary replay buffer.

Parameters:
  • start – Start position of the trace

  • end – End position of the trace

remove_trace(trace)

Removes the trace from the main memory.

Parameters:

trace – List of indices for the trace

sample(n: int, tree_id=None)

Sample n transitions from the replay buffer, following the priorities of the tree identified by tree_id.

Parameters:
  • n – Number of transitions to sample

  • tree_id – identifier of the tree whose priorities should be followed (default: {None})

Returns:

Pair of (indices, transitions)

sec_distances(traces)

Given a set of traces, this method computes each trace’s crowding distance.

Parameters:

traces – List of trace tuples

Returns:

List of distances

update(idx: int, error: float, tree_id=None)

Given a node’s idx, this method updates the corresponding priority in the tree identified by tree_id.

Parameters:
  • idx – Node’s index

  • error – New error

  • tree_id – Identifies the tree to update (default: {None})
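
Taken together, sample and update support the usual prioritized-replay cycle. A sketch, reusing the memory created above (new_errors is a placeholder for agent-specific error recomputation):

    indices, transitions = memory.sample(n=32)
    new_errors = [1.0 for _ in indices]   # placeholder: recompute errors here
    for idx, err in zip(indices, new_errors):
        memory.update(idx, err)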

Prioritized Replay Buffer

class morl_baselines.common.prioritized_buffer.PrioritizedReplayBuffer(obs_shape, action_dim, rew_dim=1, max_size=100000, obs_dtype=<class 'numpy.float32'>, action_dtype=<class 'numpy.float32'>, min_priority=1e-05)

Prioritized Replay Buffer.

Initialize the Prioritized Replay Buffer.

Parameters:
  • obs_shape – Shape of the observations

  • action_dim – Dimension of the actions

  • rew_dim – Dimension of the rewards

  • max_size – Maximum size of the buffer

  • obs_dtype – Data type of the observations

  • action_dtype – Data type of the actions

  • min_priority – Minimum priority of the buffer
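
A construction sketch, using the same illustrative sizes as in the multi-objective ReplayBuffer example above:

    from morl_baselines.common.prioritized_buffer import PrioritizedReplayBuffer

    per_buffer = PrioritizedReplayBuffer(
        obs_shape=(4,),
        action_dim=2,
        rew_dim=3,
        max_size=100_000,
        min_priority=1e-5,
    )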

add(obs, action, reward, next_obs, done, priority=None)

Add a new experience to the buffer.

Parameters:
  • obs – Observation

  • action – Action

  • reward – Reward

  • next_obs – Next observation

  • done – Done

  • priority – Priority of the new experience

get_all_data(max_samples=None, to_tensor=False, device=None)

Get all the data in the buffer.

Parameters:
  • max_samples – Maximum number of samples to return

  • to_tensor – Whether to convert the batch to a tensor

  • device – Device to move the tensor to

Returns:

batch – Batch of experiences

sample(batch_size, to_tensor=False, device=None)

Sample a batch of experience tuples from the buffer.

Parameters:
  • batch_size – Number of experiences to sample

  • to_tensor – Whether to convert the batch to a tensor

  • device – Device to move the tensor to

Returns:

batch – Batch of experiences

sample_obs(batch_size, to_tensor=False, device=None)

Sample a batch of observations from the buffer.

Parameters:
  • batch_size – Number of observations to sample

  • to_tensor – Whether to convert the batch to a tensor

  • device – Device to move the tensor to

Returns:

batch – Batch of observations

update_priorities(idxes, priorities)

Update the priorities of the experiences at idxes.

Parameters:
  • idxes – Indexes of the experiences to update

  • priorities – New priorities of the experiences
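
sample and update_priorities together form the standard prioritized-replay loop. The sketch below assumes that the sampled indices are returned as the last element of the batch (the return layout is not spelled out above, so check the source); td_errors and the random fill data are placeholders:

    import numpy as np

    # Fill with a few placeholder transitions so that sampling has something to draw from.
    for _ in range(256):
        per_buffer.add(
            np.random.rand(4).astype(np.float32),
            np.random.rand(2).astype(np.float32),
            np.random.rand(3).astype(np.float32),
            np.random.rand(4).astype(np.float32),
            False,
        )

    batch = per_buffer.sample(batch_size=128)
    *experience_arrays, idxes = batch       # assumed layout: sampled indices returned last
    td_errors = np.ones(len(idxes))         # placeholder for agent-specific TD errors
    per_buffer.update_priorities(idxes, np.abs(td_errors) + 1e-5)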

Accrued Reward Replay Buffer

class morl_baselines.common.accrued_reward_buffer.AccruedRewardReplayBuffer(obs_shape, action_shape, rew_dim=1, max_size=100000, obs_dtype=<class 'numpy.float32'>, action_dtype=<class 'numpy.float32'>)

Replay buffer with accrued rewards stored (for ESR algorithms).

Initialize the Replay Buffer.

Parameters:
  • obs_shape – Shape of the observations

  • action_shape – Shape of the actions

  • rew_dim – Dimension of the rewards

  • max_size – Maximum size of the buffer

  • obs_dtype – Data type of the observations

  • action_dtype – Data type of the actions

add(obs, accrued_reward, action, reward, next_obs, done)

Add a new experience to memory.

Parameters:
  • obs – Observation

  • accrued_reward – Accrued reward

  • action – Action

  • reward – Reward

  • next_obs – Next observation

  • done – Done
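
A sketch of filling the buffer while tracking the accrued (discounted cumulative) reward. The convention below stores the reward accrued before the current transition and updates it afterwards; whether this matches a given ESR algorithm is an assumption to verify, and the random data stands in for a real environment:

    import numpy as np
    from morl_baselines.common.accrued_reward_buffer import AccruedRewardReplayBuffer

    rew_dim = 2
    gamma = 0.99
    esr_buffer = AccruedRewardReplayBuffer(obs_shape=(4,), action_shape=(1,), rew_dim=rew_dim)

    rng = np.random.default_rng(0)
    obs = rng.random(4, dtype=np.float32)
    accrued_reward = np.zeros(rew_dim, dtype=np.float32)
    for t in range(10):                                 # placeholder rollout of 10 steps
        action = rng.random(1, dtype=np.float32)
        vec_reward = rng.random(rew_dim, dtype=np.float32)
        next_obs = rng.random(4, dtype=np.float32)
        esr_buffer.add(obs, accrued_reward, action, vec_reward, next_obs, done=False)
        accrued_reward = accrued_reward + (gamma ** t) * vec_reward
        obs = next_obs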

cleanup()

Clean up the buffer.

get_all_data(to_tensor=False, device=None)

Returns the whole buffer.

Parameters:
  • to_tensor – Whether to convert the data to tensors or not

  • device – Device to use for the tensors

Returns:

Tuple of (obs, accrued_rewards, actions, rewards, next_obs, dones)

sample(batch_size, replace=True, use_cer=False, to_tensor=False, device=None)

Sample a batch of experiences.

Parameters:
  • batch_size – Number of elements to sample

  • replace – Whether to sample with replacement or not

  • use_cer – Whether to use CER (Combined Experience Replay) or not

  • to_tensor – Whether to convert the data to tensors or not

  • device – Device to use for the tensors

Returns:

Tuple of (obs, accrued_rewards, actions, rewards, next_obs, dones)
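
A sampling sketch for the accrued-reward buffer, reusing esr_buffer from the example above and assuming PyTorch is installed for to_tensor=True:

    import torch

    obs_b, acc_b, act_b, rew_b, next_obs_b, done_b = esr_buffer.sample(
        batch_size=8,
        to_tensor=True,
        device=torch.device("cpu"),
    )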