diff --git a/.gitignore b/.gitignore index 8854c25e99..7f0a74198f 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,5 @@ tags ./pretrain .idea/ + +*.log diff --git a/examples/benchmarks/MASTER_w_qrun/README.md b/examples/benchmarks/MASTER/README.md similarity index 100% rename from examples/benchmarks/MASTER_w_qrun/README.md rename to examples/benchmarks/MASTER/README.md diff --git a/examples/benchmarks/MASTER_w_qrun/main.py b/examples/benchmarks/MASTER/main.py similarity index 100% rename from examples/benchmarks/MASTER_w_qrun/main.py rename to examples/benchmarks/MASTER/main.py diff --git a/examples/benchmarks/MASTER_w_qrun/run.sh b/examples/benchmarks/MASTER/run.sh similarity index 100% rename from examples/benchmarks/MASTER_w_qrun/run.sh rename to examples/benchmarks/MASTER/run.sh diff --git a/examples/benchmarks/MASTER_w_qrun/workflow_config_master_Alpha158.yaml b/examples/benchmarks/MASTER/workflow_config_master_Alpha158.yaml similarity index 98% rename from examples/benchmarks/MASTER_w_qrun/workflow_config_master_Alpha158.yaml rename to examples/benchmarks/MASTER/workflow_config_master_Alpha158.yaml index d2a60ce5ab..9a479a27d3 100644 --- a/examples/benchmarks/MASTER_w_qrun/workflow_config_master_Alpha158.yaml +++ b/examples/benchmarks/MASTER/workflow_config_master_Alpha158.yaml @@ -63,6 +63,7 @@ task: train_stop_loss_thred: 0.95 market: *market benchmark: *benchmark + save_prefix: *market dataset: class: MASTERTSDatasetH module_path: qlib.contrib.data.dataset diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index 8b40dba1fc..3729317201 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -7,7 +7,10 @@ import numpy as np import pandas as pd -from qlib.data.dataset import DatasetH +from qlib.data.dataset import DatasetH, TSDatasetH, TSDataSampler +from typing import Callable, Union, List, Tuple, Dict, Text, Optional +from qlib.data.dataset.handler import DataHandler, DataHandlerLP +from qlib.contrib.data.handler import check_transform_proc device = "cuda" if torch.cuda.is_available() else "cpu" @@ -351,3 +354,117 @@ def __iter__(self): } # end indice loop +################################################################################### +# lqa: for MASTER +class marketDataHandler(DataHandlerLP): + def __init__( + self, + instruments="csi300", + start_time=None, + end_time=None, + freq="day", + infer_processors=[], + learn_processors=[], + fit_start_time=None, + fit_end_time=None, + process_type=DataHandlerLP.PTYPE_A, + filter_pipe=None, + inst_processors=None, + **kwargs + ): + infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) + learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + + data_loader = { + "class": "QlibDataLoader", + "kwargs": { + "config": { + "feature": self.get_feature_config(), + }, + "filter_pipe": filter_pipe, + "freq": freq, + "inst_processors": inst_processors, + }, + } + super().__init__( + instruments=instruments, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, + infer_processors=infer_processors, + learn_processors=learn_processors, + process_type=process_type, + **kwargs + ) + + @staticmethod + def get_feature_config(): + return ( + ['Mask($close/Ref($close,1)-1, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000300")', 'Mask(Mean($volume,5)/$volume, "sh000300")', 'Mask(Std($volume,5)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,10), 
"sh000300")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000300")', 'Mask(Mean($volume,10)/$volume, "sh000300")', 'Mask(Std($volume,10)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000300")', 'Mask(Mean($volume,20)/$volume, "sh000300")', 'Mask(Std($volume,20)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000300")', 'Mask(Mean($volume,30)/$volume, "sh000300")', 'Mask(Std($volume,30)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000300")', 'Mask(Mean($volume,60)/$volume, "sh000300")', 'Mask(Std($volume,60)/$volume, "sh000300")', + 'Mask($close/Ref($close,1)-1, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000903")', 'Mask(Mean($volume,5)/$volume, "sh000903")', 'Mask(Std($volume,5)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000903")', 'Mask(Mean($volume,10)/$volume, "sh000903")', 'Mask(Std($volume,10)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000903")', 'Mask(Mean($volume,20)/$volume, "sh000903")', 'Mask(Std($volume,20)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000903")', 'Mask(Mean($volume,30)/$volume, "sh000903")', 'Mask(Std($volume,30)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000903")', 'Mask(Mean($volume,60)/$volume, "sh000903")', 'Mask(Std($volume,60)/$volume, "sh000903")', + 'Mask($close/Ref($close,1)-1, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000905")', 'Mask(Mean($volume,5)/$volume, "sh000905")', 'Mask(Std($volume,5)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000905")', 'Mask(Mean($volume,10)/$volume, "sh000905")', 'Mask(Std($volume,10)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000905")', 'Mask(Mean($volume,20)/$volume, "sh000905")', 'Mask(Std($volume,20)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000905")', 'Mask(Mean($volume,30)/$volume, "sh000905")', 'Mask(Std($volume,30)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000905")', 'Mask(Mean($volume,60)/$volume, "sh000905")', 'Mask(Std($volume,60)/$volume, "sh000905")'], + ['Mask($close/Ref($close,1)-1, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000300")', 'Mask(Mean($volume,5)/$volume, "sh000300")', 'Mask(Std($volume,5)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000300")', 'Mask(Mean($volume,10)/$volume, "sh000300")', 'Mask(Std($volume,10)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000300")', 'Mask(Mean($volume,20)/$volume, "sh000300")', 'Mask(Std($volume,20)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000300")', 
'Mask(Mean($volume,30)/$volume, "sh000300")', 'Mask(Std($volume,30)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000300")', 'Mask(Mean($volume,60)/$volume, "sh000300")', 'Mask(Std($volume,60)/$volume, "sh000300")', + 'Mask($close/Ref($close,1)-1, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000903")', 'Mask(Mean($volume,5)/$volume, "sh000903")', 'Mask(Std($volume,5)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000903")', 'Mask(Mean($volume,10)/$volume, "sh000903")', 'Mask(Std($volume,10)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000903")', 'Mask(Mean($volume,20)/$volume, "sh000903")', 'Mask(Std($volume,20)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000903")', 'Mask(Mean($volume,30)/$volume, "sh000903")', 'Mask(Std($volume,30)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000903")', 'Mask(Mean($volume,60)/$volume, "sh000903")', 'Mask(Std($volume,60)/$volume, "sh000903")', + 'Mask($close/Ref($close,1)-1, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000905")', 'Mask(Mean($volume,5)/$volume, "sh000905")', 'Mask(Std($volume,5)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000905")', 'Mask(Mean($volume,10)/$volume, "sh000905")', 'Mask(Std($volume,10)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000905")', 'Mask(Mean($volume,20)/$volume, "sh000905")', 'Mask(Std($volume,20)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000905")', 'Mask(Mean($volume,30)/$volume, "sh000905")', 'Mask(Std($volume,30)/$volume, "sh000905")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000905")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000905")', 'Mask(Mean($volume,60)/$volume, "sh000905")', 'Mask(Std($volume,60)/$volume, "sh000905")'] + ) + +class MASTERTSDatasetH(TSDatasetH): + def __init__( + self, + market_data_handler_config = Dict, + **kwargs, + ): + super().__init__(**kwargs) + marketdl = marketDataHandler(**market_data_handler_config) + self.market_dataset = DatasetH(marketdl, segments = self.segments) + + + def get_market_information( + self, + slc: slice, + ) -> Union[List[pd.DataFrame], pd.DataFrame]: + return self.market_dataset.prepare(slc) + + def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: + dtype = kwargs.pop("dtype", None) + if not isinstance(slc, slice): + slc = slice(*slc) + start, end = slc.start, slc.stop + flt_col = kwargs.pop("flt_col", None) + # TSDatasetH will retrieve more data for complete time-series + + ext_slice = self._extend_slice(slc, self.cal, self.step_len) + only_label = kwargs.pop("only_label", False) + data = super(TSDatasetH, self)._prepare_seg(ext_slice, **kwargs) + + ############################## Add market information ########################### + if not only_label: + marketData = self.get_market_information(ext_slice) + cols = pd.MultiIndex.from_tuples([("feature", feature) for feature in marketData.columns]) + marketData = pd.DataFrame(marketData.values, columns = cols, index 
= marketData.index) + # print(marketData.index) + # print(marketData.columns) + # print(data.index) + # print(data.columns) + data = data.iloc[:,:-1].join(marketData).join(data.iloc[:,-1]) + # print(data.columns) + # print(data.shape) + ################################################################################# + flt_kwargs = copy.deepcopy(kwargs) + if flt_col is not None: + flt_kwargs["col_set"] = flt_col + flt_data = super()._prepare_seg(ext_slice, **flt_kwargs) + assert len(flt_data.columns) == 1 + else: + flt_data = None + + tsds = TSDataSampler( + data=data, + start=start, + end=end, + step_len=self.step_len, + dtype=dtype, + flt_data=flt_data, + fillna_type = "ffill+bfill" + ) + return tsds \ No newline at end of file diff --git a/qlib/contrib/model/pytorch_master_ts.py b/qlib/contrib/model/pytorch_master_ts.py new file mode 100644 index 0000000000..63e2871832 --- /dev/null +++ b/qlib/contrib/model/pytorch_master_ts.py @@ -0,0 +1,413 @@ +import numpy as np +import pandas as pd +import copy +from typing import Optional, List, Tuple, Union, Text +import tqdm +import pprint as pp +import math + +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Sampler +from torch import nn +from torch.nn.modules.linear import Linear +from torch.nn.modules.dropout import Dropout +from torch.nn.modules.normalization import LayerNorm +import torch.optim as optim + +import qlib +# from qlib.utils import init_instance_by_config +# from qlib.data.dataset import Dataset, DataHandlerLP, DatasetH +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP +from ...model.base import Model +# from qlib.contrib.data.dataset import TSDataSampler +# from qlib.workflow.record_temp import SigAnaRecord, PortAnaRecord +# from qlib.workflow import R, Experiment +# from qlib.workflow.task.utils import TimeAdjuster + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, max_len=100): + super(PositionalEncoding, self).__init__() + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + return x + self.pe[:x.shape[1], :] + + +class SAttention(nn.Module): + def __init__(self, d_model, nhead, dropout): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.temperature = math.sqrt(self.d_model/nhead) + + self.qtrans = nn.Linear(d_model, d_model, bias=False) + self.ktrans = nn.Linear(d_model, d_model, bias=False) + self.vtrans = nn.Linear(d_model, d_model, bias=False) + + attn_dropout_layer = [] + for i in range(nhead): + attn_dropout_layer.append(Dropout(p=dropout)) + self.attn_dropout = nn.ModuleList(attn_dropout_layer) + + # input LayerNorm + self.norm1 = LayerNorm(d_model, eps=1e-5) + + # FFN layerNorm + self.norm2 = LayerNorm(d_model, eps=1e-5) + self.ffn = nn.Sequential( + Linear(d_model, d_model), + nn.ReLU(), + Dropout(p=dropout), + Linear(d_model, d_model), + Dropout(p=dropout) + ) + + def forward(self, x): + x = self.norm1(x) + q = self.qtrans(x).transpose(0,1) + k = self.ktrans(x).transpose(0,1) + v = self.vtrans(x).transpose(0,1) + + dim = int(self.d_model/self.nhead) + att_output = [] + for i in range(self.nhead): + if i==self.nhead-1: + qh = q[:, :, i * dim:] + kh = k[:, :, i * dim:] + vh = v[:, :, i * dim:] + else: + qh 
= q[:, :, i * dim:(i + 1) * dim] + kh = k[:, :, i * dim:(i + 1) * dim] + vh = v[:, :, i * dim:(i + 1) * dim] + + atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)) / self.temperature, dim=-1) + if self.attn_dropout: + atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh) + att_output.append(torch.matmul(atten_ave_matrixh, vh).transpose(0, 1)) + att_output = torch.concat(att_output, dim=-1) + + # FFN + xt = x + att_output + xt = self.norm2(xt) + att_output = xt + self.ffn(xt) + + return att_output + + +class TAttention(nn.Module): + def __init__(self, d_model, nhead, dropout): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.qtrans = nn.Linear(d_model, d_model, bias=False) + self.ktrans = nn.Linear(d_model, d_model, bias=False) + self.vtrans = nn.Linear(d_model, d_model, bias=False) + + self.attn_dropout = [] + if dropout > 0: + for i in range(nhead): + self.attn_dropout.append(Dropout(p=dropout)) + self.attn_dropout = nn.ModuleList(self.attn_dropout) + + # input LayerNorm + self.norm1 = LayerNorm(d_model, eps=1e-5) + # FFN layerNorm + self.norm2 = LayerNorm(d_model, eps=1e-5) + # FFN + self.ffn = nn.Sequential( + Linear(d_model, d_model), + nn.ReLU(), + Dropout(p=dropout), + Linear(d_model, d_model), + Dropout(p=dropout) + ) + + def forward(self, x): + x = self.norm1(x) + q = self.qtrans(x) + k = self.ktrans(x) + v = self.vtrans(x) + + dim = int(self.d_model / self.nhead) + att_output = [] + for i in range(self.nhead): + if i==self.nhead-1: + qh = q[:, :, i * dim:] + kh = k[:, :, i * dim:] + vh = v[:, :, i * dim:] + else: + qh = q[:, :, i * dim:(i + 1) * dim] + kh = k[:, :, i * dim:(i + 1) * dim] + vh = v[:, :, i * dim:(i + 1) * dim] + atten_ave_matrixh = torch.softmax(torch.matmul(qh, kh.transpose(1, 2)), dim=-1) + if self.attn_dropout: + atten_ave_matrixh = self.attn_dropout[i](atten_ave_matrixh) + att_output.append(torch.matmul(atten_ave_matrixh, vh)) + att_output = torch.concat(att_output, dim=-1) + + # FFN + xt = x + att_output + xt = self.norm2(xt) + att_output = xt + self.ffn(xt) + + return att_output + + +class Gate(nn.Module): + def __init__(self, d_input, d_output, beta=1.0): + super().__init__() + self.trans = nn.Linear(d_input, d_output) + self.d_output =d_output + self.t = beta + + def forward(self, gate_input): + output = self.trans(gate_input) + output = torch.softmax(output/self.t, dim=-1) + return self.d_output*output + + +class TemporalAttention(nn.Module): + def __init__(self, d_model): + super().__init__() + self.trans = nn.Linear(d_model, d_model, bias=False) + + def forward(self, z): + h = self.trans(z) # [N, T, D] + query = h[:, -1, :].unsqueeze(-1) + lam = torch.matmul(h, query).squeeze(-1) # [N, T, D] --> [N, T] + lam = torch.softmax(lam, dim=1).unsqueeze(1) + output = torch.matmul(lam, z).squeeze(1) # [N, 1, T], [N, T, D] --> [N, 1, D] + return output + + +class MASTER(nn.Module): + def __init__(self, d_feat=158, d_model=256, t_nhead=4, s_nhead=2, T_dropout_rate=0.5, S_dropout_rate=0.5, + gate_input_start_index=158, gate_input_end_index=221, beta=None): + super(MASTER, self).__init__() + # market + self.gate_input_start_index = gate_input_start_index + self.gate_input_end_index = gate_input_end_index + self.d_gate_input = (gate_input_end_index - gate_input_start_index) # F' + self.feature_gate = Gate(self.d_gate_input, d_feat, beta=beta) + + self.x2y = nn.Linear(d_feat, d_model) + self.pe = PositionalEncoding(d_model) + self.tatten = TAttention(d_model=d_model, nhead=t_nhead, dropout=T_dropout_rate) + self.satten 
= SAttention(d_model=d_model, nhead=s_nhead, dropout=S_dropout_rate) + self.temporalatten = TemporalAttention(d_model=d_model) + self.decoder = nn.Linear(d_model, 1) + + + def forward(self, x): + src = x[:, :, :self.gate_input_start_index] # N, T, D + gate_input = x[:, -1, self.gate_input_start_index:self.gate_input_end_index] + src = src * torch.unsqueeze(self.feature_gate(gate_input), dim=1) + + x = self.x2y(src) + x = self.pe(x) + x = self.tatten(x) + x = self.satten(x) + x = self.temporalatten(x) + output = self.decoder(x).squeeze(-1) + + + return output + +def calc_ic(pred, label): + df = pd.DataFrame({'pred':pred, 'label':label}) + ic = df['pred'].corr(df['label']) + ric = df['pred'].corr(df['label'], method='spearman') + return ic, ric + + +class DailyBatchSamplerRandom(Sampler): + def __init__(self, data_source, shuffle=False): + self.data_source = data_source + self.shuffle = shuffle + # calculate number of samples in each batch + self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values + self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch + self.daily_index[0] = 0 + + def __iter__(self): + if self.shuffle: + index = np.arange(len(self.daily_count)) + np.random.shuffle(index) + for i in index: + yield np.arange(self.daily_index[i], self.daily_index[i] + self.daily_count[i]) + else: + for idx, count in zip(self.daily_index, self.daily_count): + yield np.arange(idx, idx + count) + + def __len__(self): + return len(self.data_source) + + +class MASTERModel(Model): + def __init__(self, d_feat: int = 158, d_model: int = 256, t_nhead: int = 4, s_nhead: int = 2, gate_input_start_index=158, gate_input_end_index=221, + T_dropout_rate=0.5, S_dropout_rate=0.5, beta=None, n_epochs = 40, lr = 8e-6, GPU=0, seed=0, train_stop_loss_thred=None, save_path = 'model/', save_prefix= '', benchmark = 'SH000300', market = 'csi300', only_backtest = False): + + self.d_model = d_model + self.d_feat = d_feat + + self.gate_input_start_index = gate_input_start_index + self.gate_input_end_index = gate_input_end_index + + self.T_dropout_rate = T_dropout_rate + self.S_dropout_rate = S_dropout_rate + self.t_nhead = t_nhead + self.s_nhead = s_nhead + self.beta = beta + + self.n_epochs = n_epochs + self.lr = lr + self.device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu") + self.seed = seed + self.train_stop_loss_thred = train_stop_loss_thred + self.benchmark = benchmark + self.market = market + self.infer_exp_name = f"{self.market}_MASTER_seed{self.seed}_backtest" + + self.fitted = False + if self.market == 'csi300': + self.beta = 10 + else: + self.beta = 5 + if self.seed is not None: + np.random.seed(self.seed) + torch.manual_seed(self.seed) + self.model = MASTER(d_feat=self.d_feat, d_model=self.d_model, t_nhead=self.t_nhead, s_nhead=self.s_nhead, + T_dropout_rate=self.T_dropout_rate, S_dropout_rate=self.S_dropout_rate, + gate_input_start_index=self.gate_input_start_index, + gate_input_end_index=self.gate_input_end_index, beta=self.beta) + self.train_optimizer = optim.Adam(self.model.parameters(), self.lr) + self.model.to(self.device) + + self.save_path = save_path + self.save_prefix = save_prefix + self.only_backtest = only_backtest + + def init_model(self): + if self.model is None: + raise ValueError("model has not been initialized") + + self.train_optimizer = optim.Adam(self.model.parameters(), self.lr) + self.model.to(self.device) + + def load_model(self, param_path): + try: + 
self.model.load_state_dict(torch.load(param_path, map_location=self.device)) + self.fitted = True + except: + raise ValueError("Model not found.") + + def loss_fn(self, pred, label): + mask = ~torch.isnan(label) + loss = (pred[mask]-label[mask])**2 + return torch.mean(loss) + + def train_epoch(self, data_loader): + self.model.train() + losses = [] + + for data in data_loader: + data = torch.squeeze(data, dim=0) + ''' + data.shape: (N, T, F) + N - number of stocks + T - length of lookback_window, 8 + F - 158 factors + 63 market information + 1 label + ''' + feature = data[:, :, 0:-1].to(self.device) + label = data[:, -1, -1].to(self.device) + assert not torch.any(torch.isnan(label)) + + pred = self.model(feature.float()) + loss = self.loss_fn(pred, label) + losses.append(loss.item()) + + self.train_optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0) + self.train_optimizer.step() + + return float(np.mean(losses)) + + def test_epoch(self, data_loader): + self.model.eval() + losses = [] + + for data in data_loader: + data = torch.squeeze(data, dim=0) + feature = data[:, :, 0:-1].to(self.device) + label = data[:, -1, -1].to(self.device) + pred = self.model(feature.float()) + loss = self.loss_fn(pred, label) + losses.append(loss.item()) + + return float(np.mean(losses)) + + def _init_data_loader(self, data, shuffle=True, drop_last=True): + sampler = DailyBatchSamplerRandom(data, shuffle) + data_loader = DataLoader(data, sampler=sampler, drop_last=drop_last) + return data_loader + + def load_param(self, param_path): + self.model.load_state_dict(torch.load(param_path, map_location=self.device)) + self.fitted = True + + def fit(self, dataset: DatasetH): + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + train_loader = self._init_data_loader(dl_train, shuffle=True, drop_last=True) + valid_loader = self._init_data_loader(dl_valid, shuffle=False, drop_last=True) + + self.fitted = True + best_param = None + best_val_loss = 1e3 + + for step in range(self.n_epochs): + train_loss = self.train_epoch(train_loader) + val_loss = self.test_epoch(valid_loader) + + print("Epoch %d, train_loss %.6f, valid_loss %.6f " % (step, train_loss, val_loss)) + if best_val_loss > val_loss: + best_param = copy.deepcopy(self.model.state_dict()) + best_val_loss = val_loss + + if train_loss <= self.train_stop_loss_thred: + break + torch.save(best_param, f'{self.save_path}{self.save_prefix}master_{self.seed}.pkl') + + def predict(self, dataset: DatasetH, use_pretrained = True): + if use_pretrained: + self.load_param(f'{self.save_path}{self.save_prefix}master_{self.seed}.pkl') + if not self.fitted: + raise ValueError("model is not fitted yet!") + + dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + test_loader = self._init_data_loader(dl_test, shuffle=False, drop_last=False) + + pred_all = [] + + self.model.eval() + for data in test_loader: + data = torch.squeeze(data, dim=0) + feature = data[:, :, 0:-1].to(self.device) + with torch.no_grad(): + pred = self.model(feature.float()).detach().cpu().numpy() + pred_all.append(pred.ravel()) + + + pred_all = pd.DataFrame(np.concatenate(pred_all), index=dl_test.get_index()) + # pred_all = pred_all.loc[self.label_all.index] + # rec = self.backtest() + return pred_all diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 
aacd58389a..f014a78f13 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -272,6 +272,375 @@ def _get_extrema(segments, idx: int, cmp: Callable, key_func=pd.Timestamp): return candidate + +# class TSDataSampler: +# """ +# (T)ime-(S)eries DataSampler +# This is the result of TSDatasetH + +# It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series +# dataset based on tabular data. +# - On time step dimension, the smaller index indicates the historical data and the larger index indicates the future +# data. + +# If user have further requirements for processing data, user could process them based on `TSDataSampler` or create +# more powerful subclasses. + +# Known Issues: +# - For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result +# in a different data type + + +# Indices design: +# TSDataSampler has a index mechanism to help users query time-series data efficiently. + +# The definition of related variables: +# data_arr: np.ndarray +# The original data. it will contains all the original data. +# The querying are often for time-series of a specific stock. +# By leveraging this data charactoristics to speed up querying, the multi-index of data_arr is rearranged in (instrument, datetime) order + +# data_index: pd.MultiIndex with index order +# it has the same shape with `idx_map`. Each elements of them are expected to be aligned. + +# idx_map: np.ndarray +# It is the indexable data. It originates from data_arr, and then filtered by 1) `start` and `end` 2) `flt_data` +# The extra data in data_arr is useful in following cases +# 1) creating meaningful time series data before `start` instead of padding them with zeros +# 2) some data are excluded by `flt_data` (e.g. no sample pair for that index). but they are still used in time-series in X + +# Finnally, it will look like. + +# array([[ 0, 0], +# [ 1, 0], +# [ 2, 0], +# ..., +# [241, 348], +# [242, 348], +# [243, 348]], dtype=int32) + +# It list all indexable data(some data only used in historical time series data may not be indexabla), the values are the corresponding row and col in idx_df +# idx_df: pd.DataFrame +# It aims to map the key to the original position in data_arr + +# For example, it may look like (NOTE: the index for a instrument time-series is continoues in memory) + +# instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ... +# datetime +# 2017-01-03 0 242 473 717 NaN 974 ... +# 2017-01-04 1 243 474 718 NaN 975 ... +# 2017-01-05 2 244 475 719 NaN 976 ... +# 2017-01-06 3 245 476 720 NaN 977 ... + +# With these two indices(idx_map, idx_df) and original data(data_arr), we can make the following queries fast (implemented in __getitem__) +# (1) Get the i-th indexable sample(time-series): (indexable sample index) -> [idx_map] -> (row col) -> [idx_df] -> (index in data_arr) +# (2) Get the specific sample by : (, i.e. 
) -> [idx_df] -> (index in data_arr) +# (3) Get the index of a time-series data: (get the , refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for time-series) +# """ + +# # Please refer to the docstring of TSDataSampler for the definition of following attributes +# data_arr: np.ndarray +# data_index: pd.MultiIndex +# idx_map: np.ndarray +# idx_df: pd.DataFrame + +# def __init__( +# self, +# data: pd.DataFrame, +# start, +# end, +# step_len: int, +# fillna_type: str = "none", +# dtype=None, +# flt_data=None, +# ): +# """ +# Build a dataset which looks like torch.data.utils.Dataset. + +# Parameters +# ---------- +# data : pd.DataFrame +# The raw tabular data whose index order is <"datetime", "instrument"> +# start : +# The indexable start time +# end : +# The indexable end time +# step_len : int +# The length of the time-series step +# fillna_type : int +# How will qlib handle the sample if there is on sample in a specific date. +# none: +# fill with np.nan +# ffill: +# ffill with previous sample +# ffill+bfill: +# ffill with previous samples first and fill with later samples second +# flt_data : pd.Series +# a column of data(True or False) to filter data. Its index order is <"datetime", "instrument"> +# None: +# kepp all data + +# """ +# self.start = start +# self.end = end +# self.step_len = step_len +# self.fillna_type = fillna_type +# assert get_level_index(data, "datetime") == 0 +# self.data = data.swaplevel().sort_index().copy() +# data.drop( +# data.columns, axis=1, inplace=True +# ) # data is useless since it's passed to a transposed one, hard code to free the memory of this dataframe to avoid three big dataframe in the memory(including: data, self.data, self.data_arr) + +# kwargs = {"object": self.data} +# if dtype is not None: +# kwargs["dtype"] = dtype + +# self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values! +# # NOTE: +# # - append last line with full NaN for better performance in `__getitem__` +# # - Keep the same dtype will result in a better performance +# self.data_arr = np.append( +# self.data_arr, +# np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), +# axis=0, +# ) +# self.nan_idx = -1 # The last line is all NaN + +# # the data type will be changed +# # The index of usable data is between start_idx and end_idx +# self.idx_df, self.idx_map = self.build_index(self.data) +# self.data_index = deepcopy(self.data.index) + +# if flt_data is not None: +# if isinstance(flt_data, pd.DataFrame): +# assert len(flt_data.columns) == 1 +# flt_data = flt_data.iloc[:, 0] +# # NOTE: bool(np.nan) is True !!!!!!!! +# # make sure reindex comes first. Otherwise extra NaN may appear. 
+# flt_data = flt_data.swaplevel() +# flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool) +# self.flt_data = flt_data.values +# self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) +# self.data_index = self.data_index[np.where(self.flt_data)[0]] +# self.idx_map = self.idx_map2arr(self.idx_map) +# self.idx_map, self.data_index = self.slice_idx_map_and_data_index( +# self.idx_map, self.idx_df, self.data_index, start, end +# ) + +# self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance +# del self.data # save memory + +# @staticmethod +# def slice_idx_map_and_data_index( +# idx_map, +# idx_df, +# data_index, +# start, +# end, +# ): +# assert ( +# len(idx_map) == data_index.shape[0] +# ) # make sure idx_map and data_index is same so index of idx_map can be used on data_index + +# start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end)) + +# time_flter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx) +# return idx_map[time_flter_idx], data_index[time_flter_idx] + +# @staticmethod +# def idx_map2arr(idx_map): +# # pytorch data sampler will have better memory control without large dict or list +# # - https://github.com/pytorch/pytorch/issues/13243 +# # - https://github.com/airctic/icevision/issues/613 +# # So we convert the dict into int array. +# # The arr_map is expected to behave the same as idx_map + +# dtype = np.int32 +# # set a index out of bound to indicate the none existing +# no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max) + +# max_idx = max(idx_map.keys()) +# arr_map = [] +# for i in range(max_idx + 1): +# arr_map.append(idx_map.get(i, no_existing_idx)) +# arr_map = np.array(arr_map, dtype=dtype) +# return arr_map + +# @staticmethod +# def flt_idx_map(flt_data, idx_map): +# idx = 0 +# new_idx_map = {} +# for i, exist in enumerate(flt_data): +# if exist: +# new_idx_map[idx] = idx_map[i] +# idx += 1 +# return new_idx_map + +# def get_index(self): +# """ +# Get the pandas index of the data, it will be useful in following scenarios +# - Special sampler will be used (e.g. user want to sample day by day) +# """ +# return self.data_index.swaplevel() # to align the order of multiple index of original data received by __init__ + +# def config(self, **kwargs): +# # Config the attributes +# for k, v in kwargs.items(): +# setattr(self, k, v) + +# @staticmethod +# def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: +# """ +# The relation of the data + +# Parameters +# ---------- +# data : pd.DataFrame +# A DataFrame with index in order + +# RSQR5 RESI5 WVMA5 LABEL0 +# instrument datetime +# SH600000 2017-01-03 0.016389 0.461632 -1.154788 -0.048056 +# 2017-01-04 0.884545 -0.110597 -1.059332 -0.030139 +# 2017-01-05 0.507540 -0.535493 -1.099665 -0.644983 +# 2017-01-06 -1.267771 -0.669685 -1.636733 0.295366 +# 2017-01-09 0.339346 0.074317 -0.984989 0.765540 + +# Returns +# ------- +# Tuple[pd.DataFrame, dict]: +# 1) the first element: reshape the original index into a 2D dataframe +# instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ... +# datetime +# 2017-01-03 0 242 473 717 NaN 974 ... +# 2017-01-04 1 243 474 718 NaN 975 ... +# 2017-01-05 2 244 475 719 NaN 976 ... +# 2017-01-06 3 245 476 720 NaN 977 ... 
+# 2) the second element: {: } +# """ +# # object incase of pandas converting int to float +# idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) +# idx_df = lazy_sort_index(idx_df.unstack()) +# # NOTE: the correctness of `__getitem__` depends on columns sorted here +# idx_df = lazy_sort_index(idx_df, axis=1).T + +# idx_map = {} +# for i, (_, row) in enumerate(idx_df.iterrows()): +# for j, real_idx in enumerate(row): +# if not np.isnan(real_idx): +# idx_map[real_idx] = (i, j) +# return idx_df, idx_map + +# @property +# def empty(self): +# return len(self) == 0 + +# def _get_indices(self, row: int, col: int) -> np.array: +# """ +# get series indices of self.data_arr from the row, col indices of self.idx_df + +# Parameters +# ---------- +# row : int +# the row in self.idx_df +# col : int +# the col in self.idx_df + +# Returns +# ------- +# np.array: +# The indices of data of the data +# """ +# indices = self.idx_arr[max(row - self.step_len + 1, 0) : row + 1, col] + +# if len(indices) < self.step_len: +# indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices]) + +# if self.fillna_type == "ffill": +# indices = np_ffill(indices) +# elif self.fillna_type == "ffill+bfill": +# indices = np_ffill(np_ffill(indices)[::-1])[::-1] +# else: +# assert self.fillna_type == "none" +# return indices + +# def _get_row_col(self, idx) -> Tuple[int]: +# """ +# get the col index and row index of a given sample index in self.idx_df + +# Parameters +# ---------- +# idx : +# the input of `__getitem__` + +# Returns +# ------- +# Tuple[int]: +# the row and col index +# """ +# # The the right row number `i` and col number `j` in idx_df +# if isinstance(idx, (int, np.integer)): +# real_idx = idx +# if 0 <= real_idx < len(self.idx_map): +# i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good +# else: +# raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})") +# elif isinstance(idx, tuple): +# # ["datetime", "instruments"] +# date, inst = idx +# date = pd.Timestamp(date) +# i = bisect.bisect_right(self.idx_df.index, date) - 1 +# # NOTE: This relies on the idx_df columns sorted in `__init__` +# j = bisect.bisect_left(self.idx_df.columns, inst) +# else: +# raise NotImplementedError(f"This type of input is not supported") +# return i, j + +# def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): +# """ +# # We have two method to get the time-series of a sample +# tsds is a instance of TSDataSampler + +# # 1) sample by int index directly +# tsds[len(tsds) - 1] + +# # 2) sample by index +# tsds['2016-12-31', "SZ300315"] + +# # The return value will be similar to the data retrieved by following code +# df.loc(axis=0)['2015-01-01':'2016-12-31', "SZ300315"].iloc[-30:] + +# Parameters +# ---------- +# idx : Union[int, Tuple[object, str]] +# """ +# # Multi-index type +# mtit = (list, np.ndarray) +# if isinstance(idx, mtit): +# indices = [self._get_indices(*self._get_row_col(i)) for i in idx] +# indices = np.concatenate(indices) +# else: +# indices = self._get_indices(*self._get_row_col(idx)) + +# # 1) for better performance, use the last nan line for padding the lost date +# # 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in +# # precision problems. It will not cause any problems in my tests at least +# indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int) + +# if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up. 
+# data = self.data_arr[indices[0] : indices[-1] + 1] +# else: +# data = self.data_arr[indices] +# if isinstance(idx, mtit): +# # if we get multiple indexes, addition dimension should be added. +# # +# data = data.reshape(-1, self.step_len, *data.shape[1:]) +# return data + +# def __len__(self): +# return len(self.idx_map) + +# v0.8.6 class TSDataSampler: """ (T)ime-(S)eries DataSampler @@ -289,69 +658,10 @@ class TSDataSampler: - For performance issues, this Sampler will convert dataframe into arrays for better performance. This could result in a different data type - - Indices design: - TSDataSampler has a index mechanism to help users query time-series data efficiently. - - The definition of related variables: - data_arr: np.ndarray - The original data. it will contains all the original data. - The querying are often for time-series of a specific stock. - By leveraging this data charactoristics to speed up querying, the multi-index of data_arr is rearranged in (instrument, datetime) order - - data_index: pd.MultiIndex with index order - it has the same shape with `idx_map`. Each elements of them are expected to be aligned. - - idx_map: np.ndarray - It is the indexable data. It originates from data_arr, and then filtered by 1) `start` and `end` 2) `flt_data` - The extra data in data_arr is useful in following cases - 1) creating meaningful time series data before `start` instead of padding them with zeros - 2) some data are excluded by `flt_data` (e.g. no sample pair for that index). but they are still used in time-series in X - - Finnally, it will look like. - - array([[ 0, 0], - [ 1, 0], - [ 2, 0], - ..., - [241, 348], - [242, 348], - [243, 348]], dtype=int32) - - It list all indexable data(some data only used in historical time series data may not be indexabla), the values are the corresponding row and col in idx_df - idx_df: pd.DataFrame - It aims to map the key to the original position in data_arr - - For example, it may look like (NOTE: the index for a instrument time-series is continoues in memory) - - instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ... - datetime - 2017-01-03 0 242 473 717 NaN 974 ... - 2017-01-04 1 243 474 718 NaN 975 ... - 2017-01-05 2 244 475 719 NaN 976 ... - 2017-01-06 3 245 476 720 NaN 977 ... - - With these two indices(idx_map, idx_df) and original data(data_arr), we can make the following queries fast (implemented in __getitem__) - (1) Get the i-th indexable sample(time-series): (indexable sample index) -> [idx_map] -> (row col) -> [idx_df] -> (index in data_arr) - (2) Get the specific sample by : (, i.e. ) -> [idx_df] -> (index in data_arr) - (3) Get the index of a time-series data: (get the , refer to (1), (2)) -> [idx_df] -> (all indices in data_arr for time-series) """ - # Please refer to the docstring of TSDataSampler for the definition of following attributes - data_arr: np.ndarray - data_index: pd.MultiIndex - idx_map: np.ndarray - idx_df: pd.DataFrame - def __init__( - self, - data: pd.DataFrame, - start, - end, - step_len: int, - fillna_type: str = "none", - dtype=None, - flt_data=None, + self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None ): """ Build a dataset which looks like torch.data.utils.Dataset. 
@@ -359,7 +669,7 @@ def __init__( Parameters ---------- data : pd.DataFrame - The raw tabular data whose index order is <"datetime", "instrument"> + The raw tabular data start : The indexable start time end : @@ -375,7 +685,7 @@ def __init__( ffill+bfill: ffill with previous samples first and fill with later samples second flt_data : pd.Series - a column of data(True or False) to filter data. Its index order is <"datetime", "instrument"> + a column of data(True or False) to filter data. None: kepp all data @@ -385,10 +695,7 @@ def __init__( self.step_len = step_len self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 - self.data = data.swaplevel().sort_index().copy() - data.drop( - data.columns, axis=1, inplace=True - ) # data is useless since it's passed to a transposed one, hard code to free the memory of this dataframe to avoid three big dataframe in the memory(including: data, self.data, self.data_arr) + self.data = lazy_sort_index(data) kwargs = {"object": self.data} if dtype is not None: @@ -399,9 +706,7 @@ def __init__( # - append last line with full NaN for better performance in `__getitem__` # - Keep the same dtype will result in a better performance self.data_arr = np.append( - self.data_arr, - np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), - axis=0, + self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0 ) self.nan_idx = -1 # The last line is all NaN @@ -416,35 +721,18 @@ def __init__( flt_data = flt_data.iloc[:, 0] # NOTE: bool(np.nan) is True !!!!!!!! # make sure reindex comes first. Otherwise extra NaN may appear. - flt_data = flt_data.swaplevel() - flt_data = flt_data.reindex(self.data_index).fillna(False).astype(bool) + flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool) self.flt_data = flt_data.values self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) self.data_index = self.data_index[np.where(self.flt_data)[0]] self.idx_map = self.idx_map2arr(self.idx_map) - self.idx_map, self.data_index = self.slice_idx_map_and_data_index( - self.idx_map, self.idx_df, self.data_index, start, end - ) + self.start_idx, self.end_idx = self.data_index.slice_locs( + start=time_to_slc_point(start), end=time_to_slc_point(end) + ) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance - del self.data # save memory - - @staticmethod - def slice_idx_map_and_data_index( - idx_map, - idx_df, - data_index, - start, - end, - ): - assert ( - len(idx_map) == data_index.shape[0] - ) # make sure idx_map and data_index is same so index of idx_map can be used on data_index - start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end)) - - time_flter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx) - return idx_map[time_flter_idx], data_index[time_flter_idx] + del self.data # save memory @staticmethod def idx_map2arr(idx_map): @@ -480,7 +768,7 @@ def get_index(self): Get the pandas index of the data, it will be useful in following scenarios - Special sampler will be used (e.g. 
user want to sample day by day) """ - return self.data_index.swaplevel() # to align the order of multiple index of original data received by __init__ + return self.data_index[self.start_idx : self.end_idx] def config(self, **kwargs): # Config the attributes @@ -495,33 +783,25 @@ def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: Parameters ---------- data : pd.DataFrame - A DataFrame with index in order - - RSQR5 RESI5 WVMA5 LABEL0 - instrument datetime - SH600000 2017-01-03 0.016389 0.461632 -1.154788 -0.048056 - 2017-01-04 0.884545 -0.110597 -1.059332 -0.030139 - 2017-01-05 0.507540 -0.535493 -1.099665 -0.644983 - 2017-01-06 -1.267771 -0.669685 -1.636733 0.295366 - 2017-01-09 0.339346 0.074317 -0.984989 0.765540 + The dataframe with Returns ------- Tuple[pd.DataFrame, dict]: 1) the first element: reshape the original index into a 2D dataframe - instrument SH600000 SH600008 SH600009 SH600010 SH600011 SH600015 ... + instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ... datetime - 2017-01-03 0 242 473 717 NaN 974 ... - 2017-01-04 1 243 474 718 NaN 975 ... - 2017-01-05 2 244 475 719 NaN 976 ... - 2017-01-06 3 245 476 720 NaN 977 ... + 2021-01-11 0 1 2 3 4 5 ... + 2021-01-12 4146 4147 4148 4149 4150 4151 ... + 2021-01-13 8293 8294 8295 8296 8297 8298 ... + 2021-01-14 12441 12442 12443 12444 12445 12446 ... 2) the second element: {: } """ # object incase of pandas converting int to float idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) idx_df = lazy_sort_index(idx_df.unstack()) # NOTE: the correctness of `__getitem__` depends on columns sorted here - idx_df = lazy_sort_index(idx_df, axis=1).T + idx_df = lazy_sort_index(idx_df, axis=1) idx_map = {} for i, (_, row) in enumerate(idx_df.iterrows()): @@ -579,11 +859,11 @@ def _get_row_col(self, idx) -> Tuple[int]: """ # The the right row number `i` and col number `j` in idx_df if isinstance(idx, (int, np.integer)): - real_idx = idx - if 0 <= real_idx < len(self.idx_map): + real_idx = self.start_idx + idx + if self.start_idx <= real_idx < self.end_idx: i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good else: - raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})") + raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") elif isinstance(idx, tuple): # ["datetime", "instruments"] date, inst = idx @@ -626,10 +906,7 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): # precision problems. It will not cause any problems in my tests at least indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int) - if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up. - data = self.data_arr[indices[0] : indices[-1] + 1] - else: - data = self.data_arr[indices] + data = self.data_arr[indices] if isinstance(idx, mtit): # if we get multiple indexes, addition dimension should be added. 
# @@ -637,9 +914,8 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): return data def __len__(self): - return len(self.idx_map) - - + return self.end_idx - self.start_idx + class TSDatasetH(DatasetH): """ (T)ime-(S)eries Dataset (H)andler diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 714693d181..e3bc710c56 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -194,7 +194,12 @@ def __call__(self, df): # So we use numpy to accelerate filling values nan_select = np.isnan(df.values) nan_select[:, ~df.columns.isin(cols)] = False - df.values[nan_select] = self.fill_value + # df.values[nan_select] = self.fill_value + + # lqa's method + value_tmp = df.values + value_tmp[nan_select] = self.fill_value + df = pd.DataFrame(value_tmp, columns = df.columns, index = df.index) return df
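
Note on the expected input layout for the new model (not part of the diff): below is a minimal sketch, assuming a qlib checkout with this patch applied and PyTorch installed, that pushes a random batch through MASTER to illustrate the (N, T, F) layout described in train_epoch — 158 stock factors followed by 63 masked market-index features (columns 158:221); the trailing label column is stripped before the forward pass, and the batch size N and lookback T used here are illustrative.

    import torch
    from qlib.contrib.model.pytorch_master_ts import MASTER

    N, T = 300, 8                       # stocks in one daily batch, lookback window
    x = torch.randn(N, T, 221)          # 158 stock factors + 63 market features, no label column
    model = MASTER(d_feat=158, d_model=256, t_nhead=4, s_nhead=2,
                   T_dropout_rate=0.5, S_dropout_rate=0.5,
                   gate_input_start_index=158, gate_input_end_index=221, beta=5)
    with torch.no_grad():
        scores = model(x)               # gate -> embed -> TAttention -> SAttention -> temporal attention -> decoder
    print(scores.shape)                 # torch.Size([300]): one score per stock

MASTERModel picks beta itself (10 for csi300, 5 otherwise), so beta=5 above is only an illustrative value.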
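
The processor.py change replaces the in-place assignment through df.values (which can write into a temporary copy, so the fill may silently not reach the DataFrame) with an explicit array round-trip. A self-contained sketch of the same pattern on hypothetical data, filling NaN only in the selected columns:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"a": [1.0, np.nan], "b": [np.nan, 2.0]})
    cols = ["a"]                                     # fields_group columns to fill
    nan_select = np.isnan(df.values)
    nan_select[:, ~df.columns.isin(cols)] = False    # leave other columns untouched
    value_tmp = df.values
    value_tmp[nan_select] = 0.0                      # stand-in for self.fill_value
    df = pd.DataFrame(value_tmp, columns=df.columns, index=df.index)
    print(df)                                        # NaN filled in "a" only; "b" keeps its NaN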