Code and checkpoints for training the transformer-based Table QA models introduced in the paper TAPAS: Weakly Supervised Table Parsing via Pre-training.
The repository uses protocol buffers and requires the `protoc` compiler to run.
You can download the latest binary for your OS here.
On Ubuntu/Debian, it can be installed with:

```bash
sudo apt-get install protobuf-compiler
```
Afterwards, clone and install the git repository:
```bash
git clone https://github.com/google-research/tapas
cd tapas
pip install -e .
```
To run the test suite, we use the tox library, which can be installed and run with:

```bash
pip install tox
tox
```
The pre-trained Tapas checkpoints can be downloaded here:
The first two models are pre-trained on the Mask-LM task only; the last two are pre-trained on the Mask-LM task first and on SQA second.
You also need to download the task data for the fine-tuning tasks:
We need to create the TF examples before starting the training. For example, for SQA that would look like:

```bash
python tapas/run_task_main.py \
  --task="SQA" \
  --input_dir="${sqa_data_dir}" \
  --output_dir="${output_dir}" \
  --bert_vocab_file="${tapas_data_dir}/vocab.txt" \
  --mode="create_data"
```
Afterwards, training can be started by running:

```bash
python tapas/run_task_main.py \
  --task="SQA" \
  --output_dir="${output_dir}" \
  --init_checkpoint="${tapas_data_dir}/model.ckpt" \
  --bert_config_file="${tapas_data_dir}/bert_config.json" \
  --mode="train" \
  --use_tpu
```
This will use the preset hyper-parameters set in `hparam_utils.py`.
It's recommended to start a separate eval job to continuously produce predictions for the checkpoints created by the training job. Alternatively, you can run the eval job after training to only get the final results.
```bash
python tapas/run_task_main.py \
  --task="SQA" \
  --output_dir="${output_dir}" \
  --init_checkpoint="${tapas_data_dir}/model.ckpt" \
  --bert_config_file="${tapas_data_dir}/bert_config.json" \
  --mode="predict_and_evaluate"
```
Another tool to run experiments is `tapas_classifier_experiment.py`. It's more flexible than `run_task_main.py` but also requires setting all the hyper-parameters (via the respective command line flags).
Unfortunately we cannot release the pre-training data. The code for creating the pre-training TF examples can be found in the class `ToPretrainingTensorflowExample` in `tf_example_utils.py`. The implementation of the model can be found in `tapas_pretraining_experiment.py` and `tapas_pretraining_model.py`.
By default, SQA will evaluate using the reference answers of the previous questions. The numbers in the paper (Table 5) are computed using the more realistic setup where the previous answers are model predictions. `run_task_main.py` will also output additional prediction files for this setup if run on GPU.
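The difference between the two setups can be sketched with a toy sequential-QA loop (the model and data below are invented for illustration and are not TAPAS code): feeding the model its own previous predictions lets an early mistake cascade through the conversation, which the gold-history setup hides.

```python
# Toy sketch of SQA-style sequential evaluation (invented example, not
# the repository's evaluation code). Follow-up questions are answered
# conditioned on the history: either the gold previous answer or the
# model's own previous prediction (the more realistic setup).

def evaluate(model, conversation, use_gold_history=True):
    """Return accuracy over a list of (question, gold_answer) pairs."""
    history, correct = None, 0
    for question, gold in conversation:
        prediction = model(question, history)
        correct += prediction == gold
        # Feed either the reference answer or the model's own prediction
        # to the next question in the sequence.
        history = gold if use_gold_history else prediction
    return correct / len(conversation)

def toy_model(question, history):
    # Deliberately wrong on "q2"; later questions just copy the history,
    # so a single mistake cascades when predictions are fed back.
    if question == "q2":
        return "wrong"
    return history if history is not None else "a1"

conversation = [("q1", "a1"), ("q2", "a2"), ("q3", "a2"), ("q4", "a2")]
print(evaluate(toy_model, conversation, use_gold_history=True))   # 0.75
print(evaluate(toy_model, conversation, use_gold_history=False))  # 0.25
```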
For the official evaluation results one should convert the TAPAS predictions to the WTQ format and run the official evaluation script. This can be done using `convert_predictions.py`.
As discussed in the paper our code will compute evaluation metrics that deviate from the official evaluation script (Table 3 and 10).
TAPAS is essentially a BERT model and thus has the same requirements.
This means that training the large model with 512 sequence length will
require a TPU.
You can use the option `max_seq_length` to create shorter sequences. This will reduce accuracy but also make the model trainable on GPUs. Another option is to reduce the batch size (`train_batch_size`), but this will likely also affect accuracy. We added an option, `gradient_accumulation_steps`, that allows you to split the gradient over multiple batches.
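The idea behind gradient accumulation can be illustrated with a minimal numeric sketch (plain Python with a made-up linear model, not the repository's implementation): averaging the gradients of several small micro-batches before a single update reproduces the gradient of one large batch, at a fraction of the memory.

```python
# Minimal sketch of gradient accumulation (not the TAPAS training code):
# gradients from several small micro-batches are averaged before one
# parameter update, so a large effective batch fits in less memory.

def grad_fn(param, batch):
    # Gradient of mean squared error for the toy model f(x) = param * x.
    return sum(2 * x * (param * x - y) for x, y in batch) / len(batch)

def accumulated_update(param, micro_batches, lr=0.1):
    """One update from gradients accumulated over equal-sized micro-batches."""
    accumulated = sum(grad_fn(param, batch) for batch in micro_batches)
    accumulated /= len(micro_batches)  # average, matching one big batch
    return param - lr * accumulated

full_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
micro_batches = [full_batch[:2], full_batch[2:]]

# Both paths produce the same updated parameter: 2.5
print(accumulated_update(1.0, [full_batch]))
print(accumulated_update(1.0, micro_batches))
```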
Evaluation with the default test batch size (32) should be possible on GPU.
You can cite the paper to appear at ACL:
```bibtex
@inproceedings{49053,
  title = {Tapas: Weakly Supervised Table Parsing via Pre-training},
  author = {Jonathan Herzig and Paweł Krzysztof Nowak and Thomas Müller and Francesco Piccinno and Julian Martin Eisenschlos},
  year = {2020},
  URL = {https://arxiv.org/abs/2004.02349},
  note = {to appear},
  booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
  address = {Seattle, Washington, United States}
}
```
This is not an official Google product.
For help or issues, please submit a GitHub issue.