NiFi scripts: cohort export script update.
vladd-bit committed May 23, 2024
1 parent c09d07e commit 86496ed
Showing 2 changed files with 185 additions and 108 deletions.
272 changes: 167 additions & 105 deletions nifi/user-scripts/cogstack_cohort_generate_data.py
@@ -1,17 +1,23 @@
import json
import sys
import logging
import datetime
+from datetime import datetime
import os
import traceback
import regex
import multiprocess

from multiprocess import Pool, Queue
from collections import defaultdict, Counter
-from datetime import datetime
-from utils.ethnicity_map import ethnicity_map, ethnicity_map_detail
-from utils.generic import chunk, dict2json_file
+from utils.ethnicity_map import ethnicity_map
+from utils.generic import chunk, dict2json_file, dict2json_truncate_add_to_file

+# default values from /deploy/nifi.env
+USER_SCRIPT_LOGS_DIR = os.getenv("USER_SCRIPT_LOGS_DIR", "")
+
+LOG_FILE_NAME = "cohort_export.log"
+log_file_path = os.path.join(USER_SCRIPT_LOGS_DIR, str(LOG_FILE_NAME))


ANNOTATION_DOCUMENT_ID_FIELD_NAME = "meta.docid"
DOCUMENT_ID_FIELD_NAME = "docid"
@@ -34,14 +40,17 @@

CPU_THREADS = int(os.getenv("CPU_THREADS", multiprocess.cpu_count() // 2))

+INPUT_FOLDER_PATH = ""

+# json file(s) containing patient records and annotations exported by NiFi, the input format is expected to be one provided
+# by MedCAT Service which was stored in an Elasticsearch index
+INPUT_PATIENT_RECORD_FILE_NAME_PATTERN = ""
+INPUT_ANNOTATIONS_RECORDS_FILE_NAME_PATTERN = ""

for arg in sys.argv:
    _arg = arg.split("=", 1)
    if _arg[0] == "annotation_document_id_field_name":
        ANNOTATION_DOCUMENT_ID_FIELD_NAME = _arg[1]
-    if _arg[0] == "input_patient_records_path":
-        INPUT_PATIENT_RECORDS_PATH = _arg[1]
-    if _arg[0] == "input_annotations_records_path":
-        INPUT_ANNOTATIONS_RECORDS_PATH = _arg[1]
    if _arg[0] == "date_time_format":
        DATE_TIME_FORMAT = _arg[1]
    if _arg[0] == "patient_id_field_name":
@@ -62,37 +71,15 @@
        TIMEOUT = int(_arg[1])
    if _arg[0] == "output_folder_path":
        OUTPUT_FOLDER_PATH = _arg[1]
+    if _arg[0] == "input_folder_path":
+        INPUT_FOLDER_PATH = _arg[1]
+    if _arg[0] == "input_patient_record_file_name_pattern":
+        INPUT_PATIENT_RECORD_FILE_NAME_PATTERN = _arg[1]
+    if _arg[0] == "input_annotations_records_file_name_pattern":
+        INPUT_ANNOTATIONS_RECORDS_FILE_NAME_PATTERN = _arg[1]


-# json file containing annotations exported by NiFi, the input format is expected to be one provided
-# by MedCAT Service which was stored in an Elasticsearch index
-input_annotations = json.loads(open(INPUT_ANNOTATIONS_RECORDS_PATH, mode="r+").read())
-
-# json file containing record data from a SQL database or from Elasticsearch
-input_patient_record_data = json.loads(open(INPUT_PATIENT_RECORDS_PATH, mode="r+").read())
-
-# cui2ptt_pos.jsonl each line is a dictionary of cui and the value is a dictionary of patients with a count {<cui>: {<patient_id>:<count>, ...}}\n...
-cui2ptt_pos = defaultdict(Counter) # store the count of a SNOMED term for a patient
-
-# cui2ptt_tsp.jsonl each line is a dictionary of cui and the value is a dictionary of patients with a timestamp {<cui>: {<patient_id>:<tsp>, ...}}\n...
-cui2ptt_tsp = defaultdict(lambda: defaultdict(int)) # store the first mention timestamp of a SNOMED term for a patient
-
-# doc2ptt is a dictionary {<doc_id> : <patient_id>, ...}
-doc2ptt = {}
-
-# ptt2sex.json a dictionary for gender of each patient {<patient_id>:<gender>, ...}
-ptt2sex = {}
-# ptt2eth.json a dictionary for ethnicity of each patient {<patient_id>:<ethnicity>, ...}
-ptt2eth = {}
-# ptt2dob.json a dictionary for date of birth of each patient {<patient_id>:<dob>, ...}
-ptt2dob = {}
-# ptt2age.json a dictionary for age of each patient {<patient_id>:<age>, ...}
-ptt2age = {}
-# ptt2dod.json a dictionary for dod if the patient has died {<patient_id>:<dod>, ...}
-ptt2dod = {}


-def process_patient_records(patient_records: list):
+def _process_patient_records(patient_records: list):
    _ptt2sex, _ptt2eth, _ptt2dob, _ptt2age, _ptt2dod, _doc2ptt = {}, {}, {}, {}, {}, {}

    for patient_record in patient_records:
Expand All @@ -116,8 +103,7 @@ def process_patient_records(patient_records: list):
        _ptt2sex[patient_record[PATIENT_ID_FIELD_NAME]] = _tmp_gender

        dob = datetime.strptime(patient_record[PATIENT_BIRTH_DATE_FIELD_NAME], DATE_TIME_FORMAT)
-
        dod = patient_record[PATIENT_DEATH_DATE_FIELD_NAME] if PATIENT_DEATH_DATE_FIELD_NAME in patient_record.keys() else None
        patient_age = 0

        if dod not in [None, "null", 0]:
@@ -145,108 +131,184 @@

    return _ptt2sex, _ptt2eth, _ptt2dob, _ptt2age, _ptt2dod, _doc2ptt


-def process_annotation_records(annotation_records: list, _doc2ptt: dict):
+def _process_annotation_records(annotation_records: list, _doc2ptt: dict):

    _cui2ptt_pos = defaultdict(Counter)
    _cui2ptt_tsp = defaultdict(lambda: defaultdict(int))

    try:
        # for each part of the MedCAT output (e.g., part_0.pickle)
        for annotation_record in annotation_records:
            annotation_entity = annotation_record
            if "_source" in annotation_record.keys():
                annotation_entity = annotation_record["_source"]
            docid = annotation_entity[ANNOTATION_DOCUMENT_ID_FIELD_NAME]

-            if docid in list(_doc2ptt.keys()):
-                patient_id = _doc2ptt[docid]
+            if str(docid) in _doc2ptt.keys():
+                patient_id = _doc2ptt[str(docid)]
                cui = annotation_entity["nlp.cui"]

                if annotation_entity["nlp.meta_anns"]["Subject"]["value"] == "Patient" and annotation_entity["nlp.meta_anns"]["Presence"]["value"] == "True" and annotation_entity["nlp.meta_anns"]["Time"]["value"] != "Future":
                    _cui2ptt_pos[cui][patient_id] += 1

                print(annotation_record)
                if "timestamp" in annotation_entity.keys():
                    time = int(round(datetime.fromisoformat(annotation_entity["timestamp"]).timestamp()))
                    print(patient_id)

                    if _cui2ptt_tsp[cui][patient_id] == 0 or time < _cui2ptt_tsp[cui][patient_id]:
                        _cui2ptt_tsp[cui][patient_id] = time
    except Exception:
        raise Exception("exception generated by process_annotation_records: " + str(traceback.format_exc()))

    return _cui2ptt_pos, _cui2ptt_tsp

-patient_process_pool_results = []
-annotation_process_pool_results = []

-with Pool(processes=CPU_THREADS) as patient_process_pool:
-    results = list()
+def multiprocess_patient_records(input_patient_record_data: dict):

-    rec_que = Queue()
+    # ptt2sex.json a dictionary for gender of each patient {<patient_id>:<gender>, ...}
+    ptt2sex = {}
+    # ptt2eth.json a dictionary for ethnicity of each patient {<patient_id>:<ethnicity>, ...}
+    ptt2eth = {}
+    # ptt2dob.json a dictionary for date of birth of each patient {<patient_id>:<dob>, ...}
+    ptt2dob = {}
+    # ptt2age.json a dictionary for age of each patient {<patient_id>:<age>, ...}
+    ptt2age = {}
+    # ptt2dod.json a dictionary for dod if the patient has died {<patient_id>:<dod>, ...}
+    ptt2dod = {}

-    record_chunks = list(chunk(input_patient_record_data, CPU_THREADS))
+    # doc2ptt is a dictionary {<doc_id> : <patient_id>, ...}
+    doc2ptt = {}

-    counter = 0
-    for record_chunk in record_chunks:
-        rec_que.put(record_chunk)
-        patient_process_pool_results.append(patient_process_pool.starmap_async(process_patient_records, [(rec_que.get(),)], chunksize=1, error_callback=logging.error))
-        counter += 1
+    patient_process_pool_results = []

-try:
-    for result in patient_process_pool_results:
-        result_data = result.get(timeout=TIMEOUT)
-        _ptt2sex, _ptt2eth, _ptt2dob, _ptt2age, _ptt2dod, _doc2ptt = result_data[0][0], result_data[0][1], result_data[0][2], result_data[0][3], result_data[0][4], result_data[0][5]

-        ptt2sex.update(_ptt2sex)
-        ptt2eth.update(_ptt2eth)
-        ptt2dob.update(_ptt2dob)
-        ptt2age.update(_ptt2age)
-        ptt2dod.update(_ptt2dod)
-        doc2ptt.update(_doc2ptt)
+    with Pool(processes=CPU_THREADS) as patient_process_pool:
+        rec_que = Queue()

-except Exception:
-    raise Exception("exception generated by worker: " + str(traceback.format_exc()))
+        record_chunks = list(chunk(input_patient_record_data, CPU_THREADS))

+        counter = 0
+        for record_chunk in record_chunks:
+            rec_que.put(record_chunk)
+            patient_process_pool_results.append(patient_process_pool.starmap_async(_process_patient_records, [(rec_que.get(),)], chunksize=1, error_callback=logging.error))
+            counter += 1

-with Pool(processes=CPU_THREADS) as annotations_process_pool:
-    results = list()
+        try:
+            for result in patient_process_pool_results:
+                result_data = result.get(timeout=TIMEOUT)
+                _ptt2sex, _ptt2eth, _ptt2dob, _ptt2age, _ptt2dod, _doc2ptt = result_data[0][0], result_data[0][1], result_data[0][2], result_data[0][3], result_data[0][4], result_data[0][5]

+                ptt2sex.update(_ptt2sex)
+                ptt2eth.update(_ptt2eth)
+                ptt2dob.update(_ptt2dob)
+                ptt2age.update(_ptt2age)
+                ptt2dod.update(_ptt2dod)
+                doc2ptt.update(_doc2ptt)

-    rec_que = Queue()
+        except Exception as exception:
+            time = datetime.now()
+            with open(log_file_path, "a+") as log_file:
+                log_file.write("\n" + str(time) + ": " + str(exception))
+                log_file.write("\n" + str(time) + ": " + traceback.format_exc())

-    record_chunks = list(chunk(input_annotations, CPU_THREADS))
+    return doc2ptt, ptt2dod, ptt2age, ptt2dob, ptt2eth, ptt2sex

-    counter = 0
-    for record_chunk in record_chunks:
-        rec_que.put(record_chunk)
-        annotation_process_pool_results.append(annotations_process_pool.starmap_async(process_annotation_records, [(rec_que.get(), doc2ptt)], chunksize=1, error_callback=logging.error))
-        counter += 1
+def multiprocess_annotation_records(doc2ptt: dict, input_annotations: dict):

-try:
-    for result in annotation_process_pool_results:
-        result_data = result.get(timeout=TIMEOUT)
+    # cui2ptt_pos.jsonl each line is a dictionary of cui and the value is a dictionary of patients with a count {<cui>: {<patient_id>:<count>, ...}}\n...
+    cui2ptt_pos = defaultdict(Counter) # store the count of a SNOMED term for a patient

-        _cui2ptt_pos, _cui2ptt_tsp = result_data[0][0], result_data[0][1]
-        cui2ptt_pos.update(_cui2ptt_pos)
-        cui2ptt_tsp.update(_cui2ptt_tsp)
+    # cui2ptt_tsp.jsonl each line is a dictionary of cui and the value is a dictionary of patients with a timestamp {<cui>: {<patient_id>:<tsp>, ...}}\n...
+    cui2ptt_tsp = defaultdict(lambda: defaultdict(int)) # store the first mention timestamp of a SNOMED term for a patient

-except Exception:
-    raise Exception("exception generated by worker: " + str(traceback.format_exc()))

-dict2json_file(ptt2sex, os.path.join(OUTPUT_FOLDER_PATH, "ptt2sex.json"))
-dict2json_file(ptt2eth, os.path.join(OUTPUT_FOLDER_PATH, "ptt2eth.json"))
-dict2json_file(ptt2dob, os.path.join(OUTPUT_FOLDER_PATH, "ptt2dob.json"))
-dict2json_file(ptt2dod, os.path.join(OUTPUT_FOLDER_PATH, "ptt2dod.json"))
-dict2json_file(ptt2age, os.path.join(OUTPUT_FOLDER_PATH, "ptt2age.json"))

-with open('cui2ptt_pos.jsonl', 'a', encoding='utf-8') as outfile:
-    for k,v in cui2ptt_pos.items():
-        o = {k: v}
-        json_obj = json.loads(json.dumps(o))
-        json.dump(json_obj, outfile, ensure_ascii=False, indent=None, separators=(',',':'))
-        print('', file = outfile)

-with open('cui2ptt_tsp.jsonl', 'a', encoding='utf-8') as outfile:
-    for k,v in cui2ptt_tsp.items():
-        o = {k: v}
-        json_obj = json.loads(json.dumps(o))
-        json.dump(json_obj, outfile, ensure_ascii=False, indent=None, separators=(',',':'))
-        print('', file = outfile)
+    annotation_process_pool_results = []

+    with Pool(processes=CPU_THREADS) as annotations_process_pool:
+        rec_que = Queue()

+        record_chunks = list(chunk(input_annotations, CPU_THREADS))

+        counter = 0
+        for record_chunk in record_chunks:
+            rec_que.put(record_chunk)
+            annotation_process_pool_results.append(annotations_process_pool.starmap_async(_process_annotation_records, [(rec_que.get(), doc2ptt)], chunksize=1, error_callback=logging.error))
+            counter += 1

+        try:
+            for result in annotation_process_pool_results:
+                result_data = result.get(timeout=TIMEOUT)

+                _cui2ptt_pos, _cui2ptt_tsp = result_data[0][0], result_data[0][1]
+                cui2ptt_pos.update(_cui2ptt_pos)
+                cui2ptt_tsp.update(_cui2ptt_tsp)

+        except Exception as exception:
+            time = datetime.now()
+            with open(log_file_path, "a+") as log_file:
+                log_file.write("\n" + str(time) + ": " + str(exception))
+                log_file.write("\n" + str(time) + ": " + traceback.format_exc())

+    return cui2ptt_pos, cui2ptt_tsp


+#############################################


+# for testing
+# OUTPUT_FOLDER_PATH = "../../data/cogstack-cohort/"
+# INPUT_FOLDER_PATH = "../../data/cogstack-cohort/"
+# INPUT_ANNOTATIONS_RECORDS_FILE_NAME_PATTERN = "medical_reports_anns_"
+# INPUT_PATIENT_RECORD_FILE_NAME_PATTERN = "medical_reports_text__"

+global_doc2ptt = {}

+if INPUT_PATIENT_RECORD_FILE_NAME_PATTERN:
+    # read each of the patient record files one by one
+    for root, sub_directories, files in os.walk(INPUT_FOLDER_PATH):
+        for file_name in files:
+            if INPUT_PATIENT_RECORD_FILE_NAME_PATTERN in file_name:
+                f_path = os.path.join(root, file_name)

+                contents = []

+                with open(f_path, mode="r+") as f:
+                    contents = json.loads(f.read())

+                _doc2ptt, _ptt2dod, _ptt2age, _ptt2dob, _ptt2eth, _ptt2sex = multiprocess_patient_records(contents)
+                dict2json_truncate_add_to_file(_ptt2sex, os.path.join(OUTPUT_FOLDER_PATH, "ptt2sex.json"))
+                dict2json_truncate_add_to_file(_ptt2eth, os.path.join(OUTPUT_FOLDER_PATH, "ptt2eth.json"))
+                dict2json_truncate_add_to_file(_ptt2dob, os.path.join(OUTPUT_FOLDER_PATH, "ptt2dob.json"))
+                dict2json_truncate_add_to_file(_ptt2dod, os.path.join(OUTPUT_FOLDER_PATH, "ptt2dod.json"))
+                dict2json_truncate_add_to_file(_ptt2age, os.path.join(OUTPUT_FOLDER_PATH, "ptt2age.json"))

+                global_doc2ptt.update(_doc2ptt)

+if INPUT_ANNOTATIONS_RECORDS_FILE_NAME_PATTERN:
+    # read each of the annotation record files one by one
+    for root, sub_directories, files in os.walk(INPUT_FOLDER_PATH):
+        for file_name in files:
+            if INPUT_ANNOTATIONS_RECORDS_FILE_NAME_PATTERN in file_name:
+                f_path = os.path.join(root, file_name)

+                contents = []

+                with open(f_path, mode="r+") as f:
+                    contents = json.loads(f.read())

+                cui2ptt_pos, cui2ptt_tsp = multiprocess_annotation_records(global_doc2ptt, contents)
+                with open(os.path.join(OUTPUT_FOLDER_PATH, "cui2ptt_pos.jsonl"), "a+", encoding="utf-8") as outfile:
+                    for k,v in cui2ptt_pos.items():
+                        o = {k: v}
+                        json_obj = json.loads(json.dumps(o))
+                        json.dump(json_obj, outfile, ensure_ascii=False, indent=None, separators=(',',':'))
+                        print('', file=outfile)

+                with open(os.path.join(OUTPUT_FOLDER_PATH, "cui2ptt_tsp.jsonl"), "a+", encoding="utf-8") as outfile:
+                    for k,v in cui2ptt_tsp.items():
+                        o = {k: v}
+                        json_obj = json.loads(json.dumps(o))
+                        json.dump(json_obj, outfile, ensure_ascii=False, indent=None, separators=(',',':'))
+                        print('', file=outfile)
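
For reference, a minimal sketch of the JSONL line format written above, based on the comments in multiprocess_annotation_records; the CUI and patient ids are made-up illustrative values, not data from the commit:

# cui2ptt_pos.jsonl: one CUI per line, mapping patient ids to mention counts, e.g.
#   {"22298006":{"patient_1":3,"patient_7":1}}
# cui2ptt_tsp.jsonl: one CUI per line, mapping patient ids to the first-mention Unix timestamp, e.g.
#   {"22298006":{"patient_1":1684800000}}

import json
from collections import defaultdict, Counter

cui2ptt_pos = defaultdict(Counter)
with open("cui2ptt_pos.jsonl", encoding="utf-8") as f:
    for line in f:
        for cui, patients in json.loads(line).items():
            cui2ptt_pos[cui].update(patients)  # sums counts when a CUI recurs across per-file appends

Because the export loop appends one batch of lines per input file, the same CUI can occur on several lines; merging with a Counter, as above, recovers the per-patient totals.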
21 changes: 18 additions & 3 deletions nifi/user-scripts/utils/generic.py
@@ -1,11 +1,26 @@
import json
+import os

def chunk(input_list: list, num_slices: int):
    # note: num_slices is effectively the chunk size, i.e. each yielded
    # slice holds up to num_slices consecutive items of input_list
    for i in range(0, len(input_list), num_slices):
        yield input_list[i:i + num_slices]

# function to convert a dictionary to json and write to file (input_dict: dictionary, file_path: string)
-def dict2json_file(input_dict: dict, file_name: str):
+def dict2json_file(input_dict: dict, file_path: str):
    # write the json file
-    with open(file_name, 'w', encoding='utf-8') as outfile:
-        json.dump(input_dict, outfile, ensure_ascii=False, indent=None, separators=(',',':'))
+    with open(file_path, "a+", encoding="utf-8") as outfile:
+        json.dump(input_dict, outfile, ensure_ascii=False, indent=None, separators=(",", ":"))

+def dict2json_truncate_add_to_file(input_dict: dict, file_path: str):
+    # merge input_dict into an existing single-object JSON file: strip the
+    # trailing "}", append ",", then append input_dict minus its leading "{"
+    if os.path.exists(file_path):
+        with open(file_path, "a+") as outfile:
+            outfile.seek(outfile.tell() - 1, os.SEEK_SET)
+            outfile.truncate()
+            outfile.write(",")
+            json_string = json.dumps(input_dict, separators=(",", ":"))
+            json_string = json_string[1:]

+            outfile.write(json_string)
+    else:
+        dict2json_file(input_dict, file_path)