diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py
index e0501b12db..b49b087139 100644
--- a/ding/entry/__init__.py
+++ b/ding/entry/__init__.py
@@ -28,3 +28,4 @@
 from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
 from .serial_entry_bco import serial_pipeline_bco
 from .serial_entry_pc import serial_pipeline_pc
+from .serial_entry_sil import serial_pipeline_sil
diff --git a/ding/entry/serial_entry_sil.py b/ding/entry/serial_entry_sil.py
new file mode 100644
index 0000000000..c50a3528a2
--- /dev/null
+++ b/ding/entry/serial_entry_sil.py
@@ -0,0 +1,137 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+    create_serial_collector, create_serial_evaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from .utils import random_collect
+
+
+def serial_pipeline_sil(
+        input_cfg: Union[str, Tuple[dict, dict]],
+        seed: int = 0,
+        env_setting: Optional[List[Any]] = None,
+        model: Optional[torch.nn.Module] = None,
+        max_train_iter: Optional[int] = int(1e10),
+        max_env_step: Optional[int] = int(1e10),
+        dynamic_seed: Optional[bool] = True,
+) -> 'Policy':  # noqa
+    """
+    Overview:
+        Serial pipeline entry for SIL (Self-Imitation Learning), which combines on-policy training \
+        with self-imitation updates sampled from a replay buffer.
+    Arguments:
+        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+            ``str`` type means config file path. \
+            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed.
+        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+            ``BaseEnv`` subclass, collector env config, and evaluator env config.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+        - dynamic_seed (:obj:`Optional[bool]`): Whether to use dynamic seed for the collector environments.
+    Returns:
+        - policy (:obj:`Policy`): Converged policy.
+    """
+    if isinstance(input_cfg, str):
+        cfg, create_cfg = read_config(input_cfg)
+    else:
+        cfg, create_cfg = deepcopy(input_cfg)
+    create_cfg.policy.type = create_cfg.policy.type + '_command'
+    env_fn = None if env_setting is None else env_setting[0]
+    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+    # Create main components: env, policy
+    if env_setting is None:
+        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+    else:
+        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+    collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed)
+    evaluator_env.seed(cfg.seed, dynamic_seed=False)
+    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+    # Create worker components: learner, collector, evaluator, replay buffer, commander.
+    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+    collector = create_serial_collector(
+        cfg.policy.collect.collector,
+        env=collector_env,
+        policy=policy.collect_mode,
+        tb_logger=tb_logger,
+        exp_name=cfg.exp_name
+    )
+    evaluator = create_serial_evaluator(
+        cfg.policy.eval.evaluator,
+        env=evaluator_env,
+        policy=policy.eval_mode,
+        tb_logger=tb_logger,
+        exp_name=cfg.exp_name
+    )
+    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+    commander = BaseSerialCommander(
+        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+    )
+    # ==========
+    # Main loop
+    # ==========
+    # Learner's before_run hook.
+    learner.call_hook('before_run')
+
+    while True:
+        collect_kwargs = commander.step()
+        # Evaluate policy performance
+        if evaluator.should_eval(learner.train_iter):
+            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+        # Collect data by default config n_sample/n_episode
+        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+
+        tot_train_data = {'new_data': new_data, 'replay_data': []}
+
+        # Learn policy from collected data
+        for i in range(cfg.policy.sil_update_per_collect):
+            # Sample ``sil_update_per_collect`` batches of replay data for the SIL update.
+            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+            if train_data is None:
+                # It is possible that replay buffer's data count is too few to sample ``sil_update_per_collect`` batches
+                logging.warning(
+                    "Replay buffer's data can only train for {} steps. ".format(i) +
+                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
+                )
+                break
+            tot_train_data['replay_data'].append(train_data)
+        learner.train(tot_train_data, collector.envstep)
+        if learner.policy.get_attribute('priority'):
+            replay_buffer.update(learner.priority_info)
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            break
+
+    # Learner's after_run hook.
+ learner.call_hook('after_run') + import time + import pickle + import numpy as np + with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f: + eval_value_raw = [d['eval_episode_return'] for d in eval_info] + final_data = { + 'stop': stop, + 'env_step': collector.envstep, + 'train_iter': learner.train_iter, + 'eval_value': np.mean(eval_value_raw), + 'eval_value_raw': eval_value_raw, + 'finish_time': time.ctime(), + } + pickle.dump(final_data, f) + return policy diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 65f3f2757e..d906c8a021 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -50,6 +50,7 @@ from .pc import ProcedureCloningBFSPolicy from .bcq import BCQPolicy +from .sil import SILA2CPolicy, SILPPOPolicy # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 8b6123c063..5641fdf3f8 100755 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -36,6 +36,7 @@ from .sql import SQLPolicy from .bc import BehaviourCloningPolicy from .ibc import IBCPolicy +from .sil import SILA2CPolicy, SILPPOPolicy from .dqfd import DQFDPolicy from .r2d3 import R2D3Policy @@ -432,3 +433,13 @@ def _get_setting_learn(self, command_info: dict) -> dict: def _get_setting_eval(self, command_info: dict) -> dict: return {} + + +@POLICY_REGISTRY.register('sil_a2c_command') +class SILA2CCommandModePolicy(SILA2CPolicy, DummyCommandModePolicy): + pass + + +@POLICY_REGISTRY.register('sil_ppo_command') +class SILPPOCommandModePolicy(SILPPOPolicy, DummyCommandModePolicy): + pass diff --git a/ding/policy/sil.py b/ding/policy/sil.py new file mode 100644 index 0000000000..cf0c601e9d --- /dev/null +++ b/ding/policy/sil.py @@ -0,0 +1,434 @@ +from typing import List, Dict, Any +import torch + +from ding.rl_utils import sil_data, sil_error, a2c_data, a2c_error, ppo_data, ppo_error, gae, gae_data +from ding.torch_utils import to_device, to_dtype +from ding.utils import POLICY_REGISTRY, split_data_generator +from .ppo import PPOPolicy +from .a2c import A2CPolicy +from .common_utils import default_preprocess_learn + + +@POLICY_REGISTRY.register('sil_a2c') +class SILA2CPolicy(A2CPolicy): + r""" + Overview: + Policy class of SIL algorithm combined with A2C, paper link: https://arxiv.org/abs/1806.05635 + """ + config = dict( + # (string) RL policy register name (refer to function "register_policy"). + type='sil_a2c', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool) Whether to use on-policy training pipeline(behaviour policy and training policy are the same) + on_policy=True, + priority=False, + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (int) Number of epochs to use SIL loss to update the policy. 
+ sil_update_per_collect=1, + sil_recompute_adv=True, + learn=dict( + update_per_collect=1, # fixed value, this line should not be modified by users + batch_size=64, + learning_rate=0.001, + # (List[float]) + betas=(0.9, 0.999), + # (float) + eps=1e-8, + # (float) + grad_norm=0.5, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) loss weight of the value network, the weight of policy network is set to 1 + value_weight=0.5, + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.01, + # (bool) Whether to normalize advantage. Default to False. + adv_norm=False, + ignore_done=False, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + # n_sample=80, + unroll_len=1, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) discount factor for future reward, defaults int [0, 1] + discount_factor=0.9, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + ), + eval=dict(), + ) + + def _forward_learn(self, data: dict) -> Dict[str, Any]: + r""" + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs','adv'] + Returns: + - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + """ + # Extract off-policy data + data_sil = data['replay_data'] + data_sil = [ + default_preprocess_learn(data_sil[i], ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + for i in range(len(data_sil)) + ] + # Extract on-policy data + data_onpolicy = data['new_data'] + for i in range(len(data_onpolicy)): + data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'next_obs', 'reward', 'adv', 'value', 'action', 'done']} + data_onpolicy = default_preprocess_learn( + data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False + ) + data_onpolicy['weight'] = None + # Put data to correct device. 
+ if self._cuda: + data_onpolicy = to_device(data_onpolicy, self._device) + data_sil = to_device(data_sil, self._device) + self._learn_model.train() + + for batch in split_data_generator(data_onpolicy, self._cfg.learn.batch_size, shuffle=True): + # forward + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + + adv = batch['adv'] + return_ = batch['value'] + adv + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = a2c_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + + # Calculate A2C loss + a2c_loss = a2c_error(error_data) + wv, we = self._value_weight, self._entropy_weight + a2c_total_loss = a2c_loss.policy_loss + wv * a2c_loss.value_loss - we * a2c_loss.entropy_loss + + # ==================== + # A2C-learning update + # ==================== + + self._optimizer.zero_grad() + a2c_total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() + + for batch in data_sil: + # forward + with torch.no_grad(): + recomputed_value = self._learn_model.forward(batch['obs'], mode='compute_critic')['value'] + recomputed_next_value = self._learn_model.forward(batch['next_obs'], mode='compute_critic')['value'] + + traj_flag = batch.get('traj_flag', None) # traj_flag indicates termination of trajectory + compute_adv_data = gae_data( + recomputed_value, recomputed_next_value, batch['reward'], batch['done'], traj_flag + ) + recomputed_adv = gae(compute_adv_data, self._gamma, self._gae_lambda) + + recomputed_returns = recomputed_value + recomputed_adv + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + + adv = batch['adv'] if not self._cfg.sil_recompute_adv else recomputed_adv + return_ = batch['value'] + adv if not self._cfg.sil_recompute_adv else recomputed_returns + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + + # Calculate SIL loss + sil_loss, sil_info = sil_error(error_data) + wv = self._value_weight + sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + + # ==================== + # SIL-learning update + # ==================== + + self._optimizer.zero_grad() + sil_total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() + + # ============= + # after update + # ============= + # only record last updates information in logger + return { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'total_loss': sil_total_loss.item() + a2c_total_loss.item(), + 'sil_total_loss': sil_total_loss.item(), + 'a2c_total_loss': a2c_total_loss.item(), + 'sil_policy_loss': sil_loss.policy_loss.item(), + 'policy_loss': a2c_loss.policy_loss.item(), + 'sil_value_loss': sil_loss.value_loss.item(), + 'value_loss': a2c_loss.value_loss.item(), + 'entropy_loss': a2c_loss.entropy_loss.item(), + 'policy_clipfrac': sil_info.policy_clipfrac, + 'value_clipfrac': sil_info.value_clipfrac, + 'adv_abs_max': adv.abs().max().item(), + 'grad_norm': grad_norm, + } + + def _monitor_vars_learn(self) -> List[str]: + return list( + set( + super()._monitor_vars_learn() + [ + 'sil_policy_loss', 'sil_value_loss', 'a2c_total_loss', 'sil_total_loss', 'policy_clipfrac', + 'value_clipfrac' + ] + ) + ) + + 
+@POLICY_REGISTRY.register('sil_ppo')
+class SILPPOPolicy(PPOPolicy):
+    r"""
+    Overview:
+        Policy class of SIL algorithm combined with PPO, paper link: https://arxiv.org/abs/1806.05635
+    """
+    config = dict(
+        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+        type='sil_ppo',
+        # (bool) Whether to use cuda for network.
+        cuda=False,
+        # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can also be used off-policy)
+        on_policy=True,
+        # (bool) Whether to use priority(priority sample, IS weight, update priority)
+        priority=False,
+        # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority.
+        # If True, priority must be True.
+        priority_IS_weight=False,
+        # (bool) Whether to recompute advantages in each iteration of on-policy PPO
+        recompute_adv=True,
+        # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid']
+        action_space='discrete',
+        # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value
+        nstep_return=False,
+        # (bool) Whether to enable multi-agent training, i.e.: MAPPO
+        multi_agent=False,
+        # (bool) Whether to need policy data in process transition
+        transition_with_policy_data=True,
+        # (int) Number of epochs to use SIL loss to update the policy.
+        sil_update_per_collect=1,
+        learn=dict(
+            epoch_per_collect=10,
+            batch_size=64,
+            learning_rate=3e-4,
+            grad_norm=0.5,
+            # ==============================================================
+            # The following configs are algorithm-specific
+            # ==============================================================
+            # (float) The loss weight of value network, policy network weight is set to 1
+            value_weight=0.5,
+            # (float) The loss weight of entropy regularization, policy network weight is set to 1
+            entropy_weight=0.0,
+            # (float) PPO clip ratio, defaults to 0.2
+            clip_ratio=0.2,
+            # (bool) Whether to use advantage norm in a whole training batch
+            adv_norm=True,
+            value_norm=True,
+            ppo_param_init=True,
+            grad_clip_type='clip_norm',
+            grad_clip_value=0.5,
+            ignore_done=False,
+        ),
+        collect=dict(
+            # (int) Only one of [n_sample, n_episode] should be set
+            # n_sample=64,
+            # (int) Cut trajectories into pieces with length "unroll_len".
+            unroll_len=1,
+            # ==============================================================
+            # The following configs are algorithm-specific
+            # ==============================================================
+            # (float) Reward's future discount factor, aka. gamma.
+            discount_factor=0.99,
+            # (float) GAE lambda factor for the balance of bias and variance (1-step td and mc)
+            gae_lambda=0.95,
+        ),
+        eval=dict(),
+    )
+
+    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        r"""
+        Overview:
+            Forward and backward function of learn mode.
+ Arguments: + - data (:obj:`dict`): Dict type data + Returns: + - info_dict (:obj:`Dict[str, Any]`): + Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \ + adv_abs_max, approx_kl, clipfrac + """ + # Extract off-policy data + data_sil = data['replay_data'] + data_sil = [ + default_preprocess_learn(data_sil[i], ignore_done=self._cfg.learn.ignore_done, use_nstep=False) + for i in range(len(data_sil)) + ] + # Extract on-policy data + data_onpolicy = data['new_data'] + for i in range(len(data_onpolicy)): + data_onpolicy[i] = {k: data_onpolicy[i][k] for k in ['obs', 'adv', 'value', 'action', 'done', 'next_obs', 'reward', 'logit']} + data_onpolicy = default_preprocess_learn( + data_onpolicy, ignore_done=self._cfg.learn.ignore_done, use_nstep=False + ) + data_onpolicy['weight'] = None + # Put data to correct device. + if self._cuda: + data_onpolicy = to_device(data_onpolicy, self._device) + data_sil = to_device(data_sil, self._device) + self._learn_model.train() + # Convert dtype for on-policy data. + data_onpolicy['obs'] = to_dtype(data_onpolicy['obs'], torch.float32) + if 'next_obs' in data_onpolicy: + data_onpolicy['next_obs'] = to_dtype(data_onpolicy['next_obs'], torch.float32) + # Convert dtype for sil-data. + for i in range(len(data_sil)): + data_sil[i]['obs'] = to_dtype(data_sil[i]['obs'], torch.float32) + if 'next_obs' in data_sil[0]: + for i in range(len(data_sil)): + data_sil[i]['next_obs'] = to_dtype(data_sil[i]['next_obs'], torch.float32) + # ==================== + # PPO forward + # ==================== + return_infos = [] + self._learn_model.train() + + for epoch in range(self._cfg.learn.epoch_per_collect): + if self._recompute_adv: # calculate new value using the new updated value network + with torch.no_grad(): + value = self._learn_model.forward(data_onpolicy['obs'], mode='compute_critic')['value'] + next_value = self._learn_model.forward(data_onpolicy['next_obs'], mode='compute_critic')['value'] + if self._value_norm: + value *= self._running_mean_std.std + next_value *= self._running_mean_std.std + + traj_flag = data_onpolicy.get('traj_flag', None) # traj_flag indicates termination of trajectory + compute_adv_data = gae_data( + value, next_value, data_onpolicy['reward'], data_onpolicy['done'], traj_flag + ) + data_onpolicy['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda) + + unnormalized_returns = value + data_onpolicy['adv'] + + if self._value_norm: + data_onpolicy['value'] = value / self._running_mean_std.std + data_onpolicy['return'] = unnormalized_returns / self._running_mean_std.std + self._running_mean_std.update(unnormalized_returns.cpu().numpy()) + else: + data_onpolicy['value'] = value + data_onpolicy['return'] = unnormalized_returns + + else: # don't recompute adv + if self._value_norm: + unnormalized_return = data_onpolicy['adv'] + data_onpolicy['value'] * self._running_mean_std.std + data_onpolicy['return'] = unnormalized_return / self._running_mean_std.std + self._running_mean_std.update(unnormalized_return.cpu().numpy()) + else: + data_onpolicy['return'] = data_onpolicy['adv'] + data_onpolicy['value'] + + for batch in split_data_generator(data_onpolicy, self._cfg.learn.batch_size, shuffle=True): + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + adv = batch['adv'] + if self._adv_norm: + # Normalize advantage in a train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Calculate ppo error + ppo_batch = ppo_data( + output['logit'], batch['logit'], batch['action'], output['value'], 
batch['value'], adv, + batch['return'], batch['weight'] + ) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio) + wv, we = self._value_weight, self._entropy_weight + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + return_info = { + 'cur_lr': self._optimizer.defaults['lr'], + 'ppo_total_loss': total_loss.item(), + 'policy_loss': ppo_loss.policy_loss.item(), + 'value_loss': ppo_loss.value_loss.item(), + 'entropy_loss': ppo_loss.entropy_loss.item(), + 'adv_max': adv.max().item(), + 'adv_mean': adv.mean().item(), + 'value_mean': output['value'].mean().item(), + 'value_max': output['value'].max().item(), + 'approx_kl': ppo_info.approx_kl, + 'clipfrac': ppo_info.clipfrac, + } + return_infos.append(return_info) + + return_info_real = { + k: sum([return_infos[i][k] + for i in range(len(return_infos))]) / len([return_infos[i][k] for i in range(len(return_infos))]) + for k in return_infos[0].keys() + } + + for batch in data_sil: + # forward + output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic') + + adv = batch['adv'] + return_ = batch['value'] + adv + if self._adv_norm: + # norm adv in total train_batch + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + error_data = sil_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight']) + + # Calculate SIL loss + sil_loss, sil_info = sil_error(error_data) + wv = self._value_weight + sil_total_loss = sil_loss.policy_loss + wv * sil_loss.value_loss + + # ==================== + # SIL-learning update + # ==================== + + self._optimizer.zero_grad() + sil_total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self.config["learn"]["grad_norm"], + ) + self._optimizer.step() + + sil_learn_info = { + 'total_loss': sil_total_loss.item() + return_info_real['ppo_total_loss'], + 'sil_total_loss': sil_total_loss.item(), + 'sil_policy_loss': sil_loss.policy_loss.item(), + 'sil_value_loss': sil_loss.value_loss.item(), + 'policy_clipfrac': sil_info.policy_clipfrac, + 'value_clipfrac': sil_info.value_clipfrac + } + + return_info_real.update(sil_learn_info) + return return_info_real + + def _monitor_vars_learn(self) -> List[str]: + variables = list( + set( + super()._monitor_vars_learn() + [ + 'sil_policy_loss', 'sil_value_loss', 'ppo_total_loss', 'sil_total_loss', 'policy_clipfrac', + 'value_clipfrac' + ] + ) + ) + return variables diff --git a/ding/rl_utils/__init__.py b/ding/rl_utils/__init__.py index 080b37ead5..b2b8dcc08b 100644 --- a/ding/rl_utils/__init__.py +++ b/ding/rl_utils/__init__.py @@ -23,3 +23,4 @@ from .acer import acer_policy_error, acer_value_error, acer_trust_region_update from .sampler import ArgmaxSampler, MultinomialSampler, MuSampler, ReparameterizationSampler, HybridStochasticSampler, \ HybridDeterminsticSampler +from .sil import sil_error, sil_data diff --git a/ding/rl_utils/sil.py b/ding/rl_utils/sil.py new file mode 100644 index 0000000000..e3405a2455 --- /dev/null +++ b/ding/rl_utils/sil.py @@ -0,0 +1,44 @@ +from collections import namedtuple +import torch +import torch.nn.functional as F + +sil_data = namedtuple('sil_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight']) +sil_loss = namedtuple('sil_loss', ['policy_loss', 'value_loss']) +sil_info = namedtuple('sil_info', ['policy_clipfrac', 'value_clipfrac']) + + +def sil_error(data: namedtuple) -> namedtuple: + """ + Overview: + 
Implementation of SIL (Self-Imitation Learning), arXiv:1806.05635
+    Arguments:
+        - data (:obj:`namedtuple`): SIL input data with fields shown in ``sil_data``
+    Returns:
+        - sil_loss (:obj:`namedtuple`): the SIL loss items, all of which are differentiable 0-dim tensors
+    Shapes:
+        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+        - action (:obj:`torch.LongTensor`): :math:`(B, )`
+        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
+        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+    """
+    logit, action, value, adv, return_, weight = data
+    if weight is None:
+        weight = torch.ones_like(value)
+    dist = torch.distributions.categorical.Categorical(logits=logit)
+    logp = dist.log_prob(action)
+
+    # Clip the negative part of adv.
+    policy_clipfrac = adv.lt(0).float().mean().item()
+    adv = adv.clamp_min(0)
+    policy_loss = -(logp * adv * weight).mean()
+
+    # Clip the negative part of the distance between value and return.
+    rv_dist = return_ - value
+    value_clipfrac = rv_dist.lt(0).float().mean().item()
+    rv_dist = rv_dist.clamp_min(0)
+    value_loss = (F.mse_loss(rv_dist, torch.zeros_like(rv_dist), reduction='none') * weight).mean()
+    return sil_loss(policy_loss, value_loss), sil_info(policy_clipfrac, value_clipfrac)
diff --git a/ding/rl_utils/tests/test_sil.py b/ding/rl_utils/tests/test_sil.py
new file mode 100644
index 0000000000..32af5006b4
--- /dev/null
+++ b/ding/rl_utils/tests/test_sil.py
@@ -0,0 +1,26 @@
+import pytest
+import torch
+from ding.rl_utils import sil_data, sil_error
+
+random_weight = torch.rand(4) + 1
+weight_args = [None, random_weight]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('weight, ', weight_args)
+def test_sil(weight):
+    B, N = 4, 32
+    logit = torch.randn(B, N).requires_grad_(True)
+    action = torch.randint(0, N, size=(B, ))
+    value = torch.randn(B).requires_grad_(True)
+    adv = torch.rand(B)
+    return_ = torch.randn(B) * 2
+    data = sil_data(logit, action, value, adv, return_, weight)
+    loss, info = sil_error(data)
+    assert all([l.shape == tuple() for l in loss])
+    assert logit.grad is None
+    assert value.grad is None
+    total_loss = sum(loss)
+    total_loss.backward()
+    assert isinstance(logit.grad, torch.Tensor)
+    assert isinstance(value.grad, torch.Tensor)
diff --git a/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py b/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py
new file mode 100644
index 0000000000..157212ee9d
--- /dev/null
+++ b/dizoo/atari/config/serial/freeway/freeway_sil_a2c_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+freeway_sil_config = dict(
+    exp_name='freeway_sil_a2c_seed0',
+    env=dict(
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=8,
+        env_id='FreewayNoFrameskip-v4',
+        # 'ALE/Freeway-v5' is available. But special setting is needed after gym make.
+ stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=3, + encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +freeway_sil_config = EasyDict(freeway_sil_config) +main_config = freeway_sil_config + +freeway_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +freeway_sil_create_config = EasyDict(freeway_sil_create_config) +create_config = freeway_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c freeway_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/frostbite/frostbite_sil_a2c_config.py b/dizoo/atari/config/serial/frostbite/frostbite_sil_a2c_config.py new file mode 100644 index 0000000000..34f3465a78 --- /dev/null +++ b/dizoo/atari/config/serial/frostbite/frostbite_sil_a2c_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +frostbite_sil_config = dict( + exp_name='frostbite_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='FrostbiteNoFrameskip-v4', + # 'ALE/frostbiteRevenge-v5' is available. But special setting is needed after gym make. + stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +frostbite_sil_config = EasyDict(frostbite_sil_config) +main_config = frostbite_sil_config + +frostbite_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +frostbite_sil_create_config = EasyDict(frostbite_sil_create_config) +create_config = frostbite_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c frostbite_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py b/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py new file mode 100644 index 0000000000..1be32b3813 --- /dev/null +++ b/dizoo/atari/config/serial/gravitar/gravitar_sil_a2c_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +gravitar_sil_config = dict( + exp_name='gravitar_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='GravitarNoFrameskip-v4', + # 'ALE/gravitarRevenge-v5' is available. But special setting is needed after gym make. 
+ stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +gravitar_sil_config = EasyDict(gravitar_sil_config) +main_config = gravitar_sil_config + +gravitar_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +gravitar_sil_create_config = EasyDict(gravitar_sil_create_config) +create_config = gravitar_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c gravitar_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py b/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py new file mode 100644 index 0000000000..2cb8bd2ae3 --- /dev/null +++ b/dizoo/atari/config/serial/hero/hero_sil_a2c_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +hero_sil_config = dict( + exp_name='hero_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='HeroNoFrameskip-v4', + # 'ALE/heroRevenge-v5' is available. But special setting is needed after gym make. + stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +hero_sil_config = EasyDict(hero_sil_config) +main_config = hero_sil_config + +hero_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +hero_sil_create_config = EasyDict(hero_sil_create_config) +create_config = hero_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c hero_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py b/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py new file mode 100644 index 0000000000..c14a40ea6e --- /dev/null +++ b/dizoo/atari/config/serial/montezuma/montezuma_sil_a2c_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +montezuma_sil_config = dict( + exp_name='montezuma_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='MontezumaRevengeNoFrameskip-v4', + # 'ALE/MontezumaRevenge-v5' is available. But special setting is needed after gym make. 
+ stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +montezuma_sil_config = EasyDict(montezuma_sil_config) +main_config = montezuma_sil_config + +montezuma_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +montezuma_sil_create_config = EasyDict(montezuma_sil_create_config) +create_config = montezuma_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c montezuma_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py b/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py new file mode 100644 index 0000000000..3870690047 --- /dev/null +++ b/dizoo/atari/config/serial/private_eye/private_eye_sil_a2c_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +collector_env_num = 8 +evaluator_env_num = 8 +private_eye_sil_config = dict( + exp_name='private_eye_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=8, + env_id='PrivateEyeNoFrameskip-v4', + # 'ALE/private_eyeRevenge-v5' is available. But special setting is needed after gym make. 
+ stop_value=int(1e9), + frame_stack=4, + ), + policy=dict( + cuda=True, + sil_update_per_collect=2, + model=dict( + obs_shape=[4, 84, 84], + action_shape=18, + encoder_hidden_size_list=[128, 128, 512], + critic_head_hidden_size=512, + actor_head_hidden_size=512, + ), + learn=dict( + batch_size=40, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +private_eye_sil_config = EasyDict(private_eye_sil_config) +main_config = private_eye_sil_config + +private_eye_sil_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +private_eye_sil_create_config = EasyDict(private_eye_sil_create_config) +create_config = private_eye_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c private_eye_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(3e7)) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py new file mode 100644 index 0000000000..1668376d1f --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py @@ -0,0 +1,53 @@ +from easydict import EasyDict + +lunarlander_ppo_config = dict( + exp_name='lunarlander_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + env_id='LunarLander-v2', + n_evaluator_episode=8, + stop_value=200, + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + update_per_collect=4, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + nstep=1, + nstep_return=False, + adv_norm=True, + ), + collect=dict( + n_sample=128, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +lunarlander_ppo_config = EasyDict(lunarlander_ppo_config) +main_config = lunarlander_ppo_config +lunarlander_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo'), +) +lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config) +create_config = lunarlander_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial -c lunarlander_offppo_config.py -s 0` + from ding.entry import serial_pipeline + serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py b/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py new file mode 100644 index 0000000000..ebcebdb6be --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_sil_a2c_config.py @@ -0,0 +1,51 @@ +from easydict import EasyDict + +collector_env_num = 4 +evaluator_env_num = 4 +lunarlander_sil_config = dict( + exp_name='lunarlander_sil_a2c_seed0', + env=dict( + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + env_id='LunarLander-v2', + n_evaluator_episode=evaluator_env_num, + stop_value=200, + ), + policy=dict( + cuda=False, + sil_update_per_collect=5, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + batch_size=160, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=320, + discount_factor=0.99, + gae_lambda=0.95, + ), 
+ ), +) +lunarlander_sil_config = EasyDict(lunarlander_sil_config) +main_config = lunarlander_sil_config + +lunarlander_sil_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +lunarlander_sil_create_config = EasyDict(lunarlander_sil_create_config) +create_config = lunarlander_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c lunarlander_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) diff --git a/dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py new file mode 100644 index 0000000000..73b8ab2d9c --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_sil_ppo_config.py @@ -0,0 +1,54 @@ +from easydict import EasyDict + +lunarlander_sil_ppo_config = dict( + exp_name='lunarlander_sil_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + env_id='LunarLander-v2', + n_evaluator_episode=8, + stop_value=200, + ), + policy=dict( + cuda=True, + sil_update_per_collect=1, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + update_per_collect=4, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + nstep=1, + nstep_return=False, + adv_norm=True, + ), + collect=dict( + n_sample=128, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +lunarlander_sil_ppo_config = EasyDict(lunarlander_sil_ppo_config) +main_config = lunarlander_sil_ppo_config +lunarlander_sil_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_ppo'), +) +lunarlander_sil_ppo_create_config = EasyDict(lunarlander_sil_ppo_create_config) +create_config = lunarlander_sil_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_sil -c lunarlander_sil_ppo_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0, max_env_step=int(1e7)) diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py new file mode 100644 index 0000000000..dc45bdaabd --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_a2c_config.py @@ -0,0 +1,49 @@ +from easydict import EasyDict + +cartpole_sil_config = dict( + exp_name='cartpole_sil_a2c_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + policy=dict( + cuda=False, + sil_update_per_collect=5, + model=dict( + obs_shape=4, + action_shape=2, + encoder_hidden_size_list=[128, 128, 64], + ), + learn=dict( + batch_size=40, + learning_rate=0.001, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + n_sample=80, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + ), + eval=dict(evaluator=dict(eval_freq=50, )), + ), +) +cartpole_sil_config = EasyDict(cartpole_sil_config) +main_config = cartpole_sil_config + +cartpole_sil_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='sil_a2c'), +) 
+cartpole_sil_create_config = EasyDict(cartpole_sil_create_config) +create_config = cartpole_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c cartpole_sil_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) diff --git a/dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py new file mode 100644 index 0000000000..281d719361 --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_sil_ppo_config.py @@ -0,0 +1,57 @@ +from easydict import EasyDict + +cartpole_sil_ppo_config = dict( + exp_name='cartpole_sil_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + policy=dict( + cuda=False, + action_space='discrete', + sil_update_per_collect=1, + model=dict( + obs_shape=4, + action_shape=2, + action_space='discrete', + encoder_hidden_size_list=[64, 64, 128], + critic_head_hidden_size=128, + actor_head_hidden_size=128, + ), + learn=dict( + epoch_per_collect=2, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + learner=dict(hook=dict(save_ckpt_after_iter=100)), + ), + collect=dict( + n_sample=256, + unroll_len=1, + discount_factor=0.9, + gae_lambda=0.95, + ), + eval=dict(evaluator=dict(eval_freq=100, ), ), + ), +) +cartpole_sil_ppo_config = EasyDict(cartpole_sil_ppo_config) +main_config = cartpole_sil_ppo_config +cartpole_sil_ppo_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='sil_ppo'), +) +cartpole_sil_ppo_create_config = EasyDict(cartpole_sil_ppo_create_config) +create_config = cartpole_sil_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_sil -c cartpole_sil_ppo_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) diff --git a/dizoo/minigrid/config/minigrid_sil_a2c_config.py b/dizoo/minigrid/config/minigrid_sil_a2c_config.py new file mode 100644 index 0000000000..7c29abdf31 --- /dev/null +++ b/dizoo/minigrid/config/minigrid_sil_a2c_config.py @@ -0,0 +1,57 @@ +from easydict import EasyDict + +collector_env_num = 4 +evaluator_env_num = 4 +minigrid_sil_config = dict( + exp_name='minigrid_sil_a2c_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + # typical MiniGrid env id: + # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, + # please refer to https://github.com/Farama-Foundation/MiniGrid for details. 
+ env_id='MiniGrid-DoorKey-8x8-v0', + n_evaluator_episode=5, + max_step=300, + stop_value=0.96, + ), + policy=dict( + cuda=False, + sil_update_per_collect=1, + model=dict( + obs_shape=2835, + action_shape=7, + encoder_hidden_size_list=[256, 128, 64, 64], + ), + learn=dict( + batch_size=64, + learning_rate=0.0003, + value_weight=0.5, + entropy_weight=0.001, + adv_norm=True, + ), + collect=dict( + n_sample=128, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +minigrid_sil_config = EasyDict(minigrid_sil_config) +main_config = minigrid_sil_config + +minigrid_sil_create_config = dict( + env=dict( + type='minigrid', + import_names=['dizoo.minigrid.envs.minigrid_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_a2c'), +) +minigrid_sil_create_config = EasyDict(minigrid_sil_create_config) +create_config = minigrid_sil_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c minigrid_sil_a2c_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) diff --git a/dizoo/minigrid/config/minigrid_sil_ppo_config.py b/dizoo/minigrid/config/minigrid_sil_ppo_config.py new file mode 100644 index 0000000000..77925548c6 --- /dev/null +++ b/dizoo/minigrid/config/minigrid_sil_ppo_config.py @@ -0,0 +1,64 @@ +from easydict import EasyDict + +collector_env_num = 8 +minigrid_sil_ppo_config = dict( + exp_name="minigrid_sil_ppo_seed0", + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + # typical MiniGrid env id: + # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, + # please refer to https://github.com/Farama-Foundation/MiniGrid for details. + env_id='MiniGrid-Empty-8x8-v0', + max_step=300, + stop_value=0.96, + ), + policy=dict( + cuda=True, + recompute_adv=True, + sil_update_per_collect=1, + action_space='discrete', + model=dict( + obs_shape=2835, + action_shape=7, + action_space='discrete', + encoder_hidden_size_list=[256, 128, 64, 64], + ), + learn=dict( + epoch_per_collect=10, + update_per_collect=1, + batch_size=320, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.001, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + collector_env_num=collector_env_num, + n_sample=int(3200), + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +minigrid_sil_ppo_config = EasyDict(minigrid_sil_ppo_config) +main_config = minigrid_sil_ppo_config +minigrid_sil_ppo_create_config = dict( + env=dict( + type='minigrid', + import_names=['dizoo.minigrid.envs.minigrid_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='sil_ppo'), +) +minigrid_sil_ppo_create_config = EasyDict(minigrid_sil_ppo_create_config) +create_config = minigrid_sil_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_sil -c minigrid_sil_ppo_config.py -s 0` + from ding.entry import serial_pipeline_sil + serial_pipeline_sil((main_config, create_config), seed=0) diff --git a/dizoo/minigrid/envs/minigrid_env.py b/dizoo/minigrid/envs/minigrid_env.py index e0bdbfbc07..f495dfa59f 100644 --- a/dizoo/minigrid/envs/minigrid_env.py +++ b/dizoo/minigrid/envs/minigrid_env.py @@ -60,7 +60,7 @@ def reset(self) -> np.ndarray: self._env = ObsPlusPrevActRewWrapper(self._env) self._init_flag = True if self._flat_obs: - self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ), dytpe=np.float32) + self._observation_space = gym.spaces.Box(0, 1, shape=(2835, )) 
else: self._observation_space = self._env.observation_space # to be compatiable with subprocess env manager @@ -70,7 +70,7 @@ def reset(self) -> np.ndarray: self._observation_space.dtype = np.dtype('float32') self._action_space = self._env.action_space self._reward_space = gym.spaces.Box( - low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 + low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ) ) self._eval_episode_return = 0