upgrade to 2.0-style for ray.tune
donglihe-hub committed Jun 20, 2024
1 parent cb30e6a commit 8f13496
46 changes: 28 additions & 18 deletions search_params.py
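
For context: this is the standard Ray Tune 1.x-to-2.x migration. tune.run() is replaced by the tune.Tuner API, per-trial resources move into tune.with_resources(), search settings (scheduler, search algorithm, sample count) move into tune.TuneConfig, and run-level settings (name, storage path, reporter) move into RunConfig. A minimal sketch of the pattern with generic names, assuming Ray >= 2.7 (not this repository's code):

    from ray import train, tune

    def trainable(config):
        # A real trainable would fit a model; here we just report one metric.
        train.report({"val_loss": config["lr"]})

    tuner = tune.Tuner(
        tune.with_resources(trainable, resources={"cpu": 1, "gpu": 0}),
        param_space={"lr": tune.loguniform(1e-4, 1e-1)},
        tune_config=tune.TuneConfig(metric="val_loss", mode="min", num_samples=4),
        run_config=train.RunConfig(name="tune_demo"),
    )
    results = tuner.fit()                    # ResultGrid; tune.run() returned ExperimentAnalysis
    print(results.get_best_result().config)  # best trial's hyperparameters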
@@ -171,13 +171,11 @@ def prepare_retrain_config(best_config, best_log_dir, retrain):
         best_config.merge_train_val = False
 
 
-def load_static_data(config, merge_train_val=False):
+def load_static_data(config):
     """Preload static data once for multiple trials.
 
     Args:
         config (AttributeDict): Config of the experiment.
-        merge_train_val (bool, optional): Whether to merge the training and validation data.
-            Defaults to False.
 
     Returns:
         dict: A dict of static data containing datasets, classes, and word_dict.
@@ -187,7 +185,7 @@ def load_static_data(config, merge_train_val=False):
         test_data=config.test_file,
         val_data=config.val_file,
         val_size=config.val_size,
-        merge_train_val=merge_train_val,
+        merge_train_val=config.merge_train_val,
         tokenize_text="lm_weight" not in config.network_config,
         remove_no_label_data=config.remove_no_label_data,
     )
@@ -231,7 +229,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
     with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
         yaml.dump(dict(best_config), fp)
 
-    data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
+    data = load_static_data(best_config)
 
     if retrain:
         logging.info(f"Re-training with best config: \n{best_config}")
@@ -303,7 +301,7 @@ def main():
     config = init_search_params_spaces(config, parameter_columns, prefix="")
     parser.set_defaults(**config)
     config = AttributeDict(vars(parser.parse_args()))
-    # no need to include validation during parameter search
+    # Validation sets are mandatory during parameter search
     config.merge_train_val = False
     config.mode = "min" if config.val_metric == "Loss" else "max"
 
@@ -344,20 +342,32 @@ def main():
         Path(config.config).stem if config.config else config.model_name,
         datetime.now().strftime("%Y%m%d%H%M%S"),
     )
-    analysis = tune.run(
-        tune.with_parameters(train_libmultilabel_tune, **data),
-        search_alg=init_search_algorithm(config.search_alg, metric=f"val_{config.val_metric}", mode=config.mode),
-        scheduler=scheduler,
-        storage_path=config.result_dir,
-        num_samples=config.num_samples,
-        resources_per_trial={"cpu": config.cpu_count, "gpu": config.gpu_count},
-        progress_reporter=reporter,
-        config=config,
-        name=exp_name,
+    tuner = tune.Tuner(
+        tune.with_resources(
+            tune.with_parameters(train_libmultilabel_tune, **data),
+            resources={"cpu": config.cpu_count, "gpu": config.gpu_count},
+        ),
+        param_space=config,
+        tune_config=tune.TuneConfig(
+            scheduler=scheduler,
+            num_samples=config.num_samples,
+            search_alg=init_search_algorithm(
+                search_alg=config.search_alg,
+                metric=f"val_{config.val_metric}",
+                mode=config.mode,
+            ),
+        ),
+        run_config=ray_train.RunConfig(
+            name=exp_name,
+            storage_path=config.result_dir,
+            progress_reporter=reporter,
+        ),
     )
+    results = tuner.fit()
+
     # Save best model after parameter search.
-    best_trial = analysis.get_best_trial(metric=f"val_{config.val_metric}", mode=config.mode, scope="all")
-    retrain_best_model(exp_name, best_trial.config, best_trial.local_path, retrain=not config.no_retrain)
+    best_result = results.get_best_result(metric=f"val_{config.val_metric}", mode=config.mode, scope="all")
+    retrain_best_model(exp_name, best_result.config, best_result.path, retrain=not config.no_retrain)
 
 
 if __name__ == "__main__":
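
On the results side, Tuner.fit() returns a ResultGrid rather than the ExperimentAnalysis that tune.run() produced, so the best-trial lookup changes with it. A short sketch of the correspondence, using a generic metric name and assuming a completed results = tuner.fit():

    # 1.x: best_trial = analysis.get_best_trial(metric="val_Loss", mode="min", scope="all")
    #      best_trial.config, best_trial.local_path
    best_result = results.get_best_result(metric="val_Loss", mode="min", scope="all")
    best_result.config  # winning hyperparameter dict
    best_result.path    # trial directory, the 2.x counterpart of Trial.local_path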
