
markusschmitt / vmc_jax

Stars: 55 · Watchers: 5 · Forks: 16 · Size: 1.05 MB

Implementation of Variational Monte Carlo (VMC) for quantum many-body dynamics using JAX.

License: MIT License

Python 100.00%

vmc_jax's Introduction


jVMC

This is an implementation of Variational Monte Carlo (VMC) for quantum many-body dynamics using the JAX library (with Flax on top of it). It exploits automatic differentiation for easy model composition and just-in-time compilation for execution on accelerators.

  1. Documentation
  2. Installation
  3. Online example
  4. Important gotchas
  5. Citing jVMC

Please report bugs as well as other issues or suggestions on our issues page.

Documentation

Documentation is available here.

Installation

Option 1: pip-install

  1. We recommend you create a new conda environment to work with jVMC:

     conda create -n jvmc python=3.8
     conda activate jvmc
    
  2. pip-install the package

     pip install jVMC
    

Option 2: Clone and pip-install

  1. Clone the jVMC repository and check out the development branch:

     git clone https://github.com/markusschmitt/vmc_jax.git
     cd vmc_jax
    
  2. We recommend you create a new conda environment to work with jVMC:

     conda create -n jvmc python=3.8
     conda activate jvmc
    
  3. pip-install the package

     pip install .  
    

    Alternatively, for development:

     pip install -e .[dev]
    

Test that everything worked, e.g., by running 'python -c "import jVMC"' from a directory other than vmc_jax.

Compiling JAX

How to compile JAX on a supercomputing cluster

Online example

Open In Colab

Click on the badge above to open a notebook that implements an exemplary ground state search in Google Colab.

Important gotchas

Out-of-memory issues and batching

Memory requirements grow with network size. To avoid out-of-memory issues, adjust the batchSize parameter of the NQS class. batchSize sets the number of input configurations on which the network is evaluated concurrently, so out-of-memory issues are usually resolved by reducing it. The numChains parameter of the Sampler class for Markov chain Monte Carlo sampling plays a similar role, but its optimal values in terms of computational speed are typically not memory-critical.
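The effect of a batch size can be sketched in plain JAX. This is not jVMC's internal implementation; the functions toy_log_psi and evaluate_in_batches are hypothetical and only show how evaluating configurations chunk by chunk caps how many of them are processed at once, while giving the same result as evaluating all of them together.

```python
import jax
import jax.numpy as jnp


def toy_log_psi(weights, s):
    """Hypothetical stand-in for a network: log-amplitude of one spin configuration."""
    return jnp.dot(weights, 2 * s - 1)


def evaluate_in_batches(weights, configs, batch_size):
    """Evaluate toy_log_psi on configs, at most batch_size configurations at a time."""
    # vmap vectorizes over the configurations of one batch; jit compiles once per batch shape.
    batched = jax.jit(jax.vmap(toy_log_psi, in_axes=(None, 0)))
    outputs = []
    for start in range(0, configs.shape[0], batch_size):
        chunk = configs[start:start + batch_size]
        outputs.append(batched(weights, chunk))
    return jnp.concatenate(outputs)


key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
L = 8
weights = jax.random.normal(key_w, (L,))
configs = jax.random.bernoulli(key_s, 0.5, (100, L)).astype(jnp.float32)

# 100 configurations in batches of 16: six full batches plus one ragged batch of 4.
small = evaluate_in_batches(weights, configs, batch_size=16)
full = jax.vmap(toy_log_psi, in_axes=(None, 0))(weights, configs)
```

A smaller batch lowers peak memory at the cost of more (smaller) kernel launches. Note that the ragged last batch triggers one extra compilation, which padding to a fixed batch shape would avoid.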

Citing jVMC

If you use the jVMC package for your research, please cite our reference paper SciPost Phys. Codebases 2 (2022).

vmc_jax's People

Contributors

emergentspacetime, jonasrigo, lagrange2art, laurinbrunner, markusschmitt, rehmoritz, tszoldra


vmc_jax's Issues

Complex Weights Not Properly Saved in Network Checkpoints

I have come across an issue where complex weights in network checkpoints are not being saved properly. Currently, only the real parts of the complex numbers are saved and the imaginary parts are discarded. This behavior was encountered when using the write_network_checkpoint method of the OutputManager class.
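Until this is fixed, one possible workaround is to split complex arrays into real and imaginary parts before writing them to a store that keeps only real values, and to recombine them on loading. The helpers to_real_arrays and from_real_arrays below are hypothetical and not part of jVMC. Because it is not known here how write_network_checkpoint serializes parameters internally, the sketch only illustrates the idea with a flat dictionary of NumPy arrays and an in-memory .npz file.

```python
import io

import numpy as np


def to_real_arrays(params):
    """Split complex arrays into a stacked (real, imag) pair; also return the complex names."""
    stored, complex_keys = {}, []
    for name, value in params.items():
        arr = np.asarray(value)
        if np.iscomplexobj(arr):
            # Leading axis of length 2: index 0 holds the real part, index 1 the imaginary part.
            stored[name] = np.stack([arr.real, arr.imag])
            complex_keys.append(name)
        else:
            stored[name] = arr
    return stored, complex_keys


def from_real_arrays(stored, complex_keys):
    """Inverse of to_real_arrays: rebuild complex arrays from their stacked parts."""
    return {
        name: (arr[0] + 1j * arr[1]) if name in complex_keys else arr
        for name, arr in stored.items()
    }


params = {
    "kernel": np.array([[0.5 + 1.5j, -2.0 - 0.25j]]),
    "bias": np.array([0.1, 0.2]),
}

stored, complex_keys = to_real_arrays(params)
buffer = io.BytesIO()
np.savez(buffer, complex_keys=np.array(complex_keys), **stored)
buffer.seek(0)

with np.load(buffer) as data:
    keys = [str(k) for k in data["complex_keys"]]
    arrays = {k: data[k] for k in data.files if k != "complex_keys"}
    restored = from_real_arrays(arrays, keys)
```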

batched Oloc always allocates complex zeros

The get_O_loc_batched function of the Operator class calls the _alloc_Oloc_pmapd function, which always allocates zeros with dtype global_defs.tCpx. In the POVM case, however, the POVMOperator class returns real matrix elements when get_s_primes is called.
Inserting real-valued numbers into a complex JAX array with jax.lax.dynamic_update_slice then results in the following error:

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got complex128, float64.
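One way to avoid the dtype mismatch, sketched here in plain JAX, is to cast the update to the dtype of the preallocated buffer before calling jax.lax.dynamic_update_slice. An alternative would be to allocate the buffer with the dtype of the matrix elements. The helper insert_block is hypothetical, and the actual fix in jVMC may differ.

```python
import jax
import jax.numpy as jnp


def insert_block(buffer, block, start):
    """Write block into buffer at index start, casting block to buffer's dtype first.

    jax.lax.dynamic_update_slice requires operand and update to share a dtype, so a
    real-valued block must be promoted before it can be written into a complex buffer.
    (Only cast real into complex; the reverse would silently drop imaginary parts.)
    """
    return jax.lax.dynamic_update_slice(buffer, block.astype(buffer.dtype), (start,))


buffer = jnp.zeros(6, dtype=jnp.complex64)             # complex buffer, as allocated for Oloc
real_block = jnp.array([1.0, 2.0], dtype=jnp.float32)  # real matrix elements, as in the POVM case
result = insert_block(buffer, real_block, 2)
```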

adding operator_string to branchfree_operator can lead to indexing errors

Adding operator strings (here just Sz(i)) to the Hamiltonian without a prefactor (i.e., without using jvmc.operator.scal_opstr) can lead to an indexing error once a certain number of them has been added (here L >= 50 reproduces the error). In the following example, replacing the line h.add((jvmc.operator.Sz(i),)) with h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sz(i),))) solves the issue for me.

It seems that if L is large enough, compilation is distributed via MPI (see branch_free.py, line 224 onwards). This then leads to the error because the prefactors (if these are the same as the factors passed to scal_opstr) are taken into account, and a bare operator string does not have one(?). A possible solution would be to state in the documentation that scal_opstr should always be used (even if the prefactor is 1), or to apply scal_opstr with a factor of 1 automatically when this is not done explicitly.

A minimal example to reproduce the error:

import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import jVMC as jvmc

# define transverse field Ising Hamiltonian
L = 50  # works with L <= 49

h = jvmc.operator.BranchFreeOperator()

for i in range(L):
    h.add((jvmc.operator.Sz(i),))
    h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sx(i),)))

s = jnp.ones(shape=(1, 12, L), dtype=jnp.int32)
primes = h.get_s_primes(s)

The error message:

IndexError                                Traceback (most recent call last)
Cell In[1], line 16
     13     h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sx(i),)))
     15 s = jnp.ones(shape=(1, 12, L), dtype=jnp.int32)
---> 16 primes = h.get_s_primes(s)

File ~/Programming/jvmc_playground/jvmc_cpu/lib/python3.12/site-packages/jVMC/operator/base.py:141, in Operator.get_s_primes(self, s, *args)
    139 if type(fun) is tuple:
    140     self.arg_fun = fun[1]
--> 141     args = self.arg_fun(*args)
    142     fun = fun[0]
    143 else:

File ~/Programming/jvmc_playground/jvmc_cpu/lib/python3.12/site-packages/jVMC/operator/branch_free.py:240, in BranchFreeOperator.compile.<locals>.arg_fun(prefactor, init, *args)
    238     res = init[myStart:myEnd]
    239     for i,f in prefactor[myStart:myEnd]:
--> 240         res[i-myStart] = f(*args)
    242     res = np.concatenate(comm.allgather(res), axis=0)
    244 return (jnp.array(res), )

IndexError: index 51 is out of bounds for axis 0 with size 50

Examples 5 and 6 raise "KeyError"

Since the latest change in jVMC/util/util.py, examples 5 and 6 raise "KeyError: batch_size". Likewise, the test povm_t.py no longer runs without errors.

CrossValidation cannot select subset of Eloc

The TDVP class tries to access Eloc[:, 0::2] in its __call__ method if crossValidation is True, but the SampledObs object is not subscriptable.
Also, the call to solve still uses the old parameters.

jVMC.util.util.init_net raises AttributeError

The function jVMC.util.util.init_net raises the error:

AttributeError: module 'jVMC.nets' has no attribute 'CpxRNN'

This happens because the class CpxRNN has been removed but is still referenced in this function.

Diagonalization on device can raise an unhandled exception

The diagonalization on the GPU in TDVP.transform_to_eigenbasis, which uses jax.numpy.linalg.eigh, can sometimes raise a ValueError. In this case the calculation aborts and has to be restarted with the diagonalizeOnDevice parameter set to False, which makes the function fall back to the NumPy CPU version of eigh.

I suggest changing this behaviour so that, if diagonalizeOnDevice is True, the function first tries jax.numpy.linalg.eigh and, if that fails, automatically falls back to the NumPy version without the entire calculation having to be restarted.

Error message:
ValueError: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:444: operation cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast<float*>(work), d.lwork, info) failed: cuSolver execution failed
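The suggested fallback can be sketched as follows. The function name eigh_with_fallback is hypothetical and this is not jVMC's code. It catches ValueError, as in the report, and also RuntimeError, on the assumption that newer jaxlib versions raise a RuntimeError subclass for failed device calls. Since JAX dispatches work asynchronously, the results are blocked on inside the try so that a solver failure actually surfaces there; for the same reason the function has to be called outside jax.jit.

```python
import warnings

import jax.numpy as jnp
import numpy as np


def eigh_with_fallback(matrix, diagonalize_on_device=True):
    """Eigendecomposition of a Hermitian matrix, falling back to NumPy if the device solver fails."""
    if diagonalize_on_device:
        try:
            eigvals, eigvecs = jnp.linalg.eigh(matrix)
            # Block so that asynchronous solver errors are raised inside this try block.
            eigvals.block_until_ready()
            eigvecs.block_until_ready()
            return eigvals, eigvecs
        except (ValueError, RuntimeError) as err:
            warnings.warn(f"Device eigh failed ({err}); falling back to numpy.linalg.eigh on the CPU.")
    eigvals, eigvecs = np.linalg.eigh(np.asarray(matrix))
    return jnp.asarray(eigvals), jnp.asarray(eigvecs)


S = jnp.array([[2.0, 1.0], [1.0, 2.0]])
ev_device, _ = eigh_with_fallback(S)
ev_cpu, _ = eigh_with_fallback(S, diagonalize_on_device=False)
```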

Non-holomorphic networks with complex parameters not treated correctly

If we define the class

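import jax.numpy as jnp
import flax.linen as nn
import jVMC.global_defs as global_defs
from jVMC.nets.initializers import init_fn_args

class MatrixMultiplication_NonHolomorphic(nn.Module):
    holo: bool = False

    @nn.compact
    def __call__(self, s):
        # single complex-valued dense layer without bias
        layer1 = nn.Dense(1, use_bias=False, **init_fn_args(dtype=global_defs.tCpx))
        out = layer1(2 * s.ravel() - 1)
        if not self.holo:
            # adding the real part makes the output non-holomorphic in the parameters
            out = out + 1e-1 * jnp.real(out)
        return jnp.sum(out)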

and let holo=True, the gradients computed by psi.gradients for the input s=[0, 0, 0, 0] are
[-1.+0.j -1.+0.j -1.+0.j -1.+0.j -0.-1.j -0.-1.j -0.-1.j -0.-1.j] as expected.

However, if we let holo=False, the returned gradients are [-2.1+0.j -2.1+0.j -2.1+0.j -2.1+0.j]. This means that if we, for example, performed time evolution with this setup, we would get all zeros when considering only the imaginary part of the S-matrix.

If we add the gradient function

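from jax import grad
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map

def flat_gradient_cpx_nonholo(fun, params, arg):
    gr = grad(lambda p, y: jnp.real(fun.apply(p, y)))(params, arg)["params"]
    gi = grad(lambda p, y: jnp.imag(fun.apply(p, y)))(params, arg)["params"]
    # per parameter leaf: real-part gradient followed by the negated imaginary-part gradient
    g = tree_flatten(tree_map(lambda x, y: [x.ravel(), -y.ravel()], gr, gi))[0]
    return jnp.concatenate(g)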

the returned gradients are [-1.1+0.j -1.1+0.j -1.1+0.j -1.1+0.j -0. -1.j -0. -1.j -0. -1.j -0. -1.j], which is in line with the above case where holo=True.
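The expected values can be checked without Flax or jVMC. The sketch below is a simplified reduction under stated assumptions: the Dense layer is replaced by a single complex weight vector w, with arbitrary values (for s = [0, 0, 0, 0] the gradients do not depend on them), and the two gradients are concatenated in the same order as in flat_gradient_cpx_nonholo for a single parameter array, real-part gradients first and negated imaginary-part gradients second.

```python
import jax
import jax.numpy as jnp


def f(w, s):
    """Toy version of MatrixMultiplication_NonHolomorphic with holo=False (assumed simplification)."""
    out = jnp.dot(w, 2 * s - 1)
    return out + 1e-1 * jnp.real(out)


w = jnp.array([0.3 + 0.1j, -0.2 + 0.5j, 0.7 - 0.4j, 0.1 + 0.2j], dtype=jnp.complex64)
s = jnp.zeros(4)

# The real and imaginary parts are real-valued functions, so jax.grad applies to each one.
g_re = jax.grad(lambda p: jnp.real(f(p, s)))(w)
g_im = jax.grad(lambda p: jnp.imag(f(p, s)))(w)

# Real-part gradients first, then the negated imaginary-part gradients.
gradient = jnp.concatenate([g_re, -g_im])
```

This reproduces the [-1.1, ..., -1j, ...] pattern reported above, while the non-holomorphic term only rescales the real-part block.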

Please improve the error message when lDim is omitted in op.BranchFreeOperator(lDim=<local_hilbert_dimension>)

Dear Devs,
I have noticed that when one creates a Hamiltonian based on a local Hilbert space with a dimension other than 2, one has to set the lDim variable in op.BranchFreeOperator so that the operator strings are padded with identities.
If one forgets to do that, or does it incorrectly, like so

import jVMC.operator as op
hamiltonian = op.BranchFreeOperator(lDim=1)
hamiltonian.add(op.scal_opstr( 1., ( op.Sz(0), op.Sz(0) ) ) )
hamiltonian.add(op.scal_opstr( 1., ( op.Sx(0),) ) )
hamiltonian.compile()

the error message is somewhat cryptic. Thus, I propose to change line 220 in branch_free.py as follows:

try:
    self.mapC = jnp.array(self.map, dtype=np.int32)
except Exception as e:
    raise ValueError("Check that you have set <local_hilbert_dimension> in op.BranchFreeOperator(lDim=<local_hilbert_dimension>) correctly.") from e

Best, Jonas

Frequent recompilation in v1.2.0

With v1.2.0 we introduced overly frequent recompilation by moving jit_my_stuff() to the global_defs module. This severely impedes performance.
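Recompilation of this kind can be detected by counting how often JAX traces a function: Python side effects inside a jitted function run only during tracing. Setting jax.config.update("jax_log_compiles", True) also logs each compilation. The sketch below does not reproduce the jVMC change. It only contrasts a function jitted once at module level with one whose jitted closure is recreated on every call, which forces a retrace (and recompilation) each time. The names are hypothetical.

```python
import jax
import jax.numpy as jnp

traces = {"module_level": 0, "per_call": 0}


def _square_module_level(x):
    traces["module_level"] += 1  # runs only while JAX traces the function
    return x * x


# Jitted once; later calls with the same shape and dtype reuse the cached compilation.
square_jitted = jax.jit(_square_module_level)


def square_rejitted(x):
    # A new function object on every call, so JAX cannot reuse its cache and retraces.
    def _square(y):
        traces["per_call"] += 1
        return y * y

    return jax.jit(_square)(x)


x = jnp.arange(4.0)
for _ in range(3):
    square_jitted(x)
    square_rejitted(x)
```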
