GithubHelp home page GithubHelp logo

xiaofengshi / image2katex Goto Github PK

View Code? Open in Web Editor NEW
279.0 10.0 75.0 71.06 MB

公式图片ocr,输入图片输出对应的latex表达式

License: Other

Python 13.06% Makefile 0.04% Shell 0.02% HTML 86.88%
image2latex tensorflow-gpu

image2katex's Introduction

Img2Katex模型说明V2

简介

本仓库代码实现将公式图像转换为katex输出,是整个ocr项目中,该部分是最难的一部分,该仓库是对原有项目该功能的升级。

项目演示

输入公式图像数据


100c0c140c

预测输出latex

g _ { \alpha \beta } = \frac { ( 1 + | \psi | ^ { 2 } ) \delta _ { \alpha \beta } - \bar { \psi } _ { \alpha } \psi _ { \beta } } { ( 1 + | \psi | ^ { 2 } ) ^ { 2 } }

本代码主要参考的是https://guillaumegenthial.github.io/image-to-latex.html该文章对应的代码。

  • 输入的是公式的图像,进行CNN特征提取,在CNN中使用elu(指数线性单元)激活函数,一点需要特别注意,在CNN最后一层中加入位置编码,这里参考的是transformer中的位置编码实现。
  • 使用CNN对图像进行特征提取之后,得到的数据维度为NHWC,将数据维度进行转换,在高度方向上使用双向RNN编码,之所以使用高度方向上进行双向RNN操作是因为输入的图像是一种长条形图片,在高度方向进行RNN是为了找到宽度方向上的前后联系。
  • 对训练数据集的目标label要进行处理,首先分词得到每个katex单元,确定整个词表的大小;此后,对目标label头部插入START标志,尾部插入END标志,词表中不存在的katex单元使用UNKNOWN标志,统计整个训练集,使用kmeans计算bucket的尺寸,并对训练集中的图像进行resize和padding,将相近的尺寸转换至对应的bucket的尺寸,在训练的时候,从bucket中提取数据凑成一个bucket进行训练,对目标序列进行padding,使用PAD标识符对齐目标序列。
  • 接下来是decoder阶段,使用的是Attention机制的decoder的解码机制,并且提供了greedy解码和binary search的解码方式,使用集束搜索可以得到更优解(贪婪搜索可以看做是搜索窗口为1的集束搜索)。

整个实现的pipeline如下图所示


ocr


对流程图进行解释说明

对于输入的图片使用CNN网络进行特征提取,之后在高度方向上使用双向RNN,找到数据的前后依赖关系,对数据进行维度转换,之后使用GRU+attention进行解码器设计,对于Attention的计算方式可以参考相关博客文献深度学习-Python-tutorial-LSTM-GRU-Attention

本项目代码说明

本项目中包含三个模型,包含im2katex,errorchecker,dismodel分别实现图像预测生成katex,预测katex错误语法纠正,和预测katex语法错误判别器。

对于im2katex和errorchecker均可以使用项目文件夹下的makefile文件进行训练和测试。

参数说明

  • data_type:

    使用哪种数据集训练得到的权重,目前有'handwritten'—手写体图片训练集, 'original'—印刷体图片训练集, 'merged'—二者合并的训练集,默认使用二者合并的训练集。

  • model_type

    运行那个程序,目前有两个模型,一个是im2katex,也就是上文所说的输入公式图片,输出预测的katex表达式;另一个是error,这个是对im2katex的改进(预测的katex存在缺失会导致katex无法渲染生成图片,对该类错误的图片使用nmt的方式进行错误纠正),经过训练,效果不好,目前已经废除,暂不使用。

  • mode

    何种网络模型运行方式,可选参数为'trainval', 'test', 'val', 'infer’

  • encoder_type

    在im2katex模型中的encoder的方式,可选为'Augment','conv',分别对应的是是否数据增强,如果参数为Augment则在数据输入时使用数据增强,如果是conv使用默认的数据输入

环境准备

测试

目前项目对于img2katex已经训练完成,可以使用makefile文件进行测试

命令行测试

命令行测试代码为

make im2katex-inference

根据提示在命令行输入待测试图片的地址或者待测试的文件夹位置

web端测试

目前的代码中提供了web端测试页面

启动测试代码

make server

页面如下,每行对应的内容分别为:

  1. inputimage:输入的图片
  2. processimage:预处理之后的图片
  3. latexpredict:预测生成的latex
  4. renderedimage:使用预测的latex渲染生成的图片

web_demo 如果没有安装pdflatex这个渲染插件,可以将输出的latex,在 https://katex.org/ 这里进行渲染 img


训练

目前im2katex使用的数据是网上的开源数据集,包含两种数据,一种为印刷体公式图片,数量大约为10w,另一个为手写体公式图片,数量大约为7w。将两个数据集合并,可以组成混合数据集。

训练模型

如果使用当前已经准备好的标准格式数据集,直接使用make文件下的命令,

如果训练标准数据集

make im2katex-trainval-ori

如果对标准数据集进行增强

make im2katex-trainval-aug

预训练权重参数

BaiduDisk:

链接: https://pan.baidu.com/s/1cUCSxCaReXw9EPs8_lwCRQ 提取码: bx4b

训练自己的数据

数据集的 链接: https://pan.baidu.com/s/1Ni1eOb3l9udrjlA0Mfx9OQ 提取码: vf9w

数据集预处理

首先对数据集的label进行预处理,将label中的katex进行分词,可以使用./data/process_im2latex/run_formula_normal.sh下的命令

对数据提取bucket,resize,padding等操作,相关处理程序代码为./data/build_imglatex_data.py

最后得到的数据为numpy格式,按照不同的bucket size存储,内部存储着图像名和对应的label,并且生成了包含数据集的一些映射关系:词表字典,索引与字符转换关系等。

如果要训练自己的数据,需要对数据集进行预处理。

训练模型

在准备好数据之后,同样使用make命令进行训练模型。

errorchecker模型说明

该模型的设计初衷是,使用im2katex经过训练测试之后,loss已经非常小,bleu和edit distance衡量指标已经很高,但是如果使用预测的katex能否生成图像为衡量指标,预测生成率仅为50%,也就是说,虽然预测时katex和label已经非常相似,但是对于严格的ketex结构,存在少量的差异也会导致渲染图片失败,因此errorchecker是按照nmt的思路设计了一个机器翻译模型,以预测的字符串作为输入,以label作为输出,训练nmt模型,并且使用的是im2katex预训练之后得到的embeding向量,由于对于一个正确的label存在多个错误的输入,这是一种多对一的映射,因此经过短期训练之后发现nmt模型无法收敛,该模型已经放弃

disMode模型说明

同样是对原来的im2katex的进一步优化,原来的loss已经很小,loss无法下降,那么增加loss是否可以使模型训练继续下去?答案是可以的。目前的loss来源于预测的katex字符和label之间的交叉熵损失,那么借鉴对抗生成网络,现在的im2katex可以看做是生成器,那么设计一个判别器来判断预测生成的katex是否可以生成图像,这就为模型引入了一个损失。

在现有im2katex的基础上进行预测,得到预测的katex和同等数量的label列表,接下来使用能否渲染生成图片作为衡量指标对预测的katex进行过滤,只保留无法生成的图片的预测katex,并且保留对应的label,判别器对该katex进行判断,如果无法生成katex,判别器预测输出为0,如果输入的是label,那么判别器预测输出应该为1,根据对抗生成网络的损失函数,实际就是价差上损失,对判别器进行训练,如果判别器训练收敛之后,可以将该判别器加入到im2katex之中,在im2katex中dismodel为测试模式,不更新参数。

目前思路:

对于整个项目,输入时公式的图片,我们想要的是准确的katex输出。在im2katex中输入图片,得到预测的katex,计算预测katex和label之间的交叉熵损失$\sum Alog(y)+(1-A)log(1-y)$,并且预测的katex输入dismodel,计算dismodel判断katex为1的损失$mean(-\log(D(katex)))$

在disModel中,使用的是TCN网络(具体的实现可以参考https://github.com/Songweiping/TCN-TF或者https://github.com/locuslab/TCN),对输入的序列进行3层的空洞卷积,之后经过线性变换得到输出。

参考代码

image2katex's People

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

image2katex's Issues

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.