wxl1999 / unicrs

[KDD22] Official PyTorch implementation for "Towards Unified Conversational Recommender Systems via Knowledge-Enhanced Prompt Learning".

License: MIT License

Python 95.93% Shell 4.07%
conversation conversational-ai conversational-bots dialog dialogue dialogue-systems pretrained-language-model pretrained-models pretraining prompt

unicrs's People

Contributors: wxl1999

unicrs's Issues

The conversational-task result is better than the paper's

I ran your code successfully. While the recommendation results are almost the same, the conversational results are much better than those reported in your paper, and I did not change the source code. Could you tell me the reason?
Here are the results on my device:
'test/dist@2': 0.5503033723719486, 'test/dist@3': 0.9362212501763792, 'test/dist@4': 1.211090729504727
Here are the results from your paper:
[screenshot of the paper's results table omitted]
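One common source of Dist-n discrepancies is normalization: some implementations divide the number of distinct n-grams by the total n-gram count, while others divide by the number of responses, which can yield values above 1 (as in the dist@4 score above). A minimal corpus-level sketch, with hypothetical function names:

```python
def distinct_n(responses, n):
    """Corpus-level Dist-n: count distinct n-grams across all responses,
    then divide by the number of responses (one common normalization;
    dividing by the total n-gram count instead gives much smaller values)."""
    ngrams = set()
    for tokens in responses:
        for i in range(len(tokens) - n + 1):
            ngrams.add(tuple(tokens[i:i + n]))
    return len(ngrams) / len(responses)

responses = [["i", "like", "sci", "fi"], ["i", "like", "comedy", "films"]]
print(distinct_n(responses, 2))  # 5 distinct bigrams / 2 responses = 2.5
```

Checking which normalization (and which tokenization) the evaluation script uses versus what the paper reports is a likely explanation for the gap.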

environment file is needed

While trying to reproduce the method, I ran into two main challenges:

  1. The DBpedia link is no longer available.
  2. There is no environment file, and some necessary packages are not listed in the README, which makes it hard to create an environment.

Hope you can update these~
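For anyone blocked on this, a starting-point dependency list can be guessed from the training commands in the README. This is only a sketch (package names inferred, versions unpinned), not an official environment file:

```text
# requirements.txt (sketch; inferred from the training scripts, not provided by the authors)
torch
transformers
accelerate
wandb
tqdm
```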

Meaning of pooling in pre-training step

Hi Xiaolei, thank you for your work. I'm interested in various works involving knowledge graphs.
Anyway, I have a question: what is the meaning of the pooling in the pre-training step?

$$ h_{u} = \mathrm{Pooling}\big[ f(\tilde{C}_{pre} \mid \Theta_{plm}; \Theta_{fuse}) \big] $$ In this equation, I cannot understand the meaning of the pooling, nor can I find the corresponding code on GitHub.

I would be very grateful for any help.
Best regards.
Yongtaek
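For what it's worth, "Pooling" in such formulas usually means collapsing the PLM's token-level hidden states into a single user vector; mean pooling over non-padding tokens is a common choice (this is an assumption about the paper, not confirmed by it). A minimal sketch with NumPy:

```python
import numpy as np

def mean_pool(hidden_states, attention_mask):
    """Average token representations, ignoring padding positions.

    hidden_states: (batch, seq_len, hidden) token-level PLM outputs
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask[..., None].astype(float)   # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)      # (batch, hidden)
    counts = mask.sum(axis=1).clip(min=1e-9)         # avoid division by zero
    return summed / counts

h = np.ones((2, 4, 8))
m = np.array([[1, 1, 0, 0], [1, 1, 1, 1]])
pooled = mean_pool(h, m)  # shape (2, 8); averages only the unmasked tokens
```

Other choices (CLS-token pooling, max pooling) are also possible; which one the code uses would have to be confirmed by the authors.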

about process steps

Hey Xiaolei. First of all, thanks for your work! I have successfully run your code from GitHub, but I have a few questions about the preprocessing code.

  1. I got very high recall@k scores with the pretrained-prompt model.
    {'test/recall@1': 0.5659283956497578, 'test/recall@10': 0.9083115027387473, 'test/recall@50': 0.9323648487735176, 'test/ndcg@1': 0.5659283956497578, 'test/ndcg@10': 0.757117269984917, 'test/ndcg@50': 0.7626062380924029, 'test/mrr@1': 0.5659283956497578, 'test/mrr@10': 0.7063158638961657, 'test/mrr@50': 0.7075750767791935, 'test/loss': 2.209772330341917, 'epoch': 4}
    Is this because the file produced by process.py (test_data_processed.jsonl) contains both user and system responses, making this recall@k inaccurate?

  2. I wonder whether my understanding of the three preprocessing scripts is correct:
    (1) process.py extracts user and system responses and their contexts for semantic fusion.
    (2) process_mask.py extracts system responses for the conversation prompt.
    (3) merge.py merges the templates generated by the conversation prompt model with the items for the recommendation prompt.

I would be very grateful for any help.
Best regards.
siqi
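On point 1: recall@k is just a hit indicator averaged over the test examples, so if the evaluation input already contains the ground-truth item (e.g. the system response leaked into the context during preprocessing), the scores will be inflated. A minimal sketch with hypothetical names:

```python
def recall_at_k(ranked_items, target, k):
    """1.0 if the ground-truth item appears in the top-k ranked items, else 0.0."""
    return 1.0 if target in ranked_items[:k] else 0.0

def mean_recall_at_k(examples, k):
    """Average the per-example hit indicator over the test set."""
    return sum(recall_at_k(ranked, target, k) for ranked, target in examples) / len(examples)

examples = [
    (["matrix", "inception", "up"], "inception"),  # hit within top-2
    (["avatar", "up", "matrix"], "inception"),     # miss within top-2
]
print(mean_recall_at_k(examples, 2))  # 0.5
```

If the model can read the answer off its own input, this average approaches 1 regardless of recommendation quality, which would explain the unusually high scores.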

How to reproduce the rec performance on INSPIRED dataset?

Dear Author,

I am trying to reproduce the rec performance on the INSPIRED dataset.

[screenshot of the paper's INSPIRED results omitted]

I used the hyperparameters you recommend and the "best" model as the prompt encoder. Unfortunately, I was not able to reproduce the performance reported in the paper.

---- Here I attach the loss and recall@1 on the test set for the prompt pre-training, conversational training, and recommendation training steps:

[plots omitted] prompt pre-training

[plots omitted] conversational training

[plots omitted] recommendation training (as you can see, the best recall@1 I got is around 0.04, far from 0.09)

---- and here are the configurations I used for the prompt pre-training, conversational training, and recommendation training steps:

python3 train_pre.py \
    --dataset inspired \
    --tokenizer microsoft/DialoGPT-small \
    --model microsoft/DialoGPT-small \
    --text_tokenizer roberta-base \
    --text_encoder roberta-base \
    --num_train_epochs 5 \
    --gradient_accumulation_steps 1 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 128 \
    --num_warmup_steps 168 \
    --max_length 200 \
    --prompt_max_length 200 \
    --entity_max_length 32 \
    --learning_rate 6e-4 \
    --output_dir UniCRS/src/result_promptpretraining_inspired \
    --use_wandb \
    --project crs-prompt-pre-inspired \
    --name exp1 \
    --gpu 0 

prompt pre-training

python3 train_conv.py \
    --dataset inspired \
    --tokenizer microsoft/DialoGPT-small \
    --model microsoft/DialoGPT-small \
    --text_tokenizer roberta-base \
    --text_encoder roberta-base \
    --n_prefix_conv 20 \
    --prompt_encoder UniCRS/src/result_promptpretraining_inspired/best/ \
    --num_train_epochs 10 \
    --gradient_accumulation_steps 1 \
    --ignore_pad_token_for_loss \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --num_warmup_steps 976 \
    --context_max_length 200 \
    --resp_max_length 183 \
    --prompt_max_length 200 \
    --entity_max_length 32 \
    --learning_rate 1e-4 \
    --output_dir UniCRS/src/result_convprompt_inspired \
    --use_wandb \
    --project crs-prompt-conv-inspired \
    --name exp1 \
    --gpu 0

conv training

python3 infer_conv.py \
    --dataset inspired \
    --split test \
    --tokenizer microsoft/DialoGPT-small \
    --model microsoft/DialoGPT-small \
    --text_tokenizer roberta-base \
    --text_encoder roberta-base \
    --n_prefix_conv 20 \
    --prompt_encoder UniCRS/src/result_convprompt_inspired/best \
    --per_device_eval_batch_size 64 \
    --context_max_length 200 \
    --resp_max_length 183 \
    --prompt_max_length 200 \
    --entity_max_length 32 \
    --gpu 1

conv infer

python3 train_rec.py \
    --dataset inspired_gen \
    --tokenizer microsoft/DialoGPT-small \
    --model microsoft/DialoGPT-small \
    --text_tokenizer roberta-base \
    --text_encoder roberta-base \
    --n_prefix_rec 10 \
    --prompt_encoder UniCRS/src/result_promptpretraining_inspired/best \
    --num_train_epochs 5 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --num_warmup_steps 33 \
    --context_max_length 200 \
    --prompt_max_length 200 \
    --entity_max_length 32 \
    --learning_rate 1e-4 \
    --output_dir UniCRS/src/result_rec_inspired \
    --use_wandb \
    --project crs-prompt-rec-inspired \
    --name exp1 \
    --gpu 0

rec training

Thank you!

How to reproduce the performance on the ReDial dataset?

I trained according to the code provided on GitHub, but since the dataset link you provided cannot be opened, I used the mappingbased-objects_lang=en_202112.ttl dump instead. My final training results are as follows:

conv:
'test/dist@2': 0.310709750246931, 'test/dist@3': 0.49851841399746016, 'test/dist@4': 0.6383519119514605
rec:
'test/recall@1': 0.029324894514767934, 'test/recall@10': 0.16729957805907172, 'test/recall@50': 0.37953586497890296

(1) These results differ greatly from those presented in the paper. Could you give me some guidance? I hope to reproduce results similar to yours. Thank you very much.
(2) According to your paper, do I need to set --n_prefix_conv 50 in train_conv.py and --use_resp in train_rec.py?
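Regarding the substituted DBpedia dump: the mappingbased-objects files are distributed as simple one-triple-per-line data, so one quick sanity check is to scan the dump and confirm it actually contains the entities the dataset expects. A hedged sketch (assumes N-Triples-style lines; ignores comments, blanks, and literal objects):

```python
def parse_triple_line(line):
    """Parse an N-Triples-style line '<s> <p> <o> .' into a (s, p, o) tuple
    of plain URIs; returns None for comments, blanks, or literal objects."""
    line = line.strip()
    if not line.startswith("<"):
        return None
    parts = line.split(" ", 2)
    if len(parts) < 3 or not parts[2].startswith("<"):
        return None
    obj = parts[2].rstrip(" .")
    return tuple(t.strip("<>") for t in (parts[0], parts[1], obj))

sample = ("<http://dbpedia.org/resource/Inception> "
          "<http://dbpedia.org/ontology/director> "
          "<http://dbpedia.org/resource/Christopher_Nolan> .")
print(parse_triple_line(sample))
```

If many of the dataset's entities are missing from the substituted dump (a different DBpedia snapshot), that alone could explain part of the recall gap.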
