From 8f13496c42941d0a85b0f4fea439fb16ea11c402 Mon Sep 17 00:00:00 2001
From: Dongli He
Date: Wed, 19 Jun 2024 21:03:23 +0400
Subject: [PATCH] upgrade to 2.0-style for ray.tune

---
 search_params.py | 46 ++++++++++++++++++++++++++++------------------
 1 file changed, 28 insertions(+), 18 deletions(-)

diff --git a/search_params.py b/search_params.py
index 911e7344..aad38ece 100644
--- a/search_params.py
+++ b/search_params.py
@@ -171,13 +171,11 @@ def prepare_retrain_config(best_config, best_log_dir, retrain):
         best_config.merge_train_val = False
 
 
-def load_static_data(config, merge_train_val=False):
+def load_static_data(config):
     """Preload static data once for multiple trials.
 
     Args:
         config (AttributeDict): Config of the experiment.
-        merge_train_val (bool, optional): Whether to merge the training and validation data.
-            Defaults to False.
 
     Returns:
         dict: A dict of static data containing datasets, classes, and word_dict.
@@ -187,7 +185,7 @@
         test_data=config.test_file,
         val_data=config.val_file,
         val_size=config.val_size,
-        merge_train_val=merge_train_val,
+        merge_train_val=config.merge_train_val,
         tokenize_text="lm_weight" not in config.network_config,
         remove_no_label_data=config.remove_no_label_data,
     )
@@ -231,7 +229,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
     with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
         yaml.dump(dict(best_config), fp)
 
-    data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
+    data = load_static_data(best_config)
 
     if retrain:
         logging.info(f"Re-training with best config: \n{best_config}")
@@ -303,7 +301,7 @@ def main():
     config = init_search_params_spaces(config, parameter_columns, prefix="")
     parser.set_defaults(**config)
     config = AttributeDict(vars(parser.parse_args()))
-    # no need to include validation during parameter search
+    # Validation sets are mandatory during parameter search
     config.merge_train_val = False
     config.mode = "min" if config.val_metric == "Loss" else "max"
 
@@ -344,20 +342,32 @@ def main():
         Path(config.config).stem if config.config else config.model_name,
         datetime.now().strftime("%Y%m%d%H%M%S"),
     )
-    analysis = tune.run(
-        tune.with_parameters(train_libmultilabel_tune, **data),
-        search_alg=init_search_algorithm(config.search_alg, metric=f"val_{config.val_metric}", mode=config.mode),
-        scheduler=scheduler,
-        storage_path=config.result_dir,
-        num_samples=config.num_samples,
-        resources_per_trial={"cpu": config.cpu_count, "gpu": config.gpu_count},
-        progress_reporter=reporter,
-        config=config,
-        name=exp_name,
+
+    tuner = tune.Tuner(
+        tune.with_resources(
+            tune.with_parameters(train_libmultilabel_tune, **data),
+            resources={"cpu": config.cpu_count, "gpu": config.gpu_count},
+        ),
+        param_space=config,
+        tune_config=tune.TuneConfig(
+            scheduler=scheduler,
+            num_samples=config.num_samples,
+            search_alg=init_search_algorithm(
+                search_alg=config.search_alg,
+                metric=f"val_{config.val_metric}",
+                mode=config.mode,
+            ),
+        ),
+        run_config=ray_train.RunConfig(
+            name=exp_name,
+            storage_path=config.result_dir,
+            progress_reporter=reporter,
+        ),
     )
+    results = tuner.fit()
 
     # Save best model after parameter search.
-    best_trial = analysis.get_best_trial(metric=f"val_{config.val_metric}", mode=config.mode, scope="all")
-    retrain_best_model(exp_name, best_trial.config, best_trial.local_path, retrain=not config.no_retrain)
+    best_result = results.get_best_result(metric=f"val_{config.val_metric}", mode=config.mode, scope="all")
+    retrain_best_model(exp_name, best_result.config, best_result.path, retrain=not config.no_retrain)
 
 if __name__ == "__main__":
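
Note for reviewers: below is a minimal, self-contained sketch of the 2.0-style API this patch adopts (tune.Tuner / tune.with_resources / ResultGrid). The objective function, the "lr" search space, and the experiment name are invented for illustration and assume Ray >= 2.7; this is not LibMultiLabel code.

from ray import train, tune


def objective(config):
    # Stand-in for a real trainable such as train_libmultilabel_tune:
    # compute a score and report the metric the search optimizes.
    score = (config["lr"] - 0.01) ** 2
    train.report({"val_Loss": score})


tuner = tune.Tuner(
    # tune.with_resources replaces tune.run's resources_per_trial= argument.
    tune.with_resources(objective, resources={"cpu": 1, "gpu": 0}),
    # param_space replaces tune.run's config= argument.
    param_space={"lr": tune.loguniform(1e-4, 1e-1)},
    tune_config=tune.TuneConfig(metric="val_Loss", mode="min", num_samples=8),
    run_config=train.RunConfig(name="tune_api_demo"),
)
results = tuner.fit()  # ResultGrid, the 2.0 replacement for ExperimentAnalysis

best = results.get_best_result(metric="val_Loss", mode="min")
print(best.config, best.path)  # Result.path replaces Trial.local_path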