Skip to content

Commit

Permalink
Ensure 'with' is used with Sessions where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
tlocke committed Sep 7, 2023
1 parent fc288e9 commit 4eff054
Show file tree
Hide file tree
Showing 62 changed files with 4,185 additions and 4,470 deletions.
7 changes: 1 addition & 6 deletions chellow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,13 @@ def create_app(testing=False):
db_upgrade(app.root_path)
chellow.dloads.startup(app.instance_path)

sess = None
try:
sess = Session()
with Session() as sess:
configuration = sess.execute(
select(Contract).where(Contract.name == "configuration")
).scalar_one()
props = configuration.make_properties()
api_props = props.get("api", {})
api.description = api_props.get("description", "Access Chellow data")
finally:
if sess is not None:
sess.close()

for module in get_importer_modules():
if not testing:
Expand Down
145 changes: 64 additions & 81 deletions chellow/bank_holidays.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,67 @@
bh_importer = None


def _run(log_f, sess):

log_f("Starting to check bank holidays")
contract = Contract.get_non_core_by_name(sess, "bank_holidays")
contract_props = contract.make_properties()

if contract_props.get("enabled", False):
url_str = contract_props["url"]

log_f(f"Downloading from {url_str}.")
res = requests.get(url_str)
log_f(" ".join(("Received", str(res.status_code), res.reason)))
PREFIX = "DTSTART;VALUE=DATE:"
hols = collections.defaultdict(list)
for line in res.text.splitlines():
if line.startswith(PREFIX):
dt = utc_datetime_parse(line[-8:], "%Y%m%d")
hols[dt.year].append(dt)

for year in sorted(hols.keys()):
year_start = utc_datetime(year, 1, 1)
year_finish = year_start + relativedelta(years=1) - HH
rs = (
sess.query(RateScript)
.filter(
RateScript.contract == contract,
RateScript.start_date == year_start,
)
.first()
)
if rs is None:
log_f(f"Adding a new rate script starting at {hh_format(year_start)}.")

latest_rs = (
sess.query(RateScript)
.filter(RateScript.contract == contract)
.order_by(RateScript.start_date.desc())
.first()
)

contract.update_rate_script(
sess,
latest_rs,
latest_rs.start_date,
year_finish,
loads(latest_rs.script),
)
rs = contract.insert_rate_script(sess, year_start, {})

script = {"bank_holidays": [v.strftime("%Y-%m-%d") for v in hols[year]]}

contract.update_rate_script(sess, rs, rs.start_date, rs.finish_date, script)
sess.commit()
log_f(f"Updated rate script starting at {hh_format(year_start)}.")
else:
log_f(
"The automatic importer is disabled. To enable it, edit the contract "
"properties to set 'enabled' to True."
)


class BankHolidayImporter(threading.Thread):
def __init__(self):
super(BankHolidayImporter, self).__init__(name="Bank Holiday Importer")
Expand Down Expand Up @@ -58,90 +119,12 @@ def log(self, message):
def run(self):
while not self.stopped.isSet():
if self.lock.acquire(False):
sess = None
try:
sess = Session()
self.log("Starting to check bank holidays")
contract = Contract.get_non_core_by_name(sess, "bank_holidays")
contract_props = contract.make_properties()

if contract_props.get("enabled", False):
url_str = contract_props["url"]

self.log("Downloading from " + url_str + ".")
res = requests.get(url_str)
self.log(
" ".join(("Received", str(res.status_code), res.reason))
)
PREFIX = "DTSTART;VALUE=DATE:"
hols = collections.defaultdict(list)
for line in res.text.splitlines():
if line.startswith(PREFIX):
dt = utc_datetime_parse(line[-8:], "%Y%m%d")
hols[dt.year].append(dt)

for year in sorted(hols.keys()):
year_start = utc_datetime(year, 1, 1)
year_finish = year_start + relativedelta(years=1) - HH
rs = (
sess.query(RateScript)
.filter(
RateScript.contract == contract,
RateScript.start_date == year_start,
)
.first()
)
if rs is None:
self.log(
"Adding a new rate script starting at "
+ hh_format(year_start)
+ "."
)

latest_rs = (
sess.query(RateScript)
.filter(RateScript.contract == contract)
.order_by(RateScript.start_date.desc())
.first()
)

contract.update_rate_script(
sess,
latest_rs,
latest_rs.start_date,
year_finish,
loads(latest_rs.script),
)
rs = contract.insert_rate_script(sess, year_start, {})

script = {
"bank_holidays": [
v.strftime("%Y-%m-%d") for v in hols[year]
]
}

contract.update_rate_script(
sess, rs, rs.start_date, rs.finish_date, script
)
sess.commit()
self.log(
"Updated rate script starting at "
+ hh_format(year_start)
+ "."
)
else:
self.log(
"The automatic importer is disabled. To "
"enable it, edit the contract properties to "
"set 'enabled' to True."
)

with Session() as sess:
_run(self.log, sess)
except BaseException:
self.log("Outer problem " + traceback.format_exc())
sess.rollback()
self.log(f"Outer problem {traceback.format_exc()}")
finally:
if sess is not None:
sess.close()
self.lock.release()
self.log("Finished checking bank holidays.")

Expand Down
151 changes: 74 additions & 77 deletions chellow/e/bill_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,97 +64,94 @@ def status(self):
return "Not running"

def run(self):
sess = None
try:
sess = Session()
batch = Batch.get_by_id(sess, self.batch_id)
with Session() as sess:
batch = Batch.get_by_id(sess, self.batch_id)

bill_types = keydefaultdict(lambda k: BillType.get_by_code(sess, k))
bill_types = keydefaultdict(lambda k: BillType.get_by_code(sess, k))

tprs = keydefaultdict(
lambda k: None if k is None else Tpr.get_by_code(sess, k)
)
tprs = keydefaultdict(
lambda k: None if k is None else Tpr.get_by_code(sess, k)
)

read_types = keydefaultdict(lambda k: ReadType.get_by_code(sess, k))

for bf in (
sess.query(BatchFile)
.filter(BatchFile.batch == batch)
.order_by(BatchFile.upload_timestamp)
):
self.parser = _process_batch_file(sess, bf, self._log)
for self.bill_num, raw_bill in enumerate(self.parser.make_raw_bills()):
if "error" in raw_bill:
self.failed_bills.append(raw_bill)
else:
try:
mpan_core = raw_bill["mpan_core"]
supply = Supply.get_by_mpan_core(sess, mpan_core)
with sess.begin_nested():
bill = batch.insert_bill(
sess,
raw_bill["account"],
raw_bill["reference"],
raw_bill["issue_date"],
raw_bill["start_date"],
raw_bill["finish_date"],
raw_bill["kwh"],
raw_bill["net"],
raw_bill["vat"],
raw_bill["gross"],
bill_types[raw_bill["bill_type_code"]],
raw_bill["breakdown"],
supply,
)
for raw_read in raw_bill["reads"]:
bill.insert_read(
read_types = keydefaultdict(lambda k: ReadType.get_by_code(sess, k))

for bf in (
sess.query(BatchFile)
.filter(BatchFile.batch == batch)
.order_by(BatchFile.upload_timestamp)
):
self.parser = _process_batch_file(sess, bf, self._log)
for self.bill_num, raw_bill in enumerate(
self.parser.make_raw_bills()
):
if "error" in raw_bill:
self.failed_bills.append(raw_bill)
else:
try:
mpan_core = raw_bill["mpan_core"]
supply = Supply.get_by_mpan_core(sess, mpan_core)
with sess.begin_nested():
bill = batch.insert_bill(
sess,
tprs[raw_read["tpr_code"]],
raw_read["coefficient"],
raw_read["units"],
raw_read["msn"],
raw_read["mpan"],
raw_read["prev_date"],
raw_read["prev_value"],
read_types[raw_read["prev_type_code"]],
raw_read["pres_date"],
raw_read["pres_value"],
read_types[raw_read["pres_type_code"]],
raw_bill["account"],
raw_bill["reference"],
raw_bill["issue_date"],
raw_bill["start_date"],
raw_bill["finish_date"],
raw_bill["kwh"],
raw_bill["net"],
raw_bill["vat"],
raw_bill["gross"],
bill_types[raw_bill["bill_type_code"]],
raw_bill["breakdown"],
supply,
)
self.successful_bills.append(raw_bill)
except KeyError as e:
err = raw_bill.get("error", "")
raw_bill["error"] = err + " " + str(e)
self.failed_bills.append(raw_bill)
except BadRequest as e:
raw_bill["error"] = str(e.description)
self.failed_bills.append(raw_bill)

if len(self.failed_bills) == 0:
sess.commit()
self._log(
"All the bills have been successfully loaded and attached "
"to the batch."
)
else:
sess.rollback()
self._log(
f"The import has finished, but there were {len(self.failed_bills)} "
f"failures, and so the whole import has been rolled back."
)
for raw_read in raw_bill["reads"]:
bill.insert_read(
sess,
tprs[raw_read["tpr_code"]],
raw_read["coefficient"],
raw_read["units"],
raw_read["msn"],
raw_read["mpan"],
raw_read["prev_date"],
raw_read["prev_value"],
read_types[raw_read["prev_type_code"]],
raw_read["pres_date"],
raw_read["pres_value"],
read_types[raw_read["pres_type_code"]],
)
self.successful_bills.append(raw_bill)
except KeyError as e:
err = raw_bill.get("error", "")
raw_bill["error"] = err + " " + str(e)
self.failed_bills.append(raw_bill)
except BadRequest as e:
raw_bill["error"] = str(e.description)
self.failed_bills.append(raw_bill)

if len(self.failed_bills) == 0:
sess.commit()
self._log(
"All the bills have been successfully loaded and attached "
"to the batch."
)
else:
sess.rollback()
self._log(
f"The import has finished, but there were "
f"{len(self.failed_bills)} "
f"failures, and so the whole import has been rolled back."
)

except BadRequest as e:
sess.rollback()
msg = f"Problem: {e.description}"
if e.__cause__ is not None:
msg += f" {traceback.format_exc()}"
self._log(msg)
except BaseException:
sess.rollback()
self._log(f"I've encountered a problem: {traceback.format_exc()}")
finally:
if sess is not None:
sess.close()

def make_fields(self):
with import_lock:
Expand Down
6 changes: 0 additions & 6 deletions chellow/e/bill_parsers/activity_mop_stark_xlsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from werkzeug.exceptions import BadRequest

from chellow.models import Session
from chellow.utils import parse_mpan_core, to_utc


Expand Down Expand Up @@ -65,9 +64,7 @@ def _set_last_line(self, i, line):
return line

def make_raw_bills(self):
sess = None
try:
sess = Session()
bills = []
issue_date = self.get_start_date("C", 6)

Expand Down Expand Up @@ -124,8 +121,5 @@ def make_raw_bills(self):
)
except BadRequest as e:
raise BadRequest(f"Row number: {row} {e.description}")
finally:
if sess is not None:
sess.close()

return bills
Loading

0 comments on commit 4eff054

Please sign in to comment.