Unofficial PyTorch implementation of the optical flow model presented in "RAFT: Recurrent All-Pairs Field Transforms for Optical Flow".
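import torch

from raft import RAFT

# A batch of random 3-channel frames, shape (batch, channels, height, width)
batch_size = 2
frame_1 = torch.rand(batch_size, 3, 256, 256)
frame_2 = torch.rand(batch_size, 3, 256, 256)

# Build the model and point it at the pretrained weights file
optical_flow = RAFT(
    small=False,
    pretrained="weights/raft-sintel.pth",
)

# Estimate the flow between the two frames using 20 iterations
optical_flow_estimate = optical_flow(frame_1, frame_2, iters=20)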
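The frames in the example above are 256x256, so both sides are divisible by 8. RAFT extracts features at 1/8 of the input resolution, and the original repository pads inputs to a multiple of 8 before inference. Whether this implementation needs the same treatment is not stated here, so the following is only a sketch under that assumption. `compute_padding` is a hypothetical helper and is not part of this package.

```python
def compute_padding(height, width, multiple=8):
    """Return (left, right, top, bottom) padding sizes.

    Hypothetical helper: the padding makes both height and width
    divisible by `multiple`, split as evenly as possible per side.
    """
    pad_h = (-height) % multiple  # rows missing to reach the next multiple
    pad_w = (-width) % multiple   # columns missing to reach the next multiple
    left, top = pad_w // 2, pad_h // 2
    return left, pad_w - left, top, pad_h - top


print(compute_padding(256, 256))   # (0, 0, 0, 0): no padding needed
print(compute_padding(436, 1024))  # (0, 0, 2, 2): 436 rows padded to 440
```

The tuple is ordered the way `torch.nn.functional.pad` expects for the last two dimensions, so it could be applied to both frames with `torch.nn.functional.pad(frame, padding, mode="replicate")`. The estimated flow would then need cropping by the same amounts to get back to the original size.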
- Original repository: princeton-vl/RAFT
- Code parts: hmorimitsu/ptlflow
If you plan to use RAFT in your work, please cite the original paper:
@misc{teed2020raft,
      title={RAFT: Recurrent All-Pairs Field Transforms for Optical Flow},
      author={Zachary Teed and Jia Deng},
      year={2020},
      eprint={2003.12039},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}