astro-informatics / s2wav
Differentiable and accelerated wavelet transform on the sphere with JAX
Home Page: https://astro-informatics.github.io/s2wav/
License: MIT License
Once the project has progressed sufficiently, overhaul the GitHub README, the (hidden) .pip README, and the top level of the documentation. Update all badges, ensuring that they link to the correct URLs, and update the package description and contributors.
Difficulty: Very Low
A catchy, yet informative, package name would be great!
Configure auto-docs (this requires the Google-style documentation to be done first!), and add details to the documentation where necessary.
Difficulty:
This is a good issue to pick up as a first jump into JAX. I would recommend first reading this introduction to JAX. A few hints would be:
Also remember to consider static arguments (ones that don't change between evaluations, e.g. L); marking these as static can help when JIT-compiling.
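A minimal sketch of marking a shape-determining argument as static. The helper `truncate_and_weight` is hypothetical (not part of s2wav): because `L` sets the output shape, it must be a concrete Python int at trace time, so it is passed through `static_argnums`. JAX then compiles one version of the function per distinct value of `L`.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: keep degrees ell < L and weight each row by ell.
# L is static because it determines array shapes; a traced (non-static) L
# would raise an error in jnp.arange and in the slice below.
@partial(jax.jit, static_argnums=(1,))
def truncate_and_weight(flm: jnp.ndarray, L: int) -> jnp.ndarray:
    ell = jnp.arange(L)
    return flm[:L] * ell[:, None]


flm = jnp.ones((8, 15))
out = truncate_and_weight(flm, 4)  # compiles once for L=4
print(out.shape)  # (4, 15)
```

Calling again with `L=4` reuses the compiled version; a new value such as `L=6` triggers a fresh compilation.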
As a preliminary exercise it would be worthwhile working through the demo (example) scripts provided by ssht (found here) and then reading through the examples (sadly only in C) provided by s2let (e.g. here).
As always feel free to throw any questions my way, these can be a little tricky to understand at first.
Go through the functions within filter.py and wherever possible move to a vectorised implementation.
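As an illustration of the kind of change intended, the sketch below compares a per-degree loop with a broadcast equivalent. Both functions are hypothetical, and the (L, 2L-1) array layout (rows indexed by ell, columns by m) is an assumption made for the example.

```python
import numpy as np


def apply_filter_loop(flm: np.ndarray, filt: np.ndarray) -> np.ndarray:
    """Loop version: multiply each degree-ell row of flm by filt[ell]."""
    out = np.empty_like(flm)
    for ell in range(flm.shape[0]):
        out[ell] = flm[ell] * filt[ell]
    return out


def apply_filter_vectorised(flm: np.ndarray, filt: np.ndarray) -> np.ndarray:
    """Vectorised version: broadcast filt across the m axis in one operation."""
    return flm * filt[:, None]


L = 16
rng = np.random.default_rng(0)
flm = rng.standard_normal((L, 2 * L - 1)) + 1j * rng.standard_normal((L, 2 * L - 1))
filt = np.linspace(0.0, 1.0, L)
assert np.allclose(apply_filter_loop(flm, filt), apply_filter_vectorised(flm, filt))
```

Keeping the loop version around as a reference implementation makes it easy to test that the vectorised version agrees.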
Make sure we are testing where appropriate. We should aim for code coverage above 90% (as a rough rule of thumb).
Using the logging module already included, add logging to important functions within the package (where appropriate).
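A minimal sketch of the pattern, using Python's standard `logging` module. The package's own logging setup may differ, and the logger name and `check_bandlimit` helper are hypothetical. The condition checked mirrors the `assert L >= 2 * nside` in the s2fft HEALPix traceback reported on this page.

```python
import logging
from typing import Optional

# Hypothetical module-level logger; the package's own configuration may differ.
logger = logging.getLogger("s2wav.sketch")


def check_bandlimit(L: int, nside: Optional[int] = None) -> bool:
    """Return True if L is large enough for the requested HEALPix resolution."""
    logger.debug("check_bandlimit called with L=%d, nside=%s", L, nside)
    if nside is not None and L < 2 * nside:
        logger.warning(
            "L=%d is below 2*nside=%d; HEALPix transforms may fail", L, 2 * nside
        )
        return False
    return True
```

Using `%`-style arguments rather than f-strings defers string formatting until a handler actually emits the record.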
Difficulty: Very Low
Currently we call SSHT and SO3 to perform the harmonic and Wigner transforms where required (see e.g. here). We should switch these out for S2FFT functionality when available.
Translate the adjoint of the synthesis wavelet transform from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let File to translate:
S2Wav File location:
Notes:
Translate the analysis wavelet transform from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let File to translate:
S2Wav File location:
Notes:
Translate the synthesis wavelet transform from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let File to translate:
S2Wav File location:
Notes:
Hi, I was just wondering whether there is a built-in way to transform the harmonic coefficients of the wavelets back into real space, to visualise the wavelets on the sphere as depicted in the README.
Add testing metrics to be used for testing package functions
Difficulty:
Metric list:
x -> x' = inverse(forward(x))
S2Wav File location:
Notes:
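A minimal sketch of the roundtrip metric x -> x' = inverse(forward(x)). The helper names are hypothetical, and NumPy's FFT pair stands in for a forward/inverse wavelet transform purely for illustration.

```python
from typing import Callable

import numpy as np

Transform = Callable[[np.ndarray], np.ndarray]


def roundtrip_error(x: np.ndarray, forward: Transform, inverse: Transform) -> float:
    """Maximum absolute error of x -> x' = inverse(forward(x))."""
    x_rec = inverse(forward(x))
    return float(np.max(np.abs(x - x_rec)))


def relative_roundtrip_error(
    x: np.ndarray, forward: Transform, inverse: Transform
) -> float:
    """Roundtrip error normalised by the largest magnitude in x."""
    return roundtrip_error(x, forward, inverse) / float(np.max(np.abs(x)))


# Usage with NumPy's FFT as a stand-in for a forward/inverse transform pair.
x = np.random.default_rng(1).standard_normal(64)
err = roundtrip_error(x, np.fft.fft, np.fft.ifft)
```

In a test, the wavelet analysis and synthesis functions would take the place of `np.fft.fft` and `np.fft.ifft`, with a tolerance chosen for the sampling scheme and precision.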
I am trying to compute the directional wavelet transform of a HEALPix map. I have tried using both s2wav.analysis and s2wav.wavelet.flm_to_analysis (with the map-to-flm step computed separately with s2fft). I am encountering an AssertionError.
Minimal example:
import numpy as np
import s2wav as sw
nside = 128
lmax = 2 * nside
N = 3
hpx_map = np.ones((12*nside**2,))
filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')
Fails with the following error message:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[116], line 8
5 hpx_map = np.ones((12*nside**2,))
7 filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
----> 8 wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')
[... skipping hidden 11 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2wav/transforms/wavelet.py:189, in analysis(f, L, N, J_min, lam, spin, sampling, nside, reality, filters, precomps)
174 Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
175 f_wav_lmn[j - J_min] = (
176 f_wav_lmn[j - J_min]
177 .at[::2, L0j:]
(...)
185 )
186 )
188 f_wav.append(
--> 189 s2fft.wigner.inverse_jax(
190 f_wav_lmn[j - J_min],
191 Lj,
192 Nj,
193 nside,
194 sampling,
195 reality,
196 precomps[j - J_min],
197 L0j,
198 )
199 )
201 # Project all harmonic coefficients for each lm onto scaling coefficients
202 phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1))
[... skipping hidden 11 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:257, in inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower)
251 precomps = [p0, p1, p2, p3, p4]
252 return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
253 flm, L, -spin, nside, sampling, False, precomps, False, L_lower
254 )
256 fban = fban.at[N - 1 + n_start_ind :].set(
--> 257 vmap(
258 partial(func, p2=precomps[2][0], p3=precomps[3][0], p4=precomps[4][0]),
259 in_axes=(0, 0, 0, 0),
260 )(flmn[N - 1 + n_start_ind :], spins, precomps[0], precomps[1])
261 )
262 if reality:
263 f = jnp.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=0, norm="forward")
[... skipping hidden 3 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:252, in inverse_jax.<locals>.func(flm, spin, p0, p1, p2, p3, p4)
250 def func(flm, spin, p0, p1, p2, p3, p4):
251 precomps = [p0, p1, p2, p3, p4]
--> 252 return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
253 flm, L, -spin, nside, sampling, False, precomps, False, L_lower
254 )
[... skipping hidden 11 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/spherical.py:319, in inverse_jax(flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower)
315 ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
316 jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
317 )
318 if sampling.lower() == "healpix":
--> 319 return hp.healpix_ifft(ftm, L, nside, "jax")
320 else:
321 ftm = jnp.conj(jnp.fft.ifftshift(ftm, axes=1))
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/utils/healpix_ffts.py:398, in healpix_ifft(ftm, L, nside, method, reality)
368 def healpix_ifft(
369 ftm: np.ndarray,
370 L: int,
(...)
373 reality: bool = False,
374 ) -> np.ndarray:
375 """Wrapper function for the Inverse Fast Fourier Transform with spectral folding
376 in the polar regions to mitigate aliasing.
377
(...)
396 np.ndarray: HEALPix pixel-space array.
397 """
--> 398 assert L >= 2 * nside
399 if method.lower() == \"numpy\":
400 return healpix_ifft_numpy(ftm, L, nside, reality)
AssertionError:
Hello,
I want to try the s2wav library following astro-informatics/s2let#50 (comment). When I try the demo code, an error occurs in s2wav/transforms/jax_wavelets.py, line 260. It seems that filters is required to perform the wavelet transform (but it should be optional). Did I do anything wrong?
Here is my test code:
import s2wav
import numpy as np
L = 128
N = 1
f = np.ones((L, 2*L-1))
f_wav, f_scal = s2wav.analysis(f, L, N)
f = s2wav.synthesis(f_wav, f_scal, L, N)
Here is the error:
Traceback (most recent call last):
File "test_wav.py", line 7, in <module>
f_wav, f_scal = s2wav.analysis(f, L, N)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, jaxpr = infer_params_fn(
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/api.py", line 300, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 499, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 961, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 914, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/s2wav/transforms/jax_wavelets.py", line 260, in analysis
jnp.conj(filters[0]),
jax._src.traceback_util.UnfilteredStackTrace: TypeError: 'NoneType' object is not subscriptable
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "test_wav.py", line 7, in <module>
f_wav, f_scal = s2wav.analysis(f, L, N)
File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/s2wav/transforms/jax_wavelets.py", line 260, in analysis
jnp.conj(filters[0]),
TypeError: 'NoneType' object is not subscriptable
I ran the s2wav tests with pytest and they pass.
========================================================== test session starts ===========================================================
platform linux -- Python 3.8.16, pytest-7.3.1, pluggy-1.0.0
rootdir: /home/gaowenxuan/Code/s2wav
configfile: pytest.ini
collected 288 items
tests/test_filters.py ........................................ [ 13%]
tests/test_gradients.py ........ [ 16%]
tests/test_wavelets.py ..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss [ 44%]
tests/test_wavelets_base.py ..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss.. [ 79%]
ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss [100%]
============================================== 168 passed, 120 skipped in 447.66s (0:07:27) ==============================================
Thanks.
Add a utility function that takes the transform parameters and stores them in a dictionary, to reduce the number of loose arguments passed between functions.
Difficulty:
Background:
S2Let File:
S2Wav File location:
Notes:
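A minimal sketch of such a container, using a frozen dataclass that converts to a dictionary on demand. The class name is hypothetical; the field names follow the `analysis(f, L, N, J_min, lam, spin, sampling, nside, reality, filters, precomps)` signature visible in the traceback on this page, and the default values are assumptions.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class WaveletParams:
    """Hypothetical container for arguments shared by the wavelet transforms.

    Field names follow the analysis signature; defaults are assumed.
    """

    L: int
    N: int = 1
    J_min: int = 0
    lam: float = 2.0
    spin: int = 0
    sampling: str = "mw"
    nside: Optional[int] = None
    reality: bool = False

    def as_dict(self) -> dict:
        """Return the parameters as a plain dictionary, e.g. for **kwargs."""
        return asdict(self)


params = WaveletParams(L=128, N=3, sampling="healpix", nside=64)
kwargs = params.as_dict()  # e.g. some_transform(f, **kwargs)
```

Making the dataclass frozen also makes it hashable, which would let a whole parameter set be passed as a single static argument to `jax.jit`.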
Go through the functions within synthesis.py and wherever possible move to a vectorised implementation.
Every good package should have an eye catching logo! Currently we just have a placeholder.
Translate the adjoint of the analysis wavelet transform from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let File to translate:
S2Wav File location:
Notes:
Currently neither the NumPy nor the JAX wavelet transform automatically generates wavelet filters if none are passed. However, the docstrings imply that filters is an optional argument, which is not the case.
We should either:
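One route consistent with what the docstrings currently promise is to build default filters whenever `filters` is `None`. A minimal sketch, in which `placeholder_filters` and `resolve_filters` are hypothetical and the placeholder stands in for the package's real filter generator:

```python
from typing import Callable, Optional, Tuple

import numpy as np

Filters = Tuple[np.ndarray, np.ndarray]


def placeholder_filters(L: int) -> Filters:
    """Hypothetical stand-in for the package's real filter generator."""
    return np.ones((3, L)), np.ones(L)


def resolve_filters(
    L: int,
    filters: Optional[Filters] = None,
    factory: Callable[[int], Filters] = placeholder_filters,
) -> Filters:
    """Return the supplied filters, or generate defaults when none are given."""
    if filters is None:
        filters = factory(L)
    return filters
```

Each transform would call `filters = resolve_filters(L, filters)` before indexing `filters[0]`, which avoids the `'NoneType' object is not subscriptable` error reported on this page. Under `jax.jit`, the `None` check happens at trace time, so it adds no runtime cost.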
As a preliminary exercise it would be worthwhile reading through the original ssht paper to understand the underlying spherical harmonic transforms, and the original s2let paper to understand roughly how the wavelet transform operates on the sphere.
As always feel free to throw any questions my way, these can be a little tricky to understand at first.
Go through the functions within tiling.py and wherever possible move to a vectorised implementation.
This will make it easier to work with HEALPix, for example.
Translate the functions that check the dimensions of vectors from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let Files to translate:
S2Wav File location:
Notes:
For real-valued signals, exploit the conjugate symmetry, i.e. f_{l,-m} = (-1)^m f^*_{l,m}, to avoid computing the coefficients for negative m, as they follow directly from the conjugates of the positive-m values. Consequently this reduces both the number of computations and the amount of memory required by roughly a factor of 2. See this line in S2FFT, where we generate an explicitly real signal; hopefully it makes clear what is meant by conjugate symmetry.
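A minimal sketch of the storage saving: keep only the m >= 0 half and rebuild the negative-m columns from the symmetry relation. The helper is hypothetical, and the (L, 2L-1) layout with order m in column L-1+m is an assumption; entries with |m| > ell are not masked here.

```python
import numpy as np


def expand_real_flm(flm_pos: np.ndarray) -> np.ndarray:
    """Rebuild a full (L, 2L-1) flm array from its m >= 0 half.

    flm_pos[ell, m] holds f_{ell,m} for m = 0..L-1. Column L-1+m of the
    output holds f_{ell,m} for m = -(L-1)..L-1 (layout assumed).
    Negative-m entries use f_{ell,-m} = (-1)^m conj(f_{ell,m}).
    """
    L = flm_pos.shape[0]
    flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
    flm[:, L - 1 :] = flm_pos
    m = np.arange(1, L)
    flm[:, L - 1 - m] = ((-1.0) ** m) * np.conj(flm_pos[:, m])
    return flm
```

Storing `flm_pos` needs L*L entries instead of L*(2L-1), which is the roughly factor-of-2 saving described above.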
Ensure all docstrings (the summary provided at the top of each function, describing what inputs/outputs the functions should expect and briefly what the function does) are up to date and in line with best practices -- see this guide.
Difficulty:
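For reference, a short Google-style docstring of the kind the auto-docs would pick up. The function `flm_size` is a hypothetical example, not part of the package.

```python
def flm_size(L: int) -> int:
    """Count the harmonic coefficients of a signal band-limited at L.

    Args:
        L (int): Harmonic band-limit; degrees 0 <= ell < L are kept.

    Raises:
        ValueError: If L is not a positive integer.

    Returns:
        int: Number of (ell, m) pairs with |m| <= ell, i.e. L**2.
    """
    if L < 1:
        raise ValueError(f"L must be positive, got {L}")
    return L * L
```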
Add support for additional wavelet kernel generating functions. Specifically
@all-contributors please add @CosmoMatt for Code
Go through the functions within analysis.py and wherever possible move to a vectorised implementation.
Currently we support mw sampling, but we should make sure we also support mwss, dh and healpix. I suspect these are all straightforward to support, with much of the heavy lifting being done externally to this package.
Consider the following
import numpy as np
import pys2let
import s2wav
L, j_min, B = 1024, 2, 3
old_kappa = pys2let.axisym_wav_l(B, L, j_min)[1]
new_kappa = s2wav.filter_factory.filters.filters_axisym(L, J_min=j_min, lam=B)[0][j_min:].T
np.testing.assert_equal(old_kappa.shape, new_kappa.shape)
np.testing.assert_equal((old_kappa == np.inf).sum(), 0)
np.testing.assert_equal((new_kappa == np.inf).sum(), 2) # infinite values
The new method of calculating kappa results in some infinite values, which should be close to 1 (or, in practical terms, equal to 1) rather than infinity. This will cause issues when making tiling plots, as scipy.interpolate.pchip cannot handle infinite values, e.g. https://github.com/astro-informatics/sleplet/blob/f61b2af612d0c6a3a462770dbdffb8ba463a3a40/examples/arbitrary/south_america/tiling_south_america.py#L16-L46
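Where the infinities come from is not established here. One hypothesis worth checking is the kernel's bump function: if exp(-2 / (1 - t^2)) is evaluated just outside |t| < 1, rounding makes 1 - t^2 a tiny negative number and the exponential overflows to inf. The sketch below guards the integrand and adds a finiteness check that could double as a regression test. The integrand form is assumed to mirror s2let's scale-discretised (S2DW) kernel, and the function name is hypothetical.

```python
import numpy as np


def s2dw_integrand(k: np.ndarray, lam: float) -> np.ndarray:
    """s_lam(k)^2 / k for the scale-discretised kernel, zero off its support.

    The bump exp(-2 / (1 - t^2)) is only evaluated where |t| < 1, so rounding
    at the endpoints k = 1/lam and k = 1 cannot produce exp(+large) = inf.
    """
    k = np.atleast_1d(np.asarray(k, dtype=float))
    t = (k - 1.0 / lam) * (2.0 * lam / (lam - 1.0)) - 1.0
    out = np.zeros_like(k)
    inside = np.abs(t) < 1.0
    out[inside] = np.exp(-2.0 / (1.0 - t[inside] ** 2)) / k[inside]
    return out


vals = s2dw_integrand(np.linspace(0.0, 1.5, 1001), lam=3.0)
assert np.isfinite(vals).all()  # the same check applies to kappa arrays
```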
Translate basic math tiling functions from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let File to translate:
S2Wav File location:
Notes:
Not sure if I'm missing something or whether the library is designed only for massive GPUs. I'm running on a top-spec M1 MacBook Pro. This is specifically for s2wav.filter_factory.filters.filters_axisym..., but I've noticed similar behaviour for other functions (and for s2fft).
The following works fine for me from L=16 to L=2048, though it feels a little slow.
import pys2let
import s2wav
import time
B = 3
J_MIN = 2
ELL_MIN = 4
ELL_MAX = 11
for ell in range(ELL_MIN, ELL_MAX + 1):
L = 2**ell
t0 = time.time()
pys2let.axisym_wav_l(B, L, J_MIN)
t1 = time.time()
s2wav.filter_factory.filters.filters_axisym(L, J_MIN, B)
t2 = time.time()
s2wav.filter_factory.filters.filters_axisym_vectorised(L, J_MIN, B)
t3 = time.time()
s2wav.filter_factory.filters.filters_axisym_jax(L, J_MIN, B)
t4 = time.time()
print(
f"L={L:>4} | pys2let={t1-t0:.0e}s | s2wav={t2-t1:.0e}s | "
f"s2wav_vec={t3-t2:.0e}s | s2wav_jax={t4-t3:.0e}s",
)
# L= 16 | pys2let=7e-05s | s2wav=7e-03s | s2wav_vec=7e-03s | s2wav_jax=9e-02s
# L= 32 | pys2let=9e-05s | s2wav=1e-02s | s2wav_vec=1e-02s | s2wav_jax=2e-01s
# L= 64 | pys2let=2e-04s | s2wav=3e-02s | s2wav_vec=3e-02s | s2wav_jax=4e-01s
# L= 128 | pys2let=4e-04s | s2wav=5e-02s | s2wav_vec=5e-02s | s2wav_jax=8e-01s
# L= 256 | pys2let=7e-04s | s2wav=1e-01s | s2wav_vec=1e-01s | s2wav_jax=2e+00s
# L= 512 | pys2let=1e-03s | s2wav=2e-01s | s2wav_vec=2e-01s | s2wav_jax=5e+00s
# L=1024 | pys2let=3e-03s | s2wav=4e-01s | s2wav_vec=4e-01s | s2wav_jax=2e+01s
# L=2048 | pys2let=7e-03s | s2wav=8e-01s | s2wav_vec=8e-01s | s2wav_jax=5e+01s
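The JAX timings above likely include JIT compilation, since each new L triggers a fresh trace and compile, and JAX dispatches work asynchronously, so a plain `time.time()` difference may not capture when results are actually ready. A minimal sketch that separates the first call from steady-state calls; the helper `benchmark` is hypothetical, and it waits on any output exposing `block_until_ready`, a method JAX arrays provide.

```python
import time
from typing import Any, Callable


def _wait(out: Any) -> None:
    """Block until asynchronous (e.g. JAX) results are ready."""
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    elif isinstance(out, (tuple, list)):
        for item in out:
            _wait(item)


def benchmark(fn: Callable, *args: Any, repeats: int = 5, **kwargs: Any) -> dict:
    """Time the first call (includes any JIT compilation) separately from later calls."""
    t0 = time.perf_counter()
    _wait(fn(*args, **kwargs))
    first = time.perf_counter() - t0

    steady = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        _wait(fn(*args, **kwargs))
        steady.append(time.perf_counter() - t0)
    return {"first_call": first, "best_of_rest": min(steady)}
```

For example, `benchmark(s2wav.filter_factory.filters.filters_axisym_jax, L, J_MIN, B)` would report compile-inclusive and steady-state times separately for each L.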
Now if we ramp up to L=4096, pys2let returns in something like 1e-02s, while s2wav hangs (and at higher L never finishes).
import pys2let
import s2wav
import time
B = 3
J_MIN = 2
L = 4096
t0 = time.time()
pys2let.axisym_wav_l(B, L, J_MIN)
t1 = time.time()
print(f"{t1 - t0:.0e}")
Translate the core functions which generate the wavelet filters from s2let (c) to s2wav (base python).
Difficulty:
Background:
S2let File to translate:
S2Wav File location:
Notes:
Replace the manual integration (trapezium rule) with an efficient built-in function, e.g. scipy.integrate.quad.
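A minimal sketch of the kernel integrals evaluated with `scipy.integrate.quad` instead of a fixed trapezium rule. It assumes the scale-discretised kernel used by s2let, k_lam(t) = int_t^1 s_lam(x)^2 / x dx / int_{1/lam}^1 s_lam(x)^2 / x dx, with kappa_j(ell) = sqrt(k_lam(ell / lam^(j+1)) - k_lam(ell / lam^j)). All function names are hypothetical.

```python
import math

from scipy.integrate import quad


def _integrand(k: float, lam: float) -> float:
    """s_lam(k)^2 / k for the scale-discretised kernel (zero off its support)."""
    t = (k - 1.0 / lam) * (2.0 * lam / (lam - 1.0)) - 1.0
    if abs(t) >= 1.0:
        return 0.0
    return math.exp(-2.0 / (1.0 - t * t)) / k


def k_lam(t: float, lam: float = 2.0) -> float:
    """Normalised kernel: 1 for t <= 1/lam, 0 for t >= 1, smooth in between."""
    if t >= 1.0:
        return 0.0
    norm, _ = quad(_integrand, 1.0 / lam, 1.0, args=(lam,))
    val, _ = quad(_integrand, max(t, 1.0 / lam), 1.0, args=(lam,))
    return val / norm


def kappa(ell: float, j: int, lam: float = 2.0) -> float:
    """Wavelet kernel kappa_j(ell) = sqrt(k(ell / lam^(j+1)) - k(ell / lam^j))."""
    diff = k_lam(ell / lam ** (j + 1), lam) - k_lam(ell / lam**j, lam)
    return math.sqrt(max(diff, 0.0))  # clamp tiny negative rounding errors
```

Because the differences telescope, the squared scaling function k_lam(ell / lam^J_min) plus the squared kappa_j summed over all scales should equal 1, which makes a convenient numerical check.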
Currently we're working with flattened 1D arrays (as is the norm in C), but we should switch to multi-dimensional arrays (and later tensors) for Python.
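A minimal sketch of the conversion, assuming the flattened index ind = el^2 + el + m used by ssht/s2let and a 2D layout of shape (L, 2L-1) with order m in column L-1+m. Both helpers are hypothetical.

```python
import numpy as np


def flm_1d_to_2d(flm_1d: np.ndarray, L: int) -> np.ndarray:
    """Map a flattened flm vector (index el^2 + el + m) to an (L, 2L-1) array.

    Row el holds degree el; column L-1+m holds order m. Entries with |m| > el
    do not exist and are left as zero.
    """
    el, m = np.meshgrid(np.arange(L), np.arange(-(L - 1), L), indexing="ij")
    valid = np.abs(m) <= el
    flm_2d = np.zeros((L, 2 * L - 1), dtype=flm_1d.dtype)
    # Boolean indexing visits (el, m) in row-major order, matching the 1D layout.
    flm_2d[valid] = flm_1d[(el * el + el + m)[valid]]
    return flm_2d


def flm_2d_to_1d(flm_2d: np.ndarray) -> np.ndarray:
    """Inverse mapping: collect the valid (el, m) entries back into a vector."""
    L = flm_2d.shape[0]
    el, m = np.meshgrid(np.arange(L), np.arange(-(L - 1), L), indexing="ij")
    return flm_2d[np.abs(m) <= el]
```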
See section 3 of this paper. In the code you'll see this coming up in terms like this, where the band-limit differs for each j. Notice also that this is not just an upper limit but also a lower limit: a given wavelet scale j will only have non-zero harmonic coefficients within a restricted range L_0j to L_j. This part of the code update should be relatively straightforward to implement in NumPy, but may be more difficult when we come to write the JAX versions.
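A minimal NumPy sketch of per-scale band limits. It assumes L_0j = ceil(lam^(j-1)) and L_j = min(ceil(lam^(j+1)), L), consistent with the kernel support lam^(j-1) < ell < lam^(j+1); these formulas should be checked against section 3 of the paper. Both helpers are hypothetical, and the (L, 2L-1) layout with order m in column L-1+m is also an assumption.

```python
import math

import numpy as np


def scale_band_limits(j: int, L: int, lam: float = 2.0) -> tuple:
    """Return (L0_j, L_j): scale j is non-zero only for L0_j <= ell < L_j (assumed)."""
    L0j = math.ceil(lam ** (j - 1))
    Lj = min(math.ceil(lam ** (j + 1)), L)
    return L0j, Lj


def restrict_to_scale(flm: np.ndarray, j: int, lam: float = 2.0) -> np.ndarray:
    """Copy an (L, 2L-1) flm array into a smaller (L_j, 2L_j-1) array for scale j."""
    L = flm.shape[0]
    L0j, Lj = scale_band_limits(j, L, lam)
    out = np.zeros((Lj, 2 * Lj - 1), dtype=flm.dtype)
    # Orders |m| <= L_j - 1 sit in columns L-L_j .. L+L_j-2 of the full array;
    # degrees below L0_j stay zero.
    out[L0j:] = flm[L0j:Lj, L - Lj : L + Lj - 1]
    return out
```

In JAX the same slicing works as long as `j`, `L` and `lam` are static, since they determine the output shape.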
This is a good issue to pick up as a first jump into JAX. I would recommend first reading this introduction to JAX. A few hints would be:
Also remember to consider static arguments (ones that don't change between evaluations, e.g. L); marking these as static can help when JIT-compiling.