[Bug]: SPU precision issue? (closed)

linzzzzzz commented on August 17, 2024
[Bug]: SPU precision issue?

from spu.

Comments (8)

tpppppub commented on August 17, 2024

fxp_fraction_bits=18 may not be enough for your case. Try a larger value for fxp_fraction_bits.

anakinxc commented on August 17, 2024

Issue Type

Performance

Modules Involved

SPU runtime

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0b1

OS Platform and Distribution

Linux

Python Version

3.10

Compiler Version

No response

Current Behavior?

A bug happened!

Standalone code to reproduce the issue

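### Imports are not shown in the original report; the lines below are assumed
import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap

import spu
import spu.utils.simulation as pps  # assumed meaning of the `pps` alias used below

### Array A, B, C, D are generated from np.random.randn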
A = np.array([[-4.16757847e-01, -5.62668272e-02, -2.13619610e+00,
         1.64027081e+00, -1.79343559e+00, -8.41747366e-01,
         5.02881417e-01, -1.24528809e+00, -1.05795222e+00,
        -9.09007615e-01],
       [ 5.51454045e-01,  2.29220801e+00,  4.15393930e-02,
        -1.11792545e+00,  5.39058321e-01, -5.96159700e-01,
        -1.91304965e-02,  1.17500122e+00, -7.47870949e-01,
         9.02525097e-03],
       [-8.78107893e-01, -1.56434170e-01,  2.56570452e-01,
        -9.88779049e-01, -3.38821966e-01, -2.36184031e-01,
        -6.37655012e-01, -1.18761229e+00, -1.42121723e+00,
        -1.53495196e-01],
       [-2.69056960e-01,  2.23136679e+00, -2.43476758e+00,
         1.12726505e-01,  3.70444537e-01,  1.35963386e+00,
         5.01857207e-01, -8.44213704e-01,  9.76147160e-06,
         5.42352572e-01]])

B = np.array([[-0.3135082 ,  0.77101174, -1.86809065,  1.73118467,  1.46767801,
        -0.33567734,  0.61134078,  0.04797059, -0.82913529,  0.08771022],
       [ 1.00036589, -0.38109252, -0.37566942, -0.07447076,  0.43349633,
         1.27837923, -0.63467931,  0.50839624,  0.21611601, -1.85861239],
       [-0.41931648, -0.1323289 , -0.03957024,  0.32600343, -2.04032305,
         0.04625552, -0.67767558, -1.43943903,  0.52429643,  0.73527958],
       [-0.65325027,  0.84245628, -0.38151648,  0.06648901, -1.09873895,
         1.58448706, -2.65944946, -0.09145262,  0.69511961, -2.03346655]])

C = np.array([[-0.18946926, -0.07721867,  0.82470301,  1.24821292, -0.40389227,
        -1.38451867,  1.36723542,  1.21788563, -0.46200535,  0.35088849],
       [ 0.38186623,  0.56627544,  0.20420798,  1.40669624, -1.7379595 ,
         1.04082395,  0.38047197, -0.21713527,  1.1735315 , -2.34360319],
       [ 1.16152149,  0.38607805, -1.13313327,  0.43309255, -0.30408644,
         2.58529487,  1.83533272,  0.44068987, -0.71925384, -0.58341459],
       [-0.32504963, -0.56023451, -0.90224607, -0.59097228, -0.27617949,
        -0.51688389, -0.69858995, -0.92889192,  2.55043824, -1.47317325]])

D = np.array([[-1.02141473,  0.4323957 , -0.32358007,  0.42382471,  0.79918   ,
         1.26261366,  0.75196485, -0.99376098,  1.10914328, -1.76491773],
       [-0.1144213 , -0.49817419, -1.06079904,  0.59166652, -0.18325657,
         1.01985473, -1.48246548,  0.84631189,  0.49794015,  0.12650418],
       [-1.41881055, -0.25177412, -1.54667461, -2.08265194,  3.2797454 ,
         0.97086132,  1.79259285, -0.42901332,  0.69619798,  0.69741627],
       [ 0.60151581,  0.00365949, -0.22824756, -2.06961226,  0.61014409,
         0.4234969 ,  1.11788673, -0.27424209,  1.74181219, -0.44750088]])

A_1 = np.square(A)
B_1 = B + 0.1
C_1 = np.square(C)
D_1 = D + 0.1

s_0, s_1 = A_1.shape
I = jnp.tile(jnp.arange(s_1), (s_0,1))
Z = jnp.zeros(A_1.shape, dtype=int)


def compare(x1, x2):

    y0 = (x1[3]*x2[3]*(x1[0]*x2[1]-x2[0]*x1[1])+x1[1]*x2[1]*(x1[2]*x2[3]-x2[2]*x1[3]))*(x1[1]*x1[3]*x2[1]*x2[3])
    y1 = x1[5]
    z = y0 > y1
    
    zz_0 = jax.lax.select(z, x1[0], x2[0])
    zz_1 = jax.lax.select(z, x1[1], x2[1])
    zz_2 = jax.lax.select(z, x1[2], x2[2])
    zz_3 = jax.lax.select(z, x1[3], x2[3])
    zz_4 = jax.lax.select(z, x1[4], x2[4])
    zz_5 = jax.lax.select(z, x1[5], x2[5])

    return [zz_0,zz_1,zz_2,zz_3,zz_4,zz_5]


### plaintext calculation
fn = lambda a,b,c,d,e,f: jax.lax.reduce([a,b,c,d,e,f], [0.0,0.0,0.0,0.0,0,0], compare, [0])
res = vmap(fn)(A_1, B_1, C_1, D_1, I, Z)

### SPU simulation
config = spu.RuntimeConfig(
    protocol=spu.spu_pb2.ProtocolKind.CHEETAH,
    field=spu.spu_pb2.FieldType.FM64,
    fxp_fraction_bits=18,
)

simulator = pps.Simulator(2, config)

def reduce_arr(a,b,c,d,e,f):
    return vmap(fn)(a,b,c,d,e,f)

spu_argmax = pps.sim_jax(simulator, reduce_arr)

z = spu_argmax(A_1, B_1, C_1, D_1, I, Z)

print(z)

Relevant log output

output of plaintext calculation:
[Array([1.5507424, 1.2497573, 0.0557829, 0.7126968], dtype=float32),
 Array([0.14797059, 0.02552924, 0.14625552, 0.00854738], dtype=float32),
 Array([1.4832454, 1.9787943, 6.6837497, 0.8628402], dtype=float32),
 Array([-0.893761  ,  0.69166654,  1.0708613 , -0.1742421 ], dtype=float32),
 Array([7, 3, 5, 7], dtype=int32),
 Array([0, 0, 0, 0], dtype=int32)]

output of SPU simulation:
[array([1.5507393 , 1.2497559 , 0.05578232, 4.9789963 ], dtype=float32), 
array([0.14796829, 0.02552795, 0.1462555 , 0.9424553 ], dtype=float32), 
array([1.483242  , 1.9787941 , 6.6837463 , 0.31386185], dtype=float32), 
array([-0.8937607 ,  0.69166565,  1.070858  ,  0.10365677], dtype=float32), 
array([7, 3, 5, 1], dtype=int32), 
array([0, 0, 0, 0], dtype=int32)]

Hi @linzzzzzz

All MPC protocols implemented in SPU are based on fixed-point math, which is not as accurate as real floating-point. fxp_fraction_bits=18 here gives you an epsilon around 4e-6 (2^-18).

Since you are multiplying small numbers, you may need a higher fxp_fraction_bits to get a more accurate result. Also be aware that more fraction bits means fewer bits for the integer part.
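To make the trade-off above concrete, here is a minimal sketch in plain Python (not SPU code). It assumes a signed fixed-point encoding in which one bit of the field holds the sign; the helper names `fxp_profile` and `quantize` are hypothetical and only illustrate the arithmetic. The sample value is the smallest entry of A from the report, squared as in `A_1 = np.square(A)`.

```python
def fxp_profile(field_bits: int, fraction_bits: int) -> dict:
    """Resolution and range of a signed fixed-point encoding (simplified model)."""
    integer_bits = field_bits - 1 - fraction_bits  # assumption: one bit for the sign
    return {
        "epsilon": 2.0 ** -fraction_bits,   # smallest representable step
        "integer_bits": integer_bits,       # bits left for the integer part
        "max_abs": 2.0 ** integer_bits,     # rough bound on representable magnitude
    }


def quantize(x: float, fraction_bits: int) -> float:
    """Round x to the nearest multiple of 2^-fraction_bits."""
    scale = 1 << fraction_bits
    return round(x * scale) / scale


# Smallest entry of A in the report, squared as in A_1 = np.square(A).
tiny = 9.76147160e-06 ** 2

for f in (18, 30, 40):
    profile = fxp_profile(64, f)
    print(f, profile["epsilon"], profile["integer_bits"], quantize(tiny, f))
```

In this model the squared value rounds to zero at 18 fraction bits and becomes non-zero at 40, while the bits left for the integer part shrink by the same amount. Intermediate products may need even more headroom than this bound suggests, so treat the numbers as a rough guide only.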

linzzzzzz commented on August 17, 2024

Thanks for the comments! I will give it a try tomorrow.

linzzzzzz commented on August 17, 2024

@anakinxc @tpppppub
In my experiments with real-world data, the major problem seems to be overflow (not underflow) in this calculation:

y0 = (x1[3]*x2[3]*(x1[0]*x2[1]-x2[0]*x1[1])+x1[1]*x2[1]*(x1[2]*x2[3]-x2[2]*x1[3]))*(x1[1]*x1[3]*x2[1]*x2[3])

I have tried fxp_fraction_bits=18 with spu.spu_pb2.FieldType.FM128, but that still does not resolve the overflow.

Do you have any best practices for resolving an overflow issue like this? One specific question I have right now:
does the order of calculation matter? For example, could something like 1000000x1000000-999999x1000000 cause overflow while (1000000-999999)x1000000 does not?

Thank you so much for your prompt responses.

anakinxc commented on August 17, 2024

Hi @linzzzzzz

You need a larger fxp_fraction_bits for this case.

Yes, order of calculation does matter (same as calculation on CPU or GPU).
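As a rough illustration of why the order matters, here is a toy model of fixed-point arithmetic in a 64-bit ring, written in plain Python. It is a sketch under stated assumptions, not SPU's actual implementation: it assumes a product is formed with twice the fraction bits and then truncated, and all helper names (`encode`, `decode`, `fxp_mul`, `fxp_sub`, `mul_fits`) are hypothetical.

```python
K = 64  # ring width in bits, mirroring FM64
F = 18  # fraction bits, mirroring fxp_fraction_bits=18


def to_signed(v: int, k: int = K) -> int:
    """Interpret a ring element as a signed two's-complement integer."""
    v %= 1 << k
    return v - (1 << k) if v >= 1 << (k - 1) else v


def encode(x: float, f: int = F, k: int = K) -> int:
    return round(x * (1 << f)) % (1 << k)


def decode(v: int, f: int = F, k: int = K) -> float:
    return to_signed(v, k) / (1 << f)


def fxp_sub(a: int, b: int, k: int = K) -> int:
    return (a - b) % (1 << k)


def fxp_mul(a: int, b: int, f: int = F, k: int = K) -> int:
    # The raw product carries 2*f fraction bits and wraps around in the ring
    # before it is truncated back to f fraction bits.
    prod = (to_signed(a, k) * to_signed(b, k)) % (1 << k)
    return (to_signed(prod, k) >> f) % (1 << k)


def mul_fits(x: float, y: float, f: int = F, k: int = K) -> bool:
    """True if x*y with 2*f fraction bits stays below 2^(k-1)."""
    return abs(x * y) * 2 ** (2 * f) < 2 ** (k - 1)


a, b, c = encode(1_000_000), encode(999_999), encode(1_000_000)

square = fxp_mul(a, c)                # 1000000 x 1000000: raw product wraps
factored = fxp_mul(fxp_sub(a, b), c)  # (1000000 - 999999) x 1000000: stays in range

print(decode(square), mul_fits(1e6, 1e6))
print(decode(factored), mul_fits(1.0, 1e6))
```

In this model the intermediate product 1000000 x 1000000 no longer decodes to 1e12, while the factored form decodes to exactly 1000000. Keeping every intermediate product small, for example by subtracting or rescaling before multiplying, is the general idea.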

linzzzzzz commented on August 17, 2024

Thanks @anakinxc

A larger fxp_fraction_bits can only solve the underflow issue, not the overflow, right?

What I'm struggling with now is the overflow issue in our real-world data. I know the sample data I posted in this issue shows an underflow problem, but that is no longer a blocker for me with our real-world data.

anakinxc commented on August 17, 2024

Yes. One thing to note is that there is no INF/-INF in fixed-point, so there is no good way to fully match floating-point overflow behavior with the fixed-point representations used by MPC protocols.

One option is to tweak the algorithm to be less overflow-sensitive, for example by manually batching data when computing an average/mean.
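A minimal sketch of the batching idea, in plain Python rather than SPU code. The function names and the batch size are hypothetical; the point is only that the largest intermediate value is a per-batch sum instead of the sum over all elements.

```python
def naive_mean(values):
    """Mean via one global sum; returns (mean, magnitude of the sum)."""
    total = sum(values)
    return total / len(values), abs(total)


def batched_mean(values, batch_size):
    """Mean as a weighted sum of per-batch means.

    Returns (mean, largest partial-sum magnitude seen).
    """
    n = len(values)
    mean, peak = 0.0, 0.0
    for start in range(0, n, batch_size):
        batch = values[start:start + batch_size]
        partial = sum(batch)
        peak = max(peak, abs(partial))
        # Weight each batch mean by its share of the data.
        mean += (partial / len(batch)) * (len(batch) / n)
    return mean, peak


vals = [1e6] * 1000
print(naive_mean(vals))        # the global sum reaches 1e9
print(batched_mean(vals, 10))  # partial sums stay at 1e7
```

Both versions return about the same mean, but the batched one never holds a value larger than one batch's sum, which leaves more headroom in a fixed-point representation. Inside an SPU function the same structure would be written with jnp operations.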

linzzzzzz commented on August 17, 2024

Thanks @anakinxc!
