hjpwhu / cakechat

This project forked from lukalabs/cakechat


CakeChat: Emotional Generative Dialog System

Home Page: https://cakechat.replika.ai

License: Apache License 2.0

Python 100.00%


CakeChat is a dialog system that is able to express emotions in a text conversation. Try it online!


It is written in Theano and Lasagne. It uses end-to-end trained embeddings of 5 different emotions to generate responses conditioned on a given emotion.

The code is flexible and allows you to condition a response on an arbitrary categorical variable defined for some samples in the training data. With CakeChat you can, for example, train your own persona-based neural conversational model[5] or create an emotional chatting machine without external memory[4].


Network architecture and features

Network architecture

  • Model:
    • Hierarchical Recurrent Encoder-Decoder (HRED) architecture for handling deep dialog context[7]
    • Multilayer RNN with GRU cells. First layer of the utterance-level encoder is always bidirectional.
    • The thought vector is fed into the decoder at each decoding step.
    • The decoder can be conditioned on any string label, for example an emotion label or the ID of the person talking.
  • Word embedding layer:
    • May be initialized using a w2v model trained on your own corpus.
    • The embedding layer may either stay fixed or be fine-tuned along with all other weights of the network.
  • Decoding:
    • Four response generation algorithms: "sampling", "beamsearch", "sampling-reranking" and "beamsearch-reranking". Generated candidates are reranked according to their log-likelihood or the MMI criterion[3]. See the configuration settings description for details.
  • Metrics:
    • Perplexity
    • n-gram distinct metrics adjusted to the sample size[3].
    • Lexical similarity between the model's samples and some fixed dataset, computed as the cosine distance between the TF-IDF vectors of the responses generated by the model and of the tokens in the dataset.
    • Ranking metrics: mean average precision and mean recall@k[8].
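The core of the distinct-n metric is the number of unique n-grams divided by the total number of n-grams across a set of generated responses. The sketch below computes only that plain ratio; the sample-size adjustment CakeChat applies is not described here, so it is deliberately left out, and the whitespace tokenization and the name distinct_n are assumptions for illustration.

```python
from typing import Iterable, List, Set, Tuple


def distinct_n(responses: Iterable[str], n: int) -> float:
    """Plain distinct-n: unique n-grams / total n-grams over all responses.

    No sample-size adjustment is applied (CakeChat's adjustment is not reproduced).
    """
    total = 0
    unique: Set[Tuple[str, ...]] = set()
    for response in responses:
        tokens: List[str] = response.split()  # naive whitespace tokenization (assumption)
        ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0


if __name__ == "__main__":
    samples = ["i am fine", "i am good", "i am fine"]
    print(distinct_n(samples, 1))  # 4 unique unigrams out of 9
    print(distinct_n(samples, 2))  # 3 unique bigrams out of 6
```

Higher values indicate more varied output; a model that repeats the same generic reply scores close to zero.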

Quick start

Quickly build a CPU-only Docker image, run it, and start CakeChat serving the model on port 8080:

# Build & run docker container
docker build -t cakechat:latest -f dockerfiles/Dockerfile.cpu dockerfiles/

docker run --name cakechat-dev -p 127.0.0.1:8080:8080 -it cakechat:latest \
    bash -c "python tools/download_model.py && python bin/cakechat_server.py"

That's it! Now you can try it by running python tools/test_api.py -f localhost -p 8080 -c "Hi! How are you?" from the host command line.

Setup

Docker

This is the easiest way to set up the environment and install all the dependencies.

CPU-only setup

  1. Install Docker

  2. Build a docker image

Build a CPU-only image:

docker build -t cakechat:latest -f dockerfiles/Dockerfile.cpu dockerfiles/
  3. Start the container

Run a docker container in the CPU-only environment:

docker run --name <CONTAINER_NAME> -it cakechat:latest

GPU-enabled setup

  1. Install nvidia-docker for GPU support.

  2. Build a GPU-enabled docker image:

nvidia-docker build -t cakechat-gpu:latest -f dockerfiles/Dockerfile.gpu dockerfiles/
  3. Start the container

Run a docker container in the GPU-enabled environment:

nvidia-docker run --name <CONTAINER_NAME> -it cakechat-gpu:latest

That's it! Now you can train your model and chat with it.

Manual setup

If you don't want to deal with docker images and containers, you can always simply run (with sudo, --user or inside your virtualenv):

pip install -r requirements.txt

Most likely this will do the job. NB: this method only provides a CPU-only environment. To get GPU support, you'll need to build and install libgpuarray yourself (see Dockerfile.gpu for an example).

Getting the model

Using a pre-trained model

Run python tools/download_model.py to download our pre-trained model.

The model is trained with context size 3 where the encoded sequence contains 30 tokens or less and the decoded sequence contains 32 tokens or less. Both encoder and decoder contain 2 GRU layers with 512 hidden units each.
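To make the context size and length limits concrete, here is a hypothetical sketch of how a dialog could be cut into (context, response) training pairs under those limits. It is not CakeChat's preprocessing code: the whitespace tokenization, the choice to keep the last tokens of an over-long context utterance and the first tokens of an over-long response, and the name make_samples are all assumptions.

```python
from typing import List, Tuple

CONTEXT_SIZE = 3         # number of previous utterances given to the encoder
MAX_INPUT_TOKENS = 30    # per-utterance limit on the encoder side
MAX_OUTPUT_TOKENS = 32   # limit on the decoded response

Sample = Tuple[List[List[str]], List[str]]


def make_samples(dialog: List[str]) -> List[Sample]:
    """Turn a chronologically ordered dialog into (context, response) pairs."""
    samples: List[Sample] = []
    for i in range(1, len(dialog)):
        # Up to CONTEXT_SIZE utterances immediately preceding the response.
        context = dialog[max(0, i - CONTEXT_SIZE):i]
        # Assumption: keep the tail of long context utterances (closest to the reply).
        context_tokens = [utt.split()[-MAX_INPUT_TOKENS:] for utt in context]
        # Assumption: keep the head of long responses.
        response_tokens = dialog[i].split()[:MAX_OUTPUT_TOKENS]
        samples.append((context_tokens, response_tokens))
    return samples


if __name__ == "__main__":
    for ctx, resp in make_samples(["Hi!", "Hello, how are you?", "Fine, thanks.", "Great."]):
        print(ctx, "->", resp)
```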

The model was trained on preprocessed Twitter conversational data. To clean up the data, we removed URLs, retweets and citations. We also removed mentions and hashtags that are not preceded by regular words or punctuation marks, and filtered out all messages that contain more than 30 tokens.
Then we labeled each utterance with our emotion classifier, which predicts one of 5 emotions: "neutral", "joy", "anger", "sadness" and "fear". To label your own corpus with emotions, you can use, for example, the DeepMoji tool or any other emotion classifier you have.
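The cleanup rules above can be approximated with a small filter. This is a sketch, not the authors' preprocessing code: the regular expressions, the "RT @" retweet check, whitespace tokenization and the name clean_tweet are assumptions, "not preceded by regular words" is interpreted as a leading run of mentions and hashtags, and citation removal is left out.

```python
import re
from typing import Optional

URL_RE = re.compile(r"https?://\S+|www\.\S+")
# A run of @mentions / #hashtags at the very start of the message, i.e. tags
# that are not preceded by any regular word or punctuation mark.
LEADING_TAGS_RE = re.compile(r"^(?:[@#]\w+\s*)+")
MAX_TOKENS = 30


def clean_tweet(text: str) -> Optional[str]:
    """Return the cleaned message, or None if it should be dropped."""
    if text.startswith("RT @"):  # heuristic retweet detection (assumption)
        return None
    text = URL_RE.sub("", text).strip()
    text = LEADING_TAGS_RE.sub("", text)
    text = " ".join(text.split())  # collapse whitespace left by the removals
    if not text or len(text.split()) > MAX_TOKENS:
        return None
    return text


if __name__ == "__main__":
    print(clean_tweet("@bob @alice hi there http://t.co/x"))
    print(clean_tweet("I love #python"))
```

Hashtags in the middle or at the end of a message survive, since they follow regular words.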

Training your own model

  1. Put your training text corpus into data/corpora_processed/. Each line of the corpus file should be a JSON object containing a list of dialog messages sorted in chronological order. The code is fully language-agnostic: you can use any Unicode text in your datasets. Refer to our dummy corpus to see the input data format: data/corpora_processed/train_processed_dialogs.txt.

  2. The following datasets are used for validation and early stopping:

  3. Set up training parameters in cakechat/config.py. See configuration settings description for more details.
  4. Run python tools/prepare_index_files.py to build the index files with tokens and conditions from the training corpus.
  5. Run python tools/train.py. If you want to train on a GPU, don't forget to set the USE_GPU=<GPU_ID> environment variable (with GPU_ID as reported by nvidia-smi). Use SLICE_TRAINSET=N to train the model on only the first N samples of your training data, which speeds up preprocessing when debugging.
  6. You can also set IS_DEV=1 to enable the "development mode". It uses a reduced number of model parameters (decreased hidden layer dimensions, input and output sizes of token sequences, etc.), performs verbose logging and disables Theano graph optimizations. Use this mode for debugging.
  7. Weights of your model will be saved in data/nn_models/.
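A minimal sketch for writing and reading a corpus in the one-dialog-per-line layout described in step 1. The message structure is an assumption: the "condition" key matches the field mentioned in the Example use cases section, while the "text" key is a guess, so check data/corpora_processed/train_processed_dialogs.txt before relying on it. The helper names are hypothetical.

```python
import json
from typing import Dict, List

# Assumed message layout: {"text": ..., "condition": ...}; verify against the dummy corpus.
Dialog = List[Dict[str, str]]


def dialog_to_line(dialog: Dialog) -> str:
    """Serialize one dialog (messages in chronological order) as a single JSON line."""
    # ensure_ascii=False keeps non-Latin text readable in the file.
    return json.dumps(dialog, ensure_ascii=False)


def write_corpus(path: str, dialogs: List[Dialog]) -> None:
    with open(path, "w", encoding="utf-8") as f:
        for dialog in dialogs:
            f.write(dialog_to_line(dialog) + "\n")


def read_corpus(path: str) -> List[Dialog]:
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    example = [
        {"text": "Hi! How are you?", "condition": "neutral"},
        {"text": "Great, thanks!", "condition": "joy"},
    ]
    print(dialog_to_line(example))
```

After writing a new corpus, rerun tools/prepare_index_files.py so the token and condition indexes match it.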

Existing training datasets

You can train a dialog model on any text conversational dataset available to you. A great overview of existing conversational datasets can be found here: https://breakend.github.io/DialogDatasets/

Running the system

Local HTTP-server

Run a server that processes HTTP requests with given input messages (contexts) and returns the model's response messages:

python bin/cakechat_server.py

Specify the USE_GPU=<GPU_ID> environment variable if you want to use a specific GPU.

Wait until the model is compiled. Don't forget to run tools/download_model.py prior to running bin/cakechat_server.py if you want to start an API with our pre-trained model.

To make sure everything works fine, test the model on the following conversation:

– Hi, Eddie, what's up?
– Not much, what about you?
– Fine, thanks. Are you going to the movies tomorrow?

python tools/test_api.py -f 127.0.0.1 -p 8080 \
    -c "Hi, Eddie, what's up?" \
    -c "Not much, what about you?" \
    -c "Fine, thanks. Are you going to the movies tomorrow?"

HTTP-server API description

/cakechat_api/v1/actions/get_response

JSON parameters are:

| Parameter | Type | Description |
|---|---|---|
| context | list of strings | List of previous messages from the dialogue history (at most 3 are used) |
| emotion | string, one of enum | One of {'neutral', 'anger', 'joy', 'fear', 'sadness'}. The emotion to condition the response on. Optional; if not specified, 'neutral' is used |

Request
POST /cakechat_api/v1/actions/get_response
data: {
 "context": ["Hello", "Hi!", "How are you?"],
 "emotion": "joy"
}
Response OK
200 OK
{
 "response": "I'm fine!"
}
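To call the endpoint from your own Python code instead of tools/test_api.py, here is a minimal client sketch using only the standard library. It assumes the server accepts a JSON request body (sent with a Content-Type: application/json header) and replies with a JSON object containing a "response" key, as in the example above; build_payload and get_response are hypothetical helper names, not part of CakeChat.

```python
import json
import urllib.request
from typing import Dict, List, Optional, Union

API_URL = "http://127.0.0.1:8080/cakechat_api/v1/actions/get_response"
EMOTIONS = {"neutral", "anger", "joy", "fear", "sadness"}


def build_payload(context: List[str], emotion: Optional[str] = None) -> Dict[str, Union[List[str], str]]:
    """Build the request body; omitting emotion lets the server default to 'neutral'."""
    if emotion is not None and emotion not in EMOTIONS:
        raise ValueError(f"unsupported emotion: {emotion!r}")
    payload: Dict[str, Union[List[str], str]] = {"context": list(context)}
    if emotion is not None:
        payload["emotion"] = emotion
    return payload


def get_response(context: List[str], emotion: Optional[str] = None,
                 url: str = API_URL, timeout: float = 60.0) -> str:
    """POST the context to the server and return the generated reply."""
    body = json.dumps(build_payload(context, emotion)).encode("utf-8")
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.load(resp)["response"]


if __name__ == "__main__":
    print(get_response(["Hello", "Hi!", "How are you?"], emotion="joy"))
```

The emotion is validated on the client only to fail fast on typos; the server remains the authority on which values it accepts.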

Gunicorn HTTP-server

We recommend using Gunicorn to serve your model's API at production scale.

Run a server that processes HTTP queries with input messages and returns the model's response messages:

cd bin && gunicorn cakechat_server:app -w 1 -b 127.0.0.1:8080 --timeout 2000

You may need to install gunicorn from pip: pip install gunicorn.
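If you prefer a config file over command-line flags, Gunicorn can read its settings from a Python file passed with -c. The fragment below mirrors the command above; the file name gunicorn.conf.py is a choice made here, and the values are copied from that command rather than tuned recommendations.

```python
# gunicorn.conf.py: mirrors `cd bin && gunicorn cakechat_server:app -w 1 -b 127.0.0.1:8080 --timeout 2000`
bind = "127.0.0.1:8080"  # same as -b 127.0.0.1:8080
workers = 1              # same as -w 1
timeout = 2000           # same as --timeout 2000 (seconds before a silent worker is restarted)
chdir = "bin"            # replaces `cd bin`: Gunicorn changes directory before loading the app
```

Run it from the repository root with gunicorn -c gunicorn.conf.py cakechat_server:app.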

Telegram bot

You can also test your model in a Telegram bot: create a Telegram bot and run

python tools/telegram_bot.py --token <YOUR_BOT_TOKEN>

Repository overview

  • cakechat/dialog_model/ - contains computational graph, training procedure and other model utilities
  • cakechat/dialog_model/inference/ - algorithms for response generation
  • cakechat/dialog_model/quality/ - code for metrics calculation and logging
  • cakechat/utils/ - utilities for text processing, w2v training, etc.
  • cakechat/api/ - functions to run http server: API configuration, error handling
  • tools/ - scripts for training, testing and evaluating your model

Important tools

Important configuration settings

All the configuration parameters for the network architecture, training, predicting and logging steps are defined in cakechat/config.py. Some inference parameters used in an HTTP-server are defined in cakechat/api/config.py.

  • Network architecture and size

    • HIDDEN_LAYER_DIMENSION is the main parameter that defines the number of hidden units in recurrent layers.
    • WORD_EMBEDDING_DIMENSION and CONDITION_EMBEDDING_DIMENSION define the dimensionality of the vectors that each token and each condition are mapped into. Together they sum up to the dimension of the input vector passed to the encoder RNN.
    • The number of units in the decoder's output layer equals the number of tokens in the dictionary in the tokens_index directory.
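A toy sketch of how the two embedding sizes combine into the encoder input dimension, assuming the condition vector is simply concatenated to every token vector. The table sizes, the random initialization and the function name encoder_inputs are hypothetical; this is plain Python for illustration, not CakeChat's Theano graph.

```python
import random
from typing import List

# Hypothetical sizes for illustration; the real values live in cakechat/config.py.
VOCAB_SIZE = 10
NUM_CONDITIONS = 5
WORD_EMBEDDING_DIMENSION = 8
CONDITION_EMBEDDING_DIMENSION = 4

_rng = random.Random(0)


def _random_table(rows: int, dim: int) -> List[List[float]]:
    """Randomly initialized embedding table with shape (rows, dim)."""
    return [[_rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(rows)]


word_embeddings = _random_table(VOCAB_SIZE, WORD_EMBEDDING_DIMENSION)
condition_embeddings = _random_table(NUM_CONDITIONS, CONDITION_EMBEDDING_DIMENSION)


def encoder_inputs(token_ids: List[int], condition_id: int) -> List[List[float]]:
    """One vector per token: [word embedding | condition embedding]."""
    condition_vector = condition_embeddings[condition_id]
    # List concatenation: each vector has WORD_EMBEDDING_DIMENSION + CONDITION_EMBEDDING_DIMENSION entries.
    return [word_embeddings[t] + condition_vector for t in token_ids]


if __name__ == "__main__":
    x = encoder_inputs([1, 2, 3], condition_id=4)
    print(len(x), len(x[0]))
```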
  • Decoding algorithm:

    • PREDICTION_MODE_FOR_TESTS defines how the responses of the model are generated. The options are the following:
      • sampling – response is sampled from output distribution token-by-token. For every token the temperature transform is performed prior to sampling. You can control the temperature value by tuning DEFAULT_TEMPERATURE parameter.
      • sampling-reranking – multiple candidate responses are generated using the sampling procedure described above. After that, the candidates are ranked according to their MMI score[3]. You can tune this mode by picking the SAMPLES_NUM_FOR_RERANKING and MMI_REVERSE_MODEL_SCORE_WEIGHT parameters.
      • beamsearch – candidates are generated using the beam search algorithm. The candidates are ordered according to their log-likelihood score computed by the beam search procedure.
      • beamsearch-reranking – same as above, but the candidates are re-ordered after the generation in the same way as in sampling-reranking mode.

    Note that there are other parameters that affect the response generation process. See REPETITION_PENALIZE_COEFFICIENT, NON_PENALIZABLE_TOKENS, MAX_PREDICTIONS_LENGTH.
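The two building blocks of these modes can be sketched in a few lines: temperature-scaled sampling of one token, and reranking of candidates by a weighted sum of forward and reverse log-likelihoods. The weighted sum is one common form of the MMI-bidi criterion; how CakeChat actually combines the scores, and exactly how MMI_REVERSE_MODEL_SCORE_WEIGHT enters, may differ. All function names here are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple


def apply_temperature(logits: Sequence[float], temperature: float) -> List[float]:
    """Softmax of logits / temperature; lower temperature gives a sharper distribution."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Draw one token id from the temperature-transformed distribution."""
    probs = apply_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def mmi_rerank(candidates: List[Tuple[str, float, float]], reverse_weight: float) -> List[str]:
    """Order candidates by log p(response|context) + reverse_weight * log p(context|response).

    Each candidate is (text, forward_log_likelihood, reverse_log_likelihood).
    """
    ranked = sorted(candidates, key=lambda c: c[1] + reverse_weight * c[2], reverse=True)
    return [text for text, _, _ in ranked]


if __name__ == "__main__":
    rng = random.Random(0)
    print([sample_token([1.0, 2.0, 0.5], 0.7, rng) for _ in range(5)])
    print(mmi_rerank([("i don't know", -1.0, -5.0), ("yes, at eight", -2.0, -1.0)], 1.0))
```

A reverse-model term penalizes generic replies such as "i don't know", which are likely under almost any context and therefore explain the given context poorly.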

Example use cases

By providing additional condition labels within dataset entries, you can build, for example, persona-based, emotional or topic-aware conversational models.

To make use of these extra conditions, please refer to the section Training your own model. Just set the "condition" field in the training set to one of the following: persona ID, emotion or topic label, update the index files and start the training.

References

Credits & Support

CakeChat is developed and maintained by the Replika team: Michael Khalman, Nikita Smetanin, Artem Sobolev, Nicolas Ivanov, Artem Rodichev and Denis Fedorenko. Demo by Oleg Akbarov, Alexander Kuznetsov and Vladimir Chernosvitov.

All issues and feature requests can be tracked via GitHub Issues.

License

© 2018 Luka, Inc. Licensed under the Apache License, Version 2.0. See LICENSE file for more details.
