For what it's worth, I ended up implementing some of these functions myself, more or less transcribing the `batched_nms` functions from the torchvision source and rolling my own `nms` and `iou` functions:
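use tch::{Device, IndexOp, Tensor};

// Same dispatch as torchvision's batched_nms: loop over classes for large inputs,
// use the coordinate-offset trick otherwise
fn batched_nms(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    let limit = if boxes.device() == Device::Cpu { 4000 } else { 20000 };
    if boxes.numel() > limit {
        _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    } else {
        _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
    }
}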
fn _batched_nms_coordinate_trick(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    // strategy: in order to perform NMS independently per class,
    // we add an offset to all the boxes. The offset is dependent
    // only on the class idx, and is large enough so that boxes
    // from different classes do not overlap
    if boxes.numel() == 0 {
        // Empty result as an int64 index tensor, matching what nms() returns (torchvision does the same)
        Tensor::empty([0], (tch::Kind::Int64, boxes.device()))
    } else {
        let max_coordinate = boxes.max();
        // Cast the class indices to the box dtype before scaling, as torchvision does
        let offsets = idxs.to_kind(boxes.kind()) * (max_coordinate + 1.0);
        let boxes_for_nms = boxes + offsets.unsqueeze(1);
        nms(boxes_for_nms, scores, iou_threshold)
    }
}
fn _batched_nms_vanilla(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    // Based on the Detectron2 implementation: call nms() on each class independently
    let mut keep_mask = Tensor::zeros_like(scores).to_kind(tch::Kind::Bool);
    let unique = idxs.view(-1).unique_dim(0, false, false, false).0;
    for i in 0..unique.size()[0] {
        // Indices of all boxes belonging to the current class
        let curr_indices = Tensor::where_(&idxs.eq_tensor(&unique.i(i))).remove(0);
        let curr_keep_indices = nms(boxes.i(&curr_indices), &scores.i(&curr_indices), iou_threshold);
        // Mark the survivors, mapped back to indices into the full input
        keep_mask = keep_mask.index_fill(0, &curr_indices.i(&curr_keep_indices), 1i64);
    }
    let keep_indices = Tensor::where_(&keep_mask).remove(0);
    // Return the kept indices ordered by decreasing score
    keep_indices.i(&scores.i(&keep_indices).sort(-1, true).1)
}
fn nms(boxes: Tensor, scores: &Tensor, iou_threshold: f32) -> Tensor {
    // Perform non-maximum suppression, returning a tensor of kept indices in decreasing score order.
    // argsort is ascending, so popping from the end visits the highest remaining score first.
    let mut sorting: Vec<i64> = scores.argsort(0, false).try_into().unwrap();
    let mut keep: Vec<i64> = Vec::new();
    while let Some(idx) = sorting.pop() {
        keep.push(idx);
        // Walk backwards so removing an element does not shift the positions still to visit
        for i in (0..sorting.len()).rev() {
            if iou(&boxes.i(idx), &boxes.i(sorting[i])).double_value(&[]) > iou_threshold as f64 {
                sorting.remove(i);
            }
        }
    }
    Tensor::try_from(keep).unwrap().to_device(boxes.device())
}
fn iou(box1: &Tensor, box2: &Tensor) -> Tensor {
    // Calculate Intersection over Union of two (x1, y1, x2, y2) boxes.
    // Uses the continuous convention of torchvision's box_iou (widths are x2 - x1, no +1);
    // the older +1 "pixel" convention inflates areas and distorts IoU for normalized coordinates.
    let zero = Tensor::zeros_like(&box1.i(0));
    let b1_area = (box1.i(2) - box1.i(0)) * (box1.i(3) - box1.i(1));
    let b2_area = (box2.i(2) - box2.i(0)) * (box2.i(3) - box2.i(1));
    let i_xmin = box1.i(0).max_other(&box2.i(0));
    let i_xmax = box1.i(2).min_other(&box2.i(2));
    let i_ymin = box1.i(1).max_other(&box2.i(1));
    let i_ymax = box1.i(3).min_other(&box2.i(3));
    // Clamp negative widths/heights to zero so disjoint boxes get zero intersection
    let i_area = (i_xmax - i_xmin).max_other(&zero) * (i_ymax - i_ymin).max_other(&zero);
    &i_area / (b1_area + b2_area - &i_area)
}
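The offset trick in `_batched_nms_coordinate_trick` is easier to see on plain numbers. Below is a minimal std-only sketch, assuming `(x1, y1, x2, y2)` boxes with non-negative coordinates; `offset_boxes_by_class` and `overlaps` are hypothetical helpers for illustration, not tch or torchvision APIs. Every box of class `c` is shifted by `c * (max_coordinate + 1)`, so each class lands in its own region and a single NMS pass cannot suppress a box of one class with a box of another. With coordinates below -1 the shifted ranges could meet, which is why the non-negative assumption matters.

```rust
/// Shift each (x1, y1, x2, y2) box by `class * (max_coordinate + 1)`, the
/// offset used by `_batched_nms_coordinate_trick`. Assumes non-negative
/// coordinates; with values below -1 shifted classes could still overlap.
fn offset_boxes_by_class(boxes: &[[f32; 4]], idxs: &[i64]) -> Vec<[f32; 4]> {
    assert_eq!(boxes.len(), idxs.len(), "one class index per box");
    // Largest coordinate over all boxes, the analogue of `boxes.max()`
    let max_coordinate = boxes
        .iter()
        .flat_map(|b| b.iter().copied())
        .fold(f32::NEG_INFINITY, f32::max);
    boxes
        .iter()
        .zip(idxs)
        .map(|(b, &c)| {
            // The offset depends only on the class index
            let off = c as f32 * (max_coordinate + 1.0);
            [b[0] + off, b[1] + off, b[2] + off, b[3] + off]
        })
        .collect()
}

/// True if two boxes share a region of positive area.
fn overlaps(a: &[f32; 4], b: &[f32; 4]) -> bool {
    a[0].max(b[0]) < a[2].min(b[2]) && a[1].max(b[1]) < a[3].min(b[3])
}
```

For example, with boxes `[0, 0, 10, 10]` (class 0), `[2, 2, 8, 8]` (class 0) and `[0, 0, 10, 10]` (class 1), the largest coordinate is 10, so the two class-0 boxes stay put and still overlap, while the class-1 box moves to `[11, 11, 21, 21]` and no longer touches them.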
I make no claims about its performance.
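On the performance caveat: each `iou` call in the `nms` above runs several tiny tensor ops and `double_value` copies each result back to the host, so when little gets suppressed the loop costs O(n²) tensor calls. A possible alternative, sketched below, runs the same greedy suppression on plain `f32` values, assuming the boxes and scores have already been copied out of the tensors into slices (for example by flattening with `view(-1)` and using the same `try_into()` conversion `nms` uses for the sorted indices). `greedy_nms` and this `iou` are hypothetical helpers, and `iou` here uses the continuous convention of torchvision's `box_iou` (widths are `x2 - x1`, no `+1`).

```rust
use std::cmp::Ordering;

/// IoU of two (x1, y1, x2, y2) boxes with the continuous convention (no +1).
fn iou(a: &[f32; 4], b: &[f32; 4]) -> f32 {
    let area = |r: &[f32; 4]| (r[2] - r[0]).max(0.0) * (r[3] - r[1]).max(0.0);
    let iw = (a[2].min(b[2]) - a[0].max(b[0])).max(0.0);
    let ih = (a[3].min(b[3]) - a[1].max(b[1])).max(0.0);
    let inter = iw * ih;
    let union = area(a) + area(b) - inter;
    // Degenerate boxes with zero union count as non-overlapping
    if union > 0.0 { inter / union } else { 0.0 }
}

/// Greedy NMS on plain slices: returns kept indices, highest score first.
fn greedy_nms(boxes: &[[f32; 4]], scores: &[f32], iou_threshold: f32) -> Vec<usize> {
    assert_eq!(boxes.len(), scores.len(), "one score per box");
    // Visit boxes in descending score order (NaN scores compare as equal)
    let mut order: Vec<usize> = (0..boxes.len()).collect();
    order.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap_or(Ordering::Equal));
    let mut suppressed = vec![false; boxes.len()];
    let mut keep = Vec::new();
    for (pos, &i) in order.iter().enumerate() {
        if suppressed[i] {
            continue;
        }
        keep.push(i);
        // Suppress every lower-scored box that overlaps the kept one too much
        for &j in &order[pos + 1..] {
            if !suppressed[j] && iou(&boxes[i], &boxes[j]) > iou_threshold {
                suppressed[j] = true;
            }
        }
    }
    keep
}
```

Converting the kept `usize` indices to `i64` lets them go back into a tensor with `Tensor::try_from`, as the original `nms` does.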
from tch-rs.