
eagle705 / pytorch-transformer-chatbot


A simple chit-chat chatbot built with the Transformer API introduced in PyTorch v1.2

Python 100.00%
pytorch transformer chitchat chatbot seq2seq korean-nlp text-generation pytorch-nlp transformer-api pytorch-transformer-chatbot


PyTorch_Transformer_Chatbot

Simple Korean Generative Chatbot Implementation based on new PyTorch Transformer API (PyTorch v1.2 / Python 3.x)

[Figure: transformer_fig]

ToDo

  • Beam Search
  • Search hyperparams
  • Attention Visualization
  • Char-level transformer

Transformer API core logic

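  • Padding masking is very convenient: you only pass boolean key-padding masks
  • The mask that keeps the decoder input from seeing future tokens is provided as a function (generate_square_subsequent_mask)
  • The Transformer's input and output dimensions are ordered [Sequence, Batch, Embedding Dimension], so batch-first tensors must be transposed
  • Unfortunately, the Transformer API does not expose the attention weights (no attention-weight dict is returned)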
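What the two masks contain can be sketched without PyTorch. The pure-Python functions below are illustrative only: `key_padding_mask`, `square_subsequent_mask` and `pad_id=0` are assumptions, not names from the repo. The padding mask is `True` at `<pad>` positions. The causal mask is additive, with `0.0` on and below the diagonal and `-inf` above it. As far as we know, this is also the layout of the float mask returned by `generate_square_subsequent_mask`.

```python
import math
from typing import List


def key_padding_mask(batch_ids: List[List[int]], pad_id: int) -> List[List[bool]]:
    # True marks a <pad> key that attention should ignore; same rule as
    # `enc_input == self.vocab.PAD_ID` in forward()
    return [[tok == pad_id for tok in seq] for seq in batch_ids]


def square_subsequent_mask(size: int) -> List[List[float]]:
    # Additive causal mask: row i is the query position, column j the key position.
    # 0.0 lets query i see key j when j <= i; -inf hides future keys (j > i),
    # so they receive zero weight after softmax.
    return [[0.0 if j <= i else -math.inf for j in range(size)] for i in range(size)]


if __name__ == "__main__":
    print(key_padding_mask([[5, 8, 2, 0, 0]], pad_id=0))
    # [[False, False, False, True, True]]
    for row in square_subsequent_mask(3):
        print(row)
    # [0.0, -inf, -inf]
    # [0.0, 0.0, -inf]
    # [0.0, 0.0, 0.0]
```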
import torch

def forward(self, enc_input: torch.Tensor, dec_input: torch.Tensor) -> torch.Tensor:
    # enc_input: [N, S], dec_input: [N, T] token IDs (N = batch size)
    x_enc_embed = self.input_embedding(enc_input.long())  # [N, S, E]
    x_dec_embed = self.input_embedding(dec_input.long())  # [N, T, E]

    # Masking
    # Key-padding masks are True at <pad> positions, which attention then ignores
    src_key_padding_mask = enc_input == self.vocab.PAD_ID # tensor([[False, False, False,  True,  ...,  True]])
    tgt_key_padding_mask = dec_input == self.vocab.PAD_ID
    memory_key_padding_mask = src_key_padding_mask
    # Causal mask so no decoder position can attend to future tokens.
    # Build it on the inputs' device instead of relying on a global `device`.
    tgt_mask = self.transfomrer.generate_square_subsequent_mask(dec_input.size(1)).to(dec_input.device)

    # nn.Transformer expects [Sequence, Batch, Embedding], so swap the first two axes
    # ('ijk->jik' is equivalent to .transpose(0, 1))
    # einsum ref: https://pytorch.org/docs/stable/torch.html#torch.einsum
    # https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/
    x_enc_embed = torch.einsum('ijk->jik', x_enc_embed)  # [S, N, E]
    x_dec_embed = torch.einsum('ijk->jik', x_dec_embed)  # [T, N, E]

    # transformer ref: https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer
    feature = self.transfomrer(src=x_enc_embed,
                               tgt=x_dec_embed,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask,
                               tgt_mask=tgt_mask)  # src: (S,N,E) tgt: (T,N,E) -> output: (T,N,E)

    logits = self.proj_vocab_layer(feature)  # [T, N, vocab_size]
    logits = torch.einsum('ijk->jik', logits)  # back to batch-first: [N, T, vocab_size]

    return logits
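The `torch.einsum('ijk->jik', ...)` calls in `forward()` only swap the first two axes. They turn `[Batch, Sequence, Embedding]` into the `[Sequence, Batch, Embedding]` layout that `nn.Transformer` expects, and the last call swaps the logits back. Below is a pure-Python sketch of that index shuffle on nested lists. The helper name `swap_first_two_axes` is hypothetical.

```python
from typing import List, TypeVar

T = TypeVar("T")


def swap_first_two_axes(x: List[List[List[T]]]) -> List[List[List[T]]]:
    # Same index mapping as einsum 'ijk->jik': out[j][i][k] == x[i][j][k].
    # Assumes a non-empty, rectangular nested list.
    n_i, n_j = len(x), len(x[0])
    return [[list(x[i][j]) for i in range(n_i)] for j in range(n_j)]


if __name__ == "__main__":
    # Batch of N=2 sequences, S=3 positions, E=1 embedding dim: [N, S, E]
    batch_first = [[[1], [2], [3]], [[4], [5], [6]]]
    seq_first = swap_first_two_axes(batch_first)  # [S, N, E]
    print(seq_first)  # [[[1], [4]], [[2], [5]], [[3], [6]]]
    # Swapping again restores the original layout, as done for the logits
    print(swap_first_two_axes(seq_first) == batch_first)  # True
```

On a 3-D tensor, `x.transpose(0, 1)` or `x.permute(1, 0, 2)` should give the same result. Later PyTorch releases (1.9 onward, to our knowledge) also accept `batch_first=True` in the `nn.Transformer` constructor, which removes the need for the swap.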

Experiments

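  • The model overfits the train set and reaches 95% accuracy there, but accuracy on the val set is low (too low to be worth showing as an example)
  • Padding positions are excluded from both the accuracy and the loss computation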
input:  [['나/NP', '를/JKO', '사랑/NNG', '한/XSA+ETM', '그/MM', '사람/NNG', '에게/JKB', '해/VV+EC', '줄/VX+ETM', '수/NNB', '있/VV', '는/ETM', '것/NNB', '<pad>', '<pad>'], ['맥주/NNG', '한/MM', '잔/NNG', '해야지/VV+EC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred:  [['천천히/MAG', '잊어/NNP', '가/JKS', '요/JX', './SF', '</s>', '들/VV', '노력/NNG', '요/JX', '은/JX', '도/JX', '이/JKS', '도/JX', '이/JKS', '이/JKS'], ['적당히/MAG', '드세요/VV+EP+EF', './SF', '</s>', '에/JKB', '에/JKB', '드세요/VV+EP+EF', '드세요/VV+EP+EF', '구경/NNG', '드세요/VV+EP+EF', '에/JKB', '드세요/VV+EP+EF', '가리/VV', '드세요/VV+EP+EF', '드세요/VV+EP+EF']]
real:  [['천천히/MAG', '잊어/NNP', '가/JKS', '요/JX', './SF', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['적당히/MAG', '드세요/VV+EP+EF', './SF', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
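Excluding padding from the metrics, as the bullet above describes, can be sketched in plain Python. The helper names `masked_accuracy` and `masked_nll`, and the toy `pad_id`, are assumptions and do not come from the repo. In PyTorch, the loss side is usually handled by the `ignore_index` argument of `nn.CrossEntropyLoss` or `nn.NLLLoss`. We do not know which one the repo uses.

```python
import math
from typing import List

PAD = "<pad>"


def masked_accuracy(pred: List[List[str]], real: List[List[str]], pad: str = PAD) -> float:
    # Token accuracy over positions whose reference token is not padding
    correct = total = 0
    for p_seq, r_seq in zip(pred, real):
        for p, r in zip(p_seq, r_seq):
            if r == pad:
                continue  # padding never counts, right or wrong
            total += 1
            correct += int(p == r)
    return correct / total if total else 0.0


def masked_nll(log_probs: List[List[List[float]]], targets: List[List[int]], pad_id: int) -> float:
    # Mean negative log-likelihood over non-padding targets;
    # log_probs[b][t] is a log-probability vector over the vocabulary
    losses = [
        -log_probs[b][t][tgt]
        for b, seq in enumerate(targets)
        for t, tgt in enumerate(seq)
        if tgt != pad_id
    ]
    return sum(losses) / len(losses) if losses else 0.0


if __name__ == "__main__":
    # First six positions of the second sample above: the junk tokens after
    # '</s>' line up with '<pad>' in `real`, so they are ignored.
    pred = [["적당히/MAG", "드세요/VV+EP+EF", "./SF", "</s>", "에/JKB", "에/JKB"]]
    real = [["적당히/MAG", "드세요/VV+EP+EF", "./SF", "</s>", PAD, PAD]]
    print(masked_accuracy(pred, real))  # 1.0

    # Toy vocabulary of 3 ids where id 0 is padding; the third step is skipped
    dists = [[0.2, 0.5, 0.3], [0.1, 0.1, 0.8], [0.6, 0.2, 0.2]]
    lp = [[[math.log(p) for p in d] for d in dists]]
    print(masked_nll(lp, [[1, 2, 0]], pad_id=0))  # mean of -log(0.5) and -log(0.8)
```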
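  • Inference test results
    • Sentences from the train set are answered well
    • Inputs phrased reasonably (?) well also get good answers
    • Unfortunately, the answer changes when an "<unk>" appears or even a single token changes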
문장을 입력하세요: 배고파
input:  [['배고파/VA+EC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred:  [['얼른/MAG', '맛난/VA+ETM', '음식/NNG', '드세요/VV+EP+EF', './SF', '</s>']]
pred_str:  얼른 맛난 음식 드세요.

문장을 입력하세요: 너 누구야
input:  [['너/NP', '누구/NP', '야/VCP+EF', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred:  [['저/NP', '는/JX', '마음/NNG', '을/JKO', '이어주/VV', '는/ETM', '위/NNG', '로/JKB', '봇/NNG', '입니다/VCP+EF', './SF', '</s>']]
pred_str:  저는 마음을 이어주는 위로봇입니다.

문장을 입력하세요: 안녕
input:  [['안녕/IC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred:  [['안녕/NNG', '하/XSV', '세요/EP+EF', './SF', '</s>']]
pred_str:  안녕하세요.

문장을 입력하세요: 졸리다 이제 자야지
input:  [['<unk>', '다/EF', '이제/MAG', '자/VV', '야지/EC', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
pred:  [['너무/MAG', '걱정/NNG', '하/XSV', '지/EC', '마세요/VX+EP+EF', './SF', '</s>']]
pred_str:  너무 걱정하지 마세요.
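The predictions above stop at `</s>`. The README does not say how decoding is done. One common approach with this encoder-decoder setup is greedy autoregressive decoding: re-run the model on the growing decoder prefix, append the highest-scoring token, and stop at `</s>`. Below is a minimal pure-Python sketch of that approach. It assumes a hypothetical `step_logits(src, prefix)` callable standing in for the model. With the `forward()` above, that would presumably be the logits at the last decoder position. `bos_id`, `eos_id` and `max_len` are illustrative parameters. Beam search, listed under ToDo, would keep the top-k prefixes instead of a single one.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    step_logits: Callable[[Sequence[int], Sequence[int]], Sequence[float]],
    src: Sequence[int],
    bos_id: int,
    eos_id: int,
    max_len: int,
) -> List[int]:
    # Feed the growing decoder prefix back in, append the argmax token each step,
    # and stop after emitting eos_id (kept in the output) or after max_len tokens.
    prefix = [bos_id]
    out: List[int] = []
    for _ in range(max_len):
        logits = step_logits(src, prefix)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        out.append(next_id)
        if next_id == eos_id:
            break
        prefix.append(next_id)
    return out


if __name__ == "__main__":
    script = [3, 4, 2]  # toy "model" that always answers 3, 4, then eos_id=2

    def toy_step(src, prefix):
        target = script[len(prefix) - 1]
        return [1.0 if i == target else 0.0 for i in range(5)]

    print(greedy_decode(toy_step, src=[7, 8], bos_id=1, eos_id=2, max_len=10))  # [3, 4, 2]
```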

Execution Order

  • To run only the test, execute just inference.py
python build_vocab.py # build the vocabulary
python train.py # train
python inference.py # test

Requirements

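pip install torch  # PyTorch v1.2 or later
pip install mxnet
pip install gluonnlp
pip install konlpy
pip install python-mecab-ko
pip install chatspace  # may no longer be available; see the Chatspace issue below
pip install tb-nightly
pip install future
pip install pathlib  # unnecessary on Python 3.4+, where pathlib is in the standard library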

Reference Repositories


pytorch-transformer-chatbot's Issues

Replacing Chatspace

Unfortunately... it looks like Chatspace has disappeared too, because of the Lee Luda (이루다) incident...
So I replaced the following part of evaluate.py

from chatspace import Chatspace
spacer = Chatspace()

and, inside the evaluate function,

pred_str = spacer.space(pred_str)

with the following, which uses hanspell:
from hanspell import spell_checker

and, inside the evaluate function,

pred_str = spell_checker.check(pred_str)
pred_str = pred_str.checked

It works without problems so far, but I'm leaving this note because I thought the original author should know first.
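The spacer is needed because the predicted morpheme tokens must be glued back into a sentence. Below is a minimal sketch of that step. It assumes, without confirmation from the repo, that `pred_str` is built by stripping the POS tag from each `surface/TAG` token and concatenating the surfaces before spacing. `tokens_to_text` is a hypothetical helper. Taking the spacer as a parameter makes swapping Chatspace for hanspell a one-line change.

```python
from typing import Callable, List, Optional

EOS = "</s>"


def tokens_to_text(tokens: List[str], spacer: Optional[Callable[[str], str]] = None) -> str:
    # Stop at '</s>', keep the surface form of each 'surface/TAG' token
    # (rsplit on the last '/', so './SF' -> '.'), join without spaces,
    # then let an optional spacing function restore word boundaries.
    surfaces = []
    for tok in tokens:
        if tok == EOS:
            break
        surfaces.append(tok.rsplit("/", 1)[0])
    text = "".join(surfaces)
    return spacer(text) if spacer is not None else text


if __name__ == "__main__":
    pred = ["저/NP", "는/JX", "마음/NNG", "을/JKO", "이어주/VV", "는/ETM",
            "위/NNG", "로/JKB", "봇/NNG", "입니다/VCP+EF", "./SF", "</s>"]
    print(tokens_to_text(pred))  # 저는마음을이어주는위로봇입니다.
    # With the hanspell workaround from this issue (requires the hanspell package):
    # from hanspell import spell_checker
    # print(tokens_to_text(pred, spacer=lambda s: spell_checker.check(s).checked))
```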
