GithubHelp home page GithubHelp logo

Comments (20)

zhouhansheng avatar zhouhansheng commented on August 17, 2024

image

from spu.

xiushuiguande avatar xiushuiguande commented on August 17, 2024

@zhouhansheng hi,看结果这是符合预期的 。
1、请问您这边是什么场景和需求?
2、计划投入多少计算资源?
3、期望达到多高的精度呢 ?

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng

请问一下 spu 的版本是多少呀

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

@zhouhansheng hi,看结果这是符合预期的 。 1、请问您这边是什么场景和需求? 2、计划投入多少计算资源? 3、期望达到多高的精度呢 ?

  1. 目前没有实际的业务需求,希望能够处理的场景是:假设有10个参与方,各自有100000个用户数据(假设经过psi处理,得到各不相同的用户数据),现在需要联合10个参与方做一些联合统计计算。
  2. 资源投入方面只和数据量有关吧,投入更多资源能提高精度?
  3. 从测试结果看,即使128bit精度,水平联合数据达到百万,结果有效位只有2位,如果没有6位以上的精度,结果可能都没有意义。

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

Hi @zhouhansheng

请问一下 spu 的版本是多少呀

0.5.0b0

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng
请问一下 spu 的版本是多少呀

0.5.0b0

麻烦用最新的 spu 试一下?

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

Hi @zhouhansheng
请问一下 spu 的版本是多少呀

0.5.0b0

麻烦用最新的 spu 试一下?

用了最新的隐语镜像,secretflow==1.3.0.dev20231212,spu==0.6.0b0,算出来结果和之前是一样的;
128位精度设置,只能保证小数据量的足够有效位?

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng
请问一下 spu 的版本是多少呀

0.5.0b0

麻烦用最新的 spu 试一下?

用了最新的隐语镜像,secretflow==1.3.0.dev20231212,spu==0.6.0b0,算出来结果和之前是一样的; 128位精度设置,只能保证小数据量的足够有效位?

我用 spu 0.7.0.dev20231219 测试结果并没有这么大的误差。

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    return jnp.array(
        [
            jnp.median(data_concat),
            jnp.mean(data_concat),
            jnp.var(data_concat),
            jnp.std(data_concat),
            jnp.max(data_concat),
            jnp.min(data_concat),
        ]
    )


if __name__ == "__main__":
    """
    You can modify the code below for debug purpose only.
    Please DONT commit it unless it will cause build break.
    """

    sim = ppsim.Simulator.simple(2, spu_pb2.ProtocolKind.SEMI2K, spu_pb2.FieldType.FM64)

    x = np.random.uniform(1, 1000, (10000,))
    y = np.random.uniform(1, 1000, (10000,))
    
    spu_fn = ppsim.sim_jax(sim, statistic_fn)
    z = spu_fn([x, y])

    np.set_printoptions(suppress=True)

    print(f"spu out = {z}")
    print(f"cpu out = {statistic_fn([x, y])}")

结果

image

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

但是如果你的数量继续增加比如两方各100000的数据量,用 FM64 是会出现你的结果,我们看一下能不能改善

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng

本质上的问题在于算 variance 的时候 sum(square(x-mean(x))) 的值在数据量非常大的时候,可能会溢出 FM64 的表达范围

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

square

即使FM128,数据量再大些,应该也会出现这样的问题吧,对于百万级别的数据量,似乎不可避免造成比较大的误差?
隐语目前的协议中,有支持更高精度的设置吗?

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

square

即使FM128,数据量再大些,应该也会出现这样的问题吧,对于百万级别的数据量,似乎不可避免造成比较大的误差? 隐语目前的协议中,有支持更高精度的设置吗?

是的,目前 SPU 里目前没有比 128bit 更大的环的支持,不过本质上,可能需要从算法层面解决这个问题
对于数据量特别大的算 var,std 这种,可能需要切成一个个 batch 来算

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng

举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

Hi @zhouhansheng

举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

这好像只能解决溢出FM64表达范围的问题,算不准的问题还是没法解决,还有一些疑惑的地方:①FM64计算出的平均值是可能比FM128计算出来的更准一些?②由于①这个问题,按你这边代码运行也会导致方差和标准差FM64比FM128更准

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng
举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

这好像只能解决溢出FM64表达范围的问题,算不准的问题还是没法解决,还有一些疑惑的地方:①FM64计算出的平均值是可能比FM128计算出来的更准一些?②由于①这个问题,按你这边代码运行也会导致方差和标准差FM64比FM128更准

fm128 也按我的代码跑精度应该是差不多的

这里本质上的问题是mpc现在基本上都是用定点数来模拟浮点数的计算。浮点数本身encode到定点数就会有一定的误差,然后定点数计算很多非线性拟合也只能保证在一定范围内(不是非常大,也不是非常小)精度非常接近明文

本质上要全域接近明文精度,可能需要哪天有个可用的基于浮点数的 mpc 协议了。

在这之前,我认为比较可行的做法就是对于特别大的数据量,针对数值 range 调整一下算法提升精度

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

Hi @zhouhansheng
举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

这好像只能解决溢出FM64表达范围的问题,算不准的问题还是没法解决,还有一些疑惑的地方:①FM64计算出的平均值是可能比FM128计算出来的更准一些?②由于①这个问题,按你这边代码运行也会导致方差和标准差FM64比FM128更准

fm128 也按我的代码跑精度应该是差不多的

这里本质上的问题是mpc现在基本上都是用定点数来模拟浮点数的计算。浮点数本身encode到定点数就会有一定的误差,然后定点数计算很多非线性拟合也只能保证在一定范围内(不是非常大,也不是非常小)精度非常接近明文

本质上要全域接近明文精度,可能需要哪天有个可用的基于浮点数的 mpc 协议了。

在这之前,我认为比较可行的做法就是对于特别大的数据量,针对数值 range 调整一下算法提升精度

spdz中的semi2k协议应该是支持更丰富的精度设置,还可以指定小数位的精度,隐语这边目前可以实现相似设置吗?

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng
举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

这好像只能解决溢出FM64表达范围的问题,算不准的问题还是没法解决,还有一些疑惑的地方:①FM64计算出的平均值是可能比FM128计算出来的更准一些?②由于①这个问题,按你这边代码运行也会导致方差和标准差FM64比FM128更准

fm128 也按我的代码跑精度应该是差不多的
这里本质上的问题是mpc现在基本上都是用定点数来模拟浮点数的计算。浮点数本身encode到定点数就会有一定的误差,然后定点数计算很多非线性拟合也只能保证在一定范围内(不是非常大,也不是非常小)精度非常接近明文
本质上要全域接近明文精度,可能需要哪天有个可用的基于浮点数的 mpc 协议了。
在这之前,我认为比较可行的做法就是对于特别大的数据量,针对数值 range 调整一下算法提升精度

spdz中的semi2k协议应该是支持更丰富的精度设置,还可以指定小数位的精度,隐语这边目前可以实现相似设置吗?

这里 fxp_fraction_bits 可以设置小数的 bits 数量,但是小数位数越多,整数的范围会变小,这里的 trade off 需要自己斟酌

from spu.

zhouhansheng avatar zhouhansheng commented on August 17, 2024

Hi @zhouhansheng
举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

这好像只能解决溢出FM64表达范围的问题,算不准的问题还是没法解决,还有一些疑惑的地方:①FM64计算出的平均值是可能比FM128计算出来的更准一些?②由于①这个问题,按你这边代码运行也会导致方差和标准差FM64比FM128更准

fm128 也按我的代码跑精度应该是差不多的
这里本质上的问题是mpc现在基本上都是用定点数来模拟浮点数的计算。浮点数本身encode到定点数就会有一定的误差,然后定点数计算很多非线性拟合也只能保证在一定范围内(不是非常大,也不是非常小)精度非常接近明文
本质上要全域接近明文精度,可能需要哪天有个可用的基于浮点数的 mpc 协议了。
在这之前,我认为比较可行的做法就是对于特别大的数据量,针对数值 range 调整一下算法提升精度

spdz中的semi2k协议应该是支持更丰富的精度设置,还可以指定小数位的精度,隐语这边目前可以实现相似设置吗?

这里 fxp_fraction_bits 可以设置小数的 bits 数量,但是小数位数越多,整数的范围会变小,这里的 trade off 需要自己斟酌

域的范围可以设置到192这种更大的值吗?

from spu.

anakinxc avatar anakinxc commented on August 17, 2024

Hi @zhouhansheng
举个例子,改成这样

def statistic_fn(dataset_parties_input_list):
    data_concat = jnp.concatenate(dataset_parties_input_list, axis=0)
    jmedian = jnp.median(data_concat)
    jmean = jnp.mean(data_concat)
    jmax = jnp.max(data_concat)
    jmin = jnp.min(data_concat)
    
    jvar = 0
    for a in jnp.split(data_concat, 200):
        partial_diff = a - jmean
        partial_sum = jnp.sum(partial_diff**2)
        jvar += partial_sum / len(data_concat)

    return jnp.array(
        [
            jmedian, jmean, jvar, jnp.sqrt(jvar), jmax, jmin 
        ]
    )

这好像只能解决溢出FM64表达范围的问题,算不准的问题还是没法解决,还有一些疑惑的地方:①FM64计算出的平均值是可能比FM128计算出来的更准一些?②由于①这个问题,按你这边代码运行也会导致方差和标准差FM64比FM128更准

fm128 也按我的代码跑精度应该是差不多的
这里本质上的问题是mpc现在基本上都是用定点数来模拟浮点数的计算。浮点数本身encode到定点数就会有一定的误差,然后定点数计算很多非线性拟合也只能保证在一定范围内(不是非常大,也不是非常小)精度非常接近明文
本质上要全域接近明文精度,可能需要哪天有个可用的基于浮点数的 mpc 协议了。
在这之前,我认为比较可行的做法就是对于特别大的数据量,针对数值 range 调整一下算法提升精度

spdz中的semi2k协议应该是支持更丰富的精度设置,还可以指定小数位的精度,隐语这边目前可以实现相似设置吗?

这里 fxp_fraction_bits 可以设置小数的 bits 数量,但是小数位数越多,整数的范围会变小,这里的 trade off 需要自己斟酌

域的范围可以设置到192这种更大的值吗?

目前不行,最大只能到 128

from spu.

github-actions avatar github-actions commented on August 17, 2024

Stale issue message. Please comment to remove stale tag. Otherwise this issue will be closed soon.

from spu.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.