rsommerfeld / trocr

Powerful handwritten text recognition. A simple-to-use, unofficial implementation of the paper "TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models".

License: MIT License

Python 100.00%
trocr ocr handwritten-text-recognition transformer computer-vision pre-trained-model

trocr's Introduction

Handwritten Character Recognition - an unofficial implementation of the paper

TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models


This is an unofficial implementation of TrOCR based on the Hugging Face transformers library and the TrOCR paper. There is also a repository by the authors of the paper (link). The code in this repository is simply a wrapper that makes it quick to get started with training and deploying this model for character recognition tasks.

Results:

[Figure: Predictions]

After training on a dataset of 2,000 samples for 8 epochs, we reached an accuracy of 96.5%. Neither the training nor the validation dataset was completely clean; with cleaner data, even higher accuracy would likely have been possible.

Architecture:

[Figure: TrOCR architecture. Taken from the original paper.]

TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models, Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei, Preprint 2021.

1. Setup

Clone the repository and make sure conda or Miniconda is installed. Then change into the directory of the cloned repository and run

conda env create -n trocr --file environment.yml
conda activate trocr

This should install all necessary libraries.

Training without GPU:

A CUDA GPU is highly recommended, but everything also works on a CPU. In that case, create the environment from environment-cpu.yml instead (pass --file environment-cpu.yml to the conda env create command above).
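To check which device PyTorch will actually see after activating the environment, a small check like the one below can help. This is a hypothetical helper (describe_device is not part of the repository); it assumes only that torch is installed in the active environment and relies on torch.cuda.is_available().

```python
import importlib.util


def describe_device() -> str:
    """Return "cuda", "cpu", or a hint if torch is not importable (hypothetical helper)."""
    # find_spec returns None when torch is not installed in the active environment
    if importlib.util.find_spec("torch") is None:
        return "torch not installed: is the trocr environment activated?"
    import torch  # imported lazily so the check above cannot raise ImportError

    # is_available() is False on CPU-only builds or when no CUDA GPU is visible
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(describe_device())
```

If this prints "cpu" on a machine with a GPU, the CPU-only environment was probably installed.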

If the process terminates with the message "Killed", it most likely ran out of memory; reduce the batch size (see src/configs/constants.py) so that training fits into the available memory.

2. Using the repository

There are three modes: inference, validation, and training. Each can start either from a local model at the configured path (see src/configs/paths.py) or from the pretrained model on Hugging Face. Inference and validation use the local model by default; training starts from the Hugging Face model by default.

Inference (Prediction):

python -m src predict <image_files>  # predict image files using the trained local model
python -m src predict data/img1.png data/img2.png  # list all image files
python -m src predict data/*  # also works with shell expansion
python -m src predict data/* --no-local-model  # uses the pretrained huggingface model

Validation:

python -m src validate # uses pretrained local model
python -m src validate --no-local-model # loads pretrained model from huggingface

Training:

python -m src train  # starts with pretrained model from huggingface
python -m src train --local-model  # starts with pretrained local model

For validation and training, the input images should be in the directories train and val, and the labels should be in gt/labels.csv. Each row of the CSV should contain the image name followed by its label, for example img1.png,a (quote the label if necessary, e.g. if it contains a comma).
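As a sketch of producing gt/labels.csv, Python's built-in csv module applies the quoting rule automatically: with its default QUOTE_MINIMAL setting, a field is quoted only if it contains the delimiter, the quote character, or a line break. The write_labels helper and the sample rows are illustrative assumptions, not part of the repository.

```python
import csv
from pathlib import Path


def write_labels(rows, csv_path="gt/labels.csv"):
    """Write (image_name, label) pairs; csv adds quotes only where needed (hypothetical helper)."""
    path = Path(csv_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings (avoids blank rows on Windows)
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)  # default quoting is csv.QUOTE_MINIMAL
        for image_name, label in rows:
            writer.writerow([image_name, label])


# illustrative rows: a plain label, one containing a comma, one containing double quotes
rows = [("img1.png", "a"), ("img2.png", "Griffiths ,"), ("img3.png", 'say "hi"')]
```

Calling write_labels(rows) would leave img1.png,a unquoted, write the second label as "Griffiths ," and the third as "say ""hi""" (embedded quotes are doubled).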

It is also pretty straightforward to read labels from somewhere else. For that, just add the necessary code to load_filepaths_and_labels in src/dataset.py.
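A minimal sketch of what such a loader might look like, assuming it returns the image names found in data_dir together with their labels in the same order. The extra labels_path and delimiter parameters are hypothetical and the body is illustrative, not the repository's code; passing delimiter="\t" would read a tab-separated file instead.

```python
import csv
from pathlib import Path


def load_filepaths_and_labels(data_dir, labels_path="gt/labels.csv", delimiter=","):
    """Illustrative loader: image names in data_dir and their labels, in matching order.

    Hypothetical sketch; the repository's own function may differ.
    """
    label_dict = {}
    # utf-8-sig transparently drops a byte-order mark that some editors add
    with open(labels_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.reader(f, delimiter=delimiter):
            if len(row) < 2:
                continue  # skip blank or malformed lines
            # rejoin extra fields so an unquoted label containing the delimiter stays whole;
            # strip the name so stray whitespace does not break the lookup
            label_dict[row[0].strip()] = delimiter.join(row[1:])
    image_names = sorted(p.name for p in Path(data_dir).iterdir() if p.is_file())
    for name in image_names:
        assert name in label_dict, f"No label for image '{name}'"
    return image_names, [label_dict[name] for name in image_names]
```

Rejoining row[1:] keeps an unquoted label such as "Griffiths ," intact, and the assertion fails early with the name of any image that has no label.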

For choosing a subsample of the train data as validation data, this command can be used

find train -type f | shuf -n <num of val samples> | xargs -I '{}' mv {} val
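The pipeline above relies on shuf from GNU coreutils, which may not be available by default on macOS or Windows. A rough Python equivalent is sketched below; move_random_subset is a hypothetical helper, and unlike find it only considers files directly inside the source directory, not in subdirectories.

```python
import random
import shutil
from pathlib import Path


def move_random_subset(num_val, src="train", dst="val", seed=None):
    """Move num_val randomly chosen files from src to dst (hypothetical helper)."""
    # only top-level files; find would also descend into subdirectories
    files = sorted(p for p in Path(src).iterdir() if p.is_file())
    if num_val > len(files):
        raise ValueError(f"only {len(files)} files in {src}, cannot move {num_val}")
    Path(dst).mkdir(parents=True, exist_ok=True)
    # a fixed seed makes the split reproducible; None gives a different split each run
    chosen = random.Random(seed).sample(files, num_val)
    for path in chosen:
        shutil.move(str(path), str(Path(dst) / path.name))
    return [p.name for p in chosen]
```

For example, move_random_subset(200, seed=0) would move 200 files from train to val and return their names.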

 

3. Integrating into other projects

If you want to use the predictions as part of a bigger project, you can use the TrocrPredictor interface in src/main.py. For this to work, make sure to run all code as Python modules.

See the following code example:

from PIL import Image
from trocr.src.main import TrocrPredictor

# load images
image_names = ["data/img1.png", "data/img2.png"]
images = [Image.open(img_name) for img_name in image_names]

# predict directly on Pillow images or on file names
model = TrocrPredictor()
predictions = model.predict_images(images)  # from Pillow images
predictions = model.predict_for_file_names(image_names)  # or from file names

# print results
for file_name, prediction in zip(image_names, predictions):
    print(f'Prediction for {file_name}: {prediction}')

4. Adapting the Code

In general, it should be easy to adapt the code for other input formats or use cases.

  • Learning Rate, Batch size, Train Epoch Count, Logging, Word Len: src/configs/constants.py
  • Input Paths, Model Checkpoint Path: src/configs/paths.py
  • Different label format: src/dataset.py : load_filepaths_and_labels

The word len constant is very important: to enable batch training, all labels must be padded to the same length. Some experimentation may be needed here; for us, padding to 8 worked well.
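To illustrate what padding to a fixed word len means, the sketch below truncates or pads a tokenized label to exactly word_len positions. The function names are hypothetical; filling padded positions with -100 is an assumption borrowed from common Hugging Face training setups, where -100 matches the default ignore_index of PyTorch's CrossEntropyLoss, so padded positions do not contribute to the loss.

```python
def pad_labels(token_ids, word_len, ignore_index=-100):
    """Truncate or pad one tokenized label to exactly word_len positions (hypothetical).

    Padded positions hold ignore_index so a loss using that ignore index skips them.
    """
    ids = list(token_ids)[:word_len]  # labels longer than word_len are cut off
    return ids + [ignore_index] * (word_len - len(ids))


def pad_batch(batch, word_len, ignore_index=-100):
    """Apply pad_labels to every label so the batch forms a rectangular array."""
    return [pad_labels(ids, word_len, ignore_index) for ids in batch]
```

For example, pad_labels([5, 6], 4) returns [5, 6, -100, -100]. Because longer labels are truncated, the value may need experimentation: too small cuts labels off, too large spends computation on padding.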

If you want to change specifics of the model, you can supply a TrOCRConfig object to the transformers interface. See https://huggingface.co/docs/transformers/model_doc/trocr#transformers.TrOCRConfig for more details.

5. Contact

If the setup does not work, please let me know in a GitHub issue! Sub-dependencies sometimes update and become incompatible with other dependencies, in which case the dependency list needs to be updated.

Feel free to submit issues with questions about the implementation as well.

For questions about the paper or the architecture, please get in touch with the authors.


trocr's Issues

Model.eval()

Hi!

During training, you don't call model.eval() when calling the validate() function. Could this affect model performance?

Local Model file

Could you please provide the local model file that you have already trained?

issue on shape of underlying tensors

Hi, I wanted to give it a try, but when I run inference with the base model, I get the following error:
IndexError: The shape of the mask [1, 4] at index 0 does not match the shape of the indexed tensor [8, 5] at index 0

I tried both JPG and PNG images and got exactly the same output. Using the Python example code didn't change anything either; the error is the same.
Is this caused by a new TensorFlow version? Can you help me fix the code to make it work? Perhaps the dtype is defined incorrectly.

Error trying to run the model (Runtime error CUDA out of memory)

Hi, I've been struggling for a while to run your model; every time I try, it shows the same error message.
First, it showed WinError 1455, like this one:
[WinError 1455] The paging file is too small for this operation to complete. Error loading "C:\ProgramData\Anaconda3\lib\site-packages\torch\lib\caffe2_detectron_ops_gpu.dll" or one of its dependencies.

After searching the web and modifying some parameters, I got what seems to be another version of the same error:
RuntimeError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 6.00 GiB total capacity; 5.30 GiB already allocated; 0 bytes free; 5.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Do you have any idea how to manage this situation?
Thanks

Beam Search

Hi,

thanks for the great and simple repo!
Does prediction use greedy decoding or beam search?
If you know how to implement beam search generation, I would be very happy if you could help me out!

Cheers, Jonas

Fix No module named 'torch'

Command:
from trocr.src.main import TrocrPredictor

Output:

ModuleNotFoundError                       Traceback (most recent call last)
Cell In[5], line 2
      1 from PIL import Image, ImageDraw, ImageFont
----> 2 from trocr.src.main import TrocrPredictor

File c:\Users\ranas\Documents\ML\OCR Based PDF Reader\trocr\src\__init__.py:1
----> 1 from .main import TrocrPredictor
      3 # expose the TrocrPredictor interface to other models
      4 __all__ = ["TrocrPredictor"]

File c:\Users\ranas\Documents\ML\OCR Based PDF Reader\trocr\src\main.py:2
      1 from PIL import Image
----> 2 from torch.utils.data import DataLoader
      4 from .configs import paths
      5 from .configs import constants

ModuleNotFoundError: No module named 'torch'

Accuracy goes to 0.0 frequently

Hi, I have a problem with training the model. The gradient seems to explode frequently, though not in every training run. Here is a graph that shows the problem.

[Image: MicrosoftTeams-image]

I've tried printing the model's predictions at each validation step, but when the gradient explodes, the model keeps predicting empty labels.
I'm using a portion of the IAM dataset, and my labels are structured like this: file-name.png,¤label¤
I'm using the character '¤' because it does not appear in the dataset, so I can predict double quotes (I've modified the CSV reader to use this character to delimit the label).
I've tried forcing a fresh download of the pretrained weights at the beginning of each training run, without effect.
I've also tried increasing the word len, also without effect.
I'm surely missing something, but I can't see what.

Do you have any idea what could cause the model to behave this way?
Thanks

Loss Function?

Hi, the loss function used in the original paper is CTC. However, it seems that you have used the loss function of the decoder of this class. Is there a reason for this?

Multi-GPU support

I noticed that only one of the GPUs is used when I train. I tried wrapping the model with torch.nn.DataParallel(model), but I kept getting "RuntimeError: grad can be implicitly created only for scalar outputs". I am not familiar enough with Torch's multi-GPU support to fix it yet. While I keep looking, I hope someone can help add support for multiple GPUs.

score/certainty of prediction

Hi

I was wondering if there is any way to score a prediction, i.e., how certain the model is that the output is correct.

Thanks

Issues with training using IAM dataset

I cloned your repo on an Ubuntu 20.04 server and tested the train and inference commands with your included data; they all ran correctly.
Next, I ran the train command on IAM as a sanity check. For the training data, I first partitioned IAM into train and val folders (a random 90%/10% split). For instance, in the train folder:
a01-000u-00.png
a01-000u-01.png
a01-000u-02.png
a01-000u-03.png
a01-000u-05.png
a01-000u-06.png
a01-000x-02.png
a01-000x-03.png
a01-000x-04.png
.....

Then, in the gt folder, I created a labels.csv file using the same format as the one included in your repo. Here are a few lines from the beginning of the file:
a01-000u-00.png,A MOVE to stop Mr. Gaitskell from
a01-000u-01.png,nominating any more Labour life Peers
a01-000u-02.png,is to be made at a meeting of Labour
a01-000u-03.png,Ps tomorrow . Mr. Michael Foot has
a01-000u-04.png,put down a resolution on the subject
a01-000u-05.png,and he is to be backed by Mr. Will
a01-000u-06.png,P for Manchester Exchange .
a01-000x-00.png,A MOVE to stop Mr. Gaitskell from nominating
a01-000x-01.png,any more Labour life Peers is to be made at a
a01-000x-02.png,Ps tomorrow . Mr. Michael
a01-000x-03.png,Foot has put down a resolution on the subject
a01-000x-04.png,and he is to be backed by Mr. Will Griffiths ,
a01-000x-05.png,P for Manchester Exchange .

In other words, labels for all files in train and val folders are all included in the gt/labels.csv. Is that format correct for gt/labels.csv?

After converting the IAM data to the format above, I ran the command python -m src train, which produced the error below:

Traceback (most recent call last):
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/shenw/workspace/sandbox/htr/trocr-2/src/main.py", line 6, in <module>
    main()
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/typer/main.py", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/shenw/workspace/anaconda3/envs/trocr-2/lib/python3.9/site-packages/typer/main.py", line 500, in wrapper
    return callback(**use_params)  # type: ignore
  File "/home/shenw/workspace/sandbox/htr/trocr-2/src/cli.py", line 10, in train
    main_train(local_model)
  File "/home/shenw/workspace/sandbox/htr/trocr-2/src/main.py", line 30, in main_train
    train_dataset = HCRDataset(paths.train_dir, processor)
  File "/home/shenw/workspace/sandbox/htr/trocr-2/src/dataset.py", line 54, in __init__
    self.image_name_list, self.label_list = load_filepaths_and_labels(data_dir)
  File "/home/shenw/workspace/sandbox/htr/trocr-2/src/dataset.py", line 41, in load_filepaths_and_labels
    assert file_name in label_dict, f"No label for image '{file_name}'"
AssertionError: No label for image 'n02-045-01.png'

It appears that the label for n02-045-01.png is missing from gt/labels.csv. So I checked the label for n02-045-01.png in gt/labels.csv: it's there, and I can display it correctly. Any idea why it produces the above error?
Your help will be greatly appreciated.

How to train a model from scratch?

Hello,

Sorry if I did not understand the instructions correctly, but is there a way to train a model from scratch on the IAM dataset? i.e. without loading TrOCR pretrained weights.

What kind of detector would you recommend?

This model only does recognition. Thus, to apply it to arbitrary images, one needs a bounding box detector.

Which one would you recommend?

I am currently using the one from paddleocr.

Cheers,
Chris
