
ikihiyori / pygru4rec


This project forked from yhs968/pygru4rec


PyTorch Implementation of Session-based Recommendations with Recurrent Neural Networks (ICLR 2016, Hidasi et al.)

Jupyter Notebook 50.38% Python 48.77% Shell 0.85%


pyGRU4REC


Environment

  • Python 3.6.3
  • PyTorch 0.3.0.post4
  • pandas 0.20.3
  • numpy 1.13.3

Usage

Training / Test Set Specifications

  • Filenames
    • Training set should be named as train.tsv
    • Test set should be named as test.tsv
  • File Paths
    • train.tsv and test.tsv should be located under the data directory, i.e. data/train.tsv and data/test.tsv.
  • Contents
    • train.tsv and test.tsv should be TSV files storing pandas DataFrames that satisfy the following requirements (without headers):
      • The 1st column should contain the integer session IDs.
      • The 2nd column should contain the integer item IDs.
      • The 3rd column should contain the timestamps.
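The file layout above can be checked with a small reader before training. This is a minimal, stdlib-only sketch, not code from the repository: `load_sessions` and `write_sessions` are hypothetical helpers, and treating the timestamps as numeric values (e.g. Unix seconds) is an assumption, since the README does not specify their format.

```python
import csv
import os
import tempfile
from typing import Iterable, List, Tuple

# (SessionID, ItemID, Timestamp); numeric timestamps are an assumption
Row = Tuple[int, int, float]


def load_sessions(path: str) -> List[Row]:
    """Read a headerless, tab-separated file of (SessionID, ItemID, Timestamp) rows."""
    rows: List[Row] = []
    with open(path, newline="") as f:
        for lineno, fields in enumerate(csv.reader(f, delimiter="\t"), start=1):
            if not fields:  # skip blank lines
                continue
            if len(fields) != 3:
                raise ValueError(f"{path}:{lineno}: expected 3 columns, got {len(fields)}")
            session_id, item_id, timestamp = fields
            # int()/float() raise ValueError on malformed IDs or timestamps
            rows.append((int(session_id), int(item_id), float(timestamp)))
    return rows


def write_sessions(path: str, rows: Iterable[Row]) -> None:
    """Write rows in the same headerless, tab-separated layout."""
    with open(path, "w", newline="") as f:
        csv.writer(f, delimiter="\t").writerows(rows)


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "train.tsv")
    write_sessions(path, [(1, 10, 1000.0), (1, 11, 1005.0), (2, 10, 1010.0)])
    print(load_sessions(path))  # [(1, 10, 1000.0), (1, 11, 1005.0), (2, 10, 1010.0)]
```

With pandas, which the project already depends on, an equivalent read would be `pd.read_csv("data/train.tsv", sep="\t", header=None, names=["SessionID", "ItemID", "Timestamp"])`; the column names there are illustrative.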

Training/Testing using Jupyter Notebook

See example.ipynb for the full Jupyter notebook, which:

  1. Loads the data
  2. Trains a GRU4REC model
  3. Loads the trained GRU4REC model
  4. Tests a GRU4REC model

Training using run_train.py

  • Before using run_train.py, I highly recommend taking a look at example.ipynb to see how the implementation works.
  • Default parameters are the same as the TOP1 loss case in the GRU4REC paper.
  • Intermediate models created after each training epoch are stored in models/ unless specified otherwise.
  • The log file will be written to logs/train.out.
$ python run_train.py > logs/train.out

Args:
    --loss_type: Loss function type. Should be one of 'TOP1', 'BPR', 'CrossEntropy'. (Default: 'TOP1')
    --model_name: The prefix for the intermediate models stored during training. (Default: 'GRU4REC')
    --hidden_size: The dimension of the GRU hidden layer. (Default: 100)
    --num_layers: The number of GRU layers. (Default: 1)
    --batch_size: Training batch size. (Default: 50)
    --dropout_input: Dropout probability of the GRU input layer. (Default: 0)
    --dropout_hidden: Dropout probability of the GRU hidden layer. (Default: .5)
    --optimizer_type: Optimizer type. Should be one of 'Adagrad', 'RMSProp', 'Adadelta', 'Adam', 'SGD'. (Default: 'Adagrad')
    --lr: Learning rate for the optimizer. (Default: 0.01)
    --weight_decay: Weight decay for the optimizer. (Default: 0)
    --momentum: Momentum for the optimizer. (Default: 0)
    --eps: eps parameter for the optimizer. (Default: 1e-6)
    --n_epochs: The number of training epochs to run. (Default: 10)
    --time_sort: Whether to sort the sessions in the dataset in time order. (Default: False)
    --n_samples: The number of samples to use for training. If -1, all samples in the training set are used. (Default: -1)
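For reference on the `--loss_type` options, here is a minimal pure-Python sketch of the three losses under the mini-batch negative sampling described in the GRU4REC paper: for each session in a mini-batch, the target items of the other sessions act as negatives. It follows the paper's TOP1 and BPR formulas plus a standard softmax cross-entropy; it is not the repository's code, which works on PyTorch tensors and may differ in details (for example, whether the target column itself is included in the average). The function names and the `scores` layout are assumptions.

```python
import math
from typing import List, Sequence, Tuple

# scores[b][k]: score given by session b to the target item of session k in the
# same mini-batch. The diagonal holds the positive (true next item) scores; the
# off-diagonal entries act as sampled negatives. Requires a batch size >= 2.
Matrix = Sequence[Sequence[float]]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def log_sigmoid(x: float) -> float:
    # log(sigmoid(x)) without underflow for large negative x
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def split_row(scores: Matrix, b: int) -> Tuple[float, List[float]]:
    # Positive score and the list of negative scores for session b
    row = scores[b]
    return row[b], [s for k, s in enumerate(row) if k != b]


def bpr_loss(scores: Matrix) -> float:
    # -(1/N_S) * sum_j log sigmoid(r_i - r_j), averaged over the batch
    total = 0.0
    for b in range(len(scores)):
        pos, negs = split_row(scores, b)
        total += -sum(log_sigmoid(pos - s) for s in negs) / len(negs)
    return total / len(scores)


def top1_loss(scores: Matrix) -> float:
    # (1/N_S) * sum_j [sigmoid(r_j - r_i) + sigmoid(r_j ** 2)], averaged over the batch;
    # the second term regularizes the negative scores towards zero
    total = 0.0
    for b in range(len(scores)):
        pos, negs = split_row(scores, b)
        total += sum(sigmoid(s - pos) + sigmoid(s * s) for s in negs) / len(negs)
    return total / len(scores)


def cross_entropy_loss(scores: Matrix) -> float:
    # Softmax over each row, then the negative log-probability of the diagonal entry
    total = 0.0
    for b, row in enumerate(scores):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_z - row[b]
    return total / len(scores)


if __name__ == "__main__":
    batch = [[5.0, 0.0], [0.0, 5.0]]
    print(top1_loss(batch))  # ≈ 0.5067
```

Note that TOP1's regularization term is at least 0.5 per negative, so it keeps negative scores near zero rather than pushing them arbitrarily low.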

Testing using run_test.py

  • The log file will be written to logs/test.out.
$ python run_test.py model_file > logs/test.out

Args:
    model_file: Name of the intermediate model under the `./models` directory, e.g. `python run_test.py GRU4REC_TOP1_Adagrad_0.01_epoch10 > ./logs/test.out`
    --loss_type: Loss function type. Should be one of 'TOP1', 'BPR', 'CrossEntropy'. (Default: 'TOP1')
    --hidden_size: The dimension of the GRU hidden layer. (Default: 100)
    --num_layers: The number of GRU layers. (Default: 1)
    --batch_size: Training batch size. (Default: 50)
    --dropout_input: Dropout probability of the GRU input layer. (Default: 0)
    --dropout_hidden: Dropout probability of the GRU hidden layer. (Default: .5)
    --optimizer_type: Optimizer type. Should be one of 'Adagrad', 'RMSProp', 'Adadelta', 'Adam', 'SGD'. (Default: 'Adagrad')
    --lr: Learning rate for the optimizer. (Default: 0.01)
    --weight_decay: Weight decay for the optimizer. (Default: 0)
    --momentum: Momentum for the optimizer. (Default: 0)
    --eps: eps parameter for the optimizer. (Default: 1e-6)
    --n_epochs: The number of training epochs to run. (Default: 10)
    --time_sort: Whether to sort the sessions in the training set in time order. (Default: False)
    --n_samples: The number of samples to use for training. If -1, all samples in the training set are used. (Default: -1)
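The argument list above maps directly onto a standard `argparse` parser. The following is a hypothetical sketch that mirrors the documented names and defaults, not the actual contents of run_test.py; in particular, parsing `--time_sort` as an on/off switch and the helper `resolve_model_path` are assumptions.

```python
import argparse
import os


def build_test_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented run_test.py arguments;
    # the real script may define them differently.
    p = argparse.ArgumentParser(description="Evaluate a trained GRU4REC model")
    p.add_argument("model_file", help="name of a model file under ./models")
    p.add_argument("--loss_type", default="TOP1", choices=["TOP1", "BPR", "CrossEntropy"])
    p.add_argument("--hidden_size", type=int, default=100)
    p.add_argument("--num_layers", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=50)
    p.add_argument("--dropout_input", type=float, default=0.0)
    p.add_argument("--dropout_hidden", type=float, default=0.5)
    p.add_argument("--optimizer_type", default="Adagrad",
                   choices=["Adagrad", "RMSProp", "Adadelta", "Adam", "SGD"])
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--momentum", type=float, default=0.0)
    p.add_argument("--eps", type=float, default=1e-6)
    p.add_argument("--n_epochs", type=int, default=10)
    p.add_argument("--time_sort", action="store_true")  # assumption: a boolean switch
    p.add_argument("--n_samples", type=int, default=-1)
    return p


def resolve_model_path(model_file: str, model_dir: str = "./models") -> str:
    # model_file is a bare name; the models directory is prepended
    return os.path.join(model_dir, model_file)


if __name__ == "__main__":
    args = build_test_parser().parse_args(["GRU4REC_TOP1_Adagrad_0.01_epoch10"])
    print(resolve_model_path(args.model_file), args.loss_type, args.hidden_size)
```

Because `--hidden_size` and `--num_layers` describe the network's shape, they presumably need to match the values used at training time for the saved model to load.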

Reproducing the results of the original paper

  • This PyTorch implementation gives slightly better results than the original Theano code. I suspect this comes from implementation differences between Theano and PyTorch.
  • The results were reproducible within only 2 or 3 epochs, unlike the original Theano implementation, which runs for 10 epochs by default.
$ bash run_train.sh
$ bash run_test.sh
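The README does not say which metrics run_test.py reports. The GRU4REC paper evaluates next-item prediction with Recall@20 and MRR@20, so here is a minimal sketch of those two metrics, assuming one dense score row over all candidate items per test event. The function name and input layout are hypothetical, and tie handling may differ from the repository's evaluation code.

```python
from typing import Sequence, Tuple


def recall_mrr_at_k(scores: Sequence[Sequence[float]], targets: Sequence[int],
                    k: int = 20) -> Tuple[float, float]:
    """Recall@k and MRR@k for next-item prediction.

    scores[n][i] is the score of item i at test event n; targets[n] is the
    index of the item that actually came next.
    """
    hits = 0
    reciprocal_rank_sum = 0.0
    for row, target in zip(scores, targets):
        # Rank of the target = 1 + number of items scored strictly higher
        # (ties are resolved in the target's favour here)
        rank = 1 + sum(1 for s in row if s > row[target])
        if rank <= k:
            hits += 1
            reciprocal_rank_sum += 1.0 / rank  # targets ranked beyond k contribute 0
    n = len(targets)
    return hits / n, reciprocal_rank_sum / n


if __name__ == "__main__":
    scores = [[0.1, 0.9, 0.5], [0.8, 0.1, 0.3]]
    targets = [1, 2]
    print(recall_mrr_at_k(scores, targets, k=1))  # (0.5, 0.5)
    print(recall_mrr_at_k(scores, targets, k=2))  # (1.0, 0.75)
```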

ToDo

  • Multi-GPU training support
  • Optimize the testing code (currently too slow)
