Skip to content

Commit

Permalink
Save handler file
Browse files Browse the repository at this point in the history
  • Loading branch information
MogicianXD committed Mar 30, 2024
1 parent 07b8ba9 commit c77520c
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions examples/benchmarks/MASTER/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,21 @@ def parse_args():
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
with open("./workflow_config_master_Alpha158.yaml", 'r') as f:
basic_config = yaml.safe_load(f)
config = yaml.safe_load(f)

h_conf = config["task"]["dataset"]["kwargs"]["handler"]
h_path = DIRNAME / 'handler_{config["dataset"]["kwargs"]["segments"]["train"]}_{config["dataset"]["kwargs"]["segments"]["test"]}.pkl'
if not h_path.exists():
h = init_instance_by_config(h_conf)
h.to_pickle(h_path, dump_all=True)
print('Save preprocessed data to', h_path)
config["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
dataset = init_instance_by_config(config['task']["dataset"])

###################################
# train model
###################################

dataset = init_instance_by_config(basic_config['task']["dataset"])
if not os.path.exists('./model'):
os.mkdir("./model")

Expand All @@ -63,14 +72,14 @@ def parse_args():
print("------------------------")
print(f"seed: {seed}")

basic_config['task']["model"]['kwargs']["seed"] = seed
model = init_instance_by_config(basic_config['task']["model"])
config['task']["model"]['kwargs']["seed"] = seed
model = init_instance_by_config(config['task']["model"])

# start exp
if not args.only_backtest:
model.fit(dataset=dataset)
else:
model.load_model(f"./model/{basic_config['market']}master_{seed}.pkl")
model.load_model(f"./model/{config['market']}master_{seed}.pkl")

with R.start(experiment_name=f"workflow_seed{seed}"):
# prediction
Expand All @@ -84,7 +93,7 @@ def parse_args():

# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, basic_config['port_analysis_config'], "day")
par = PortAnaRecord(recorder, config['port_analysis_config'], "day")
par.generate()

metrics = recorder.list_metrics()
Expand Down

0 comments on commit c77520c

Please sign in to comment.