- Python >=3.11
- Rye
- Download the M5 dataset following the instructions in the README
- Run `./dataset/M5/extract.sh` to extract the dataset
- Run `rye sync` to download the required dependencies
- Run `rye run python pretrainm5.py` for pre-training (at least 10 epochs)
- Run `rye run python trainm5.py` for training (at least 100 epochs)
- Requires at least 70 GB of VRAM with mixed precision
- Set `SCALE_PREC = False` in `trainm5.py` to use FP16 on GPUs with less than 40 GB of VRAM
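The run steps above can be chained in a small driver script. This is a minimal sketch, assuming it is run from the repository root; the script itself (a hypothetical `run_m5.py`) and its `--run` flag are not part of the repository. By default it only prints the commands, and it executes them in order only when `--run` is given.

```python
"""Run the M5 pipeline steps from the README in order.

A minimal sketch, not part of the repository: the script name and the
--run flag are hypothetical. Without --run it only prints the commands.
"""
import shlex
import subprocess
import sys

# (description, command) pairs, in the order the README lists them.
STEPS: list[tuple[str, list[str]]] = [
    ("extract the M5 dataset", ["./dataset/M5/extract.sh"]),
    ("install dependencies", ["rye", "sync"]),
    ("pre-train (at least 10 epochs)", ["rye", "run", "python", "pretrainm5.py"]),
    ("train (at least 100 epochs)", ["rye", "run", "python", "trainm5.py"]),
]


def run_pipeline(dry_run: bool = True) -> list[str]:
    """Print (and, unless dry_run, execute) each step; return the command lines."""
    issued: list[str] = []
    for description, cmd in STEPS:
        command_line = shlex.join(cmd)
        print(f"==> {description}: {command_line}")
        issued.append(command_line)
        if not dry_run:
            # check=True raises CalledProcessError and stops at the first failing step.
            subprocess.run(cmd, check=True)
    return issued


if __name__ == "__main__":
    # Execute the steps only when explicitly asked to with --run.
    run_pipeline(dry_run="--run" not in sys.argv[1:])
```

`python run_m5.py` prints the plan and `python run_m5.py --run` executes it. The README does not say how the epoch counts are configured, so the sketch passes no epoch arguments; check `pretrainm5.py` and `trainm5.py` for that.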