tf2qa

7th place solution to the TensorFlow 2.0 Question Answering competition

Solution summary: https://www.kaggle.com/c/tensorflow2-question-answering/discussion/127259

Environment: Python 3.6+, TensorFlow 1.15

Files

Most of the model code is based on the BERT-joint baseline. The evaluation code is based on the official NQ metric, modified for this competition. A minimal sketch of how the scripts fit together follows the file list below.

  • prepare_nq_data.py: pre-processing
  • jb_train_tpu.py: training on TPU
  • jb_pred_tpu.py: inference and evaluation of dev set on TPU
  • ensemble_and_tune.py: tuning ensemble weights and thresholds
  • 7th-place-submission.ipynb: inference notebook (same as the submitted Kaggle kernel)
  • vocab_cased-nq.txt: vocab file for cased model with special NQ tokens added
  • bert_config_cased.json: config file for cased model
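
A minimal sketch of the end-to-end flow, i.e. the order the scripts are meant to be run in (bucket paths, directory names and checkpoint ranges below are placeholders; the exact flags used for each model are listed in the sections that follow):

# 1. pre-processing: shard the NQ train data into TF records (CPU only, no TPU needed)
for shard in {0..49}
do
	python3 prepare_nq_data.py --do_lower_case=True --tfrecord_dir=<tfrecord_dir> --shard=$shard
done

# 2. training: fine-tune a whole-word-masking BERT checkpoint on TPU
python3 jb_train_tpu.py --tpu=<tpu_name> --model_suffix=<suffix> --train_precomputed_file=gs://<your_bucket>/tfrecords/<tfrecord_dir>/nq-train.tfrecords-* --init_checkpoint=gs://<your_bucket>/wwm_uncased_L-24_H-1024_A-16/bert_model.ckpt --num_train_epochs=1

# 3. inference and evaluation on the dev set, sweeping a range of checkpoints
python3 jb_pred_tpu.py --tpu=<tpu_name> --model_suffix=<suffix> --ckpt_from=<from> --ckpt_to=<to> --eval_set=dev --do_predict=True

# 4. tune ensemble weights and thresholds over the per-model dev predictions
#    (see ensemble_and_tune.py for its expected inputs)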

Scripts for the 3 single models

model c: wwm, neg sampling, max_contexts=200, dev 64.5

# pre-processing (this step does not require TPU and could be distributed over multiple processes)
export do_lower_case=True
export max_contexts=200
export tfrecord_dir=fix_top_level_bug_max_contexts_200_0.01_0.04
for shard in {0..49} 
do 
	python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard=$shard --max_contexts=$max_contexts
done
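# The loop above runs the 50 shards sequentially; since pre-processing is CPU-only,
# the shards can also be run in parallel, e.g. (a sketch; adjust -P to the number of
# processes your machine can handle):
printf '%s\n' {0..49} | xargs -P 8 -I{} python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard={} --max_contexts=$max_contexts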

# training
export TPU_NAME=node-1
export train_batch_size=64
export learning_rate=4e-5
export model_suffix=_wwm_fix_top_level_bug_max_contexts_200_0.01_0.04
export train_precomputed_file=gs://<your_bucket>/tfrecords/fix_top_level_bug_max_contexts_200_0.01_0.04/nq-train.tfrecords-*
export init_checkpoint=gs://<your_bucket>/wwm_uncased_L-24_H-1024_A-16/bert_model.ckpt
python3 jb_train_tpu.py --tpu=$TPU_NAME --model_suffix=${model_suffix} --train_batch_size=${train_batch_size} --learning_rate=${learning_rate} --train_precomputed_file=$train_precomputed_file --init_checkpoint=$init_checkpoint --num_train_epochs=1 

# evaluation (ckpt 9500 turned out to be the best)
export MODEL_SUFFIX=_wwm_fix_top_level_bug_max_contexts_200_0.01_0.04-64-4.00E-05
export CKPT_FROM=8000
export CKPT_TO=10000
export doc_stride=256
export do_lower_case=True
python3 jb_pred_tpu.py --tpu=$TPU_NAME --doc_stride=$doc_stride --model_suffix=$MODEL_SUFFIX --ckpt_from=$CKPT_FROM --ckpt_to=$CKPT_TO --eval_set=dev --do_predict=True --do_lower_case=$do_lower_case

model d: wwm, neg sampling, stride=192, dev 63.8

# pre-processing (this step does not require TPU and could be distributed over multiple processes)
export do_lower_case=True
export doc_stride=192
export tfrecord_dir=stride_192_0.01_0.04
for shard in {0..49} 
do 
	python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard=$shard --doc_stride=$doc_stride
done

# training
export TPU_NAME=node-1
export train_batch_size=64
export learning_rate=2e-5
export model_suffix=_wwm_stride_192_neg_0.01_0.04
export train_precomputed_file=gs://<your_bucket>/tfrecords/stride_192_0.01_0.04/nq-train.tfrecords-*
export init_checkpoint=gs://<your_bucket>/wwm_uncased_L-24_H-1024_A-16/bert_model.ckpt
python3 jb_train_tpu.py --tpu=$TPU_NAME --model_suffix=${model_suffix} --train_batch_size=${train_batch_size} --learning_rate=${learning_rate} --train_precomputed_file=$train_precomputed_file --init_checkpoint=$init_checkpoint --num_train_epochs=1 

# evaluation (ckpt 7000 turned out to be the best)
export MODEL_SUFFIX=_wwm_stride_192_neg_0.01_0.04-64-2.00E-05
export CKPT_FROM=5000
export CKPT_TO=8000
export doc_stride=256
export do_lower_case=True
python3 jb_pred_tpu.py --tpu=$TPU_NAME --doc_stride=$doc_stride --model_suffix=$MODEL_SUFFIX --ckpt_from=$CKPT_FROM --ckpt_to=$CKPT_TO --eval_set=dev --do_predict=True --do_lower_case=$do_lower_case

model e: wwm, neg sampling, cased, dev 63.3

# pre-processing (this step does not require TPU and could be distributed over multiple processes)
export do_lower_case=False
export tfrecord_dir=fix_top_level_bug_cased_0.01_0.04
for shard in {0..49} 
do 
	python3 prepare_nq_data.py --do_lower_case=$do_lower_case --tfrecord_dir=$tfrecord_dir --include_unknowns_answerable=0.01 --include_unknowns_unanswerable=0.04 --shard=$shard
done

# training
export TPU_NAME=node-1
export train_batch_size=64
export learning_rate=4.5e-5
export model_suffix=_wwm_cased_fix_top_level_bug_0.01_0.04
export train_precomputed_file=gs://<your_bucket>/tfrecords/fix_top_level_bug_cased_0.01_0.04/nq-train.tfrecords-*
export init_checkpoint=gs://<your_bucket>/wwm_cased_L-24_H-1024_A-16/bert_model.ckpt
export do_lower_case=False
python3 jb_train_tpu.py --tpu=$TPU_NAME --model_suffix=${model_suffix} --train_batch_size=${train_batch_size} --learning_rate=${learning_rate} --train_precomputed_file=$train_precomputed_file --init_checkpoint=$init_checkpoint --num_train_epochs=1 --do_lower_case=${do_lower_case}

# evaluation (ckpt 8500 turned out to be the best)
export MODEL_SUFFIX=_wwm_cased_fix_top_level_bug_0.01_0.04-64-4.50E-05
export CKPT_FROM=6000
export CKPT_TO=8500
export doc_stride=256
export do_lower_case=False
python3 jb_pred_tpu.py --tpu=$TPU_NAME --doc_stride=$doc_stride --model_suffix=$MODEL_SUFFIX --ckpt_from=$CKPT_FROM --ckpt_to=$CKPT_TO --eval_set=dev --do_predict=True --do_lower_case=$do_lower_case
