Are this TF version's inputs `token_ids, segment_ids` the same as the original model's inputs `input_ids, token_type_ids`? (cdial-gpt-tf, open)

bojone commented on August 22, 2024
Are this TF version's inputs `token_ids, segment_ids` the same as the original model's inputs `input_ids, token_type_ids`?

from cdial-gpt-tf.

Comments (9)

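bojone commented on August 22, 2024

What exactly is the difference? Could you copy and paste some of the values at the same positions?

I haven't run the PyTorch version, but I loaded the model with the TF version of transformers and compared the two: the bert4keras version's output is essentially identical, so it should be correct. Also, the chat quality in actual testing is fine.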
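One way to make "essentially identical" concrete is sketched below, assuming both outputs are NumPy arrays of shape (batch, seq_len, vocab) in the same representation (both logits or both probabilities). The helper name `compare_outputs` is hypothetical; it reports the largest absolute difference and the fraction of positions where the top-1 token agrees, which is what matters most for greedy decoding.

```python
import numpy as np


def compare_outputs(a, b):
    """Hypothetical helper: compare two model outputs of identical shape.

    Assumes both arrays have shape (batch, seq_len, vocab) and use the same
    representation (both logits or both probabilities).
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    # Largest elementwise deviation between the two outputs
    max_abs_diff = float(np.max(np.abs(a - b)))
    # Fraction of positions whose highest-scoring token is the same in both
    top1_agreement = float(np.mean(np.argmax(a, axis=-1) == np.argmax(b, axis=-1)))
    return max_abs_diff, top1_agreement
```

A small `max_abs_diff` together with a `top1_agreement` of 1.0 is a reasonable working definition of "essentially identical"; what counts as small is a judgment call that depends on float precision.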

wulikai1993 commented on August 22, 2024

Directly below this line of the original project, https://github.com/thu-coai/CDial-GPT/blob/master/interact.py#L81 , I inserted the following code:

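# Convert the PyTorch tensors to NumPy arrays for the Keras model
# (.cpu() is required if interact.py placed the tensors on a GPU)
tf_input_ids = input_ids.cpu().numpy()
tf_token_type_ids = token_type_ids.cpu().numpy()
tf_logits = tf_model.predict([tf_input_ids, tf_token_type_ids])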

Here `tf_model` was loaded from the checkpoint downloaded from this project.

Then I printed `logits` and `tf_logits`; the output is as follows:

<class 'torch.Tensor'>
torch.Size([1, 6, 13088])
tensor([[[ -6.3841,  -9.6811,   4.7181,  ...,  -8.7645,   0.7623,   0.9506],
         [ -6.3841,  -9.6811,   4.7181,  ...,  -8.7645,   0.7623,   0.9506],
         [ -9.3133, -12.9856,   2.2705,  ..., -12.7288,  -8.5736,  -7.1285],
         [ -8.5503, -11.9364,   7.3586,  ..., -11.0706,  -8.0792,  -6.5877],
         [ -8.5294, -11.8703,   0.3997,  ..., -11.7161,  -9.0465,  -9.0090],
         [ -8.7134, -13.0577,  -0.3204,  ..., -13.2630, -11.2105,  -9.9821]]])
<class 'numpy.ndarray'>
(1, 6, 13088)
[[[3.0044863e-08 8.1140666e-07 5.3718887e-02 ... 7.5160841e-08
   1.0162239e-03 1.2269986e-03]
  [3.0044863e-08 8.1140666e-07 5.3718887e-02 ... 7.5160841e-08
   1.0162249e-03 1.2269994e-03]
  [2.6168695e-10 1.0278778e-08 1.0963416e-03 ... 3.4622555e-10
   2.0836838e-08 9.0029353e-08]
  [5.9358007e-10 1.7513411e-08 1.4495271e-01 ... 1.4628808e-09
   2.7349104e-08 1.2074231e-07]
  [3.6721937e-10 1.0199308e-08 7.8793033e-05 ... 4.2641926e-10
   5.6711422e-09 5.8271072e-09]
  [1.0654655e-12 8.3270356e-11 4.1641871e-07 ... 8.8329517e-13
   6.3699354e-12 2.1308885e-11]]]

wulikai1993 commented on August 22, 2024

> Also, the chat quality in actual testing is fine.

Could I take a look at your chat testing code?

bojone commented on August 22, 2024

> > Also, the chat quality in actual testing is fine.
>
> Could I take a look at your chat testing code?

All the code is right here: https://github.com/bojone/CDial-GPT-tf/blob/master/example.py . Calling `chatbot.response` simulates a chat.

bojone commented on August 22, 2024

> Directly below this line of the original project, https://github.com/thu-coai/CDial-GPT/blob/master/interact.py#L81 , I inserted the following code: [...]
>
> Here `tf_model` was loaded from the checkpoint downloaded from this project.
>
> Then I printed `logits` and `tf_logits`. [...]

The PyTorch output you printed is presumably taken before the softmax, whereas my output is after the softmax.
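Following that explanation, the PyTorch logits need a softmax over the vocabulary axis before they can be compared with the bert4keras probabilities. A minimal NumPy sketch (the function name is illustrative, not from either project):

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax over the vocabulary axis."""
    z = np.asarray(logits, dtype=np.float64)
    # Subtracting the row maximum leaves the result unchanged but keeps exp() from overflowing
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)
```

Inside interact.py this could be applied as, for example, `np.allclose(softmax(logits.detach().cpu().numpy()), tf_logits, atol=1e-5)`, where the tolerance is an arbitrary assumption. In PyTorch itself, `torch.softmax(logits, dim=-1)` performs the same conversion.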

wulikai1993 commented on August 22, 2024

This PyTorch model is simply the PyTorch `OpenAIGPTLMHeadModel` from transformers (https://github.com/thu-coai/CDial-GPT/blob/master/interact.py#L132). Logically, if this bert4keras version matches the TF version of transformers, it should also match the PyTorch version.

bojone commented on August 22, 2024

> This PyTorch model is simply the PyTorch `OpenAIGPTLMHeadModel` from transformers (https://github.com/thu-coai/CDial-GPT/blob/master/interact.py#L132). Logically, if this bert4keras version matches the TF version of transformers, it should also match the PyTorch version.

It's the difference between probas and logits.
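The printed rows are truncated ("..."), so the full softmax cannot be recomputed from them. A check that still works on the visible entries: if p = softmax(z), then log(p_i) = z_i - log(sum_j exp(z_j)), so log(p_i) - z_i must be the same constant for every vocabulary entry i in a row. A minimal sketch, using only the six visible values of the first row from the output above; the helper name `log_prob_offsets` is hypothetical:

```python
import numpy as np


def log_prob_offsets(logits, probas):
    """Return log(probas) - logits elementwise.

    If probas == softmax(logits) for the same row, every entry equals
    -logsumexp(logits), so the result should be (nearly) constant.
    """
    z = np.asarray(logits, dtype=np.float64)
    p = np.asarray(probas, dtype=np.float64)
    return np.log(p) - z


# The six visible values of the first row of each printed output
torch_row0 = [-6.3841, -9.6811, 4.7181, -8.7645, 0.7623, 0.9506]
tf_row0 = [3.0044863e-08, 8.1140666e-07, 5.3718887e-02,
           7.5160841e-08, 1.0162239e-03, 1.2269986e-03]

offsets = log_prob_offsets(torch_row0, tf_row0)
print(np.round(offsets, 2))
print("spread:", offsets.max() - offsets.min())
```

If the spread is much larger than the rounding in the printed values, the two outputs differ by more than a softmax.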

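wulikai1993 commented on August 22, 2024

> but I loaded the model with the TF version of transformers and compared the two: the bert4keras version's output is essentially identical

In my tests, they are not consistent.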

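qiuxia-alone commented on August 22, 2024

> > but I loaded the model with the TF version of transformers and compared the two: the bert4keras version's output is essentially identical
>
> In my tests, they are not consistent.

@wulikai1993 Hey, which config file did you use?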