# 7B_full_ppo_low_memory.yaml
# Forked from pytorch/torchtune.
# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a Mistral 7B model.
#
# This config has been tested on an A100 80GB.
# This config uses hyperparameters based on a small set of experiments and on
# information available from existing implementations.
#
# This config assumes that you've run the following commands before launching
# this run:
# tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/ --ignore-patterns None
# tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --hf-token HF_TOKEN
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders.
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training,
# you can run:
# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# Tokenizer: built from the tokenizer.model file in the downloaded
# Mistral-7B-Instruct-v0.2 directory.
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model
  # no maximum sequence length is configured for the tokenizer
  max_seq_len: null
# Dataset: text-completion dataset that reads the `prompt` column of the
# `train` split; EOS is not appended to the samples.
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: trl-internal-testing/sentiment-trl-style
  split: train
  column: prompt
  add_eos: False
# Policy model: stock Mistral 7B builder, no argument overrides.
policy_model:
  _component_: torchtune.models.mistral.mistral_7b
# Reward/value model: the Mistral classifier is assembled by hand through the
# component builder because the reward model checkpoint has a larger
# vocabulary than the base model (one extra padding token).
reward_and_value_model:
  _component_: torchtune.models.mistral._component_builders.mistral_classifier
  attn_dropout: 0.0
  embed_dim: 4096
  intermediate_dim: 14336
  max_seq_len: 32768
  norm_eps: 1.0e-05
  # single output class
  num_classes: 1
  num_heads: 32
  num_kv_heads: 8
  num_layers: 32
  vocab_size: 32001
# Policy model checkpointer. Update this section when resuming from a
# checkpoint.
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files:
    - "pytorch_model-00001-of-00003.bin"
    - "pytorch_model-00002-of-00003.bin"
    - "pytorch_model-00003-of-00003.bin"
  # when resuming training, `recipe_checkpoint` should be updated here and
  # nowhere else
  recipe_checkpoint: null
  output_dir: ${output_dir}/policy
  model_type: MISTRAL
# Reference policy checkpointer. At the start of training it should be set up
# identically to the policy model checkpointer. `checkpoint_files` must keep
# pointing at the original policy weights, even when resuming training.
ref_policy_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files:
    - "pytorch_model-00001-of-00003.bin"
    - "pytorch_model-00002-of-00003.bin"
    - "pytorch_model-00003-of-00003.bin"
  output_dir: ${output_dir}/policy
  model_type: MISTRAL
# Value model checkpointer. Update `checkpoint_files` when resuming from a
# checkpoint. The value model is identical to the reward model, so it is
# helpful to initialise it from the trained reward model weights.
value_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files:
    - "model-00001-of-00003.safetensors"
    - "model-00002-of-00003.safetensors"
    - "model-00003-of-00003.safetensors"
  output_dir: ${output_dir}/value
  model_type: REWARD
# Reward model checkpointer. `checkpoint_files` must keep pointing at the
# original reward model weights, even when resuming training.
reward_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files:
    - "model-00001-of-00003.safetensors"
    - "model-00002-of-00003.safetensors"
    - "model-00003-of-00003.safetensors"
  output_dir: ${output_dir}/value
  model_type: REWARD
# Run-level settings. `output_dir` is the root that the `${output_dir}`
# interpolations elsewhere in this file resolve against.
resume_from_checkpoint: False
output_dir: /tmp/mistral7b-ppo-finetune
seed: null
shuffle: True

# Training environment
device: cuda

# Training arguments
batch_size: 64
num_steps: 10000
ppo_epochs: 2
ppo_batch_size: 32
gradient_accumulation_steps: 1
# Memory management and performance
compile: True
# Paged AdamW from bitsandbytes (install with `pip install bitsandbytes`).
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  # Written with an explicit decimal point so that YAML 1.1 loaders resolve it
  # as a float; a bare `3e-6` is read as a string by such loaders.
  lr: 3.0e-6
# NOTE(review): the keys below are kept top-level (not optimizer arguments);
# the source had lost its indentation, so confirm against the recipe.
optimizer_in_bwd: True
log_peak_memory_stats: False
enable_activation_checkpointing: True
# Reduced precision
dtype: bf16

# Generation settings.
# `forward_batch_size` is the batch size for forward passes during generation.
forward_batch_size: 16
max_generated_tokens: 58
temperature: 0.7
top_k: null

# Reward penalties.
# generations shorter than `min_response_length` are penalised
min_response_length: 18
# generations without a stop token are penalised when this is enabled
penalise_no_eos: True
# scalar penalty applied whenever a generation is penalised
reward_penalty: -3
# Token ids treated as "end of sequence" tokens.
stop_token_ids:
  - 2  # eos_id
  - 28723  # mistral "." token
whiten_rewards: False
# Hyperparameters for GAE
gamma: 1
lmbda: 0.95
# Hyperparameters for PPO
loss:
  _component_: torchtune.rlhf.loss.PPOLoss
  epsilon: 0.2
  value_coeff: 0.1
  value_clip_range: 0.2
# NOTE(review): `kl_coeff` is kept as a top-level key rather than a PPOLoss
# argument; the source had lost its indentation, so confirm against the recipe.
kl_coeff: 0.01
# Logging: metrics are written by a disk logger under the run's `output_dir`.
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
# NOTE(review): kept as a top-level key (not a DiskLogger argument); the source
# had lost its indentation, so confirm against the recipe.
log_every_n_steps: 1