This is the repository for StockCL: Selective Contrastive Learning for Stock Trend Forecasting via Learnable Concepts (DASFAA 2024). In this paper, we develop a novel selective contrastive learning framework named StockCL for stock trend forecasting, which is applicable to any stock trend forecasting model. Our key insight is to identify the latent concepts that drive stock trends and to select reliable contrastive pairs according to the concepts the samples belong to and their label similarity.
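The pair-selection idea can be illustrated with a minimal sketch. It is not the implementation in ./scripts: the function name `select_positive_pairs`, the hard integer `concept_ids` (standing in for whatever assignment the learnable concepts produce), and the threshold `label_tol` are all assumptions made for illustration. The sketch only shows one plausible rule: two samples form a positive pair when they share a concept and their labels are close.

```python
import numpy as np

def select_positive_pairs(concept_ids, labels, label_tol=0.01):
    """Return an (n, n) boolean mask of candidate positive pairs.

    Assumed rule (illustrative only, not taken from the paper's code):
    samples i and j are a positive pair if they are assigned to the same
    concept AND their labels differ by at most `label_tol`.
    """
    concept_ids = np.asarray(concept_ids)
    labels = np.asarray(labels, dtype=float)
    # Pairwise comparison via broadcasting: entry (i, j) compares i with j.
    same_concept = concept_ids[:, None] == concept_ids[None, :]
    similar_label = np.abs(labels[:, None] - labels[None, :]) <= label_tol
    mask = same_concept & similar_label
    np.fill_diagonal(mask, False)  # a sample is never paired with itself
    return mask

# Hypothetical batch: four samples, two concepts, daily returns as labels.
mask = select_positive_pairs([0, 0, 1, 1], [0.020, 0.025, 0.020, -0.030])
```

In this hypothetical batch only samples 0 and 1 are selected as positives: samples 2 and 3 share a concept but their labels are far apart, and samples 0 and 2 have equal labels but different concepts.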
The paper and the poster can be found in ./paper. The code can be found in ./scripts.
The time complexity of the concept pool component is
Evaluation Metrics. Stock trend forecasting is mainly used in quantitative investment to rank stocks in terms of their future daily return. Hence, we use four popular ranking-based evaluation metrics: IC, ICIR, Rank IC and Rank ICIR. At each date
where
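The exact formulas are not reproduced in this text, so the following is a sketch under the commonly used definitions, which are an assumption here: IC is the per-date Pearson correlation between predictions and realized returns, Rank IC is the same correlation computed on ranks, and ICIR / Rank ICIR are the mean of the per-date values divided by their standard deviation across dates. The function names and the list-of-arrays input layout are hypothetical.

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation between two 1-D arrays."""
    return float(np.corrcoef(a, b)[0, 1])

def ranks(a):
    """Ordinal ranks 0..n-1. Ties are broken by position, not averaged."""
    order = np.argsort(a, kind="stable")
    return np.argsort(order, kind="stable").astype(float)

def ranking_metrics(preds_by_date, labels_by_date):
    """Compute IC, ICIR, Rank IC and Rank ICIR.

    Both arguments are lists with one 1-D array per trading date, holding
    the predicted scores and the realized returns of the stocks on that date.
    """
    pairs = list(zip(preds_by_date, labels_by_date))
    ic = np.array([pearson(p, y) for p, y in pairs])
    rank_ic = np.array([pearson(ranks(p), ranks(y)) for p, y in pairs])
    # Assumption: sample standard deviation (ddof=1) across dates.
    return {
        "IC": ic.mean(),
        "ICIR": ic.mean() / ic.std(ddof=1),
        "Rank IC": rank_ic.mean(),
        "Rank ICIR": rank_ic.mean() / rank_ic.std(ddof=1),
    }

# Toy example with two dates and three stocks per date.
metrics = ranking_metrics(
    [np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0])],
    [np.array([1.0, 2.0, 3.0]), np.array([1.0, 3.0, 2.0])],
)
```

At least two dates with differing per-date correlations are needed; otherwise the standard deviation is zero and the ICIR values are undefined.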
Experiment Setup. We use PyTorch to develop StockCL and the stock trend forecasting models. All experiments are run on an NVIDIA RTX 2080 Ti GPU. The dimension
Training curves on the CSI500 dataset using ALSTM as the stock trend forecasting model.
To understand how StockCL helps overcome overfitting and improves the forecasting model's generalization ability, we plot the training curves of ALSTM on CSI500 as an example, as shown above. Without StockCL, the validation performance drops noticeably after the 7th epoch while the model's performance on the training set keeps increasing. Without additional supervision signals, the limited but complex stock training data causes the model to overfit too early, which limits its generalization ability. With StockCL, however, the validation performance peaks at the 11th epoch and does not drop significantly afterwards. With the additional supervision signals from contrastive learning, the model can better learn the general data distribution rather than simply memorizing the training samples. As a result, with StockCL the performance on the training set is lower than that of the forecasting model alone, while the performance on the validation set is higher and more stable.

To run the project, first refer to qlib to download the standard dataset and prepare the environment. Then run the script:
python multi_y_run.py -data_time=new3 -base_model=ALSTM -stock_set=csi300 -y_con -memory