#
# Copyright (c) 2019 MagicStack Inc.
# All rights reserved.
#
# See LICENSE for details.
##

import argparse
import collections
import datetime
import json

import progress.bar
import sqlalchemy as sa
import sqlalchemy.orm as orm

import _sqlalchemy.models as m


def bar(label, total):
    return progress.bar.Bar(label[:32].ljust(32), max=total)
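

# bulk_insert() writes `data` into the target table in 1000-row chunks,
# committing after each chunk and advancing a progress bar.  Chunking keeps
# each INSERT statement and transaction to a bounded size.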
def bulk_insert(db, label, data, into):
    label = f"Creating {len(data)} {label}"
    pbar = bar(label, len(data))

    while data:
        chunk = data[:1000]
        data = data[1000:]
        db.execute(sa.insert(into), chunk)
        db.commit()
        pbar.next(len(chunk))
    pbar.finish()
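

# Rows are loaded with explicit "id" values, which bypasses the Postgres
# sequence behind each table's autoincrementing primary key.  setval() below
# advances "<tablename>_id_seq" to the current maximum id so that later
# autoincrement inserts do not collide with the loaded rows.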
def reset_sequence(db, tablename):
    tab = sa.table(tablename, sa.column("id"))

    db.execute(
        sa.select(
            sa.func.setval(
                f"{tablename}_id_seq",
                sa.select(tab.c.id)
                .order_by(tab.c.id.desc())
                .limit(1)
                .scalar_subquery(),
            )
        )
    )
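

# Purge the existing rows, group the fixture records by model name, then
# bulk-insert them in dependency order.  The record layout ("model",
# "fields", optional "pk") appears to be the Django-style fixture format.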
def load_data(filename, engine):
    session_factory = orm.sessionmaker(bind=engine)
    Session = orm.scoped_session(session_factory)

    with Session() as db:
        # first clear all the existing data, deleting dependent tables
        # before the tables they reference so FK constraints are satisfied
        print("purging existing data...")
        db.execute(sa.delete(m.Directors))
        db.execute(sa.delete(m.Cast))
        db.execute(sa.delete(m.Review))
        db.execute(sa.delete(m.Movie))
        db.execute(sa.delete(m.Person))
        db.execute(sa.delete(m.User))
        db.commit()

    # read the JSON data
    print("loading JSON... ", end="", flush=True)
    with open(filename, "rt") as f:
        records = json.load(f)

    data = collections.defaultdict(list)
    for rec in records:
        rtype = rec["model"].split(".")[-1]
        datum = rec["fields"]
        if "pk" in rec:
            datum["id"] = rec["pk"]
        # convert datetime
        if rtype == "review":
            datum["creation_time"] = datetime.datetime.fromisoformat(
                datum["creation_time"]
            )

        data[rtype].append(datum)
    print("done")
    with Session() as db:
        # bulk create all the users
        bulk_insert(db, "users", data["user"], m.User)

        # bulk create all the people
        bulk_insert(db, "people", data["person"], m.Person)

        # bulk create all the movies
        bulk_insert(db, "movies", data["movie"], m.Movie)

        # bulk create all the reviews
        bulk_insert(db, "reviews", data["review"], m.Review)

        # bulk create all the directors
        bulk_insert(db, "directors", data["directors"], m.Directors)

        # bulk create all the cast
        bulk_insert(db, "cast", data["cast"], m.Cast)

        # reconcile the autoincrementing sequences with the ids
        # that were inserted explicitly above
        reset_sequence(db, "cast")
        reset_sequence(db, "directors")
        reset_sequence(db, "movie")
        reset_sequence(db, "person")
        reset_sequence(db, "user")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Load a specific fixture, old data will be purged."
)
parser.add_argument("filename", type=str, help="The JSON dataset file")
args = parser.parse_args()
engine = sa.create_engine(
"postgresql+asyncpg://sqlalch_bench:edgedbbenchmark@localhost:15432/sqlalch_bench?async_fallback=True"
)
load_data(args.filename, engine)
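

# Example invocation, assuming the benchmark's Postgres service is listening
# on localhost:15432 and "movies.json" is a fixture file (hypothetical name):
#
#     python loaddata.py movies.json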