-
Notifications
You must be signed in to change notification settings - Fork 1
/
average_prediction.py
32 lines (29 loc) · 1.25 KB
/
average_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
import pandas as pd
import os
import utils
import numpy as np
def average_prediction(istart, istop):
    """Average the ``tow`` predictions of several submission files.

    One CSV file is read per version in ``range(istart, istop)``. The file
    path is built from ``config.SUBMISSIONS_FOLDER`` and the name returned by
    ``utils.get_submission_name(config)``, which is formatted with
    ``version=<i>`` for each version.

    Args:
        istart: First version to include.
        istop: Version at which to stop (excluded).

    Returns:
        A ``pandas.DataFrame`` with two columns: ``flight_id`` (taken from
        the first file) and ``tow`` (element-wise mean over all files).

    Raises:
        ValueError: If the version range is empty, or if a file does not
            list exactly the same ``flight_id`` values, in the same order,
            as the first file.
    """
    versions = range(istart, istop)
    if not versions:
        # Fail early with a clear message instead of an IndexError further down.
        raise ValueError(
            f"empty version range: istart={istart}, istop={istop}"
        )

    config = utils.read_config()
    fname = os.path.join(
        config.SUBMISSIONS_FOLDER, utils.get_submission_name(config)
    )
    frames = [pd.read_csv(fname.format(version=v)) for v in versions]

    # Rows are averaged positionally, so every file must list the same
    # flights in the same order. np.array_equal also copes with arrays of
    # different lengths, where an element-wise `==` would not.
    reference_ids = frames[0].flight_id.values
    for version, frame in zip(versions, frames):
        if not np.array_equal(reference_ids, frame.flight_id.values):
            raise ValueError(
                f"flight_id mismatch between version {istart} and version {version}"
            )

    tows = np.array([frame.tow.values for frame in frames])
    results = pd.DataFrame(
        {"flight_id": reference_ids, "tow": np.mean(tows, axis=0)}
    )
    print("results.shape",results.shape)
    return results
def main():
    """Command-line entry point: average submissions and dump them to a CSV.

    Parses ``-istart`` (default 0), ``-istop`` and ``-out_csv``, calls
    ``average_prediction(istart, istop)`` and writes the resulting frame to
    ``out_csv`` without the index column.
    """
    # NOTE(review): `readers` is not referenced in this function; it is
    # presumably imported for its side effects -- confirm before removing.
    import readers
    parser = argparse.ArgumentParser(
        description="averages model's prediction in the submission folder, starting from version @istart to @istop excluded, results is dumped in @out_csv",
    )
    parser.add_argument("-istart", default=0, type=int)
    # Required: when omitted, istop is None and range(istart, None) fails
    # with an opaque TypeError.
    parser.add_argument("-istop", type=int, required=True)
    # Required: DataFrame.to_csv(None) returns the CSV as a string instead of
    # writing a file, so the averaged result would be silently discarded.
    parser.add_argument("-out_csv", required=True)
    args = parser.parse_args()
    averaged = average_prediction(args.istart, args.istop)
    averaged.to_csv(args.out_csv, index=False)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()