Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push master #4

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ tags

./pretrain
.idea/

*.log
16 changes: 16 additions & 0 deletions examples/benchmarks/MASTER/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## Overview
This is an alternative version of the MASTER benchmark.

paper: [MASTER: Market-Guided Stock Transformer for Stock Price Forecasting](https://arxiv.org/abs/2312.15235)

codes: [https://github.com/SJTU-Quant/MASTER](https://github.com/SJTU-Quant/MASTER)

## Run
You can use the bash script directly to run the code (set the `universe` and `only_backtest` flags in `run.sh`); `main.py` will test the model with 10 random seeds:
```
bash run.sh
```
<!-- or you can just directly use `qrun` to run the code (note that you should modify your `qlib` installation, since we add or modify some files in `qlib/contrib/data/dataset.py`, `qlib/data/dataset/__init__.py`, `qlib/data/dataset/processor.py` and `qlib/contrib/model/pytorch_master.py`):
```
qrun workflow_config_master_Alpha158.yaml
``` -->
97 changes: 97 additions & 0 deletions examples/benchmarks/MASTER/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Qlib provides two kinds of interfaces.
(1) Users could define the Quant research workflow by a simple configuration.
(2) Qlib is designed in a modularized way and supports creating research workflow by code just like building blocks.

The interface of (1) is `qrun XXX.yaml`. The interface of (2) is script like this, which nearly does the same thing as `qrun XXX.yaml`
"""
import sys
from pathlib import Path
DIRNAME = Path(__file__).absolute().resolve().parent
sys.path.append(str(DIRNAME))
sys.path.append(str(DIRNAME.parent.parent.parent))

import qlib
from qlib.constant import REG_CN
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
from qlib.tests.data import GetData
import yaml
import argparse
import os
import pprint as pp
import numpy as np

def parse_args():
    """Parse command-line arguments. Extend with extra options as needed."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--only_backtest",
        action="store_true",
        help="whether only backtest or not",
    )
    return arg_parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # Download the default Qlib CN dataset if it is not already present,
    # then initialise qlib against it.
    provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
    GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    with open("./workflow_config_master_Alpha158.yaml", "r") as f:
        basic_config = yaml.safe_load(f)

    ###################################
    # train model
    ###################################
    dataset = init_instance_by_config(basic_config["task"]["dataset"])
    # `exist_ok=True` avoids the check-then-create race of the previous
    # `os.path.exists` + `os.mkdir` pair.
    os.makedirs("./model", exist_ok=True)

    # Metric values collected across all seeds; one list per metric key.
    metric_keys = [
        "IC",
        "ICIR",
        "Rank IC",
        "Rank ICIR",
        "1day.excess_return_without_cost.annualized_return",
        "1day.excess_return_without_cost.information_ratio",
    ]
    all_metrics = {k: [] for k in metric_keys}

    for seed in range(10):
        print("------------------------")
        print(f"seed: {seed}")

        basic_config["task"]["model"]["kwargs"]["seed"] = seed
        model = init_instance_by_config(basic_config["task"]["model"])

        # Either train from scratch or load a previously saved model.
        if not args.only_backtest:
            model.fit(dataset=dataset)
        else:
            model.load_model(f"./model/{basic_config['market']}master_{seed}.pkl")

        with R.start(experiment_name=f"workflow_seed{seed}"):
            # prediction
            recorder = R.get_recorder()
            sr = SignalRecord(model, dataset, recorder)
            sr.generate()

            # Signal Analysis
            sar = SigAnaRecord(recorder)
            sar.generate()

            # backtest. If users want to use backtest based on their own prediction,
            # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
            par = PortAnaRecord(recorder, basic_config["port_analysis_config"], "day")
            par.generate()

            metrics = recorder.list_metrics()
            print(metrics)
            for k in all_metrics:
                all_metrics[k].append(metrics[k])
            pp.pprint(all_metrics)

    # Report mean +- std of each metric over the 10 seeds.
    for k, values in all_metrics.items():
        print(f"{k}: {np.mean(values)} +- {np.std(values)}")
23 changes: 23 additions & 0 deletions examples/benchmarks/MASTER/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Run the MASTER benchmark. `mkdir -p` is idempotent, so no existence check
# (the old `if [ ! -d ... ]` pattern was racy and more verbose).
mkdir -p ./logs ./backtest

# set the config
universe=csi300
only_backtest=false

# Rewrite the market / benchmark-index placeholders inside the workflow yaml.
sed -i "s/csi.../$universe/g" workflow_config_master_Alpha158.yaml
# POSIX `[` uses `=` (not `==`); quote the variable so an empty value cannot
# break the test.
if [ "$universe" = "csi300" ]; then
    sed -i "s/SH....../SH000300/g" workflow_config_master_Alpha158.yaml
elif [ "$universe" = "csi500" ]; then
    sed -i "s/SH....../SH000905/g" workflow_config_master_Alpha158.yaml
fi

# Test the flag value explicitly instead of executing `$only_backtest` as a
# command, which silently did the wrong thing for any value other than the
# literal programs `true`/`false`.
if [ "$only_backtest" = "true" ]; then
    nohup python -u main.py --only_backtest > "./backtest/${universe}.log" 2>&1 &
else
    nohup python -u main.py > "./logs/${universe}.log" 2>&1 &
fi
# Print the PID of the background run.
echo $!
95 changes: 95 additions & 0 deletions examples/benchmarks/MASTER/workflow_config_master_Alpha158.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Qlib workflow configuration for the MASTER benchmark on Alpha158 features.
# NOTE(review): indentation was reconstructed to the standard qlib workflow
# layout; `run.sh` rewrites the `csi...` / `SH......` tokens below via sed.
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market csi300
benchmark: &benchmark SH000300
# Handler for per-stock Alpha158 features.
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -5) / Ref($close, -1) - 1"]
# Handler for index-level (market) features consumed by MASTERTSDatasetH.
market_data_handler_config: &market_data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy
        kwargs:
            signal: <PRED>
            topk: 30
            n_drop: 30
    backtest:
        start_time: 2017-01-01
        end_time: 2020-08-01
        account: 100000000
        benchmark: *benchmark
        exchange_kwargs:
            deal_price: close
task:
    model:
        class: MASTERModel
        module_path: qlib.contrib.model.pytorch_master_ts
        kwargs:
            seed: 0
            n_epochs: 40
            lr: 0.000008
            train_stop_loss_thred: 0.95
            market: *market
            benchmark: *benchmark
            save_prefix: *market
    dataset:
        class: MASTERTSDatasetH
        module_path: qlib.contrib.data.dataset
        kwargs:
            handler:
                class: Alpha158
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
            step_len: 8
            market_data_handler_config: *market_data_handler_config
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              model: <MODEL>
              dataset: <DATASET>
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
140 changes: 139 additions & 1 deletion qlib/contrib/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import numpy as np
import pandas as pd

from qlib.data.dataset import DatasetH
from qlib.data.dataset import DatasetH, TSDatasetH, TSDataSampler
from typing import Callable, Union, List, Tuple, Dict, Text, Optional
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
from qlib.contrib.data.handler import check_transform_proc


device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -351,3 +354,138 @@ def __iter__(self):
}

# end indice loop

###################################################################################
# lqa: for MASTER
class marketDataHandler(DataHandlerLP):
    """Market Data Handler for MASTER (see `examples/benchmarks/MASTER`).

    Loads index-level ("market") features through a ``QlibDataLoader`` and runs
    them through the standard ``DataHandlerLP`` learn/infer processor pipeline.

    Args:
        instruments (str): instrument list
        start_time (str): start time
        end_time (str): end time
        freq (str): data frequency
        infer_processors (list): inference processors
        learn_processors (list): learning processors
        fit_start_time (str): fit start time
        fit_end_time (str): fit end time
        process_type (str): process type
        filter_pipe (list): filter pipe
        inst_processors (list): instrument processors
    """

    def __init__(
        self,
        instruments="csi300",
        start_time=None,
        end_time=None,
        freq="day",
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        process_type=DataHandlerLP.PTYPE_A,
        filter_pipe=None,
        inst_processors=None,
        **kwargs
    ):
        # `None` defaults instead of the previous mutable `[]` defaults, which
        # are shared between calls.
        infer_processors = check_transform_proc(infer_processors or [], fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors or [], fit_start_time, fit_end_time)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": {
                    "feature": self.get_feature_config(),
                },
                "filter_pipe": filter_pipe,
                "freq": freq,
                "inst_processors": inst_processors,
            },
        }
        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            process_type=process_type,
            **kwargs
        )

    @staticmethod
    def get_feature_config():
        """
        Get market features (63-dimensional): 21 rolling statistics for each of
        the csi300 (sh000300), csi100 (sh000903) and csi500 (sh000905) indices.

        The first list holds the names shown for the features, the second the
        expressions to fetch; in this benchmark they are identical, so both are
        generated from the same template instead of being written out twice.
        """
        index_codes = ["sh000300", "sh000903", "sh000905"]
        windows = [5, 10, 20, 30, 60]
        fields = []
        for code in index_codes:
            # 1-day index return, then mean/std of the return and of relative
            # volume over each rolling window.
            fields.append(f'Mask($close/Ref($close,1)-1, "{code}")')
            for w in windows:
                fields.append(f'Mask(Mean($close/Ref($close,1)-1,{w}), "{code}")')
                fields.append(f'Mask(Std($close/Ref($close,1)-1,{w}), "{code}")')
                fields.append(f'Mask(Mean($volume,{w})/$volume, "{code}")')
                fields.append(f'Mask(Std($volume,{w})/$volume, "{code}")')
        return (list(fields), list(fields))

class MASTERTSDatasetH(TSDatasetH):
    """
    MASTER Time Series Dataset with Handler.

    In addition to the per-stock features prepared by ``TSDatasetH``, this
    dataset joins index-level ("market") features produced by
    ``marketDataHandler`` into every prepared segment.

    Args:
        market_data_handler_config (dict): market data handler config (required)
    """

    def __init__(
        self,
        market_data_handler_config: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Bug fix: the default used to be `= Dict` — the `typing.Dict` class
        # itself — which was clearly meant as an annotation. The argument is
        # effectively required, so fail with a clear message when it is missing.
        if market_data_handler_config is None:
            raise ValueError("`market_data_handler_config` is required for MASTERTSDatasetH")
        marketdl = marketDataHandler(**market_data_handler_config)
        self.market_dataset = DatasetH(marketdl, segments=self.segments)

    def get_market_information(
        self,
        slc: slice,
    ) -> Union[List[pd.DataFrame], pd.DataFrame]:
        """Prepare the market (index-level) data for the given slice."""
        return self.market_dataset.prepare(slc)

    def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
        """Prepare one segment as a ``TSDataSampler``, appending market features.

        Unless ``only_label=True`` is passed (label-only preparation for
        testing), the market features are inserted before the last column so
        that the label stays last.
        """
        dtype = kwargs.pop("dtype", None)
        if not isinstance(slc, slice):
            slc = slice(*slc)
        start, end = slc.start, slc.stop
        flt_col = kwargs.pop("flt_col", None)
        # TSDatasetH will retrieve more data for complete time-series
        ext_slice = self._extend_slice(slc, self.cal, self.step_len)
        only_label = kwargs.pop("only_label", False)
        data = super(TSDatasetH, self)._prepare_seg(ext_slice, **kwargs)

        ############################## Add market information ###########################
        # If we only need the label for testing, we do not need market information.
        if not only_label:
            marketData = self.get_market_information(ext_slice)
            # Lift market columns into the ("feature", <name>) MultiIndex used
            # by the stock data so the join lines up.
            cols = pd.MultiIndex.from_tuples([("feature", feature) for feature in marketData.columns])
            marketData = pd.DataFrame(marketData.values, columns=cols, index=marketData.index)
            # Keep the label as the last column: features | market | label.
            data = data.iloc[:, :-1].join(marketData).join(data.iloc[:, -1])
        #################################################################################
        flt_kwargs = copy.deepcopy(kwargs)
        if flt_col is not None:
            flt_kwargs["col_set"] = flt_col
            flt_data = super()._prepare_seg(ext_slice, **flt_kwargs)
            assert len(flt_data.columns) == 1
        else:
            flt_data = None

        tsds = TSDataSampler(
            data=data,
            start=start,
            end=end,
            step_len=self.step_len,
            dtype=dtype,
            flt_data=flt_data,
            fillna_type="ffill+bfill",
        )
        return tsds
Loading
Loading