Skip to content

Commit

Permalink
Support wl_cumcat & kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
dongkwan-kim committed Jun 24, 2024
1 parent 7c431e5 commit 0a70ce3
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions WL4S/wl4s.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,16 @@ def experiment(args, h_or_k_list, splits, all_y, **model_kwargs):
best_wl = torch.zeros(args.runs, dtype=torch.float)

if args.wl_cumcat:
assert args.dtype == "histogram"
cumcat_list, cumcat_h = [], torch.tensor([])
for h in h_or_k_list:
cumcat_h = torch.cat((cumcat_h, h), dim=-1)
cumcat_list.append(cumcat_h)
sum_k_list = [None, None, None]
if args.dtype == "histogram":
for h in h_or_k_list:
cumcat_h = torch.cat((cumcat_h, h), dim=-1)
cumcat_list.append(cumcat_h)
else: # kernel
for i, k_list in enumerate(h_or_k_list):
sum_k_list = k_list if sum_k_list[0] is None else [k + sum_k for k, sum_k in zip(k_list, sum_k_list)]
cumcat_list.append(tuple(sum_k_list))
h_or_k_list = cumcat_list

for run in range(1, args.runs + 1):
Expand Down

0 comments on commit 0a70ce3

Please sign in to comment.