datawhalechina / torch-rechub

A lightweight PyTorch framework for recommendation models, easy to use and easy to extend.

License: MIT License

Python 62.66% Jupyter Notebook 37.34%
ctr-prediction pytorch recommendation-system recsys

torch-rechub's Introduction

Torch-RecHub

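Chinese Wiki site

Check the latest development progress, claim development tasks that interest you, read notes on reproducing rechub models, join the rechub contributor team, and more.

Click the link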

Installation

# Stable release
pip install torch-rechub

# Latest version (recommended)
git clone https://github.com/datawhalechina/torch-rechub.git
cd torch-rechub
python setup.py install

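Core positioning

Easy to use and easy to extend, focused on reproducing recommendation models that are practical in industry and on a broad ecosystem of recommendation scenarios.

Key features

  • Easy-to-use scikit-learn style API (fit, predict), plug and play

  • Model training is decoupled from model definition, so it is easy to extend and different training mechanisms can be set for different types of models

  • Accepts pandas DataFrame and Dict data as input, keeping the learning curve low

  • Highly modular, with support for common layers that are easy to call and assemble into new models

    • LR, MLP, FM, FFM, CIN

    • target-attention, self-attention, transformer

  • Support for common ranking models

    • WideDeep, DeepFM, DIN, DCN, xDeepFM, and others

  • Support for common matching (recall) models

    • DSSM, YoutubeDNN, YoutubeDSSM, FacebookEBR, MIND, and others

  • Rich multi-task learning support

    • Models such as SharedBottom, ESMM, MMOE, PLE, and AITM

    • Dynamic loss weighting mechanisms such as GradNorm, UWL, and MetaBalance

  • Focus on more ecosystem-oriented recommendation scenarios

    • Cold start

    • Delayed feedback

    • Debiasing

  • Support for a rich set of training mechanisms

    • Contrastive learning

    • Distillation

  • Support for third-party high-performance open-source Trainers (PyTorch Lightning)

  • More models are under development

Quick start

Usage examples

  • For usage examples of all models, see /examples

  • For the notebook tutorials from the June 2022 Datawhale-RecHub recommendation course team study, see /tutorials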

Fine ranking (CTR prediction)

from torch_rechub.models.ranking import DeepFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator

# x, y, deep_features and fm_features are assumed to have been prepared beforehand
dg = DataGenerator(x, y)
# 70% train, 10% validation, the remainder is used as the test set
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)

model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})

ctr_trainer = CTRTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)

Multi-task ranking

from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer

# features and the dataloaders are assumed to have been prepared beforehand
task_types = ["classification", "classification"]
model = MMOE(features, task_types, 8, expert_params={"dims": [32, 16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])

mtl_trainer = MTLTrainer(model)
mtl_trainer.fit(train_dataloader, val_dataloader)
# evaluate with the multi-task trainer (the original snippet referenced ctr_trainer here)
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)

Matching (recall) models

from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import MatchDataGenerator

# x, y, test_user, all_item, user_features and item_features are assumed to have been prepared beforehand
dg = MatchDataGenerator(x, y)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)

model = DSSM(user_features, item_features, temperature=0.02,
             user_params={
                 "dims": [256, 128, 64],
                 "activation": 'prelu',
             },
             item_params={
                 "dims": [256, 128, 64],
                 "activation": 'prelu',
             })

match_trainer = MatchTrainer(model)
match_trainer.fit(train_dl)

torch-rechub's People

Contributors

1985312383, 1qweasdzxc, bokang-ugent, ginnie23, hasai666, hjh233, icecapriccio, inease, jjplane, morningsky, storyandwine, wangych6, yinpu

torch-rechub's Issues

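A question about negative sampling

Taking sample_method=0 as an example, the negative_sample() function samples item_id at random. It does not check that the sampled item_id is one the current user has not clicked. So how should the case be handled where the sampled item_id happens to be one the user clicked in the past, or happens to be the item_id of the current positive sample? Which part of the code deals with this? I may have missed it, so any pointers would be appreciated.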
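To make the concern concrete, here is a minimal sketch of one common way to avoid such collisions: reject any candidate that appears in the user's click history. This is not torch-rechub's negative_sample(). The function name sample_negatives, its parameters, and the assumption that item ids run from 1 to n_items are all hypothetical and chosen only for illustration.

```python
import random


def sample_negatives(clicked_items, n_items, n_neg, rng=random):
    """Draw n_neg item ids in [1, n_items] that the user has not clicked.

    Hypothetical helper for illustration; sampling is uniform and with
    replacement, and clicked items (including the current positive) are rejected.
    """
    clicked = set(clicked_items)
    # Guard against an infinite loop when every item has been clicked.
    if all(item in clicked for item in range(1, n_items + 1)):
        raise ValueError("the user has clicked every item; no negatives exist")
    negatives = []
    while len(negatives) < n_neg:
        candidate = rng.randint(1, n_items)  # inclusive on both ends
        if candidate not in clicked:
            negatives.append(candidate)
    return negatives


rng = random.Random(0)  # fixed seed so the run is reproducible
negatives = sample_negatives(clicked_items=[1, 2, 3], n_items=10, n_neg=5, rng=rng)
```

Rejection sampling like this is cheap when the click history is small relative to the catalogue. When most items have been clicked, sampling from a precomputed list of unclicked items may be a better fit.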

Default argument value is mutable

I noticed that mutable default argument values are used in rechub. This may cause unexpected results that are hard to detect, because the same mutable object is shared across all calls and instances.

Ref: Python documentation

Default parameter values are evaluated from left to right when the function definition is executed. This means that the expression is evaluated once, when the function is defined, and that the same “pre-computed” value is used for each call. This is especially important to understand when a default parameter value is a mutable object, such as a list or a dictionary: if the function modifies the object (e.g. by appending an item to a list), the default parameter value is in effect modified. This is generally not what was intended. A way around this is to use None as the default, and explicitly test for it in the body of the function, e.g.:

def whats_on_the_telly(penguin=None):
    if penguin is None:
        penguin = []
    penguin.append("property of the zoo")
    return penguin
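
For contrast, the following minimal sketch shows the pitfall side by side with the None-default fix. The function names append_buggy and append_safe are hypothetical and do not come from rechub.

```python
def append_buggy(item, bucket=[]):
    # The default list is created once, at definition time, and then reused.
    bucket.append(item)
    return bucket


def append_safe(item, bucket=None):
    # A fresh list is created on every call that omits the argument.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


first = append_buggy("a")
second = append_buggy("b")  # same list object as `first`, now holding both items

safe_first = append_safe("a")
safe_second = append_safe("b")  # independent list holding only "b"
```

The same sharing applies to a dict default in a model constructor: if the constructor mutates the dict, every later instance built with the default sees the change.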

Column split question

In the Ali-CCP data processing script, what is the basis for splitting the columns into sparse_columns and dense_columns?

Mistake in preprocess_ali_ccp.py

In Line 81, vocabulary[k][v] should be set to 1 at first to coordinate with Line 90 (which should be >=10, for example). In this case the TRUE threshold is 12 now.

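Loading the model from Java in production

The model takes a dict keyed by feature as its input, which makes loading it for inference from Java in production a real headache. ai.djl.pytorch does not seem to support this input format either.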

PLE bug

Running PLE with the aliccp MTL example runs into the following problem:

        for ple_out, tower, predict_layer in zip(ple_outs, self.towers, self.predict_layers):
            tower_out = tower(ple_out)  #[batch_size, 1]

However, the shape of tower_out is [batch_size, 8].

So I wonder whether, in the PLE model initialization,

        self.towers = nn.ModuleList(
            MLP(expert_params["dims"][-1], output_layer=False, **tower_params_list[i]) for i in range(self.n_task))

should be changed to

        self.towers = nn.ModuleList(
            MLP(expert_params["dims"][-1], output_layer=True, **tower_params_list[i]) for i in range(self.n_task))

@morningsky corrections are welcome

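Loss monitoring

When using a matching model, I would like to monitor the training process. When using the torch_rechub library directly, how can I get the loss computed in train_one_epoch in torch_rechub/trainers/match_trainer? Could a return value be added?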
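
The request above can be sketched as follows. This is a framework-free illustration of the pattern, not the code of MatchTrainer: the functions train_one_epoch and fit below are stand-ins, and step_fn is a hypothetical callable that performs one optimisation step on a batch and returns its scalar loss.

```python
def train_one_epoch(batches, step_fn):
    """Run one pass over the batches and return the mean batch loss."""
    total_loss = 0.0
    n_batches = 0
    for batch in batches:
        total_loss += float(step_fn(batch))  # step_fn returns the batch loss
        n_batches += 1
    return total_loss / n_batches if n_batches else float("nan")


def fit(batches, step_fn, n_epoch):
    """Train for n_epoch epochs and collect the mean loss of each epoch."""
    history = []
    for _ in range(n_epoch):
        history.append(train_one_epoch(batches, step_fn))
    return history


# Toy run: each "batch" is a number and the "loss" is half of it.
history = fit([1.0, 2.0, 3.0], step_fn=lambda batch: batch * 0.5, n_epoch=2)
```

Returning the per-epoch mean keeps the training loop unchanged for existing callers while letting a caller log, plot, or early-stop on the collected values.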

VisibleDeprecationWarning(Creating an ndarray from ragged nested sequences) When Saving data cache

I think it's a very good example textbook for newcomers to recommender systems, and it contains a variety of models and tools encapsulated for use, so thank you very much for developing this program! But here is a small problem that may affect the experience of using it:

When calling np.save("./data/ml-1m/saved/data_cache.npy", (x_train, y_train, x_test)), NumPy emits a warning:

...\site-packages\numpy\core\_asarray.py:171: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  return array(a, dtype, copy=False, order=order, subok=True)

I think the reason is that, since NumPy 1.19.0, one must specify dtype=object when creating an array from "ragged" sequences.

Release Notes:

Deprecate automatic dtype=object for ragged input
Calling np.array([[1, [1, 2, 3]]) will issue a DeprecationWarning as per NEP 34. Users should explicitly use dtype=object to avoid the warning.

Corresponding Pull Request:

numpy/numpy#15119

I will link a PR to this issue that fixes this warning.
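
One way to avoid the warning is to build the object array explicitly before saving, as in the minimal sketch below. It is not necessarily the change made in the linked PR. The helper pack_ragged and the toy x_train, y_train, and x_test values are hypothetical, and an in-memory buffer stands in for the .npy file path.

```python
import io

import numpy as np


def pack_ragged(*parts):
    """Pack objects of differing shapes into a 1-D array with dtype=object."""
    packed = np.empty(len(parts), dtype=object)
    for i, part in enumerate(parts):
        packed[i] = part  # store each object as a single element
    return packed


# Toy stand-ins for the cached data; the shapes differ on purpose.
x_train = {"user_id": np.array([1, 2, 3])}
y_train = np.array([1, 0, 1])
x_test = {"user_id": np.array([4, 5])}

buffer = io.BytesIO()  # stands in for the data_cache.npy path
np.save(buffer, pack_ragged(x_train, y_train, x_test))

buffer.seek(0)
# Object arrays are pickled, so loading requires allow_pickle=True.
x_train2, y_train2, x_test2 = np.load(buffer, allow_pickle=True)
```

Filling a preallocated object array element by element sidesteps NumPy's shape inference entirely, so it also behaves the same when the parts happen to share a leading dimension.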

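Training ESMM on the full ali-ccp dataset gives cvr and ctcvr AUC = 1.0 on the test set

Environment

ubuntu 22.04, python=3.9.18, torch=2.1.0, GPU=3090

Preprocessing

Used the preprocess_ali_ccp.py script to produce the train set and test set (the test set was not further split into a validation set and a test set).

Parameter settings

In tutorials/Multi_Task.ipynb, task_types = ["classification", "classification"] was changed to task_types = ["classification", "classification", "classification"]. (On line 111 of MTLTrainer, the ESMM loss is total_loss = sum(loss_list[1:]), so training has to set task_type to three binary classification tasks: cvr, ctr, ctcvr.) All other hyperparameters were left unchanged and the model was trained on the full dataset, which produced the log below.

Output log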

train loss: {'task_0:': 0.0009412071586964892, 'task_1:': 0.15991775316040385, 'task_2:': 5.389307421478335e-05}
epoch: 0 validation scores: [1.0, 0.5951526439762982, 1.0]

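ESMM questions

Hello, I am a beginner and have a few questions. If you see this, I would be grateful for your help. Thank you:
(1) Which tasks do the two AUC values output by the ESMM model correspond to?
(2) Is it possible to output the predicted values for each task?
(3) Must the input features be 0/1, or can continuous values be used?
(4) Can other evaluation metrics be added?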

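ESMM inputs

Hello, are the inputs to the main task (cvr) and the auxiliary task (ctr) in the ESMM architecture the same, that is, all impression samples? In the architecture diagram the inputs are fed into two separate networks. If the two networks take the same input, why is the diagram not drawn like MMOE, with a single input, a shared bottom, and then a split into the different tasks? This is what confuses me.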
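
As background for this question, the ESMM paper defines pCTCVR = pCTR × pCVR and trains on all impressions using only the CTR and CTCVR losses. The sketch below illustrates that relation in plain Python under those assumptions. It is not torch-rechub's implementation, and the names binary_cross_entropy and esmm_loss, as well as the toy probabilities and labels, are hypothetical.

```python
import math


def binary_cross_entropy(prob, label, eps=1e-7):
    """Log loss for one sample, with the probability clipped for stability."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def esmm_loss(pctr, pcvr, clicks, conversions):
    """Return (ctr_loss + ctcvr_loss, pctcvr) over a batch of impressions.

    pctr and pcvr are the two tower outputs for the same impressions.
    No loss is applied to pcvr directly; it is supervised only through pctcvr.
    """
    pctcvr = [ctr * cvr for ctr, cvr in zip(pctr, pcvr)]
    # An impression is a CTCVR positive only if it was clicked and converted.
    ctcvr_labels = [c * v for c, v in zip(clicks, conversions)]
    n = len(pctr)
    ctr_loss = sum(binary_cross_entropy(p, y) for p, y in zip(pctr, clicks)) / n
    ctcvr_loss = sum(binary_cross_entropy(p, y) for p, y in zip(pctcvr, ctcvr_labels)) / n
    return ctr_loss + ctcvr_loss, pctcvr


# Two impressions: the first is clicked and converted, the second is neither.
loss, pctcvr = esmm_loss(pctr=[0.5, 0.2], pcvr=[0.4, 0.1], clicks=[1, 0], conversions=[1, 0])
```

Because both labels used in the loss are defined for every impression, both towers can consume the same impression-level input even though the CVR output has no direct loss of its own.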

pyspark support question

On big-data clusters, data is usually processed and output with pyspark. How can DataFrames produced by pyspark be supported?
