Merge pull request #10 from Optum/efficient_msp
Efficient MSP
jstremme authored Nov 11, 2022
2 parents f6e7312 + dd2f133 commit e06fb90
Showing 6 changed files with 277 additions and 77 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -24,11 +24,13 @@ This codebase contains scripts to:
1. Pretrain language models (LMs) in a self-supervised fashion using masked language modeling (MLM) and fine-tune these LMs for document classification.
2. Compute the importance of multi-token text blocks to fine-tuned LM predictions for a given document or set of documents using a variety of methods.

In our paper, [*Extend and Explain: Interpreting Very Long Language Models*](https://arxiv.org/abs/2209.01174), we propose a novel method called the Masked Sampling Procedure (MSP) and compare it to 1) random sampling and 2) the [Sampling and Occlusion (SOC) algorithm
In our paper, ["Extend and Explain: Interpreting Very Long Language Models"](https://arxiv.org/abs/2209.01174), we propose a novel method called the Masked Sampling Procedure (MSP) and compare it to 1) random sampling and 2) the [Sampling and Occlusion (SOC) algorithm
from Jin et al.](https://arxiv.org/pdf/1911.06194.pdf). MSP is well-suited to very long, sparse-attention LMs, and has been validated for medical documents using two physician annotators.
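
At a high level, MSP repeatedly masks random multi-token blocks of a document, re-runs the fine-tuned classifier, and attributes changes in the predicted label probabilities to the blocks that were masked. The toy sketch below illustrates only this core idea and is not the implementation in this repository (see `explain/msp.py`); `score_fn` and `mask_id` are hypothetical stand-ins for a fine-tuned LM scoring one label and for the tokenizer's mask token ID.

```python
import random

def msp_sketch(token_ids, score_fn, mask_id, n_trials=1000, block_len=10, mask_prob=0.1):
    """Estimate block importance as the average drop in predicted probability when the block is masked."""
    base = score_fn(token_ids)
    drops = {}  # block start index -> probability drops observed across trials
    for _ in range(n_trials):
        masked, starts = list(token_ids), []
        for j in range(0, len(token_ids), block_len):
            if random.random() < mask_prob:
                masked[j : j + block_len] = [mask_id] * len(masked[j : j + block_len])
                starts.append(j)
        drop = base - score_fn(masked)
        for j in starts:
            drops.setdefault(j, []).append(drop)
    return {start: sum(d) / len(d) for start, d in drops.items()}
```

Blocks with the largest average drop matter most to the prediction; the actual procedure aggregates over many trials per document and reports the top-ranked blocks.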

The code to run MSP currently supports [HuggingFace LMs](https://huggingface.co/models) and [Datasets](https://huggingface.co/datasets) and would require slight modifications to use other types of models and input data. If you need to fine-tune or continue pretraining an existing LM, check out `models/README.md`. To create a Hugging Face Dataset, check out the documentation [here](https://huggingface.co/docs/datasets/index).
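
As a minimal sketch of what such a dataset might look like (assuming a `text` column, an integer `label` column, and a `test` split, which is what the explanation scripts in this repository appear to read; check the configs for the exact schema):

```python
from datasets import Dataset, DatasetDict

# Toy documents and labels; real inputs would be long (clinical) documents.
train = Dataset.from_dict({"text": ["first example note", "second example note"], "label": [0, 1]})
test = Dataset.from_dict({"text": ["held-out example note"], "label": [1]})

dataset = DatasetDict({"train": train, "test": test})
dataset.save_to_disk("my_text_blocks_dataset")  # hypothetical output path
```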

The code used for the experiments in ["Extend and Explain"](https://arxiv.org/abs/2209.01174) can be found in the [first release of this repository (v0.0.1)](https://github.com/Optum/long-medical-document-lms/releases/tag/v0.0.1). Since then, changes have been made to MSP (for example in [PR#10](https://github.com/Optum/long-medical-document-lms/pull/10)) to improve runtime performance and potentially the clinical informativeness of explanations through new features such as GPU-efficient batching and sentence-level masked sampling.
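
Sentence-level masked sampling replaces fixed-length token blocks with sentence boundaries detected by `pysbd`, as used in the updated `explain/msp.py`. A minimal sketch of the segmentation step on its own (the example text is illustrative):

```python
import pysbd

# clean=False tells pysbd not to normalize whitespace, so the returned
# sentences preserve the original text of the source document.
segmenter = pysbd.Segmenter(language="en", clean=False)
sentences = segmenter.segment("Patient presents with chest pain. EKG shows ST elevation.")
print(sentences)
```

In MSP, each sentence returned by the segmenter is encoded and masked as a whole block with probability `P`.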

### Environment

All scripts are intended to be run in a Python 3.8 [Anaconda](https://www.anaconda.com/products/individual) environment. To create such an environment run `conda create --name text-blocks python=3.8` then `source activate text-blocks` to activate the environment. Dependencies can be installed from `requirements.txt` by running `pip install -r requirements.txt` from the base directory of this repository.
26 changes: 15 additions & 11 deletions explain/explain_with_msp.py
@@ -52,9 +52,6 @@ def main():
shutil.rmtree(output_path)
os.makedirs(output_path)

# Configure Device and Empty GPU Cache
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load Data, Tokenizer, and Model
if PARAMS["offline"]:
os.environ["HF_DATASETS_OFFLINE"] = "1"
@@ -77,21 +74,21 @@ def tokenize_function(batch):

# Tokenize Text
# Code runs on the test data split by default
dataset["test"] = dataset["test"].map(
dataset = dataset.map(
tokenize_function, batched=True, batch_size=PARAMS["batch_size"]
)

# Take a Random Sample of the Test Data
sample_data = dataset["test"].shuffle()[0 : PARAMS["num_sample"]]
sample_data = dataset.shuffle()[0 : PARAMS["num_sample"]]

# Check Average Precision of Classifier
# To Do: Add CIs to prediction
check_average_precision(
model=model,
data=sample_data,
device=device,
class_strategy=PARAMS["class_strategy"],
average="macro",
average="micro",
batch_size=PARAMS["batch_size"],
)

# Start timer
@@ -100,7 +97,9 @@ def tokenize_function(batch):
# Run MSP
times = []
all_results = []
for s, doc_input_ids in enumerate(sample_data["input_ids"]):
for s, (doc_input_ids, doc_text) in enumerate(
zip(sample_data["input_ids"], sample_data["text"])
):

# Indicate sample number
logger.info(f"Running MSP for sample {s} of {PARAMS['num_sample']}...")
@@ -109,19 +108,24 @@ def tokenize_function(batch):
results = predict_with_masked_texts(
model=model,
input_ids=doc_input_ids,
text=doc_text,
n=PARAMS["N"],
k=PARAMS["K"],
p=PARAMS["P"],
mask_token_id=tokenizer.mask_token_id,
idx2label=PARAMS["idx2label"],
print_every=PARAMS["print_every"],
debug=PARAMS["debug"],
device=device,
max_seq_len=PARAMS["max_seq_len"],
class_strategy=PARAMS["class_strategy"],
tokenizer=tokenizer,
by_sent_segments=PARAMS["by_sent_segments"],
batch_size=PARAMS["batch_size"],
)
all_results.append(results)

results["indices_len"] = results["masked_text_indices"].apply(lambda x: len(x))
results["tokens_len"] = results["masked_text_tokens"].apply(lambda x: len(x))

# Compute time to run MSP on one doc
doc_time = time.time()
times.append(doc_time)
@@ -154,7 +158,6 @@ def tokenize_function(batch):
all_input_ids=sample_data["input_ids"],
all_labels=sample_data["label"],
times=times,
device=device,
tokenizer=tokenizer,
num_sample=PARAMS["num_sample"],
max_seq_len=PARAMS["max_seq_len"],
@@ -166,6 +169,7 @@ def tokenize_function(batch):
k=PARAMS["K"],
p=PARAMS["P"],
m=PARAMS["M"],
by_sent_segments=PARAMS["by_sent_segments"],
)

# End timer
194 changes: 154 additions & 40 deletions explain/msp.py
@@ -3,30 +3,113 @@
"""
import os
import pysbd
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from utils import (
predict_with_clf_model,
configure_model_for_inference,
convert_binary_to_multi_label,
torch_model_predict,
torch_model_predict_indiv,
)


def run_trial_on_fixed_blocks(tokenizer, input_ids, p, k):

# At each trial, save the masked tokens and the start index of each block
masked_text_tokens = []
masked_text_indices = []

# For each sample, create a new sample consisting of masked and unmasked blocks
new_sample = []

# Iterate through fixed length blocks
for j in range(0, len(input_ids), k):
block = input_ids[j : j + k]

# Mask a block with probability P and add the block to the new sample
if random.uniform(0, 1) < p:

# Create a masked block of appropriate length
# Save the index of the block start
# Add the block to the new sample
# Save the block
mask_block = [tokenizer.mask_token_id] * k
new_sample.extend(mask_block)
masked_text_indices.append(j)
masked_text_tokens.append(block)

else:
new_sample.extend(block)

return new_sample, masked_text_tokens, masked_text_indices


def run_trial_on_sentences(segmenter, tokenizer, text, p):

# Check that tokenizer always gives us a CLS and SEP token at the start and end
test_sent = "This is a test sentence."
e_test_sent = tokenizer.encode(test_sent)
assert_msg = "Tokenizer should always add CLS and SEP tokens to the start and end of each input sequence respectively."
assert (
e_test_sent[0] == tokenizer.cls_token_id
and e_test_sent[-1] == tokenizer.sep_token_id
), assert_msg

# At each trial, save the masked tokens and the start index of each block
masked_text_tokens = []
masked_text_indices = []

# For each sample, create a new sample consisting of masked and unmasked blocks
# We're removing CLS tokens, so make sure the new sample starts with one
new_sample = [tokenizer.cls_token_id]

# Iterate through logical text segments
for segment in segmenter.segment(text):

# Encode a block and remove the CLS and SEP tokens
# Slicing should be faster than removing [tokenizer.cls_token_id, tokenizer.sep_token_id] explicitly
# But it assumes these are always present when encoding
block = tokenizer.encode(segment)
cleaned_block = block[1:-1]

# Mask a block with probability P and add the block to the new sample
if random.uniform(0, 1) < p:

# Create a masked block of appropriate length
# Save the index of the block start
# Add the block to the new sample
# Save the block
mask_block = [tokenizer.mask_token_id] * len(cleaned_block)
masked_text_indices.append(len(new_sample))
new_sample.extend(mask_block)
masked_text_tokens.append(cleaned_block)

else:
new_sample.extend(cleaned_block)

return new_sample, masked_text_tokens, masked_text_indices


def predict_with_masked_texts(
model,
input_ids,
text,
n,
k,
p,
mask_token_id,
idx2label,
print_every,
debug,
device,
max_seq_len,
class_strategy,
tokenizer,
by_sent_segments,
batch_size,
):
"""
Returns the probabilities for each label for each iteration with the masked strings and labels.
@@ -39,15 +122,18 @@ def predict_with_masked_texts(
)
print(f"Each block of {k} subword tokens is masked with probability {p}.")

# Configure model for inference
model = configure_model_for_inference(model, device)

# Track the text strings masked in each trial and their indices
all_masked_text_tokens = []
all_masked_text_indices = []

# Track the probabilities for each label from each round of masking for each trial
all_probs = []
# Collect the new samples created from each round of masking for each trial
collected_new_samples = []

# Initialize segmenter if generating per sentence explanations
if by_sent_segments:

# Initialize segmenter
segmenter = pysbd.Segmenter(language="en", clean=False)

# Run trials
for i in range(n):
@@ -61,37 +147,66 @@ def predict_with_masked_texts(
if i % print_every == 0:
print(f" On iteration {i} of {n}...")

# At each trial, save the strings of masked text and the start index of each string
masked_text_tokens = []
masked_text_indices = []

# For each sample, create a new sample consisting of masked and unmasked blocks
new_sample = []
for j in range(0, len(input_ids), k):
block = input_ids[j : j + k]

# Mask a block with probability P and add the block to the new sample
if random.uniform(0, 1) < p:
mask_block = [mask_token_id] * k
new_sample.extend(mask_block)
masked_text_indices.append(j)
masked_text_tokens.append(block)
else:
new_sample.extend(block)

# Compute probabilities of each label on the new sample
prob = predict_with_clf_model(
model,
sample_input_ids=[new_sample[0:max_seq_len]],
device=device,
class_strategy=class_strategy,
)[0]
# Generate sentence-level or fixed block-level explanations
if by_sent_segments:
(
new_sample,
masked_text_tokens,
masked_text_indices,
) = run_trial_on_sentences(
segmenter=segmenter, tokenizer=tokenizer, text=text, p=p
)
else:
(
new_sample,
masked_text_tokens,
masked_text_indices,
) = run_trial_on_fixed_blocks(
tokenizer=tokenizer, input_ids=input_ids, p=p, k=k
)

# Save the probabilities, text strings, and indices from this trial
all_probs.append(prob)
# Save the masked blocks and start indices from this trial
all_masked_text_tokens.append(masked_text_tokens)
all_masked_text_indices.append(masked_text_indices)

# Collect the new sample from this trial
collected_new_samples.append(new_sample[0:max_seq_len])

# Pad new sequences
# Generate pointers to check that predictions are returned in the same order
padded_sequences = pad_sequence(
torch.tensor(collected_new_samples, dtype=torch.int64),
batch_first=True,
padding_value=tokenizer.pad_token_id,
)
pointers_orig = torch.tensor([[i] for i in range(len(padded_sequences))])

# Build dataset of new sequences to use to generate label probabilities
# Include the pointers we created
ds = TensorDataset(padded_sequences, pointers_orig)
dataloader = DataLoader(
dataset=ds,
batch_size=batch_size,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
)

# Iterate through data loader to make predictions and return the pointers
all_probs, pointers_returned = torch_model_predict(
model=model,
test_loader=dataloader,
class_strategy=class_strategy,
return_data_loader_targets=True,
)

# Check that predictions came back in the same order
# We want all_probs, all_masked_text_tokens, and all_masked_text_indices in corresponding order
assert np.array_equal(
pointers_orig.numpy(), pointers_returned
), "Order of records was shuffled during inference!"

# Build dataframe of results
results = pd.DataFrame(
all_probs, columns=[idx2label[idx] for idx in range(len(all_probs[0]))]
@@ -109,7 +224,6 @@ def post_process_and_save_msp_results(
all_input_ids,
all_labels,
times,
device,
tokenizer,
num_sample,
max_seq_len,
@@ -121,6 +235,7 @@ def post_process_and_save_msp_results(
k,
p,
m,
by_sent_segments,
):
"""
This step iterates through the results of running MSP for each document and:
@@ -134,7 +249,7 @@
"""

# Configure model for inference
model = configure_model_for_inference(model, device)
model = configure_model_for_inference(model)

# Iterate through results on all documents to post-process and save explanations
for s, (results, doc_input_ids, doc_y, doc_time) in enumerate(
@@ -151,10 +266,9 @@
)

# Get predictions on sample with no masking
full_yhat = predict_with_clf_model(
full_yhat = torch_model_predict_indiv(
model,
sample_input_ids=[doc_input_ids[0:max_seq_len]],
device=device,
class_strategy=class_strategy,
)[0]

@@ -252,7 +366,7 @@ def post_process_and_save_msp_results(
# Add parameters used to generate results
top_m_df["sample"] = s
top_m_df["P"] = p
top_m_df["K"] = k
top_m_df["K"] = k if not by_sent_segments else "By Sentence Segments"
top_m_df["N"] = n
top_m_df["runtime_secs"] = doc_time
