Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MedicalQuestionAnswering Integration & Added new pipelines #243

Merged
merged 1 commit into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
18 changes: 18 additions & 0 deletions nlu/components/classifiers/span_medical/span_medical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class SpanMedical:
    """Factory helpers for the Spark NLP for Healthcare MedicalQuestionAnswering annotator.

    Both loaders lazily import ``sparknlp_jsl`` so the licensed healthcare
    library is only required when a medical QA component is actually built.
    The returned annotator reads the ``document_question`` and ``context``
    columns and writes its prediction to the ``answer`` column.
    """

    @staticmethod
    def get_default_model():
        """Return the default pretrained MedicalQuestionAnswering annotator."""
        from sparknlp_jsl.annotator import MedicalQuestionAnswering

        annotator = MedicalQuestionAnswering.pretrained()
        annotator = annotator.setInputCols(["document_question", "context"])
        return annotator.setOutputCol("answer")

    @staticmethod
    def get_pretrained_model(name, language, bucket=None):
        """Return a specific pretrained MedicalQuestionAnswering annotator.

        :param name: pretrained model name
        :param language: model language code
        :param bucket: optional remote bucket/location to fetch from
        """
        from sparknlp_jsl.annotator import MedicalQuestionAnswering

        annotator = MedicalQuestionAnswering.pretrained(name, language, bucket)
        annotator = annotator.setInputCols(["document_question", "context"])
        return annotator.setOutputCol("answer")
25 changes: 25 additions & 0 deletions nlu/pipe/col_substitution/col_substitution_HC.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,28 @@ def substitute_generic_classifier_parser_cols(c, cols, is_unique=True, nlu_ident
logger.info(f'Dropping unmatched metadata_col={col} for c={c}')
# new_cols[col]= f"{new_base_name}_confidence"
return new_cols
def substitute_hc_span_classifier_cols(c, cols, nlu_identifier=True):
    """
    Map raw healthcare QA span-classifier output columns to user-facing names.

    Substring-matches each raw column name and renames it under the ``answer``
    base name (score columns become ``*_confidence``, span offsets become
    ``*_start``/``*_end``, etc.). Columns matching none of the patterns are
    dropped from the result.
    """
    # new_base_name = 'answer' if nlu_identifier == 'UNIQUE' else f'{nlu_identifier}_answer'
    base = 'answer'
    renamed = {}
    for column in cols:
        # Most specific patterns first so e.g. 'answer_results_score' is not
        # swallowed by the plain 'answer_results' match.
        if 'answer_results_score' in column:
            renamed[column] = f'{base}_confidence'
        elif 'span_start_score' in column:
            renamed[column] = f'{base}_start_confidence'
        elif 'span_end_score' in column:
            renamed[column] = f'{base}_end_confidence'
        elif 'start' in column and 'score' not in column:
            renamed[column] = f'{base}_start'
        elif 'end' in column and 'score' not in column:
            renamed[column] = f'{base}_end'
        elif 'sentence' in column:
            renamed[column] = f'{base}_sentence'
        elif 'answer_results' in column:
            renamed[column] = base
    return renamed
4 changes: 3 additions & 1 deletion nlu/spellbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10598,7 +10598,7 @@ class Spellbook:
'de.deid.pipeline': 'german_deid_pipeline_spark24',
'de.med_ner.deid_generic.pipeline': 'ner_deid_generic_pipeline'},
'en': {

'en.answer_question.clinical_notes_onnx.pipeline': 'clinical_notes_qa_base_onnx_pipeline',
'en.classify.bert_sequence.binary_rct_biobert.pipeline': 'bert_sequence_classifier_binary_rct_biobert_pipeline',
'en.classify.bert_sequence.vop_hcp_consult.pipeline': 'bert_sequence_classifier_vop_hcp_consult_pipeline',
'en.classify.bert_sequence.vop_drug_side_effect.pipeline': 'bert_sequence_classifier_vop_drug_side_effect_pipeline',
Expand Down Expand Up @@ -10634,6 +10634,7 @@ class Spellbook:
'en.explain_doc.clinical_ade': 'explain_clinical_doc_ade',
'en.explain_doc.clinical_radiology.pipeline': 'explain_clinical_doc_radiology',
'en.explain_doc.era': 'explain_clinical_doc_era',
'en.explain_doc.clinical_granular': 'explain_clinical_doc_granular',
'en.icd10_icd9.mapping': 'icd10_icd9_mapping',
'en.icd10cm.umls.mapping': 'icd10cm_umls_mapping',
'en.icd10cm_resolver.pipeline': 'icd10cm_resolver_pipeline',
Expand Down Expand Up @@ -10765,6 +10766,7 @@ class Spellbook:
'en.spell.clinical.pipeline': 'spellcheck_clinical_pipeline',
'en.summarize.biomedical_pubmed.pipeline':'summarizer_biomedical_pubmed_pipeline',
'en.summarize.clinical_guidelines_large.pipeline': 'summarizer_clinical_guidelines_large_pipeline',
'en.summarize.clinical_laymen_onnx.pipeline': 'summarizer_clinical_laymen_onnx_pipeline',
'en.summarize.clinical_jsl_augmented.pipeline': 'summarizer_clinical_jsl_augmented_pipeline',
'en.summarize.clinical_questions.pipeline': 'summarizer_clinical_questions_pipeline',
'en.summarize.generic_jsl.pipeline': 'summarizer_generic_jsl_pipeline',
Expand Down
3 changes: 2 additions & 1 deletion nlu/universe/annotator_class_universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AnnoClassRef:
JSL_anno2_py_class: Dict[JslAnnoId, JslAnnoPyClass] = {

A_N.E5_SENTENCE_EMBEDDINGS: 'E5Embeddings',
A_N.INSTRUCTOR_SENTENCE_EMBEDDINGS:'InstructorEmbeddings',
A_N.INSTRUCTOR_SENTENCE_EMBEDDINGS: 'InstructorEmbeddings',

A_N.WHISPER_FOR_CTC: 'WhisperForCTC',
A_N.HUBERT_FOR_CTC: 'HubertForCTC',
Expand Down Expand Up @@ -240,6 +240,7 @@ class AnnoClassRef:

}
JSL_anno_HC_ref_2_py_class: Dict[JslAnnoId, JslAnnoPyClass] = {
HC_A_N.MEDICAL_QUESTION_ANSWERING: 'MedicalQuestionAnswering',
HC_A_N.MEDICAL_TEXT_GENERATOR: 'MedicalTextGenerator',
HC_A_N.MEDICAL_SUMMARIZER:'MedicalSummarizer',
HC_A_N.ZERO_SHOT_NER: 'ZeroShotNerModel',
Expand Down
22 changes: 22 additions & 0 deletions nlu/universe/component_universes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from nlu.components.classifiers.span_longformer.span_longformer import SpanLongFormerClassifier
from nlu.components.classifiers.span_roberta.span_roberta import SpanRobertaClassifier
from nlu.components.classifiers.span_xlm_roberta.span_xlm_roberta import SpanXlmRobertaClassifier
from nlu.components.classifiers.span_medical.span_medical import SpanMedical
from nlu.components.classifiers.token_albert.token_albert import TokenAlbert
from nlu.components.classifiers.token_bert.token_bert import TokenBert
from nlu.components.classifiers.token_bert_healthcare.token_bert_healthcare import TokenBertHealthcare
Expand Down Expand Up @@ -3278,6 +3279,27 @@ class ComponentUniverse:
computation_context=ComputeContexts.spark,
output_context=ComputeContexts.spark,
),
H_A.MEDICAL_QUESTION_ANSWERING: partial(NluComponent,
name=H_A.MEDICAL_QUESTION_ANSWERING,
jsl_anno_class_id= H_A.MEDICAL_QUESTION_ANSWERING,
jsl_anno_py_class= ACR.JSL_anno_HC_ref_2_py_class[
H_A.MEDICAL_QUESTION_ANSWERING],
node= NLP_HC_FEATURE_NODES.nodes[
H_A.MEDICAL_QUESTION_ANSWERING],
get_default_model= SpanMedical.get_default_model,
get_pretrained_model= SpanMedical.get_pretrained_model,
type= T.QUESTION_SPAN_CLASSIFIER,
pdf_extractor_methods={
'default': default_span_classifier_config,
'default_full': default_full_span_classifier_config, },
pdf_col_name_substitutor=substitute_hc_span_classifier_cols,
output_level=L.INPUT_DEPENDENT_DOCUMENT_CLASSIFIER,
description='TODO',
provider=ComponentBackends.hc,
license=Licenses.hc,
computation_context=ComputeContexts.spark,
output_context=ComputeContexts.spark,
),

A.MULTI_DOCUMENT_ASSEMBLER: partial(NluComponent,
name=A.MULTI_DOCUMENT_ASSEMBLER,
Expand Down
1 change: 1 addition & 0 deletions nlu/universe/feature_node_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class NLP_HC_NODE_IDS: # or Mode Node?
ENTITY_CHUNK_EMBEDDING = JslAnnoId('entity_chunk_embedding')
MEDICAL_SUMMARIZER = JslAnnoId('med_summarizer')
MEDICAL_TEXT_GENERATOR = JslAnnoId('med_text_generator')
MEDICAL_QUESTION_ANSWERING = JslAnnoId('med_question_answering')

class OCR_NODE_IDS:
"""All available Feature nodes in OCR
Expand Down
2 changes: 2 additions & 0 deletions nlu/universe/feature_node_universes.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ class NLP_HC_FEATURE_NODES():
H_F = NLP_HC_FEATURES
# HC Feature Nodes
nodes = {
A.MEDICAL_QUESTION_ANSWERING: NlpFeatureNode(A.MEDICAL_QUESTION_ANSWERING, [F.DOCUMENT_QUESTION, F.DOCUMENT_QUESTION_CONTEXT], [F.CLASSIFIED_SPAN]),

A.MEDICAL_TEXT_GENERATOR: NlpFeatureNode(A.MEDICAL_TEXT_GENERATOR, [F.DOCUMENT], [F.DOCUMENT_GENERATED]),

A.MEDICAL_SUMMARIZER: NlpFeatureNode(A.MEDICAL_SUMMARIZER, [F.DOCUMENT], [F.DOCUMENT_GENERATED]),
Expand Down
Loading