Пример классификации изображений с моделями с нуля на PyTorch. Датасет Animals-10 взят с сайта Kaggle.
Дообучение готовых моделей здесь.
- Подготовить виртуальную среду. Использовался
python 3.11.8
- Установить зависимости.
pip install -r requirements.txt
- В терминале запустить
main.py
Пример:python main.py path/to/dataset densenet
Если написатьpython main.py
без аргументов, то выведется подробная инструкция по использованию. - Информация об обучении сохраняется в папку
runs
.
Подробнее об аргументах
- Обязательные аргументы:
- путь до данных - данные могут быть в формате папок train, val, test, или сразу папки с классами. В этом случае для val будет отобрано 10% датасета.
- модель или путь до модели - есть подготовленные модели mobilenetv1 и densenet. Можно так же указать путь до готовой модели в формате .pt или .pth.
- Необязательные аргументы:
- alpha - множитель для модели mobilenetv1.
- reps - список повторений слоёв для densenet.
- bottleneck - для densenet использовать или нет стиль bottleneck
- batch - размер батча
- epochs - количество эпох
- img_size - размер изображений
- augs - какие аугментации применять к изображениям
- name - имя проекта, под которым будет сохраняться информация об обучении
python main.py path/to/dataset mobilenetv1 alpha=0.5 batch=32 epochs=100 img_size=224 name=experiment1
python main.py path/to/dataset/ densenet reps=[2,4,8] bottleneck=True augs='soft'
Выполнить python predict.py path/to/best.pt path/to/images
Результаты обучения сохраняются в папку runs/project_name
со следующей информацией:
![]() |
![]() |
![]() |
Достигнута точность ~80%: Как можно её улучшить:
- экспериментировать с аугментацией
- экспериментировать с гиперпараметрами(количество эпох, размер батчей, размер изображений и т.д.)
- собрать больше данных
- Использовать другие архитектуры (например, ViT, ResNet, EfficientNet)
- использовать fine-tuning и feature extraction предобученных моделей
![]() |
![]() |
![]() |
![]() |
![]() |
- save best model
- add tensorboard support
- add args.yaml file to store all parameters
- add resume training from saved model