Problems reproducing results on "I2B2 2006: Smoker Identification" dataset #7
Hi Arne,
Thank you for your carefully laid out issue. What you tried seems reasonable. This repository contains code transferred from the original implementation; I will cross-check with the original and re-verify things on our end. Do you have a link to the forked implementation with the described changes?
…On Fri, Mar 6, 2020, 6:19 AM ArneD ***@***.***> wrote:
First of all, thanks for the interesting paper and code.
However, I am having problems reproducing the results reported in the
paper.
I tried reproducing the results reported on the "I2B2 2006: Smoker
Identification" dataset.
My results after 1000 epochs of training (using the architecture:
DocumentBertLSTM) are (on the test set):
| Metric | PAST_SMOKER | CURRENT_SMOKER | NON-SMOKER | UNKNOWN |
| --- | --- | --- | --- | --- |
| Precision | 0.36363636363636365 | 0.0 | 0.4230769230769231 | 0.7704918032786885 |
| Recall | 0.36363636363636365 | 0.0 | 0.6875 | 0.746031746031746 |
| F1 | 0.36363636363636365 | 0.0 | 0.5238095238095238 | 0.7580645161290323 |

Micro-F1: 0.8004807692307693
Macro-F1: 0.7331169082125604
So only a micro-F1 score of around 0.80 was obtained (versus the 0.981 reported in the paper).
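For reference, the per-class and averaged scores above follow the standard scikit-learn definitions; a tiny sketch with made-up labels (the repository's own evaluation code may aggregate slightly differently):

```python
from sklearn.metrics import f1_score, precision_recall_fscore_support

# Toy labels only, to illustrate how the reported metrics are defined.
labels = ["PAST_SMOKER", "CURRENT_SMOKER", "NON-SMOKER", "UNKNOWN"]
y_true = ["UNKNOWN", "NON-SMOKER", "PAST_SMOKER", "UNKNOWN", "CURRENT_SMOKER"]
y_pred = ["UNKNOWN", "NON-SMOKER", "UNKNOWN", "UNKNOWN", "NON-SMOKER"]

# Per-class precision / recall / F1 (one value per entry of `labels`).
p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, labels=labels, zero_division=0)
print(dict(zip(labels, f)))

# Micro-F1 aggregates over all individual decisions; macro-F1 averages the per-class F1s.
print("micro-F1:", f1_score(y_true, y_pred, labels=labels, average="micro", zero_division=0))
print("macro-F1:", f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0))
```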
The n2c2_2006_train_config.ini config file was used, along with the
train_n2c2_2006.py script.
I used ClinicalBERT (https://github.com/EmilyAlsentzer/clinicalBERT) as the BERT model.
In order for the script to work, I had to change lines 34-35 in
bert_document_classification/examples/ml4health_2019_replication/data.py
to
```python
if partition == 'train':
    with open("data/smokers_surrogate_%s_all_version2.xml" % partition) as raw:
        file = raw.read().strip()
elif partition == 'test':
    with open("data/smokers_surrogate_%s_all_groundtruth_version2.xml" % partition) as raw:
        file = raw.read().strip()
```
(I could also have renamed the files)
Also, at line 14 of the n2c2_2006_train_config.ini config file I found the
following:
`freeze_bert: False`
This is confusing, because at lines 35-40 of
bert_document_classification/bert_document_classification/document_bert_architectures.py:
```python
with torch.set_grad_enabled(False):
    for doc_id in range(document_batch.shape[0]):
        bert_output[doc_id][:self.bert_batch_size] = self.dropout(
            self.bert(document_batch[doc_id][:self.bert_batch_size, 0],
                      token_type_ids=document_batch[doc_id][:self.bert_batch_size, 1],
                      attention_mask=document_batch[doc_id][:self.bert_batch_size, 2])[1])
```
Thus ClinicalBERT is always frozen, independent of the freeze_bert option in the config file. This makes sense, because the paper reports that BERT was frozen.
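If the flag were meant to take effect, the no-grad context could be gated on it. Here is a minimal sketch of that idea (the helper name and its freeze_bert argument are my own, not the repository's API):

```python
import torch

def encode_with_bert(bert, dropout, document_batch, bert_batch_size, freeze_bert=True):
    # Sketch only: same slicing as the snippet above, but gradients flow through
    # BERT only when freeze_bert is False, so the config option has an effect.
    hidden_size = bert.config.hidden_size
    bert_output = torch.zeros(document_batch.shape[0], bert_batch_size, hidden_size,
                              device=document_batch.device)
    with torch.set_grad_enabled(not freeze_bert):
        for doc_id in range(document_batch.shape[0]):
            pooled = bert(document_batch[doc_id][:bert_batch_size, 0],
                          token_type_ids=document_batch[doc_id][:bert_batch_size, 1],
                          attention_mask=document_batch[doc_id][:bert_batch_size, 2])[1]
            bert_output[doc_id][:bert_batch_size] = dropout(pooled)
    return bert_output
```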
Next, I had to run the training on a single GPU, because on two GPUs the training loss did not decrease.
How many GPUs were used in the paper? Since the number of GPUs affects the effective batch size, it could affect performance.
I tried increasing and decreasing the batch size, but this did not result
in better performance.
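One way to pin training to a single device is to restrict the visible GPUs before CUDA is initialized; with more than one visible GPU, nn.DataParallel splits every batch across devices, which is what changes the effective per-GPU batch size. A generic sketch, not the repository's code:

```python
import os

# Restrict PyTorch to one GPU *before* CUDA is initialized
# (GPU index 0 is an arbitrary choice; adjust to whichever device is free).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn

model = nn.LSTM(input_size=768, hidden_size=100, batch_first=True)  # stand-in model
if torch.cuda.device_count() > 1:
    # With several visible GPUs, nn.DataParallel splits each batch across
    # devices, so every replica sees a smaller (effective) batch size.
    model = nn.DataParallel(model)

print("visible GPUs:", torch.cuda.device_count())
```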
On the other hand, using the pretrained models (from bert_document_classification.models import SmokerPhenotypingBert), an F1 score of 0.981 (as reported in the paper) was obtained when evaluating on the test set.
As a side note, it could be made clearer that only the first bert_batch_size*510 tokens (3570 by default) of the documents are used for document classification (i.e., only bert_batch_size sequences are fed to the LSTM and classification layers).
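To illustrate that truncation (a sketch only, using the transformers tokenizer on a synthetic document rather than the repository's own encoding code):

```python
from transformers import BertTokenizer

# Documents are split into 510-token pieces (leaving room for [CLS]/[SEP]);
# only the first bert_batch_size pieces ever reach the LSTM/classifier,
# i.e. at most 7 * 510 = 3570 tokens with the default setting.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_batch_size = 7  # repository default

text = "no history of tobacco use " * 1000        # synthetic long document
tokens = tokenizer.tokenize(text)
chunks = [tokens[i:i + 510] for i in range(0, len(tokens), 510)]
used = chunks[:bert_batch_size]                   # everything after chunk 7 is dropped

print(len(tokens), "tokens ->", len(chunks), "chunks, of which", len(used), "are used")
```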
I used PyTorch version 1.4.
Hi, I made two changes:
1)
2)
This is because, as explained above, on 2 GPUs the training loss does not decrease. I can push my code to a repository if this would still be necessary.
I pushed my code to https://github.com/ArneDefauw/BERT_doc_classification. The only meaningful changes to the master branch of this repository are the ones described above. I also added results (log files) of two experiments: one on the "I2B2 2006: Smoker Identification" dataset (as described above) and one on the "20newsgroups" dataset from scikit-learn (results are OK, but not much better than, for example, using tf-idf + SVM). Config files for the experiments on the 20newsgroups dataset are also included. You will also find the Dockerfile I used for building the Docker image from which I ran the experiments. Notebooks with only the essential parts of the code are also included (I switched to the more recent transformers library there, plus some small modifications with no effect on the results).
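For reference, the tf-idf + SVM baseline mentioned above, in its standard scikit-learn form, looks roughly like this (vectorizer and SVM settings are illustrative, not the exact configuration behind the 20newsgroups log files):

```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC

# Standard 20newsgroups split, with the usual header/footer/quote stripping.
train = fetch_20newsgroups(subset="train", remove=("headers", "footers", "quotes"))
test = fetch_20newsgroups(subset="test", remove=("headers", "footers", "quotes"))

# tf-idf features feeding a linear SVM.
clf = make_pipeline(TfidfVectorizer(sublinear_tf=True, min_df=2), LinearSVC())
clf.fit(train.data, train.target)
pred = clf.predict(test.data)

print("micro-F1:", f1_score(test.target, pred, average="micro"))
print("macro-F1:", f1_score(test.target, pred, average="macro"))
```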
After some analysis of the code, I may have found the reason for the discrepancy between the results reported in the paper and the results I found: it seems necessary to unfreeze the last layers of BERT.
I refer to lines 38-74 in https://github.com/ArneDefauw/BERT_doc_classification/blob/master/bert_document_classification/bert_document_classification/document_bert_architectures.py and lines 142-143 in https://github.com/ArneDefauw/BERT_doc_classification/blob/master/bert_document_classification/bert_document_classification/document_bert.py.
When unfreezing the last encoder layer of BERT, I obtain an F1 score of 0.918 on the "I2B2 2006: Smoker Identification" dataset after only 160 epochs (with the training loss still decreasing), using the BERT+LSTM architecture.
Note that this is confusing, as in the paper I find the following phrase: "In our proposed architectures (Figure 1), we use a frozen (no parameter fine-tuning) instance of ClinicalBERT..."
I guess that for obtaining the results reported in the paper, a fully unfrozen version of BERT was used. Is this correct?
I only unfroze the last layers because the fully unfrozen BERT+LSTM did not fit on the GPUs I used (11 GB), and because of faster training.
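For illustration, unfreezing only the last encoder layer of a Hugging Face-style BERT can be done along these lines (a sketch; the model identifier and the "encoder.layer.11." prefix for a 12-layer model are assumptions, not the exact code in the linked fork):

```python
from transformers import BertModel

# Sketch: freeze everything except the last encoder layer and the pooler,
# whose pooled [CLS] output is what the classification head consumes.
bert = BertModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

for name, param in bert.named_parameters():
    param.requires_grad = name.startswith(("encoder.layer.11.", "pooler."))

trainable = sum(p.numel() for p in bert.parameters() if p.requires_grad)
total = sum(p.numel() for p in bert.parameters())
print(f"trainable: {trainable:,} / {total:,} parameters")
```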
Arne,
Thank you again for the continued detailed analysis. What you quoted indeed disagrees with our original implementation. This may very well be an error in writing/interpretation on our part when analyzing our iterated experiments. Additionally, fine-tuning agrees with other reported work showing that the CLS representation, without fine-tuning, is not a meaningful segment-level representation.
I am currently on holiday and will address this on return.