
Doesn't ffill+bfill still introduce future information into the test set? #25

Open
YHY-10 opened this issue Jan 1, 2025 · 2 comments
YHY-10 commented Jan 1, 2025

Thanks to the authors for this inspiring work.
I've recently been following MASTER, MATCC, and the related series of work. Even though the authors updated the open-source data, the pickle files still use ffill+bfill. Doesn't that still introduce future information?
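
For reference, a minimal pandas sketch of the concern (hypothetical values, not the repo's actual pipeline): ffill only propagates past values forward, but bfill propagates future values backward, so any row it fills has effectively seen a later date.

import pandas as pd

# Hypothetical feature series with a missing value at the start of the window.
s = pd.Series(
    [None, 11.0, 12.0],
    index=pd.to_datetime(["2025-01-01", "2025-01-02", "2025-01-03"]),
)

filled = s.ffill().bfill()
# ffill cannot fill the leading gap (there is no past value),
# so bfill sets 2025-01-01 to 11.0, a value first observed on
# 2025-01-02: future information leaking backward in time.

ffill, on the other hand, repeats the last observed row across consecutive days, which is what produces the duplicated rows discussed in the comment below.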

LITONG99 (Contributor) commented Jan 8, 2025

No. If you don't know the label, you cannot compare the model prediction to it.

weituo2002 commented

It does have an effect. I just ran an experiment: after completely removing the duplicated rows, the IC drops to around 0.04.
Just add a function to base_model.py that removes duplicated rows; these duplicated rows are exactly what the filling produces:
# Drop stocks whose data contain duplicated feature rows
import torch

def drop_duplicates(x, tolerance=1e-10):
    """
    Check whether any of the 8 days of data for each stock contains
    duplicated feature rows.

    Args:
        x: feature tensor of shape (N, T, F), where
           N is the number of stocks,
           T=8 is the number of trading days per stock,
           F=221 is the number of features per day.

    Returns:
        mask: boolean mask marking the stocks that contain no duplicated rows.

    If any two days of a stock have exactly identical features, all of that
    stock's data is dropped. This keeps the data consistent: either all 8
    days of a stock are kept, or the stock is removed entirely.
    (Note: the tolerance parameter is currently unused; matching is exact.)
    """
    N, T, F = x.shape
    # Create the mask, initially all True
    mask = torch.ones(N, dtype=torch.bool, device=x.device)

    # Check the 8 days of data for each stock
    for stock_idx in range(N):
        stock_data = x[stock_idx]  # shape: (8, 221)

        # Look for duplicated rows among this stock's 8 days
        found_duplicate = False
        feature_dict = {}

        # Compare any two of the 8 days for equality
        for day_idx in range(T):
            feat_key = tuple(stock_data[day_idx].cpu().numpy())
            if feat_key in feature_dict:
                # Found a duplicated row; mark this stock for removal
                found_duplicate = True
                break
            else:
                feature_dict[feat_key] = day_idx

        if found_duplicate:
            # A duplicated row was found: mark this stock as False
            mask[stock_idx] = False

    return mask
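
For large N, the per-stock Python loop above can be slow. Here is a vectorized sketch with the same exact-match semantics (the name drop_duplicates_vectorized is illustrative; note it materializes an (N, T, T, F) comparison tensor, so memory grows with T squared):

def drop_duplicates_vectorized(x):
    # x: (N, T, F). Keep a stock only if all T of its rows are pairwise distinct.
    N, T, F = x.shape
    # Pairwise day-vs-day equality per stock: (N, T, T) boolean
    eq = (x.unsqueeze(1) == x.unsqueeze(2)).all(dim=-1)
    # Ignore the diagonal (every day trivially equals itself)
    eq &= ~torch.eye(T, dtype=torch.bool, device=x.device)
    # True where no off-diagonal pair of days is identical
    return ~eq.any(dim=2).any(dim=1)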

Then, in def train_epoch(self, data_loader):, modify the per-batch loop to call the duplicate-dropping function:
for batch_idx, data in enumerate(data_loader):
    data = torch.squeeze(data, dim=0)
    feature = data[:, :, 0:-1].to(self.device)  # shape: (N, 8, 221)
    label = data[:, -1, -1].to(self.device)     # shape: (N,)

    # print(f"Batch {batch_idx} - Original size: {feature.shape[0]}")

    # Handle extreme values
    mask, label = drop_extreme(label)
    feature = feature[mask, :, :]
    # print(f"After extreme drop: {feature.shape[0]}")

    if feature.shape[0] > 0:
        # Handle duplicated rows
        dup_mask = drop_duplicates(feature)
        # print(f"Duplicate mask sum: {dup_mask.sum()} out of {len(dup_mask)}")
        feature = feature[dup_mask, :, :]
        label = label[dup_mask]
        # print(f"After duplicate drop: {feature.shape[0]}")

        if feature.shape[0] > 0:
            # Normalize
            label = zscore(label)
            # print(f"Label after zscore - mean: {label.mean():.4f}, std: {label.std():.4f}")

The final results drop considerably:
IC: 0.0474 ± 0.0061
ICIR: 0.2996 ± 0.0499
RIC: 0.0513 ± 0.0067
RICIR: 0.3015 ± 0.0533
