Concave-Augmented Pareto Q-Learning (CAPQL)

class morl_baselines.multi_policy.capql.capql.CAPQL(env, learning_rate: float = 0.0003, gamma: float = 0.99, tau: float = 0.005, buffer_size: int = 1000000, net_arch: List = [256, 256], batch_size: int = 128, num_q_nets: int = 2, alpha: float = 0.2, learning_starts: int = 1000, gradient_updates: int = 1, project_name: str = 'MORL-Baselines', experiment_name: str = 'CAPQL', wandb_entity: str | None = None, log: bool = True, seed: int | None = None, device: device | str = 'auto')

CAPQL algorithm.

Paper: “Multi-Objective Reinforcement Learning: Convexity, Stationarity and Pareto Optimality”, Haoye Lu, Daniel Herman & Yaoliang Yu, ICLR 2023. https://openreview.net/pdf?id=TjEzIsyEsQ6. Code based on: https://github.com/haoyelu/CAPQL

CAPQL algorithm with continuous actions.

It extends the Soft Actor-Critic algorithm to multi-objective RL: the policy and Q-networks are conditioned on the preference weight vector.

Parameters:
  • env (gym.Env) – The environment to train on.

  • learning_rate (float, optional) – The learning rate. Defaults to 3e-4.

  • gamma (float, optional) – The discount factor. Defaults to 0.99.

  • tau (float, optional) – The soft update coefficient. Defaults to 0.005.

  • buffer_size (int, optional) – The size of the replay buffer. Defaults to int(1e6).

  • net_arch (List, optional) – The network architecture for the policy and Q-networks. Defaults to [256, 256].

  • batch_size (int, optional) – The batch size for training. Defaults to 128.

  • num_q_nets (int, optional) – The number of Q-networks to use. Defaults to 2.

  • alpha (float, optional) – The entropy regularization coefficient. Defaults to 0.2.

  • learning_starts (int, optional) – The number of steps to take before starting to train. Defaults to 1000.

  • gradient_updates (int, optional) – The number of gradient steps to take per update. Defaults to 1.

  • project_name (str, optional) – The name of the project. Defaults to “MORL-Baselines”.

  • experiment_name (str, optional) – The name of the experiment. Defaults to “CAPQL”.

  • wandb_entity (Optional[str], optional) – The wandb entity. Defaults to None.

  • log (bool, optional) – Whether to log to wandb. Defaults to True.

  • seed (Optional[int], optional) – The seed to use. Defaults to None.

  • device (Union[th.device, str], optional) – The device to use for training. Defaults to “auto”.
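
A minimal construction sketch. The environment name and the use of mo_gymnasium are assumptions; any continuous-action, multi-objective Gymnasium environment should work:

    import mo_gymnasium as mo_gym

    from morl_baselines.multi_policy.capql.capql import CAPQL

    # Assumed environment: a continuous-action, multi-objective task from mo-gymnasium.
    env = mo_gym.make("mo-hopper-v4")

    agent = CAPQL(
        env,
        learning_rate=3e-4,
        gamma=0.99,
        batch_size=128,
        log=False,  # skip wandb logging for a quick local run
        seed=42,
    )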

eval(obs: ndarray | Tensor, w: ndarray | Tensor, torch_action=False) ndarray | Tensor

Evaluate the policy action for the given observation and weight vector.
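
A hedged usage sketch: given a constructed (ideally trained) agent, eval returns the action for an observation under a particular preference weight vector. The uniform weights are illustrative, and env.unwrapped.reward_space is assumed to expose the number of objectives:

    import numpy as np

    obs, _ = env.reset(seed=0)

    # Uniform preference weights over the objectives (illustrative choice).
    reward_dim = env.unwrapped.reward_space.shape[0]
    w = np.ones(reward_dim, dtype=np.float32) / reward_dim

    action = agent.eval(obs, w)  # NumPy action by default; pass torch_action=True for a Tensor
    obs, vector_reward, terminated, truncated, info = env.step(action)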

get_config()

Get the configuration of the agent.

load(path, load_replay_buffer=True)

Load the agent weights from a file.

save(save_dir='weights/', filename=None, save_replay_buffer=True)

Save the agent’s weights and replay buffer.
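
A hedged checkpointing sketch; the directory, filename, and the file extension passed to load() are illustrative assumptions, not guaranteed by the API:

    # Save weights and the replay buffer (directory/filename are illustrative).
    agent.save(save_dir="weights/", filename="capql_checkpoint", save_replay_buffer=True)

    # Restore into a freshly constructed agent with the same env and architecture.
    # The ".tar" extension is an assumption about the file save() writes.
    restored = CAPQL(env, log=False)
    restored.load(path="weights/capql_checkpoint.tar", load_replay_buffer=True)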

train(total_timesteps: int, eval_env: Env, ref_point: ndarray, known_pareto_front: List[ndarray] | None = None, num_eval_weights_for_front: int = 100, num_eval_episodes_for_front: int = 5, num_eval_weights_for_eval: int = 50, eval_freq: int = 10000, reset_num_timesteps: bool = False, checkpoints: bool = False, save_freq: int = 10000)

Train the agent.

Parameters:
  • total_timesteps (int) – Total number of timesteps to train the agent for.

  • eval_env (gym.Env) – Environment to use for evaluation.

  • ref_point (np.ndarray) – Reference point for hypervolume calculation.

  • known_pareto_front (Optional[List[np.ndarray]]) – Optimal Pareto front, if known.

  • num_eval_weights_for_front (int) – Number of weights to evaluate for the Pareto front.

  • num_eval_episodes_for_front (int) – Number of episodes to run when evaluating the policy.

  • num_eval_weights_for_eval (int) – Number of weights to use when evaluating the Pareto front, e.g., for computing expected utility.

  • eval_freq (int) – Number of timesteps between evaluations during an iteration.

  • reset_num_timesteps (bool) – Whether to reset the number of timesteps.

  • checkpoints (bool) – Whether to save checkpoints.

  • save_freq (int) – Number of timesteps between checkpoints.
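
A hedged training-call sketch: the evaluation environment mirrors the (assumed) training task, and the reference point is an illustrative pessimistic lower bound whose length must match the number of objectives:

    import numpy as np

    eval_env = mo_gym.make("mo-hopper-v4")  # separate instance of the (assumed) same task

    reward_dim = env.unwrapped.reward_space.shape[0]
    agent.train(
        total_timesteps=200_000,
        eval_env=eval_env,
        ref_point=np.full(reward_dim, -100.0),  # illustrative; should be dominated by all attainable returns
        num_eval_weights_for_front=100,
        num_eval_episodes_for_front=5,
        eval_freq=10_000,
    )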

update()

Update the policy and the Q-nets.