xiaoyangmoa / attention-ocr

This project forked from emedvedev/attention-ocr

A TensorFlow model for text recognition (CNN + seq2seq with visual attention), available as a Python package and compatible with Google Cloud ML Engine.

License: MIT License

Attention-based OCR

Visual attention-based OCR model for image recognition with additional tools for creating TFRecords datasets and exporting the trained model with weights as a SavedModel or a frozen graph.

Acknowledgements

This project is based on a model by Qi Guo and Yuntian Deng. You can find the original model in the da03/Attention-OCR repository.

The model

Authors: Qi Guo and Yuntian Deng.

The model first runs a sliding CNN on the image (images are resized to height 32 while preserving aspect ratio). Then an LSTM is stacked on top of the CNN. Finally, an attention model is used as a decoder for producing the final outputs.
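The resize step (height 32, aspect ratio preserved) can be sketched as a small size computation. This is a minimal illustration assuming the new width is rounded to the nearest integer; the helper name scaled_size is hypothetical, and the package's own preprocessing may round differently.

```python
from typing import Tuple


def scaled_size(width: int, height: int, target_height: int = 32) -> Tuple[int, int]:
    """Return (new_width, target_height) so the image keeps its aspect ratio.

    Hypothetical helper: assumes nearest-integer rounding of the width.
    """
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    # Scale the width by the same factor that maps height -> target_height,
    # keeping at least one pixel of width for extremely tall, narrow images.
    new_width = max(1, round(width * target_height / height))
    return new_width, target_height
```

For example, a 100x50 image maps to (64, 32). With Pillow installed, the result can be passed straight to Image.resize, e.g. img.resize(scaled_size(*img.size)), since img.size is (width, height).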

OCR example

Installation

pip install aocr

Note: TensorFlow and NumPy will be installed as dependencies. Additional dependencies are PIL/Pillow, distance, and six.
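To confirm after installation that these dependencies are importable, a short standard-library check can help. This is a sketch: the helper name missing_dependencies is hypothetical, and it assumes the usual import names (Pillow installs the PIL module).

```python
import importlib.util
from typing import Iterable, List

# Assumed import names for the dependencies listed above (Pillow is imported as PIL).
AOCR_DEPENDENCIES = ["tensorflow", "numpy", "PIL", "distance", "six"]


def missing_dependencies(modules: Iterable[str] = AOCR_DEPENDENCIES) -> List[str]:
    """Return the top-level modules that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    print("missing:", ", ".join(missing) if missing else "none")
```

find_spec only locates the module without importing it, so the check stays fast even for a large package like TensorFlow.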

Usage

Create a dataset

To build a TFRecords dataset, you need a collection of images and an annotation file with their respective labels.

aocr dataset ./datasets/annotations-training.txt ./datasets/training.tfrecords
aocr dataset ./datasets/annotations-testing.txt ./datasets/testing.tfrecords

Annotations are simple text files containing the image paths (either absolute or relative to your working directory) and their corresponding labels:

datasets/images/hello.jpg hello
datasets/images/world.jpg world
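If your images happen to be named after their labels, an annotation file in this format can be generated with a short script. This is a sketch under that assumption: the label-as-filename convention and the helpers write_annotations and read_annotations are hypothetical, and the parser assumes the path is separated from the label by the first space, so image paths should not contain spaces.

```python
import os
from typing import List, Tuple

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def write_annotations(image_dir: str, out_path: str) -> int:
    """Write one '<path> <label>' line per image, using the file stem as the label."""
    count = 0
    with open(out_path, "w", encoding="utf-8") as out:
        for name in sorted(os.listdir(image_dir)):
            stem, ext = os.path.splitext(name)
            if ext.lower() not in IMAGE_EXTENSIONS:
                continue  # skip non-image files such as the annotation file itself
            out.write(f"{os.path.join(image_dir, name)} {stem}\n")
            count += 1
    return count


def read_annotations(path: str) -> List[Tuple[str, str]]:
    """Parse an annotation file back into (image_path, label) pairs."""
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            # Split on the first space only, so labels may contain spaces.
            image_path, label = line.split(" ", 1)
            pairs.append((image_path, label))
    return pairs
```

Calling write_annotations("datasets/images", "datasets/annotations-training.txt") on a directory holding hello.jpg and world.jpg would produce the two lines shown above.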

Train

aocr train ./datasets/training.tfrecords

A new model will be created, and the training will start. Note that it takes quite a long time to reach convergence, since we are training the CNN and attention model simultaneously.

The --steps-per-checkpoint parameter determines how often the model checkpoints will be saved (the default output dir is checkpoints/).

Important: there are many training options available. See the CLI help or the Parameters section of this README.

Test and visualize

aocr test ./datasets/testing.tfrecords

Additionally, you can visualize the attention results during testing (saved to out/ by default):

aocr test --visualize ./datasets/testing.tfrecords

Example output images in results/correct:

Image 0 (j/j)
Image 1 (u/u)
Image 2 (n/n)
Image 3 (g/g)
Image 4 (l/l)
Image 5 (e/e)

Export

After the model is trained and a checkpoint is available, it can be exported as either a frozen graph or a SavedModel.

# SavedModel (default):
aocr export ./exported-model

# Frozen graph:
aocr export --format=frozengraph ./exported-model

Either command loads the weights from the latest checkpoint and exports the model into the ./exported-model directory.

Serving

An exported SavedModel can be served as an HTTP REST API using TensorFlow Serving. Start the server with the following command:

tensorflow_model_server --port=9000 --rest_api_port=9001 --model_name=yourmodelname --model_base_path=./exported-model

Note: tensorflow_model_server expects the exported files to be inside a subdirectory named after the model version number, so you need to manually move the contents of exported-model into exported-model/1.
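The manual move can also be scripted with the standard library. A minimal sketch; the helper name move_into_version_dir is hypothetical, and version 1 matches the note above.

```python
import os
import shutil


def move_into_version_dir(export_dir: str, version: int = 1) -> str:
    """Move everything in export_dir into export_dir/<version>/ for TensorFlow Serving."""
    version_dir = os.path.join(export_dir, str(version))
    if os.path.exists(version_dir):
        raise FileExistsError(f"{version_dir} already exists")
    # Snapshot the listing before creating the version directory,
    # so the new directory is never moved into itself.
    entries = os.listdir(export_dir)
    os.makedirs(version_dir)
    for name in entries:
        shutil.move(os.path.join(export_dir, name), os.path.join(version_dir, name))
    return version_dir
```

For a SavedModel export, move_into_version_dir("./exported-model") would leave saved_model.pb and the variables/ directory under ./exported-model/1/, which is the path --model_base_path=./exported-model then picks up.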

Now you can send a prediction request to the running server, for example:

curl -X POST \
  http://localhost:9001/v1/models/yourmodelname:predict \
  -H 'cache-control: no-cache' \
  -H 'content-type: application/json' \
  -d '{
  "signature_name": "serving_default",
  "inputs": {
     	"input": { "b64": "/9j/4AAQ==" }
  }
}'

The REST API requires binary inputs to be Base64-encoded and wrapped in an object containing a b64 key. See 'Encoding binary values' in the TensorFlow Serving documentation.
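The same request can be sent from Python using only the standard library. This sketch reuses the port, model name, signature name, and input key from the curl example; the helper names build_predict_request and predict are hypothetical, and since the response format is not described here, the sketch simply returns the parsed JSON.

```python
import base64
import json
import urllib.request


def build_predict_request(image_bytes: bytes) -> dict:
    """Wrap raw image bytes in the JSON body used by the curl example."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return {
        "signature_name": "serving_default",
        "inputs": {"input": {"b64": encoded}},
    }


def predict(image_path: str,
            url: str = "http://localhost:9001/v1/models/yourmodelname:predict") -> dict:
    """POST an image file to a running tensorflow_model_server and return the parsed reply."""
    with open(image_path, "rb") as f:
        body = json.dumps(build_predict_request(f.read())).encode("utf-8")
    req = urllib.request.Request(url, data=body, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```

Keeping the payload builder separate from the HTTP call makes it easy to inspect the exact JSON that is sent before a server is running.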

Google Cloud ML Engine

To train the model in the Google Cloud Machine Learning Engine, upload the training dataset into a Google Cloud Storage bucket and start a training job with the gcloud tool.

  1. Set the environment variables:
# Prefix for the job name.
export JOB_PREFIX="aocr"

# Region to launch the training job in.
# Should be the same as the storage bucket region.
export REGION="us-central1"

# Your storage bucket.
export GS_BUCKET="gs://aocr-bucket"

# Path to store your training dataset in the bucket.
export DATASET_UPLOAD_PATH="training.tfrecords"
  2. Upload the training dataset:
gsutil cp ./datasets/training.tfrecords $GS_BUCKET/$DATASET_UPLOAD_PATH
  3. Launch the ML Engine job:
export NOW=$(date +"%Y%m%d_%H%M%S")
export JOB_NAME="$JOB_PREFIX$NOW"
export JOB_DIR="$GS_BUCKET/$JOB_NAME"

gcloud ml-engine jobs submit training $JOB_NAME \
    --job-dir=$JOB_DIR \
    --module-name=aocr \
    --package-path=aocr \
    --region=$REGION \
    --scale-tier=BASIC_GPU \
    --runtime-version=1.2 \
    -- \
    train $GS_BUCKET/$DATASET_UPLOAD_PATH \
    --steps-per-checkpoint=500 \
    --batch-size=512 \
    --num-epoch=20
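The job submission can also be assembled from Python, which makes the role of the bare -- explicit: gcloud consumes the flags before it and forwards everything after it to the aocr module as user arguments. This sketch only builds the argument list; the helper name build_training_command is hypothetical, and the values mirror the environment variables and flags from the steps above.

```python
from datetime import datetime
from typing import List, Optional


def build_training_command(job_prefix: str, region: str, bucket: str,
                           dataset_path: str, now: Optional[datetime] = None) -> List[str]:
    """Mirror the shell steps: timestamped job name, gcloud flags, '--', then aocr arguments."""
    now = now or datetime.now()
    # Same format as: date +"%Y%m%d_%H%M%S"
    job_name = job_prefix + now.strftime("%Y%m%d_%H%M%S")
    job_dir = f"{bucket}/{job_name}"
    gcloud_args = [
        "gcloud", "ml-engine", "jobs", "submit", "training", job_name,
        f"--job-dir={job_dir}",
        "--module-name=aocr",
        "--package-path=aocr",
        f"--region={region}",
        "--scale-tier=BASIC_GPU",
        "--runtime-version=1.2",
    ]
    # Everything after "--" is forwarded to the aocr module, not parsed by gcloud.
    aocr_args = [
        "train", f"{bucket}/{dataset_path}",
        "--steps-per-checkpoint=500",
        "--batch-size=512",
        "--num-epoch=20",
    ]
    return gcloud_args + ["--"] + aocr_args
```

The resulting list could be passed to subprocess.run(cmd, check=True); using a list rather than a single string avoids shell quoting issues.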

Parameters

Global

  • log-path: Path for the log file.

Testing

  • visualize: Output the attention maps on the original image.

Exporting

  • format: Format for the export (either savedmodel or frozengraph).

Training

  • steps-per-checkpoint: Number of steps between checkpoints (printing perplexity and saving the model).
  • num-epoch: The number of whole data passes.
  • batch-size: Batch size.
  • initial-learning-rate: Initial learning rate. Note that we use AdaDelta, so the initial value does not matter much.
  • target-embedding-size: Embedding dimension for each target.
  • attn-num-hidden: Number of hidden units in the attention decoder cell.
  • attn-num-layers: Number of layers in the attention decoder cell (the encoder's number of hidden units will be attn-num-hidden*attn-num-layers).
  • no-resume: Create new weights even if there are checkpoints present.
  • max-gradient-norm: Clip gradients to this norm.
  • no-gradient-clipping: Do not perform gradient clipping.
  • gpu-id: GPU to use.
  • use-gru: Use GRU cells instead of LSTM.
  • max-width: Maximum width for the input images. WARNING: images wider than the maximum will be discarded.
  • max-height: Maximum height for the input images.
  • max-prediction: Maximum length of the predicted word/phrase.
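The max-gradient-norm option can be illustrated with global-norm clipping, the scheme implemented by TensorFlow's tf.clip_by_global_norm. This pure-Python sketch assumes that scheme (all gradients are rescaled together whenever their combined L2 norm exceeds the limit); treat it as an illustration of the idea rather than aocr's exact training code.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Scale every gradient by max_norm / global_norm when global_norm exceeds max_norm."""
    # Global norm: L2 norm over all gradient entries taken together.
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if global_norm <= max_norm:
        return [list(grad) for grad in grads]  # already within the limit, unchanged
    scale = max_norm / global_norm
    return [[g * scale for g in grad] for grad in grads]
```

For example, gradients [[3.0], [4.0]] have a global norm of 5, so with max_norm=1.0 they are scaled to roughly [[0.6], [0.8]]. Scaling everything by one factor preserves the direction of the overall update, which is why this is preferred over clipping each value independently.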

References

Convert a formula to its LaTeX source

What You Get Is What You See: A Visual Markup Decompiler

Torch attention OCR

Contributors

emedvedev, ckirmse, da03, mattfeury, sivanke, dos1in, imoonkey, linjm, stickler-ci, nektor211, brishtiteveja, rmoe, rrtaylor, pokonski, mgaitan, mariusmez, gammasts, alpexjava, adamwp
