Statistical NLP (Group 17)

This is the repository for Group 17 of the Statistical Natural Language Processing module at UCL.

This repository implements the Matching Networks architecture (Vinyals et al., 2016) in PyTorch and applies it to a language modelling task. The architecture is flexible enough to allow easy experimentation with distance metrics, the number of labels per episode, the number of examples per label, etc.
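As a rough illustration of the core mechanism (a toy sketch, not the repository's actual code), Matching Networks classify a query by attending over the support set: the query embedding is compared with every support embedding under the chosen distance metric, and the softmax over those scores weights the support labels.

import torch
import torch.nn.functional as F

def matching_predictions(query_emb, support_emb, support_labels, num_labels,
                         distance="euclidean"):
    """Toy sketch of the Matching Networks attention step.

    query_emb:      (d,) embedding of the query sentence
    support_emb:    (N*k, d) embeddings of the support sentences
    support_labels: (N*k,) integer label of each support sentence
    """
    if distance == "euclidean":
        # Negate distances so that closer support points get higher attention.
        scores = -torch.cdist(query_emb.unsqueeze(0), support_emb).squeeze(0)
    else:
        scores = F.cosine_similarity(query_emb.unsqueeze(0), support_emb, dim=1)

    attention = F.softmax(scores, dim=0)                      # (N*k,)
    one_hot = F.one_hot(support_labels, num_labels).float()   # (N*k, num_labels)
    return attention @ one_hot                                # distribution over labels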

More details can be found in the associated paper.

Demo

You can experiment with the model using the attached Colab Notebook.

One_Shot_Learning_for_Language_Modelling.ipynb

Getting Started

To keep the environment as reproducible as possible, we use pipenv to handle dependencies. To install it, just follow the instructions at https://pipenv.readthedocs.io/en/latest/.

To create the environment and install all required dependencies for the first time, just run:

$ pipenv install

This will create a virtualenv and install all required dependencies.

Installing new dependencies

To add new dependencies just run:

$ pipenv install numpy

Remember to commit the updated Pipfile and Pipfile.lock files so that everyone else can install the same versions!

Folder Structure

Most of the source code can be found under the src/ folder. However, we also include a set of command line tools, which should help with sampling, training and testing models. These can be found under the bin/ folder.

Additionally, you can find the following folders:

  • wikitext-2/: Raw WikiText-2 data set.
  • data/: Pre-sampled set of label/sentence pairs and pre-generated vocabulary.
  • models/: Pre-trained models. The filenames encode the different parameters used to train the model.
  • results/: Data generated after evaluating the models. It includes predictions on the test set, embeddings and attention maps.
  • figures/: Figures generated from the data in the results/ folder.

Tests

We are using pytest for writing and running unit tests. You can see some examples in the src/tests/ folder.

To run all tests, just run the following command:

$ pytest -s src/tests
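For reference, a pytest unit test is just a function whose name starts with test_. A hypothetical example (not taken from src/tests/):

# test_vocab_example.py -- hypothetical example, not part of src/tests/
def encode(tokens, vocab, unk=0):
    """Map tokens to integer ids, falling back to the <unk> id (toy helper)."""
    return [vocab.get(t, unk) for t in tokens]

def test_encode_handles_unknown_tokens():
    vocab = {"<unk>": 0, "music": 1, "play": 2}
    assert encode(["play", "music", "vynils"], vocab) == [2, 1, 0]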

Dataset

In the data/ folder you can find the train.csv and test.csv files, which contain 9,000 labels and 1,000 labels respectively, with 10 examples per label.

The data is in CSV format with two columns:

  • label: The word acting as the label, which the model needs to predict.
  • sentence: The sentence acting as input, where the particular word has been replaced with the token <blank_token>.

An example can be seen below:

label,sentence
music,no need to be a hipster to play <blank_token> in vynils
music,nowadays <blank_token> doesn't sound as before
...
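To get a quick feel for the data, the pairs can be read with the standard csv module (a minimal sketch; the project itself loads them through its own data pipeline):

import csv

# Minimal sketch: read the label/sentence pairs from data/train.csv.
with open("data/train.csv", newline="") as f:
    pairs = [(row["label"], row["sentence"]) for row in csv.DictReader(f)]

print(len(pairs))   # 9,000 labels x 10 examples = 90,000 pairs for the training set
print(pairs[0])     # e.g. ('music', 'no need to be a hipster to play <blank_token> in vynils')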

Sampling new pairs

If you want to sample a new set of pairs from the WikiText-2 dataset, you can use the bin.sample script. For example, to resample the entire dataset, you could run:

$ python -m bin.sample -N 9000 -k 10 wikitext-2/wiki.train.tokens data/train.csv
$ python -m bin.sample -N 1000 -k 10 wikitext-2/wiki.test.tokens data/test.csv

Note that the file will be pre-processed first so that it resembles text from the Penn Treebank (PTB).
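Conceptually, sampling boils down to choosing N candidate words that occur in at least k sentences and blanking them out. A rough, hypothetical sketch of that idea (the real bin.sample script additionally handles tokenisation and the PTB-style pre-processing mentioned above):

import csv
import random
from collections import defaultdict

def sample_pairs(sentences, out_path, N, k, token="<blank_token>"):
    """Toy sketch: pick N words with at least k example sentences and blank them out."""
    by_word = defaultdict(list)
    for sent in sentences:
        for word in set(sent.split()):
            by_word[word].append(sent)

    labels = random.sample([w for w, s in by_word.items() if len(s) >= k], N)

    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["label", "sentence"])
        for word in labels:
            for sent in random.sample(by_word[word], k):
                blanked = " ".join(token if w == word else w for w in sent.split())
                writer.writerow([word, blanked])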

Generating vocabulary

To make things easy to replicate, we generate the vocabulary over the training set in advance and store it in a file, which can then be used for training and testing. You can have a look at the format in data/vocab.json.

To re-generate it (after sampling new pairs, for example), you can use the bin.vocab script:

$ python -m bin.vocab data/train.csv data/vocab.json

This command will store the vocabulary's state as a JSON file.
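Since the state is plain JSON, it can be inspected directly (a minimal sketch; check data/vocab.json for its exact layout rather than assuming a particular structure):

import json

# Load the stored vocabulary state; see data/vocab.json for the actual layout.
with open("data/vocab.json") as f:
    vocab = json.load(f)

print(type(vocab), len(vocab))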

Training

Training of a new model can be performed using the bin.train script:

$ python -m bin.train -N 5 -k 2 -d euclidean data/vocab.json data/train.csv

The N and k parameters control the number of labels per episode and the number of examples per label, respectively. The remaining arguments select the distance metric, the pre-computed vocabulary and the training set.
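As a rough illustration of what an episode looks like (a toy sketch, not the repository's sampler), each episode draws N labels with k support sentences per label, plus a query sentence per label:

import random
from collections import defaultdict

def build_episode(pairs, N, k):
    """Toy sketch: sample an N-way, k-shot episode from (label, sentence) pairs."""
    by_label = defaultdict(list)
    for label, sentence in pairs:
        by_label[label].append(sentence)

    # Keep only labels with enough sentences for k support examples plus one query.
    labels = random.sample([l for l, s in by_label.items() if len(s) > k], N)

    support, queries = [], []
    for label in labels:
        examples = random.sample(by_label[label], k + 1)
        support += [(label, s) for s in examples[:k]]
        queries.append((label, examples[k]))
    return support, queries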

After convergence, the best model's state_dict is stored under the models/ folder, with the different parameters encoded in its name. For example, the model poincare_vanilla_N=5_k=2_model_7.pth was trained using the Poincaré distance metric and vanilla embeddings, with 5 labels and 2 examples each per episode. The file name also shows that it converged after 7 epochs.

These choices are discussed in more detail in the associated paper.

Evaluation

Accuracy on a test set for a given model's snapshot can be measured using the bin.test script:

$ python -m bin.test -v data/vocab.json -m models/euclidean_vanilla_N\=5_k\=3_model_24.pth data/test.csv

This command has extra flags which allow you to:

  • -p: Store the predictions in the results/ folder.
  • -e: Generate embeddings and attention for a single episode and store them in the results/ folder.

Some of the already generated data can be seen in the results/ folder.
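If you just need a quick accuracy figure from a stored predictions file, something like this works (a sketch assuming a CSV with label and prediction columns and a hypothetical file name; check the actual files in results/ for their layout):

import csv

# Hypothetical layout: one row per test example with "label" and "prediction" columns.
with open("results/predictions.csv", newline="") as f:
    rows = list(csv.DictReader(f))

accuracy = sum(r["label"] == r["prediction"] for r in rows) / len(rows)
print(f"accuracy: {accuracy:.3f}")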

Repository

This repository can be found at https://github.com/adriangonz/statistical-nlp-17.

Contributors

ad-szwarc, adriangonz, azurereflection, pilatracu, ucabtuc

Issues

Implement Matching Networks

Implement Matching Networks in PyTorch. Note that this also includes the "pipeline" between the dataset CSVs and the model itself.

Implement evaluation metrics

Implement evaluation metrics for the results of the models. This may include ROC curves and some kind of summarising tables or metrics, but any other suggestion is more than welcome.

These plots will end up in the report, so perhaps it would be good to make them easy to re-render from some kind of CSV with the predictions output by the model.

Process WT2 dataset

Process the WikiText-2 (WT2) dataset to generate pairs of sentences with missing words. Note that there should be at least N labels per episode with k example sentences each.

The final format will probably be a CSV like

sentence,word
The <missing> shot the duck.,hunter
There is a <missing> in my boot!,snake

Here the <missing> token is "special" and marks where the blank word goes.
