Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyuan5986 committed Mar 30, 2024
1 parent e5513dd commit 258b187
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
12 changes: 10 additions & 2 deletions examples/benchmarks/MASTER/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
## overview
This is an alternative version of the MASTER benchmark.

paper: [MASTER: Market-Guided Stock Transformer for Stock Price Forecasting](https://arxiv.org/abs/2312.15235)

codes: [https://github.com/SJTU-Quant/MASTER](https://github.com/SJTU-Quant/MASTER)

## run
You can run the code directly with the bash script (set the `universe` and `only_backtest` flags in `run.sh`); `main.py` will then test the model with 10 random seeds:
```
bash run.sh
```
or you can just directly use `qrun` to run the codes:
<!-- or you can just directly use `qrun` to run the codes (note that you should modify your `qlib`, since we add or modify some files in `qlib/contrib/data/dataset.py`, `qlib/data/dataset/__init__.py`, `qlib/data/dataset/processor.py` and `qlib/contrib/model/pytorch_master.py`):
```
qrun workflow_config_master_Alpha158.yaml
```
``` -->
4 changes: 1 addition & 3 deletions examples/benchmarks/MASTER/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

import qlib
from qlib.constant import REG_CN
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
import yaml
import argparse
import os
import torch
import pprint as pp
import numpy as np

Expand Down
33 changes: 27 additions & 6 deletions qlib/contrib/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,25 @@ def __iter__(self):
}

# end indice loop

###################################################################################
# lqa: for MASTER
class marketDataHandler(DataHandlerLP):
"""Market Data Handler for MASTER (see `examples/benchmarks/MASTER`)
Args:
instruments (str): instrument list
start_time (str): start time
end_time (str): end time
freq (str): data frequency
infer_processors (list): inference processors
learn_processors (list): learning processors
fit_start_time (str): fit start time
fit_end_time (str): fit end time
process_type (str): process type
filter_pipe (list): filter pipe
inst_processors (list): instrument processors
"""
def __init__(
self,
instruments="csi300",
Expand Down Expand Up @@ -399,6 +415,10 @@ def __init__(

@staticmethod
def get_feature_config():
"""
Get the market features (63-dimensional), which are built from the csi100 index, csi300 index and csi500 index.
The first list is the name to be shown for the feature, and the second list is the feature to fetch.
"""
return (
['Mask($close/Ref($close,1)-1, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000300")', 'Mask(Mean($volume,5)/$volume, "sh000300")', 'Mask(Std($volume,5)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000300")', 'Mask(Mean($volume,10)/$volume, "sh000300")', 'Mask(Std($volume,10)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000300")', 'Mask(Mean($volume,20)/$volume, "sh000300")', 'Mask(Std($volume,20)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000300")', 'Mask(Mean($volume,30)/$volume, "sh000300")', 'Mask(Std($volume,30)/$volume, "sh000300")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000300")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000300")', 'Mask(Mean($volume,60)/$volume, "sh000300")', 'Mask(Std($volume,60)/$volume, "sh000300")',
'Mask($close/Ref($close,1)-1, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,5), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,5), "sh000903")', 'Mask(Mean($volume,5)/$volume, "sh000903")', 'Mask(Std($volume,5)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,10), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,10), "sh000903")', 'Mask(Mean($volume,10)/$volume, "sh000903")', 'Mask(Std($volume,10)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,20), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,20), "sh000903")', 'Mask(Mean($volume,20)/$volume, "sh000903")', 'Mask(Std($volume,20)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,30), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,30), "sh000903")', 'Mask(Mean($volume,30)/$volume, "sh000903")', 'Mask(Std($volume,30)/$volume, "sh000903")', 'Mask(Mean($close/Ref($close,1)-1,60), "sh000903")', 'Mask(Std($close/Ref($close,1)-1,60), "sh000903")', 'Mask(Mean($volume,60)/$volume, "sh000903")', 'Mask(Std($volume,60)/$volume, "sh000903")',
Expand All @@ -409,6 +429,12 @@ def get_feature_config():
)

class MASTERTSDatasetH(TSDatasetH):
"""
MASTER Time Series Dataset with Handler
Args:
market_data_handler_config (dict): market data handler config
"""
def __init__(
self,
market_data_handler_config = Dict,
Expand Down Expand Up @@ -438,17 +464,12 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
data = super(TSDatasetH, self)._prepare_seg(ext_slice, **kwargs)

############################## Add market information ###########################
# If we only need label for testing, we do not need to add market information
if not only_label:
marketData = self.get_market_information(ext_slice)
cols = pd.MultiIndex.from_tuples([("feature", feature) for feature in marketData.columns])
marketData = pd.DataFrame(marketData.values, columns = cols, index = marketData.index)
# print(marketData.index)
# print(marketData.columns)
# print(data.index)
# print(data.columns)
data = data.iloc[:,:-1].join(marketData).join(data.iloc[:,-1])
# print(data.columns)
# print(data.shape)
#################################################################################
flt_kwargs = copy.deepcopy(kwargs)
if flt_col is not None:
Expand Down
2 changes: 2 additions & 0 deletions qlib/data/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def __call__(self, df):
# So we use numpy to accelerate filling values
nan_select = np.isnan(df.values)
nan_select[:, ~df.columns.isin(cols)] = False

# FIXME: For pandas==2.0.3, the following code will not set the nan value to be self.fill_value
# df.values[nan_select] = self.fill_value

# lqa's method
Expand Down

0 comments on commit 258b187

Please sign in to comment.