Comments (9)

ricardoV94 commented on September 27, 2024

I'm closing this issue here, as it seems there is nothing left to do other than re-release pymc and pymc-experimental.

Let me know if I missed anything.

ricardoV94 commented on September 27, 2024

This may be related to #1, perhaps because OpFromGraph does not handle special gradient variable types like DisconnectedInput. Just a very vague hypothesis, though.

ricardoV94 commented on September 27, 2024

We should check whether the problem goes away again after #723.

ricardoV94 commented on September 27, 2024

This may also have been solved earlier by #340.

ricardoV94 commented on September 27, 2024

@iAvicenna, can you try running with the latest pymc-experimental? I think this no longer fails.

iAvicenna commented on September 27, 2024

Hello, I tried running this with:

pymc_experimental: '0.1.0'
pymc: '5.13.1'
pytensor/pytensor-base: 2.20.0

It now selects:

>Metropolis: [w]
>NUTS: [mu_x, mu_y, σ]

Previously everything was sampled with Metropolis and none of it could be switched to NUTS. I still cannot force w to be sampled with NUTS.
I am trying this with the versions on the main branch. Should I try again after installing pytensor from the fix_OFG_grad branch?

ricardoV94 commented on September 27, 2024

Strange, I think I tried from main and it was all NUTS.

iAvicenna commented on September 27, 2024

Hmm, so here is my environment. It is not minimal: it includes numpyro, jax, and CUDA-related libraries because I wanted to test the GPU, plus spyder-kernels. Let me know if you think any of those could cause a problem:

_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
arviz                     0.18.0             pyhd8ed1ab_0    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd4edc92_1    conda-forge
binutils_impl_linux-64    2.40                 ha885e6a_0    conda-forge
binutils_linux-64         2.40                 hdade7a5_3    conda-forge
blas                      2.122                  openblas    conda-forge
blas-devel                3.9.0           22_linux64_openblas    conda-forge
brotli                    1.1.0                hd590300_1    conda-forge
brotli-bin                1.1.0                hd590300_1    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
c-ares                    1.28.1               hd590300_0    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cachetools                5.3.3              pyhd8ed1ab_0    conda-forge
cairo                     1.18.0               h3faef2a_0    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
cloudpickle               3.0.0              pyhd8ed1ab_0    conda-forge
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
cons                      0.4.6              pyhd8ed1ab_0    conda-forge
contourpy                 1.2.1           py312h8572e83_0    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
debugpy                   1.8.1           py312h30efb56_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
dm-tree                   0.1.8           py312h72fbbdf_4    conda-forge
etuples                   0.3.9              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.0              pyhd8ed1ab_2    conda-forge
executing                 2.0.1              pyhd8ed1ab_0    conda-forge
expat                     2.6.2                h59595ed_0    conda-forge
filelock                  3.13.4             pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 h77eed37_1    conda-forge
fontconfig                2.14.2               h14ed4e7_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.51.0          py312h98912ed_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fribidi                   1.0.10               h36c2ea0_0    conda-forge
gcc                       12.3.0               h915e2ae_6    conda-forge
gcc_impl_linux-64         12.3.0               h1562d66_6    conda-forge
gcc_linux-64              12.3.0               h6477408_3    conda-forge
gdk-pixbuf                2.42.11              hb9ae30d_0    conda-forge
giflib                    5.2.2                hd590300_0    conda-forge
graphite2                 1.3.13            h59595ed_1003    conda-forge
graphviz                  9.0.0                h78e8752_1    conda-forge
gtk2                      2.24.33              h280cfa0_4    conda-forge
gts                       0.7.6                h977cf35_4    conda-forge
gxx                       12.3.0               h915e2ae_6    conda-forge
gxx_impl_linux-64         12.3.0               h1562d66_6    conda-forge
gxx_linux-64              12.3.0               h4a1b8e8_3    conda-forge
h5netcdf                  1.3.0              pyhd8ed1ab_0    conda-forge
h5py                      3.11.0          nompi_py312h1b477d7_100    conda-forge
harfbuzz                  8.3.0                h3d44ed6_0    conda-forge
hdf5                      1.14.3          nompi_h4f84152_100    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
ipykernel                 6.29.3             pyhd33586a_0    conda-forge
ipython                   8.22.2             pyh707e725_0    conda-forge
jax                       0.4.26                   pypi_0    pypi
jaxlib                    0.4.26+cuda12.cudnn89          pypi_0    pypi
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
joblib                    1.4.0              pyhd8ed1ab_0    conda-forge
jupyter_client            8.6.1              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.2           py312h7900ff3_0    conda-forge
kernel-headers_linux-64   2.6.32              he073ed8_17    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.5           py312h8572e83_1    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.40                 h55db66e_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20240116.1      cxx17_h59595ed_2    conda-forge
libaec                    1.1.3                h59595ed_0    conda-forge
libblas                   3.9.0           22_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hd590300_1    conda-forge
libbrotlidec              1.1.0                hd590300_1    conda-forge
libbrotlienc              1.1.0                hd590300_1    conda-forge
libcblas                  3.9.0           22_linux64_openblas    conda-forge
libcurl                   8.7.1                hca28451_0    conda-forge
libdeflate                1.20                 hd590300_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 hd590300_2    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-devel_linux-64     12.3.0             h2af2641_106    conda-forge
libgcc-ng                 13.2.0               hc881cc4_6    conda-forge
libgd                     2.3.3                h119a65a_9    conda-forge
libgfortran-ng            13.2.0               h69a702a_6    conda-forge
libgfortran5              13.2.0               h43f5ff8_6    conda-forge
libglib                   2.80.0               hf2295e7_6    conda-forge
libgomp                   13.2.0               hc881cc4_6    conda-forge
libhwloc                  2.10.0          default_h2fb2949_1000    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0           22_linux64_openblas    conda-forge
liblapacke                3.9.0           22_linux64_openblas    conda-forge
libnghttp2                1.58.0               h47da74e_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libopenblas               0.3.27          pthreads_h413a1c8_0    conda-forge
libpng                    1.6.43               h2797004_0    conda-forge
librsvg                   2.58.0               hadf69e7_1    conda-forge
libsanitizer              12.3.0               h2af2641_6    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libsqlite                 3.45.3               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-devel_linux-64  12.3.0             h2af2641_106    conda-forge
libstdcxx-ng              13.2.0               h95c4c6d_6    conda-forge
libtiff                   4.6.0                h1dd3fc0_3    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp                   1.3.2                h658648e_1    conda-forge
libwebp-base              1.3.2                hd590300_1    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.6               h232c23b_2    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               18.1.3               h4dfa4b3_0    conda-forge
logical-unification       0.4.6              pyhd8ed1ab_0    conda-forge
markdown-it-py            3.0.0              pyhd8ed1ab_0    conda-forge
matplotlib-base           3.8.4           py312he5832f3_0    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
mdurl                     0.1.2              pyhd8ed1ab_0    conda-forge
minikanren                1.0.3              pyhd8ed1ab_0    conda-forge
mkl                       2023.2.0         h84fe81f_50496    conda-forge
mkl-service               2.4.1           py312h4daa2fd_0    conda-forge
ml-dtypes                 0.4.0                    pypi_0    pypi
multipledispatch          0.6.0                      py_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.4.20240210         h59595ed_0    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
numpy                     1.26.4          py312heda63a1_0    conda-forge
numpyro                   0.14.0                   pypi_0    pypi
nvidia-cublas-cu12                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.4.131                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12                 pypi_0    pypi
nvidia-cufft-cu12                 pypi_0    pypi
nvidia-cusolver-cu12                 pypi_0    pypi
nvidia-cusparse-cu12               pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
openblas                  0.3.27          pthreads_h7a3da1a_0    conda-forge
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openssl                   3.2.1                hd590300_1    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
packaging                 24.0               pyhd8ed1ab_0    conda-forge
pandas                    2.2.2           py312hfb8ada1_0    conda-forge
pango                     1.52.2               ha41ecd1_0    conda-forge
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
patsy                     0.5.6              pyhd8ed1ab_0    conda-forge
pcre2                     10.43                hcad00b1_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.3.0          py312hdcec9eb_0    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
pixman                    0.43.2               h59595ed_0    conda-forge
platformdirs              4.2.0              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.42             pyha770c72_0    conda-forge
psutil                    5.9.8           py312h98912ed_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pygments                  2.17.2             pyhd8ed1ab_0    conda-forge
pymc                      5.13.1               hd8ed1ab_0    conda-forge
pymc-base                 5.13.1             pyhd8ed1ab_0    conda-forge
pymc-experimental         0.1.0                    pypi_0    pypi
pyparsing                 3.1.2              pyhd8ed1ab_0    conda-forge
pytensor                  2.20.0          py312h30efb56_1    conda-forge
pytensor-base             2.20.0          py312hfb8ada1_1    conda-forge
python                    3.12.3          hab00c5b_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python-graphviz           0.20.3             pyh717bed2_0    conda-forge
python-tzdata             2024.1             pyhd8ed1ab_0    conda-forge
python_abi                3.12                    4_cp312    conda-forge
pytz                      2024.1             pyhd8ed1ab_0    conda-forge
pyzmq                     26.0.2          py312h8fd38d8_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
rich                      13.7.1             pyhd8ed1ab_0    conda-forge
scikit-learn              1.4.2           py312h394d371_0    conda-forge
scipy                     1.13.0          py312heda63a1_0    conda-forge
seaborn                   0.13.2               hd8ed1ab_0    conda-forge
seaborn-base              0.13.2             pyhd8ed1ab_0    conda-forge
setuptools                69.5.1             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
spyder-kernels            2.5.1           unix_pyh707e725_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
statsmodels               0.14.1          py312hc7c0aa3_0    conda-forge
sysroot_linux-64          2.12                he073ed8_17    conda-forge
tbb                       2021.12.0            h00ab1b0_0    conda-forge
threadpoolctl             3.4.0              pyhc1e730c_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
toolz                     0.12.1             pyhd8ed1ab_0    conda-forge
tornado                   6.4             py312h98912ed_0    conda-forge
tqdm                      4.66.2                   pypi_0    pypi
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typing-extensions         4.11.0               hd8ed1ab_0    conda-forge
typing_extensions         4.11.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
wurlitzer                 3.0.3              pyhd8ed1ab_0    conda-forge
xarray                    2024.3.0           pyhd8ed1ab_0    conda-forge
xarray-einstats           0.7.0              pyhd8ed1ab_0    conda-forge
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.1.1                hd590300_0    conda-forge
xorg-libsm                1.2.4                h7391055_0    conda-forge
xorg-libx11               1.8.9                h8ee46fc_0    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h0b41bf4_2    conda-forge
xorg-libxrender           0.9.11               hd590300_0    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h0b41bf4_1003    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zeromq                    4.3.5                h59595ed_1    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.5                hfc55251_0    conda-forge

When I run the code above (forcing NUTS on all variables), I get:

Traceback (most recent call last):

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/spyder_kernels/ in compat_exec
    exec(code, globals, locals)

  File ~/Dropbox/data_analysis/MODELING/ngs_pfra_proportions_with_rf_sally/
    step = pm.NUTS([weights, mu_x, mu_y, sigma])

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pymc/step_methods/hmc/ in __init__
    super().__init__(vars, **kwargs)

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pymc/step_methods/hmc/ in __init__
    super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **pytensor_kwargs)

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pymc/step_methods/ in __init__
    func = model.logp_dlogp_function(vars, dtype=dtype, **pytensor_kwargs)

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pymc/model/ in logp_dlogp_function
    return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pymc/model/ in __init__
    grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/ in grad
    _rval: Sequence[Variable] = _populate_grad_dict(

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/ in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/ in access_grad_cache
    term = access_term_cache(node)[idx]

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/ in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/ in access_grad_cache
    term = access_term_cache(node)[idx]

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/ in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/graph/ in L_op
    return self.grad(inputs, output_grads)

  File ~/miniconda3/envs/pymc_experimental_env2/lib/python3.12/site-packages/pytensor/graph/ in grad
    raise NotImplementedError()


ricardoV94 commented on September 27, 2024

You're totally right!

I reproduced the example on Colab:

I must have tried the wrong example, or used bleeding-edge versions of pytensor/pymc that have not been released yet. Let me confirm.

EDIT: It needs the not-yet-released version of PyMC on the main branch. I updated the example above and it now works. Sampling will be a bit slow unless you also use pymc-experimental from main, since we got a speedup in the logp in pymc-devs/pymc-experimental#337.

EDIT: I updated the example to also use bleeding-edge pymc-experimental. To reproduce the problem, just run !pip install pymc-experimental==0.1.0 instead of the git route, and leave pymc as it is.

