
A differentiable program for mapping brain function



In functional neuroimaging and adjacent fields, the advent of large, public data repositories has brought with it a proliferation of instruments and methods for analysing these data. This has introduced new challenges for the field: How can we ensure that our analyses are reproducible? How can we ensure that our analyses are valid? Conditioned on our dataset and our scientific question, how do we choose from among the available methods to design an analytic workflow in a principled way? How can we know that the workflow we've designed is suited to answering the questions we are asking?

We typically approach the problem of designing a scientific workflow combinatorially: we begin from a set of analytic options, and we choose from among these to configure our workflow. The hypercoil library provides a framework for going beyond this combinatorial approach, and for designing principled workflows that are differentiable. Instead of selecting from among a set of available methods, we can learn the (locally) best workflow for our dataset and our scientific question. This approach is particularly well-suited to the functional neuroimaging domain, where the data are high-dimensional and the scientific questions are often complex.
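To make the contrast concrete, here is a toy sketch in plain NumPy. It is not hypercoil code. Three hypothetical candidate pipelines each produce an output. The combinatorial approach keeps whichever single candidate scores best. The relaxed approach instead learns softmax mixing weights over the candidates by gradient descent, using a hand-derived gradient. The candidates, the target, and the squared-error objective are all illustrative assumptions.

```python
import numpy as np

def softmax(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()

# Hypothetical outputs of three candidate pipeline configurations (one per row)
candidates = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
target = np.array([0.5, 0.5])  # hypothetical quantity the workflow should produce

def objective(y):
    """Squared error between a workflow's output and the target."""
    return float(np.sum((y - target) ** 2))

# Combinatorial design: score every discrete option and keep the best one
best_single = min(objective(c) for c in candidates)

# Differentiable design: learn softmax mixing weights over the options
theta = np.zeros(len(candidates))  # start from a uniform mixture
learning_rate = 1.0
for _ in range(200):
    w = softmax(theta)
    y = w @ candidates
    grad_w = 2.0 * candidates @ (y - target)    # dL/dw_k = 2 (y - target) . c_k
    grad_theta = w * grad_w - w * (w @ grad_w)  # chain rule through the softmax
    theta -= learning_rate * grad_theta

learned = objective(softmax(theta) @ candidates)
print(f"best single option: {best_single:.3f}, learned mixture: {learned:.4f}")
```

The learned mixture can sit between the discrete options, so it can reach a lower loss than any single choice. With a non-convex objective, gradient descent would only guarantee a local optimum, which is why the text speaks of the locally best workflow.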

Built upon the same principles that power deep neural networks, this library provides software instruments for designing, deploying, and evaluating both differentiable programs and GPU-accelerated workflows. Our current focus is on applications related to fMRI-derived brain connectivity, but the library is designed so that it can eventually be generalised to other domains.

Public warning: At this time, this software should be used as if it were in a pre-alpha state. Many operations are fragile or incompletely documented. Edge cases might not be covered, and there are certainly bugs lurking in the code base. Expect breaking changes as the code base is further expanded and refined. Contributions or ideas for improvement are always welcome.

Technical overview

This library is implemented using JAX, which combines a NumPy-like API with automatic differentiation and support for GPU acceleration. The library is designed to be modular and extensible, and to be used in conjunction with existing tools in the Python ecosystem. The library is currently under active development, and is not yet ready for use outside of research and development.

functional and nn: composable differentiable functionals

The functional module provides a set of composable differentiable functionals, which can be used to construct differentiable programs. Common steps in a functional neuroimaging workflow are provided as pre-defined functionals, including (among others):

  • cov : Covariance estimation: empirical covariance, Pearson correlation, partial correlation, conditional correlation, and others
  • fourier : Frequency-domain operations, such as temporal filtering
  • graph : Graph-theoretic operations, such as graph Laplacian estimation and community detection
  • interpolate : Methods for temporal interpolation over artefact-contaminated time frames
  • kernel : Similarity kernels and pairwise distance metrics
  • resid : Residualisation and regression
  • semidefinite : Projection between the positive semidefinite cone and tangent spaces
  • sphere : Operations on spherical approximations to the cortex, such as geodesics and spherical convolution
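As an illustration of what the resid and cov functionals compute, here is a minimal NumPy sketch. It does not use the hypercoil API, and the function names are hypothetical. It performs least-squares residualisation of signals against confounds, then computes Pearson correlation and partial correlation (the latter from the precision matrix). Correlating the residuals is one way to obtain a correlation conditioned on the confounds.

```python
import numpy as np

def residualise(Y, X):
    """Regress confounds X (k x t) out of signals Y (n x t); return residuals."""
    # Least-squares betas B solving X.T @ B ~= Y.T
    B, *_ = np.linalg.lstsq(X.T, Y.T, rcond=None)
    return Y - (X.T @ B).T

def pearson(Y):
    """Pearson correlation between the rows of Y."""
    return np.corrcoef(Y)

def partial_corr(Y):
    """Partial correlation between rows of Y from the precision matrix."""
    P = np.linalg.inv(np.cov(Y))
    d = np.sqrt(np.diag(P))
    R = -P / np.outer(d, d)
    np.fill_diagonal(R, 1.0)
    return R

rng = np.random.default_rng(0)
n, t, k = 5, 200, 3
confounds = rng.standard_normal((k, t))
# Synthetic signals contaminated by a random mixture of the confounds
signals = rng.standard_normal((n, t)) + rng.standard_normal((n, k)) @ confounds

clean = residualise(signals, confounds)
R = pearson(clean)        # correlation conditioned on the confounds
Rp = partial_corr(clean)  # additionally conditions each pair on all other signals
```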

The nn module provides a set of neural network layers that can be used to construct differentiable programs. These layers provide an alternative API to the functional module and also include more complicated parameterised functionals. They are implemented using the JAX-based Equinox library.
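The following sketch shows the layer idea in plain Python. A frozen dataclass bundles a parameter (one gain per frequency bin) with a functional (multiplication in the frequency domain). This loosely mirrors how an Equinox module, which is itself a dataclass, pairs fields with a __call__ method. However, it is neither Equinox nor hypercoil's FrequencyDomainFilter. The class name, the ideal_bandpass constructor, and the sampling interval tr are assumptions made for illustration.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class FreqFilterLayer:
    """Hypothetical layer: a parameter (weight) plus a functional (filtering)."""
    weight: np.ndarray  # one gain per rFFT frequency bin

    @classmethod
    def ideal_bandpass(cls, n_time, low, high, tr=1.0):
        """Initialise the gains as an ideal band-pass between low and high (Hz)."""
        freqs = np.fft.rfftfreq(n_time, d=tr)
        return cls(weight=((freqs >= low) & (freqs <= high)).astype(float))

    def __call__(self, x):
        # Filter along the last (time) axis by scaling each frequency bin
        X = np.fft.rfft(x, axis=-1)
        return np.fft.irfft(X * self.weight, n=x.shape[-1], axis=-1)

# Usage: keep a 0.05 Hz oscillation and remove a 0.25 Hz one (tr = 1 s)
t = np.arange(200)
slow = np.sin(2 * np.pi * 0.05 * t)
fast = np.sin(2 * np.pi * 0.25 * t)
layer = FreqFilterLayer.ideal_bandpass(200, low=0.01, high=0.1)
filtered = layer(slow + fast)
```

In a trainable version, weight would be the parameter updated by gradient descent. Passing it through a logistic map, similar in spirit to the MappedLogits mapping used in the example further down, would keep each gain in (0, 1).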

init: functional parameterisation

The init module provides mechanisms for parameterising differentiable functionals without learning. This includes a set of pre-defined parameterisations that incorporate domain knowledge from functional neuroimaging. These components can be used to implement pre-existing workflows or to learn a new workflow starting from a pre-existing one.

Parameterisation includes both initialisation and mapping to a constrained space. For example, the init.atlas module provides a set of parameterisations that correspond to different types of brain atlases, such as surface atlases, volumetric atlases, discrete parcellations, and probabilistic functional modes. Complementarily, the init.mapparam module uses transformations to constrain parameters to a particular subspace or manifold, such as the positive semidefinite cone, the sphere, or the probability simplex.
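A few standard constraint maps can be sketched in NumPy to show what mapping to a constrained space means in practice: an unconstrained array is stored, and a transformation is applied each time the value is used. These are generic choices and not necessarily the transformations that init.mapparam implements: a logistic sigmoid for (0, 1), a softmax for the probability simplex, A @ A.T + eps * I for the positive definite cone, and normalisation for the sphere. The last lines mimic Dirichlet initialisation by drawing simplex-valued columns and storing their logarithms, so that the softmax map recovers them exactly.

```python
import numpy as np

def sigmoid(z):
    """Map reals to the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z, axis=0):
    """Map reals to the probability simplex along `axis`."""
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def to_spd(A, eps=1e-6):
    """Map a square or wide matrix to a symmetric positive definite matrix."""
    return A @ A.T + eps * np.eye(A.shape[0])

def to_sphere(z, axis=-1):
    """Map vectors to the unit sphere along `axis`."""
    return z / np.linalg.norm(z, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
raw = rng.standard_normal((4, 6))  # unconstrained stored parameters

gains = sigmoid(raw)              # every entry in (0, 1)
simplex = softmax(raw, axis=0)    # each column is a categorical distribution
spd = to_spd(raw)                 # 4 x 4 positive definite
sphere = to_sphere(raw)           # each row has unit norm

# Dirichlet-style initialisation: sample on the simplex, store the logs
p = rng.dirichlet(np.ones(4), size=6).T  # shape (4, 6), columns sum to 1
stored = np.log(p)
recovered = softmax(stored, axis=0)      # equals p
```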

loss: learning signals

The loss module provides a set of learning signals (i.e., loss functions) that can be used to train differentiable programs. Loss functions are designed for use in combination with Equinox's filters and the excellent optax library for optimisation.

Loss functions are implemented compositionally using a functional API, and comprise two components: a score function and a scalarisation. The score function maps tensors from a differentiable program to a tensor of scores, and the scalarisation maps the scores to a scalar loss. The loss module provides a set of pre-defined score functions and scalarisations with applications in functional neuroimaging and beyond.
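The score-then-scalarise composition can be sketched in a few lines of NumPy. The function names below echo those imported in the example, but they are hypothetical re-implementations written for illustration, not hypercoil's code. The score is the distance from each entry to the nearest of two symmetric modes. The scalarisation is a vector p-norm applied to the scores, and it is written as a function that wraps a score function into a loss.

```python
from functools import partial

import numpy as np

def bimodal_symmetric_score(x, modes=(-1.0, 1.0)):
    """Distance from each entry of x to the nearest of two modes."""
    lo, hi = modes
    centre, half_width = (lo + hi) / 2, (hi - lo) / 2
    return np.abs(np.abs(x - centre) - half_width)

def vnorm_scalarise(p=1, axis=None):
    """Return a wrapper that turns a score function into a p-norm loss."""
    def compose(score_fn):
        def loss_fn(x):
            s = np.abs(score_fn(x))
            return np.sum(s ** p, axis=axis) ** (1.0 / p)
        return loss_fn
    return compose

loss_fn = vnorm_scalarise(p=1)(partial(bimodal_symmetric_score, modes=(-1, 1)))

weak = np.array([[1.0, 0.2], [0.2, 1.0]])
strong = np.array([[1.0, -0.9], [-0.9, 1.0]])
print(loss_fn(weak), loss_fn(strong))  # weak correlations incur the larger loss
```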

formula: functional grammar

This library also includes an extensible functional grammar for various purposes. Internally, we use it to implement confound model specification, an FSLmaths-like API for image manipulation, and a syntax for addressing and filtering neural network parameters.
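The grammar itself parses strings into functions. As a sketch of what the 36-parameter confound formula in the example below requests, the following NumPy code applies the transformations by hand. It assumes, following the formula's conventional meaning, that ^^2 appends squared terms, that dd1 appends first temporal derivatives, and that rps expands to six realignment parameters. It also reads the censoring term as: flag frames with fd > 0.5 or dv > 1.5, then scatter the flags into one indicator column per flagged frame. The fMRIPrep-style column names and the helper functions are hypothetical, and this is not the grammar's parser.

```python
import numpy as np

rng = np.random.default_rng(0)
n_t = 100
# Hypothetical base regressors: rps (6) + wm + csf + gsr = 9 columns
base_names = [
    "trans_x", "trans_y", "trans_z", "rot_x", "rot_y", "rot_z",
    "white_matter", "csf", "global_signal",
]
base = {name: rng.standard_normal(n_t) for name in base_names}

def power(columns, order=2):
    """Append powers 2..order of every column (our reading of ^^order)."""
    out = dict(columns)
    for name, v in columns.items():
        for k in range(2, order + 1):
            out[f"{name}_power{k}"] = v ** k
    return out

def dd(columns, order=1):
    """Append backward-difference derivatives 1..order (our reading of dd<order>)."""
    out = dict(columns)
    for name, v in columns.items():
        d = v
        for k in range(1, order + 1):
            # The first k frames have no derivative; pad with NaN (filled with 0 later)
            d = np.concatenate([[np.nan], np.diff(d)])
            out[f"{name}_derivative{k}"] = d
    return out

# dd1((rps + wm + csf + gsr)^^2): 9 columns -> 18 after ^^2 -> 36 after dd1
model_36p = dd(power(base, order=2), order=1)

# Censoring: flag frames where fd > 0.5 OR dv > 1.5, one spike column per flag
fd = np.abs(rng.normal(0.2, 0.2, n_t))
dv = np.abs(rng.normal(1.0, 0.3, n_t))
flagged = np.flatnonzero((fd > 0.5) | (dv > 1.5))
spikes = np.eye(n_t)[:, flagged]  # shape (n_t, number of flagged frames)
```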

viz: visualisation

Visualisation utilities will include (inter alia) a PyVista-based 3D visualisation API for plotting brain surfaces, atlases, and networks, and a set of utilities for plotting brain connectivity matrices. These utilities will be designed to automatically read information from differentiable models using a functional reporting system. This framework remains under development.

Installation

Right now, just pip install from GitHub (for example, pip install git+https://github.com/hypercoil/hypercoil.git). Come back in a few weeks and ask again about PyPI.

A simple example

Here's a small example that shows how the above modules can be combined into a simple differentiable program. The program first filters a time series, then estimates its correlation conditioned on a confound model, and finally projects the estimated correlation matrix from the positive semidefinite cone into a tangent space. The model is then trained using a simple loss function that promotes correlations with a large magnitude.

Note that this is not a particularly useful model, but it serves to illustrate the basic principles. (Astute readers will also notice several places in the code where processing decisions are incorrect or oversimplified. This is intentional, as this vignette is not intended to be instructional with regard to functional neuroimaging.)

import json
from functools import partial
from pkg_resources import resource_filename as pkgrf

import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import pandas as pd

from hypercoil.formula import ConfoundFormulaGrammar
from hypercoil.functional import conditionalcorr
from hypercoil.init import (
    FreqFilterSpec,
    DirichletInitialiser,
    MappedLogits,
    SPDGeometricMean,
)
from hypercoil.loss import (
    bimodal_symmetric,
    vnorm_scalarise,
)
from hypercoil.neuro.synth import (
    synthesise_matched,
)
from hypercoil.nn import (
    FrequencyDomainFilter,
    TangentProject,
    BinaryCovariance,
)

#-----------------------------------------------------------------------------#
# 1. Generate some synthetic data: first, configure the dimensions.
max_epoch = 10
log_interval = 1
n_subjects = 10
n_voxels = 400
n_time_points = 200
n_channels = 4  # Data channels: These could be different connectivity
                #                "states" captured by the covariance.
                #                Or, if we made the weights fixed rather
                #                than trainable, they could be different
                #                pipeline configurations for multiverse
                #                analysis.
key = jax.random.PRNGKey(0)
data_key, filter_key, cov_key, proj_key = jax.random.split(key, 4)

#-----------------------------------------------------------------------------#
# 2. Create a synthetic time series with spectrum and covariance matched to
#    a parcellated human brain.
ref_path = pkgrf(
    'hypercoil',
    'examples/synthetic/data/synth-regts/atlas-schaefer400_desc-synth_ts.tsv'
)
ref_data = pd.read_csv(ref_path, sep='\t', header=None).values.T
reference = jnp.array(ref_data)

X = synthesise_matched(
    reference=reference,
    key=data_key,
)[..., :n_time_points]

#-----------------------------------------------------------------------------#
# 3. Define the confound model. Let's use a standard 36-parameter model with
#    censoring.
confounds = pkgrf('hypercoil', 'examples/data/desc-confounds_timeseries.tsv')
metadata = pkgrf('hypercoil', 'examples/data/desc-confounds_timeseries.json')
confounds = pd.read_csv(confounds, sep='\t')
with open(metadata) as file:
    metadata = json.load(file)

# Specify the confound model using a formula.
model_36p = 'dd1((rps + wm + csf + gsr)^^2)'
model_censor = '[SCATTER]([OR](1_[>0.5](fd) + 1_[>1.5](dv)))'
model_formula = f'{model_36p} + {model_censor}'

# Parse the formula into a function.
f = ConfoundFormulaGrammar().compile(model_formula)
confounds, metadata = f(confounds, metadata)
confounds = confounds.fillna(0)
confounds = jnp.array(confounds.values).T[..., :n_time_points]

#-----------------------------------------------------------------------------#
# 4. Create the differentiable program.

# Define a parameterisation for the filter. Here, we're using an ideal
# bandpass filter with a frequency range of 0.01-0.1 Hz.
high_pass, low_pass = 0.01, 0.1
filter_spec = FreqFilterSpec(Wn=(high_pass, low_pass), ftype='ideal')

# Define a parameterisation for the tangent projection. Here, we're using
# the geometric mean of the covariance matrices as the initial point of
# tangency.
proj_spec = SPDGeometricMean(psi=1e-3)

# Instantiate the filter layer using the parameterisation we defined above.
filter = FrequencyDomainFilter.from_specs(
    (filter_spec,),
    time_dim=n_time_points,
    key=filter_key,
)
# Using the `MappedLogits` parameter mapping, we can constrain the filter
# weights within the range (0, 1). Each weight then represents the
# attenuation of amplitude in a frequency band.
filter = MappedLogits.map(filter, where='weight')

# Instantiate the covariance estimator layer.
cov = BinaryCovariance(
    estimator=conditionalcorr,
    dim=n_time_points,
    out_channels=n_channels,
    l2=0.1,
    key=cov_key,
)
# Let's initialise the covariance weights from a Dirichlet distribution.
cov = DirichletInitialiser.init(
    cov,
    concentration=[1.0] * n_channels,
    where='weight',
    axis=0,
    key=cov_key,
)
# Note that the Dirichlet initialiser automatically transforms our
# weight into a `ProbabilitySimplexParameter`! This way, the weights
# are always guaranteed to be valid categorical probability distributions.

# Instantiate the tangent projection layer using the parameterisation
# we defined above.
init_data = cov(filter(X), filter(confounds))
proj = TangentProject.from_specs(
    mean_specs=(proj_spec,),
    init_data=init_data,
    recondition=1e-5,
    key=proj_key,
)

# Finally, let's create the program that combines the filter, covariance
# estimator, and tangent projection layers.
class Model(eqx.Module):
    filter: FrequencyDomainFilter
    cov: BinaryCovariance
    proj: TangentProject

    def __call__(self, x, confounds, *, key):
        x, confounds = self.filter(x), self.filter(confounds)
        x = self.cov(x, confounds)
        x = self.proj(x, key=key)
        return x

model = Model(filter=filter, cov=cov, proj=proj)

#-----------------------------------------------------------------------------#
# 5. Define a learning signal. The "bimodal symmetric" score measures the
#    distance from each element in the correlation matrix to the nearest
#    of two modes. By setting the modes to -1 and 1, we assign large scores to
#    weak correlations and small scores to strong correlations.
#
#    The "vnorm scalarise" function then takes the matrix of scores and
#    converts it into a scalar by summing the absolute values of the scores.
#    Later, we'll use an optimisation algorithm to minimise this scalar score,
#    thereby promoting strong correlations.

scalarisation = vnorm_scalarise(p=1, axis=None)
score = partial(bimodal_symmetric, modes=(-1, 1))
loss = scalarisation(score) # We are composing the two functions here to
                            # create a new function that takes a matrix
                            # and returns a scalar.

#-----------------------------------------------------------------------------#
# 6. Define the "forward pass" of the differentiable program. This is the
#    function that maps from input data to the output score.
def forward(model, X, confounds, *, key):
    return loss(model(X, confounds, key=key))

#-----------------------------------------------------------------------------#
# 7. Configure the optimisation algorithm. Here, we're using Adam with a
#    learning rate of 5e-4.
opt = optax.adam(5e-4)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

#-----------------------------------------------------------------------------#
# 8. Define a function that computes the loss and its gradient, updates the
#    model parameters, and returns the updated model, optimiser state, and loss.
def update(model, opt_state, X, confounds, *, key):
    value, grad = eqx.filter_value_and_grad(forward)(
        model, X, confounds, key=key)
    updates, opt_state = opt.update(
        eqx.filter(grad, eqx.is_inexact_array),
        opt_state,
        eqx.filter(model, eqx.is_inexact_array),
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, value

#-----------------------------------------------------------------------------#
# 9. Run the optimisation loop.
update_step = eqx.filter_jit(update)  # Wrap once, outside the loop.
for i in range(max_epoch):
    model, opt_state, value = update_step(
        model, opt_state, X, confounds, key=jax.random.fold_in(key, i))
    if i % log_interval == 0:
        print(f'Iteration {i}: loss = {value:.3f}')
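The tangent projection in step 4 can be illustrated outside the library with a short NumPy sketch. Each SPD matrix S is whitened by a reference point and passed through the matrix logarithm, logm(ref^-1/2 S ref^-1/2), and the inverse map returns it to the cone. As the reference point, this sketch uses the log-Euclidean mean, swapped in for the geometric mean used in the example because it has a closed form. The function names are hypothetical, and this is not TangentProject's implementation.

```python
import numpy as np

def _spd_fn(S, fn):
    """Apply a scalar function to a symmetric matrix through its eigenvalues."""
    w, V = np.linalg.eigh(S)
    return (V * fn(w)) @ V.T

def tangent_project(S, ref):
    """Whitened log map: logm(ref^-1/2 @ S @ ref^-1/2)."""
    ref_isqrt = _spd_fn(ref, lambda w: 1.0 / np.sqrt(w))
    return _spd_fn(ref_isqrt @ S @ ref_isqrt, np.log)

def tangent_unproject(T, ref):
    """Inverse map back to the SPD cone: ref^1/2 @ expm(T) @ ref^1/2."""
    ref_sqrt = _spd_fn(ref, np.sqrt)
    return ref_sqrt @ _spd_fn(T, np.exp) @ ref_sqrt

def log_euclidean_mean(mats):
    """expm of the average matrix logarithm (a stand-in for the geometric mean)."""
    return _spd_fn(np.mean([_spd_fn(S, np.log) for S in mats], axis=0), np.exp)

rng = np.random.default_rng(0)
covs = []
for _ in range(5):
    x = rng.standard_normal((4, 50))
    covs.append(np.cov(x) + 0.1 * np.eye(4))  # well-conditioned SPD matrices

ref = log_euclidean_mean(covs)
tangent = [tangent_project(S, ref) for S in covs]  # symmetric, unconstrained
```

Because tangent vectors are ordinary symmetric matrices, Euclidean operations such as averaging or regression can be applied to them; the reference point itself maps to the zero matrix.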
