From c77520c5f99116f726ddfe0c5f1cb6a7e6f87573 Mon Sep 17 00:00:00 2001 From: MogicianXD <1023276135@qq.com> Date: Sat, 30 Mar 2024 18:55:56 +0800 Subject: [PATCH] Save handler file --- examples/benchmarks/MASTER/main.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/benchmarks/MASTER/main.py b/examples/benchmarks/MASTER/main.py index d60b3a019d..c1a7331f9f 100644 --- a/examples/benchmarks/MASTER/main.py +++ b/examples/benchmarks/MASTER/main.py @@ -38,12 +38,21 @@ def parse_args(): GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) with open("./workflow_config_master_Alpha158.yaml", 'r') as f: - basic_config = yaml.safe_load(f) + config = yaml.safe_load(f) + + h_conf = config["task"]["dataset"]["kwargs"]["handler"] + h_path = DIRNAME / 'handler_{config["dataset"]["kwargs"]["segments"]["train"]}_{config["dataset"]["kwargs"]["segments"]["test"]}.pkl' + if not h_path.exists(): + h = init_instance_by_config(h_conf) + h.to_pickle(h_path, dump_all=True) + print('Save preprocessed data to', h_path) + config["dataset"]["kwargs"]["handler"] = f"file://{h_path}" + dataset = init_instance_by_config(config['task']["dataset"]) + ################################### # train model ################################### - dataset = init_instance_by_config(basic_config['task']["dataset"]) if not os.path.exists('./model'): os.mkdir("./model") @@ -63,14 +72,14 @@ def parse_args(): print("------------------------") print(f"seed: {seed}") - basic_config['task']["model"]['kwargs']["seed"] = seed - model = init_instance_by_config(basic_config['task']["model"]) + config['task']["model"]['kwargs']["seed"] = seed + model = init_instance_by_config(config['task']["model"]) # start exp if not args.only_backtest: model.fit(dataset=dataset) else: - model.load_model(f"./model/{basic_config['market']}master_{seed}.pkl") + model.load_model(f"./model/{config['market']}master_{seed}.pkl") with R.start(experiment_name=f"workflow_seed{seed}"): # prediction @@ -84,7 +93,7 @@ def parse_args(): # backtest. If users want to use backtest based on their own prediction, # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. - par = PortAnaRecord(recorder, basic_config['port_analysis_config'], "day") + par = PortAnaRecord(recorder, config['port_analysis_config'], "day") par.generate() metrics = recorder.list_metrics()