torchvision bindings (tch-rs issue, closed)

vesuvisian commented on June 29, 2024
torchvision bindings

Comments (1)

vesuvisian commented on June 29, 2024

For what it's worth, I ended up implementing some of these functions myself: I more or less transcribed the batched_nms functions from the torchvision source and rolled my own nms and iou functions:

use tch::{Device, IndexOp, Kind, Tensor};

/// Port of torchvision's `batched_nms`: performs NMS independently per class (given by `idxs`)
/// and returns int64 indices into `boxes`, sorted by decreasing score.
/// Boxes are expected in (x1, y1, x2, y2) format.
fn batched_nms(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    // Same size cut-off as torchvision: 4000 elements on CPU, 20000 elsewhere.
    let limit = if boxes.device() == Device::Cpu { 4000 } else { 20000 };
    if boxes.numel() > limit {
        _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    } else {
        _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
    }
}

fn _batched_nms_coordinate_trick(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    // strategy: in order to perform NMS independently per class,
    // we add an offset to all the boxes. The offset is dependent
    // only on the class idx, and is large enough so that boxes
    // from different classes do not overlap
    if boxes.numel() == 0 {
        // NMS returns indices, so the empty result is int64 rather than float.
        Tensor::empty([0], (Kind::Int64, boxes.device()))
    } else {
        let max_coordinate = boxes.max();
        // Cast the class ids to the box dtype before scaling them into offsets.
        let offsets = idxs.to_kind(boxes.kind()) * (max_coordinate + 1.0);
        let boxes_for_nms = boxes + offsets.unsqueeze(1);
        nms(&boxes_for_nms, scores, iou_threshold)
    }
}

fn _batched_nms_vanilla(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    // Based on the Detectron2 implementation: just call nms() on each class independently
    let mut keep_mask = Tensor::zeros_like(scores).to_kind(Kind::Bool);
    let unique = idxs.view(-1).unique_dim(0, false, false, false).0;
    for i in 0..unique.size()[0] {
        // Positions of all boxes belonging to the current class.
        let curr_indices = Tensor::where_(&idxs.eq_tensor(&unique.i(i))).remove(0);
        let curr_keep_indices = nms(&boxes.i(&curr_indices), &scores.i(&curr_indices), iou_threshold);
        // Map the per-class survivors back to positions in the full input and mark them.
        keep_mask = keep_mask.index_fill(0, &curr_indices.i(&curr_keep_indices), 1i64);
    }
    let keep_indices = Tensor::where_(&keep_mask).remove(0);
    // Return the kept indices ordered by decreasing score.
    keep_indices.i(&scores.i(&keep_indices).sort(-1, true).1)
}

fn nms(boxes: &Tensor, scores: &Tensor, iou_threshold: f32) -> Tensor {
    // Greedy non-maximum suppression, returning a tensor of indices to keep
    // in decreasing order of score. argsort is ascending, so pop() yields the best remaining box.
    let mut sorting: Vec<i64> = scores.argsort(0, false).to_device(Device::Cpu).try_into().unwrap();
    let mut keep: Vec<i64> = Vec::new();
    while let Some(idx) = sorting.pop() {
        keep.push(idx);
        // Walk backwards so that remove(i) does not shift entries not yet visited.
        for i in (0..sorting.len()).rev() {
            if iou(&boxes.i(idx), &boxes.i(sorting[i])).double_value(&[]) > iou_threshold as f64 {
                sorting.remove(i);
            }
        }
    }
    Tensor::try_from(keep).unwrap().to_device(boxes.device())
}

fn iou(box1: &Tensor, box2: &Tensor) -> Tensor {
    // Intersection over Union of two (x1, y1, x2, y2) boxes.
    // Areas are (x2 - x1) * (y2 - y1), the convention torchvision's nms uses (no +1 pixel offset).
    let zero = Tensor::zeros_like(&box1.i(0));
    let b1_area = (box1.i(2) - box1.i(0)) * (box1.i(3) - box1.i(1));
    let b2_area = (box2.i(2) - box2.i(0)) * (box2.i(3) - box2.i(1));
    let i_xmin = box1.i(0).max_other(&box2.i(0));
    let i_xmax = box1.i(2).min_other(&box2.i(2));
    let i_ymin = box1.i(1).max_other(&box2.i(1));
    let i_ymax = box1.i(3).min_other(&box2.i(3));
    // Clamp at zero so that disjoint boxes have an empty intersection.
    let i_area = (i_xmax - i_xmin).max_other(&zero) * (i_ymax - i_ymin).max_other(&zero);
    &i_area / (b1_area + b2_area - &i_area)
}

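I make no claims about its performance: nms() computes the IoU of one pair of boxes at a time with tiny tensor operations, so it will be slow when there are many boxes.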
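For comparison, here is a dependency-free sketch of the same greedy NMS over plain `[x1, y1, x2, y2]` arrays. It computes the box areas once, sorts the indices by score once, and flags suppressed boxes instead of removing them from a vector, so no tensor is allocated per box pair. This is an illustration, not part of the posted code: `nms_plain` is a hypothetical name, and while the loop is loosely modeled on how torchvision's CPU NMS kernel is organized (sorted order plus a suppressed flag per box), it is not that code.

```rust
/// Greedy NMS over plain `[x1, y1, x2, y2]` boxes (hypothetical helper).
/// Returns the indices of the kept boxes in decreasing score order.
fn nms_plain(boxes: &[[f32; 4]], scores: &[f32], iou_threshold: f32) -> Vec<usize> {
    assert_eq!(boxes.len(), scores.len());
    // Areas use (x2 - x1) * (y2 - y1) and are computed once up front.
    let areas: Vec<f32> = boxes.iter().map(|b| (b[2] - b[0]) * (b[3] - b[1])).collect();
    // Indices sorted by descending score; sort_by is stable, so ties keep input order.
    let mut order: Vec<usize> = (0..boxes.len()).collect();
    order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
    let mut suppressed = vec![false; boxes.len()];
    let mut keep = Vec::new();
    for (pos, &i) in order.iter().enumerate() {
        if suppressed[i] {
            continue;
        }
        keep.push(i);
        // Suppress every lower-scored box whose IoU with box i exceeds the threshold.
        for &j in &order[pos + 1..] {
            if suppressed[j] {
                continue;
            }
            let w = (boxes[i][2].min(boxes[j][2]) - boxes[i][0].max(boxes[j][0])).max(0.0);
            let h = (boxes[i][3].min(boxes[j][3]) - boxes[i][1].max(boxes[j][1])).max(0.0);
            let inter = w * h;
            let iou = inter / (areas[i] + areas[j] - inter);
            if iou > iou_threshold {
                suppressed[j] = true;
            }
        }
    }
    keep
}
```

For example, with boxes `[0, 0, 10, 10]`, `[1, 1, 11, 11]` and `[20, 20, 30, 30]`, scores `0.9, 0.8, 0.7` and a threshold of 0.5, the second box (IoU 81/119, about 0.68, with the first) is suppressed and the call returns `[0, 2]`. This is still quadratic in the worst case, but each comparison is a handful of float operations.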
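The coordinate trick in `_batched_nms_coordinate_trick` can be checked on plain arrays: shifting every box by `class_id * (max_coordinate + 1)` moves boxes of different classes into disjoint regions, so their IoU drops to zero while IoUs within a class are unchanged. The sketch below assumes non-negative coordinates and class ids (with negative coordinates the shifted ranges can still overlap); `offset_boxes_by_class` and `iou` here are hypothetical helpers, not tch or torchvision APIs.

```rust
/// Shift each `[x1, y1, x2, y2]` box by `class_id * (max_coordinate + 1)` (hypothetical helper).
/// Assumes non-negative coordinates and class ids, so that each class lands in its own range.
fn offset_boxes_by_class(boxes: &[[f32; 4]], class_ids: &[i64]) -> Vec<[f32; 4]> {
    // Largest coordinate over all boxes.
    let max_coordinate = boxes.iter().flatten().copied().fold(f32::NEG_INFINITY, f32::max);
    boxes
        .iter()
        .zip(class_ids)
        .map(|(b, &c)| {
            let off = c as f32 * (max_coordinate + 1.0);
            [b[0] + off, b[1] + off, b[2] + off, b[3] + off]
        })
        .collect()
}

/// IoU of two `[x1, y1, x2, y2]` boxes, with areas computed as (x2 - x1) * (y2 - y1).
fn iou(a: &[f32; 4], b: &[f32; 4]) -> f32 {
    // Intersection width and height, clamped at zero for disjoint boxes.
    let w = (a[2].min(b[2]) - a[0].max(b[0])).max(0.0);
    let h = (a[3].min(b[3]) - a[1].max(b[1])).max(0.0);
    let inter = w * h;
    let area = |r: &[f32; 4]| (r[2] - r[0]) * (r[3] - r[1]);
    inter / (area(a) + area(b) - inter)
}
```

For two identical boxes `[0, 0, 10, 10]` labelled class 0 and class 1, the offset is 11, the second box becomes `[11, 11, 21, 21]`, and their IoU goes from 1.0 to 0.0; a single NMS pass over the shifted boxes therefore never lets one class suppress another.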
