Implementation of MNIST classification using the NEAT (NeuroEvolution of Augmenting Topologies) algorithm combined with backpropagation, written in JAX.
python3 -m venv .venv
source .venv/bin/activate
pip install hydra-core "jax[cpu]" flax tensorflow matplotlib tqdm scikit-learn
pip install torch torchvision torchaudio
pip install --upgrade setuptools
Note: Change the JAX installation variant ("jax[cpu]" above) to match your environment, e.g. a GPU or TPU build.
The NEAT algorithm has many components, but this implementation focuses only on adding and removing nodes in the network graph, with the weights trained by backpropagation in JAX. The population and speciation parts of the algorithm are not implemented.
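The README does not show how nodes are added and removed, so the following is only a minimal sketch under our own assumptions, not the repository's code. It stores the graph as a dense weight matrix `W` plus a 0/1 connection mask `M` over nodes kept in topological order (every connection i -> j has i < j). The names `forward`, `loss_fn`, `add_node` and `remove_node` are hypothetical. `add_node` follows the classic NEAT add-node mutation (disable a connection a -> b, insert a new node h, set a -> h to weight 1 and h -> b to the old weight), and the weights are then trained with `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical representation: W holds weights, M is a 0/1 mask of enabled
# connections, and nodes are in topological order (inputs first, outputs last).

def forward(W, M, x, n_in, n_out):
    """Evaluate the graph for a batch x of shape (batch, n_in); returns logits."""
    n = W.shape[0]
    values = [x[:, i] for i in range(n_in)]
    for j in range(n_in, n):
        pre = sum(values[i] * W[i, j] * M[i, j] for i in range(j))
        # Hidden nodes use tanh; the last n_out nodes are linear outputs.
        values.append(pre if j >= n - n_out else jnp.tanh(pre))
    return jnp.stack(values[n - n_out:], axis=1)

def loss_fn(W, M, x, y, n_in, n_out):
    """Mean softmax cross-entropy for integer labels y."""
    logp = jax.nn.log_softmax(forward(W, M, x, n_in, n_out))
    return -jnp.mean(jnp.take_along_axis(logp, y[:, None], axis=1))

def add_node(W, M, a, b):
    """NEAT-style add-node mutation: split connection a -> b (requires a < b)."""
    w_ab = W[a, b]
    # Insert an empty row and column at index b so the new node sits between a and b.
    W = jnp.insert(jnp.insert(W, b, 0.0, axis=0), b, 0.0, axis=1)
    M = jnp.insert(jnp.insert(M, b, 0.0, axis=0), b, 0.0, axis=1)
    h, b = b, b + 1                    # new node index, shifted index of the old target
    M = M.at[a, b].set(0.0)            # disable the original connection
    M = M.at[a, h].set(1.0).at[h, b].set(1.0)
    W = W.at[a, h].set(1.0).at[h, b].set(w_ab)  # incoming weight 1, outgoing keeps old weight
    return W, M

def remove_node(W, M, k):
    """Delete hidden node k together with all of its incoming and outgoing connections."""
    W = jnp.delete(jnp.delete(W, k, axis=0), k, axis=1)
    M = jnp.delete(jnp.delete(M, k, axis=0), k, axis=1)
    return W, M

# Usage: a toy problem with 4 inputs and 3 classes (Iris-sized), random data.
n_in, n_out = 4, 3
n = n_in + n_out
W0 = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (n, n))
M0 = jnp.zeros((n, n)).at[:n_in, n_in:].set(1.0)   # inputs fully connected to outputs
x = jax.random.normal(jax.random.PRNGKey(1), (8, n_in))
y = jnp.arange(8) % n_out

W1, M1 = add_node(W0, M0, 0, n_in)                 # split input 0 -> first output
loss_before = loss_fn(W1, M1, x, y, n_in, n_out)
g = jax.grad(loss_fn)(W1, M1, x, y, n_in, n_out)   # gradient w.r.t. W only
W2 = W1 - 0.01 * g                                 # one plain SGD step
loss_after = loss_fn(W2, M1, x, y, n_in, n_out)
W3, M3 = remove_node(W2, M1, n_in)                 # remove the hidden node again
```

Because every weight is multiplied by its mask entry, disabled or absent connections receive exactly zero gradient, so the topology changes only through explicit mutations such as `add_node` and `remove_node`.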
python3 main.py
Note: Change the Hydra configuration file to train on the other supported datasets (Iris or Digits).
After running main.py, you can access all the training logs. Below is the accuracy graph for the MNIST dataset.