-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
126 lines (99 loc) · 5.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Sample script for training an RL agent with sb3-gym-interface.

Examples:
    (DEBUG)
    python train.py --offline --env <myGymEnv-v0> -t 1000 --eval_freq 500 --reward_threshold

    (OFFICIAL)
    python train.py --env <myGymEnv-v0> -t 5000000 --eval_freq 40000 --seed 42 --now 12 --algo ppo --reward_threshold
"""
from pprint import pprint
import argparse
import pdb
import sys
import socket
import os
import gym
import torch
import wandb
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
# import random_envs
# from customvecenvs.RandomVecEnv import RandomSubprocVecEnv
from utils.utils import *
from policy.policy import Policy
def main(args=None):
    """Train an agent on ``args.env`` and evaluate its best policy on ``args.test_env``.

    Steps:
      1. Normalise options: rescale --eval_freq by the number of parallel
         workers and default --test_env to --env.
      2. Start a wandb run (disabled with --offline) and create the run
         directory ``runs/<env>/<run_name>_<random>/`` with the saved config.
      3. Train with ``args.now`` SubprocVecEnv workers. Save the final model,
         its full state and the overall best state dict.
      4. Rebuild a policy on the test env, load the best weights and evaluate
         it for ``args.test_episodes`` episodes. Log the results to wandb.

    Args:
        args: namespace as produced by ``parse_args``. When None (the
            default), the command line is parsed here.

    Raises:
        ValueError: if no training env (--env) is given or --now is < 1.
    """
    if args is None:
        args = parse_args()

    # Explicit checks rather than ``assert`` (which is stripped under ``python -O``).
    if args.env is None:
        raise ValueError('--env must be specified')
    if args.now < 1:
        # Would otherwise raise ZeroDivisionError in the eval_freq rescale below.
        raise ValueError('--now must be >= 1')

    # Making eval_freq behave w.r.t. global timesteps, so it follows --timesteps convention.
    args.eval_freq = max(args.eval_freq // args.now, 1)
    # Hard-coded for now. Avoids taking up all CPUs when parallelizing with multiple
    # environments and processes on hephaestus.
    # NOTE(review): max() sets a *floor* of 6 threads (never fewer), not a cap -- confirm intent.
    torch.set_num_threads(max(6, args.now))

    if args.test_env is None:
        args.test_env = args.env

    pprint(vars(args))
    set_seed(args.seed)

    random_string = get_random_string(5)
    wandb.init(config=vars(args),
               project="<PROJECT_NAME>",
               group=('default_group' if args.group is None else args.group),
               name=args.algo+'_seed'+str(args.seed)+'_'+random_string,
               save_code=True,
               tags=None,
               notes=args.notes,
               mode=('online' if not args.offline else 'disabled'))

    run_path = "runs/"+str(args.env)+"/"+get_run_name(args)+"_"+random_string+"/"
    create_dirs(run_path)
    save_config(vars(args), run_path)
    wandb.config.path = run_path
    wandb.config.hostname = socket.gethostname()

    # --- Training ---
    env = make_vec_env(args.env, n_envs=args.now, seed=args.seed, vec_env_cls=SubprocVecEnv)
    # Fail fast on a bad --test_env before spending time on training. The probe
    # env is closed right away; the actual evaluation env is built further below.
    gym.make(args.test_env).close()
    try:
        policy = Policy(algo=args.algo,
                        env=env,
                        # Honour --lr; 1e-3 is the value previously hard-coded here.
                        lr=(1e-3 if args.lr is None else args.lr),
                        gradient_steps=args.gradient_steps,
                        device=args.device,
                        seed=args.seed)

        print('--- POLICY TRAINING ---')
        avg_return, std_return, best_policy, info = policy.train(timesteps=args.timesteps,
                                                                 stopAtRewardThreshold=args.reward_threshold,
                                                                 n_eval_episodes=args.eval_episodes,
                                                                 eval_freq=args.eval_freq,
                                                                 best_model_save_path=run_path,
                                                                 return_best_model=True)

        policy.save_state_dict(run_path+"final_model.pth")
        policy.save_full_state(run_path+"final_full_state.zip")
    finally:
        # Shut down the worker subprocesses even if training fails.
        env.close()

    torch.save(best_policy, run_path+"overall_best.pth")
    wandb.save(run_path+"overall_best.pth")

    print('\n\nMean reward and stdev:', avg_return, std_return)
    wandb.run.summary["train_avg_return"] = avg_return
    wandb.run.summary["train_std_return"] = std_return
    # Label reported by Policy.train for which model was returned as best.
    wandb.run.summary["which_best_model"] = info['which_one']

    # --- Evaluation of the best policy ---
    print('\n\n--- POLICY EVALUATION ---')
    test_env = make_vec_env(args.test_env, n_envs=args.now, seed=args.seed, vec_env_cls=SubprocVecEnv)
    try:
        policy = Policy(algo=args.algo, env=test_env, device=args.device, seed=args.seed)
        policy.load_state_dict(best_policy)
        avg_return, std_return = policy.eval(n_eval_episodes=args.test_episodes)
    finally:
        test_env.close()

    print('\n\nTest return (avg & std):', avg_return, std_return)
    wandb.run.summary["test_avg_return"] = avg_return
    wandb.run.summary["test_std_return"] = std_return

    wandb.finish()
def parse_args(argv=None):
    """Build this script's command-line parser and parse the arguments.

    Args:
        argv: list of argument strings to parse. When None (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Exits via ``parser.error`` (``SystemExit`` with status 2, after printing
    usage) when --env is missing or --now is smaller than 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', required=True, type=str, help='Train gym env [Hopper-v3, ...]')
    parser.add_argument('--test_env', default=None, type=str, help='Test gym env')
    parser.add_argument('--group', default=None, type=str, help='Wandb run group')
    parser.add_argument('--algo', default='sac', type=str, help='RL Algo [ppo, sac]')
    parser.add_argument('--lr', default=None, type=float, help='Learning rate')
    parser.add_argument('--gradient_steps', default=-1, type=int, help='Number of gradient steps when policy is updated in sb3 using SAC. -1 means as many as --now')
    parser.add_argument('--now', default=1, type=int, help='Number of parallel environments, i.e. Number Of Workers')
    parser.add_argument('--timesteps', '-t', default=1000, type=int, help='Training timesteps (global across all parallel environments)')
    parser.add_argument('--eval_freq', default=10000, type=int, help='timesteps frequency for training evaluations (global across all parallel environments)')
    parser.add_argument('--reward_threshold', default=False, action='store_true', help='Stop at reward threshold')
    parser.add_argument('--eval_episodes', default=50, type=int, help='# episodes for training evaluations')
    parser.add_argument('--test_episodes', default=100, type=int, help='# episodes for test evaluations')
    parser.add_argument('--seed', default=0, type=int, help='Random seed')
    parser.add_argument('--device', default='cpu', type=str, help='<cpu,cuda>')
    parser.add_argument('--notes', default=None, type=str, help='Wandb notes')
    parser.add_argument('--offline', default=False, action='store_true', help='Offline run without wandb')

    args = parser.parse_args(argv)
    # The worker count is used to create vectorised envs and as a divisor for
    # --eval_freq, so reject it with a usage message instead of failing later.
    if args.now < 1:
        parser.error('--now must be >= 1')
    return args
if __name__ == '__main__':
    # Parse the command line only when executed as a script, so that merely
    # importing this module does not read or validate sys.argv.
    # ``args`` is still bound as a module-level global (assigned here), so code
    # that reads the global keeps working.
    args = parse_args()
    main()