Commit

checkin fast_r2d2

aaron.hx committed Mar 2, 2022
1 parent 215ec6f commit 71477ca
Showing 47 changed files with 30,402 additions and 874 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -217,3 +217,6 @@ cache_dict

# train
logs

model_data/
scripts/
207 changes: 146 additions & 61 deletions README.md
@@ -1,94 +1,179 @@
# R2D2/Fast-R2D2

This is the official code for the papers "[R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling](https://arxiv.org/abs/2107.00967)" and "[Fast-R2D2: A Pretrained Recursive Neural Network based on Pruned CKY for Grammar Induction and Text Representation](https://arxiv.org/abs/2203.00281)". The current repo is refactored from the original version used in the papers; if you meet any issues, please feel free to leave feedback.

## Requirements
gcc >= 5.0
pytorch == 1.9.0+cu111
cuda == 11.1

For other versions of PyTorch, please make sure the corresponding version of CUDA is installed.
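
You can verify the installed PyTorch/CUDA pairing with a quick check:

```python
# Sanity check: print the installed PyTorch version, the CUDA version it
# was built against, and whether a GPU is visible.
import torch

print("torch:", torch.__version__)             # e.g. 1.9.0+cu111
print("built with CUDA:", torch.version.cuda)  # e.g. 11.1
print("CUDA available:", torch.cuda.is_available())
```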

## Setup

```bash
export PATH="/usr/local/gcc-version/bin:$PATH"
export CXX=g++

python setup.py build_ext --inplace
```

Check that r2d2lib compiled correctly:

```bash
python -m unittest unittests/cuda_unittest.py
```

## Dataset

WikiText103: https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
GLUE: https://gluebenchmark.com/

## Dataset preprocessing

Split WikiText103 into sentences and remove special tokens such as @-@ and @.@:

```bash
python -m utils.data_processor --corpus_path CORPUS_PATH --output_path OUTPUT_PATH --task_type split
```
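
For reference, the split step also normalizes WikiText's escaped punctuation. A rough sketch of the kind of cleanup involved; the exact rules live in utils.data_processor, so the replacements below are illustrative assumptions:

```python
# Illustrative only: approximate the WikiText cleanup performed by the
# split step. The authoritative rules are in utils.data_processor.
import re

def clean_wikitext(line: str) -> str:
    line = line.replace(" @-@ ", "-")  # "state @-@ of @-@ the @-@ art" -> "state-of-the-art"
    line = line.replace(" @.@ ", ".")  # "1 @.@ 5" -> "1.5"
    line = line.replace(" @,@ ", ",")  # "1 @,@ 000" -> "1,000"
    return re.sub(r"\s+", " ", line).strip()

print(clean_wikitext("it scored 1 @,@ 000 points with a state @-@ of @-@ the @-@ art model"))
```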

Convert raw text to ids:
```bash
CORPUS_PATH=path_to_your_corpus
OUTPUT_PATH=path_to_processed_corpus
CONFIG_PATH=
VOCAB_DIR=

python -m utils.data_processor --corpus_path $CORPUS_PATH --output_path $OUTPUT_PATH --task_type tokenizing \
--vocab_dir $VOCAB_DIR --config_path $CONFIG_PATH
```
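
The tokenizing step maps each sentence to BERT vocabulary ids. A rough Python equivalent for intuition (an illustration, not the repo's actual code path):

```python
# Illustration: convert raw text to ids with the bert-base-uncased vocab,
# mirroring what the tokenizing task produces. Not the repo's own pipeline.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
ids = tokenizer("the quick brown fox", add_special_tokens=False)["input_ids"]
print(ids)  # token ids from the BERT vocabulary, e.g. [1996, 4248, 2829, 4419]
```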

## Train
Pretrain with span constraints:
```bash
VOCAB_DIR=data/en_config
CONFIG_PATH=data/en_config/fast_r2d2.json
PROCESSED_CORPUS=path_to_corpus_processed_in_the_previous_step  # tokenized and converted to ids
OUTPUT_DIR=path_to_output_model_dir

cd trainer

python -m torch.distributed.launch --nproc_per_node 8 parser_r2d2_trainer.py \
--batch_size 96 --max_batch_len 1536 \
--lr 5e-5 --parser_lr 1e-2 \
--vocab_dir $VOCAB_DIR \
--config_path $CONFIG_PATH \
--max_grad_norm 1.0 --input_type ids \
--corpus_path $PROCESSED_CORPUS \
--output_dir $OUTPUT_DIR \
--num_samples 256 --log_step 500 --epochs 60 \
--seperator " "
```

## Grammar Induction

Generate trees in PTB format:

```bash
python -m eval.r2d2_ptb_printer \
    --model_path path_to_r2d2_model \
    --parser_path path_to_r2d2_parser \
    --parser_only --in_word \
    --config_path path_to_your_config \
    --corpus_path data/wsj/wsj_test_raw.txt \
    --output_path path_to_output_file
```

For evaluating F1 scores on constituency trees, please refer to https://github.com/harvardnlp/compound-pcfg/blob/master/compare_trees.py.
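
As background, constituency F1 compares the bracketing spans of predicted and gold trees. A minimal sketch of unlabeled span F1, for illustration only (the linked compare_trees.py is the authoritative scorer):

```python
# Minimal sketch of unlabeled span F1 between predicted and gold trees,
# each represented as a set of (start, end) constituent spans.
def span_f1(pred: set, gold: set) -> float:
    if not pred or not gold:
        return 0.0
    overlap = len(pred & gold)
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Example: two matching spans out of three on each side -> P = R = 2/3.
pred = {(0, 4), (0, 2), (2, 4)}
gold = {(0, 4), (1, 4), (2, 4)}
print(round(span_f1(pred, gold), 3))  # 0.667
```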

Evaluating compatibility with dependency trees:
Download the WSJ dataset and convert it to dependency trees with Stanford CoreNLP (https://stanfordnlp.github.io/CoreNLP/).
As WSJ is not a free dataset, it is not included in this project. Please refer to the files in data/predict_trees for the exact format of the induced trees.
```bash
R2D2_TREE=path_to_output_file_generated_by_r2d2_ptb_printer

python compare_trees.py --tree1 path_to_gold_tree --tree2 R2D2_TREE
```


## GLUE tasks

Fine-tune on GLUE tasks with 8 A100 GPUs:

```bash
TASK_NAME=SST-2/CoLA/QQP/MNLI

python -m torch.distributed.launch --nproc_per_node 8 trainer/fast_r2d2_glue_trainer.py \
--max_grad_norm 1.0 --lr 5e-5 --parser_lr 1e-3 \
--config_path path_to_config \
--vocab_dir path_to_vocab_dir \
--task_type sst-2 --glue_dir path_to_glue_dir/$TASK_NAME --max_batch_len 1536 \
--max_batch_size 8 --output_dir path_to_model_save_dir \
--epochs 10 --pretrain_dir path_to_pretrain_model_dir \
--log_step 50 --num_samples 256 --sampler random --apex_mode O0
```

Evaluation:

```bash
TASK_TYPE=sst-2/mnli/cola/qqp
EVAL_TURN=number_of_the_turn_to_evaluate
MODE=forced/cky

python -m eval.eval_fast_r2d2 \
    --model_dir path_to_dir_of_fine_tuned_models \
    --config_path path_to_config \
    --vocab_dir dir_to_vocab \
    --task_type $TASK_TYPE \
    --glue_dir dir_of_glue_task \
    --max_batch_len 1024 \
    --max_batch_size 32 \
    --turn $EVAL_TURN \
    --r2d2_mode $MODE
```

## Evaluate speed

Sample sentences from WikiText103 (tokenized and converted to ids):

```bash
python -m utils.data_processor --task_type sampling --corpus_path path_to_wiki103_ids --output_path path_to_wiki103_outputs
```

```bash
LEN_RANGE=50/100/200/500
R2D2_MODE=forced/cky
BATCH_SIZE=50

python eval/eval_speed.py \
    --model_dir path_to_pretrain_model_dir \
    --config_path path_to_pretrain_model_dir/config.json \
    --vocab_dir path_to_pretrain_model_dir \
    --corpus_path model_data/en_wiki/wiki103.speed.ids.$LEN_RANGE \
    --max_batch_len 2500000 \
    --input_type ids \
    --model fast-r2d2 \
    --turn 59 \
    --r2d2_mode $R2D2_MODE \
    --batch_size $BATCH_SIZE
```

## Contact

[email protected]
42 changes: 42 additions & 0 deletions cuda_extension/binding.cpp
@@ -0,0 +1,42 @@
// coding=utf-8
// Copyright (c) 2022 Ant Group
// Author: Xiang Hu

#include <torch/torch.h>
#include "r2d2lib.h"

// part3: pybind11 (binds the C++ classes below to Python; the names registered here, e.g. forward/backward, are the method names that can later be called from Python)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
std::string name = std::string("TreeNode");
py::class_<ExportNode>(m, name.c_str())
.def_readwrite("cache_id", &ExportNode::cache_id)
.def_readwrite("left_i", &ExportNode::left_i)
.def_readwrite("right_i", &ExportNode::right_i)
.def_readwrite("left_j", &ExportNode::left_j)
.def_readwrite("right_j", &ExportNode::right_j)
.def_readwrite("left_idx", &ExportNode::left_idx)
.def_readwrite("right_idx", &ExportNode::right_idx)
.def_readwrite("log_p", &ExportNode::log_p);
name = std::string("TableCell");
py::class_<ExportCell>(m, name.c_str())
.def_readwrite("best_tree_idx", &ExportCell::best_tree_idx)
.def_readwrite("nodes", &ExportCell::nodes);
name = std::string("TablesManager");
py::class_<TablesManager>(m, name.c_str())
.def(py::init([](bool directional, size_t window_size, size_t beam_size)
{ return new TablesManager(directional, window_size, beam_size); }))
.def("encoding_start", &TablesManager::encoding_start)
.def("step", &TablesManager::step)
.def("set_merge_trajectories", &TablesManager::set_merge_trajectories)
.def("beam_select", &TablesManager::beam_select)
.def("step_over", &TablesManager::step_over)
.def("encoding_over", &TablesManager::encoding_over)
.def("current_step", &TablesManager::current_step)
.def("finished", &TablesManager::finished)
.def("prepare_bilm", &TablesManager::prepare_bilm)
.def("total_len", &TablesManager::total_len)
.def("batch_size", &TablesManager::batch_size)
.def("recover_sampled_trees", &TablesManager::recover_sampled_trees)
.def("dump_cells", &TablesManager::dump_cells);
}
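
Once the extension is compiled, the classes bound above are available from Python. A minimal sketch, assuming the module is built under the name r2d2lib (as the compilation check earlier suggests); only the constructor signature is visible in the binding, so the argument values are illustrative:

```python
# Hedged sketch: instantiate the TablesManager bound above.
# py::init takes (directional, window_size, beam_size); the values
# below are illustrative assumptions, not recommended settings.
import r2d2lib

tables = r2d2lib.TablesManager(True, 4, 2)
print(type(tables))  # <class 'r2d2lib.TablesManager'>
```
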
28 changes: 28 additions & 0 deletions cuda_extension/common.h
@@ -0,0 +1,28 @@
// coding=utf-8
// Copyright (c) 2022 Ant Group
// Author: Xiang Hu

#ifndef R2D2_COMMON_H
#define R2D2_COMMON_H
#include <stdio.h>
#include <c10/cuda/CUDACachingAllocator.h>

static void HandleError(cudaError_t err,
const char *file,
int line)
{
if (err != cudaSuccess)
{
printf("%s in %s at line %d\n", cudaGetErrorString(err),
file, line);
exit(EXIT_FAILURE);
}
}
// Evaluate the CUDA call once so its side effects are not repeated,
// then free cached blocks and report the error.
// Usage: HANDLE_ERROR(cudaMalloc(&ptr, size));
#define HANDLE_ERROR(fn_call)                              \
    do                                                     \
    {                                                      \
        cudaError_t _err = (fn_call);                      \
        if (_err != cudaSuccess)                           \
        {                                                  \
            c10::cuda::CUDACachingAllocator::emptyCache(); \
            HandleError(_err, __FILE__, __LINE__);         \
        }                                                  \
    } while (0)

#endif