Repository to fine-tune a BERT-base multi-label/multi-class classifier, built on the HuggingFace Transformers library. The repository also includes a Flask API wrapper for inference.
To install the repository, please run the following command:

```shell
git clone https://github.com/JulesBelveze/BERT-multi-label-classifier.git
```
The repository uses Poetry as a package manager (see the Poetry documentation for details). To install the required packages, please run the following commands:

```shell
python3 -m venv .venv/bert-mlc
source .venv/bert-mlc/bin/activate
poetry install
```
This repo uses neptune.ai to manage experiments. We invite you to look at their documentation if needed.
- `models/`: folder containing custom models
- `utils/`: folder containing function utilities
- `main.py`: main file to run
- `train.py`: file containing the training procedure
- `eval.py`: file containing the evaluation procedure
- `app.py`: file containing the Flask app
- `inferer.py`: file containing the model inferer
- `poetry.lock`: Poetry file
- `pyproject.toml`: Poetry file
- `requirements_inference.txt`: required packages for inference
- `Dockerfile`: file to run the API as a Docker image
- multi-class: you can download it here
- multi-label: Toxic Comment Classification Challenge | Kaggle
We provide customisation of four different models: BERT, RoBERTa, XLM-RoBERTa and DistilBERT.
The model is an adaptation of HuggingFace's BertForSequenceClassification model to handle the multi-label case. The key modification is the loss function: each label is scored independently (a sigmoid plus binary cross-entropy setup) instead of a softmax over mutually exclusive classes.
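As a rough sketch of what such a multi-label adaptation looks like (the class, names, and tensor shapes below are illustrative, not the repository's actual code): the classification head stays a plain linear layer, but the loss becomes `BCEWithLogitsLoss`, so each label gets an independent sigmoid probability and several labels can be active at once.

```python
import torch
from torch import nn

# Illustrative sketch only -- not the repository's actual model code.
# The head is unchanged; the loss is swapped to BCEWithLogitsLoss so that
# several labels can be active at once for a single input.
class MultiLabelHead(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(self, pooled_output, labels=None):
        logits = self.classifier(pooled_output)
        loss = self.loss_fct(logits, labels.float()) if labels is not None else None
        return loss, logits

head = MultiLabelHead(hidden_size=768, num_labels=6)
pooled = torch.randn(4, 768)           # stand-in for BERT's pooled [CLS] output
labels = torch.randint(0, 2, (4, 6))   # each sample may carry several labels
loss, logits = head(pooled, labels)
probs = torch.sigmoid(logits)          # one independent probability per label
```

Note that the linear head itself is identical to the single-label case; only the loss (and the sigmoid applied at inference time) changes.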
The model is essentially an MLP on top of a BERT encoder. Once again, the custom model provided extends HuggingFace's BertForSequenceClassification model, this time to integrate class weights into the loss function.
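A minimal sketch of the class-weighting idea (the counts and the inverse-frequency formula below are made-up illustrations, not taken from the repository): PyTorch's `CrossEntropyLoss` accepts a per-class `weight` tensor, so under-represented classes contribute more to the loss.

```python
import torch
from torch import nn

# Illustrative sketch only: inverse-frequency class weights passed to the loss.
class_counts = torch.tensor([900.0, 90.0, 10.0])    # made-up, imbalanced counts
weights = class_counts.sum() / (len(class_counts) * class_counts)
loss_fct = nn.CrossEntropyLoss(weight=weights)      # rare classes weigh more

logits = torch.randn(4, 3)               # stand-in for the MLP head's output
labels = torch.tensor([0, 1, 2, 0])      # one exclusive class per sample
loss = loss_fct(logits, labels)
```

With these counts, the rarest class ends up weighted roughly 90x more than the most frequent one, which counteracts the imbalance during training.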
The inferer only supports single-input inference. It handles all the preprocessing steps required to feed raw text into the classification model. It can be used in the following way:
```python
model_infer = ModelInferer(config=config, checkpoint_path=checkpoint_path, quantize=True)
model_infer.predict("I hate you more than you can imagine")
```
We also provide a Flask API that wraps the inferer, as well as a Dockerfile to containerize the app for production usage.