ming71 / dal

[AAAI 2021] Official implementation of "Dynamic Anchor Learning for Arbitrary-Oriented Object Detection".

License: Apache License 2.0

Languages: Python 86.72%, Makefile 0.04%, C++ 1.07%, Cuda 9.30%, SWIG 0.06%, Shell 0.02%, Cython 2.80%
Topics: dal, arbitrary-oriented-object-detection

dal's Introduction

DAL

This project hosts the official implementation for our AAAI 2021 paper:

Dynamic Anchor Learning for Arbitrary-Oriented Object Detection [paper] [comments].

Abstract

In this paper, we propose a dynamic anchor learning (DAL) method, which utilizes the newly defined matching degree to comprehensively evaluate the localization potential of the anchors and carry out a more efficient label assignment process. In this way, the detector can dynamically select high-quality anchors to achieve accurate object detection, and the divergence between classification and regression will be alleviated.
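As a rough, non-authoritative sketch of the idea only (the exact formulation and hyper-parameter values are defined in the paper, not here): the matching degree combines the spatial alignment of an anchor (its input IoU with the ground truth) with the feature alignment ability reflected by the regressed box (the output IoU), minus a penalty on the gap between the two, and anchors with a high matching degree are dynamically selected as positive samples.

import numpy as np

def matching_degree(input_iou, output_iou, alpha=0.5, gamma=2.0):
    # Illustrative only: alpha and gamma are placeholder hyper-parameters,
    # not the values used in the paper or in this repository.
    u = np.abs(input_iou - output_iou)   # penalty on the input/output IoU gap
    return alpha * input_iou + (1 - alpha) * output_iou - u ** gamma

# Anchors whose matching degree exceeds a threshold would be kept as positives.
print(matching_degree(np.array([0.6, 0.3]), np.array([0.7, 0.2])))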

Getting Started

The code builds a rotated RetinaNet with the proposed DAL method for oriented object detection. The supported datasets include DOTA, HRSC2016, ICDAR2013, ICDAR2015, UCAS-AOD, NWPU VHR-10, and VOC.

Installation

Install requirements:

pip install -r requirements.txt
pip install git+https://github.com/lehduong/torch-warmup-lr.git

Build the Cython and CUDA modules:

cd $ROOT/utils
sh make.sh
cd $ROOT/utils/overlaps_cuda
python setup.py build_ext --inplace
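Optionally, a quick smoke test to confirm the extension built and is importable (the import path below is taken from the error reports further down this page; run it from $ROOT):

cd $ROOT
python -c "from utils.overlaps_cuda.rbbox_overlaps import rbbx_overlaps"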

Installation for DOTA_devkit:

cd $ROOT/datasets/DOTA_devkit
sudo apt-get install swig
swig -c++ -python polyiou.i
python setup.py build_ext --inplace
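Optionally, a quick check that the SWIG wrapper was generated (assuming the build produces a polyiou module in this directory, as is typical for DOTA_devkit):

python -c "import polyiou"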

Inference

You can use the following command to test a dataset. Note that weight, img_dir, dataset, and hyp should be modified as appropriate.

python demo.py
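For example, assuming weight, img_dir, dataset, and hyp are exposed as plain settings inside demo.py (the names come from the note above; the values below are placeholders only):

# Hypothetical values edited in demo.py before running:
weight  = 'weights/best.pth'        # path to a trained checkpoint (placeholder)
img_dir = 'HRSC2016/Test/images'    # directory of test images (placeholder)
dataset = 'HRSC2016'                # dataset name (placeholder)
hyp     = 'hyp.py'                  # hyper-parameter file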

Train

  1. Move the dataset to the $ROOT directory.
  2. Generate imageset files for dataset division via:
cd $ROOT/datasets
python generate_imageset.py
  3. Modify the configuration file hyp.py (a sample sketch follows below) and the arguments in train.py, then start training:
python train.py
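For reference, hyp.py defines hyper-parameters in key: value form along these lines; the key names and format are taken from configurations quoted in the issues and logs further down this page, and the values here are illustrative rather than recommended settings (lr0 and warmup_lr set the initial and warm-up learning rates, warm_epoch the warm-up length, num_classes the number of object classes, and the remaining keys the schedule and checkpoint/evaluation intervals):

lr0: 0.0001
warmup_lr: 0.00001
warm_epoch: 5
num_classes: 15
epochs: 100
batch_size: 4
save_interval: 5
test_interval: 5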

Evaluation

Different datasets use different evaluation methods. For UCAS-AOD/HRSC2016/VOC/NWPU VHR-10, you need to prepare the labels in the appropriate format in advance. Take evaluation on HRSC2016 as an example:

cd $ROOT/datasets/evaluate
python hrsc2gt.py

Then you can run the evaluation:

python eval.py

Note that:

  • the hrsc2gt.py script needs to be run only once per dataset; when testing on a different dataset, it must be run again.
  • the imageset file used in hrsc2gt.py is generated by generate_imageset.py.

Main Results

Method | Dataset     | Bbox | Backbone   | Input Size | mAP/F1
DAL    | DOTA        | OBB  | ResNet-101 | 800 x 800  | 71.78
DAL    | UCAS-AOD    | OBB  | ResNet-101 | 800 x 800  | 89.87
DAL    | HRSC2016    | OBB  | ResNet-50  | 416 x 416  | 88.60
DAL    | ICDAR2015   | OBB  | ResNet-101 | 800 x 800  | 82.4
DAL    | ICDAR2013   | HBB  | ResNet-101 | 800 x 800  | 81.3
DAL    | NWPU VHR-10 | HBB  | ResNet-101 | 800 x 800  | 88.3
DAL    | VOC 2007    | HBB  | ResNet-101 | 800 x 800  | 76.1

Detections

(figure: DOTA detection results)

Citation

If you find our work or code useful in your research, please consider citing:

@inproceedings{ming2021dynamic,
  title={Dynamic Anchor Learning for Arbitrary-Oriented Object Detection},
  author={Ming, Qi and Zhou, Zhiqiang and Miao, Lingjuan and Zhang, Hongwei and Li, Linhao},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={3},
  pages={2355--2363},
  year={2021}
}

If you have any questions, please contact me via issue or email.


dal's Issues

about s2anet-DAL

Dear author:
Thank you for your repo!
In the AAAI paper, the experiments on DOTA show 76.95 mAP with the s2anet-DAL method. I would like to know more about the details, e.g., whether DAL is applied in the FAM or the ODM module?

IndexError: too many indices for array

Hi ming71:
When I trained the model to the 2nd iteration of the first epoch (batch size 1) on the DOTA dataset, I encountered the following problem:
Epoch gpu_mem cls reg total targets img_size
0/99 4.38G 2.35 4.93 7.28 33 768: 0%| | 2/1411 [00:00<10:31, 2.23it/s]
Traceback (most recent call last):
File "/home/mccc/Program_Code/Rotated_Object_Detection/DAL/train.py", line 279, in <module>
train_model(arg, hyps)
File "/home/mccc/Program_Code/Rotated_Object_Detection/DAL/train.py", line 146, in train_model
for i, (ni, batch) in enumerate(pbar):
File "/home/mccc/anaconda3/envs/DAL/lib/python3.6/site-packages/tqdm/std.py", line 1176, in __iter__
for obj in iterable:
File "/home/mccc/anaconda3/envs/piou/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 322, in __next__
return self._process_next_batch(batch)
File "/home/mccc/anaconda3/envs/piou/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 357, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
File "/home/mccc/anaconda3/envs/piou/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/mccc/anaconda3/envs/piou/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/mccc/Program_Code/Rotated_Object_Detection/DAL/datasets/dota_dataset.py", line 43, in __getitem__
bboxes = roidb['boxes'][gt_inds, :]
IndexError: too many indices for array

Have you encountered this problem?
My GPU is an NVIDIA RTX 2080 Ti with 11 GB of memory.
Looking forward to your reply! Thanks.

Problem with from utils.overlaps_cuda.rbbox_overlaps import rbbx_overlaps!!

The program was still running fine in the morning, but it started failing in the afternoon. The error is as follows:
from utils.overlaps.rbox_overlaps import rbbx_overlaps
ImportError: cannot import name 'rbbx_overlaps' from 'utils.overlaps.rbox_overlaps' (/home/oem/programs/DAL/utils/overlaps/rbox_overlaps.cpython-37m-x86_64-linux-gnu.so)

I deleted the .so file and recompiled, but the problem is still not solved. Looking forward to a reply!!

About the angle of rbox

Thanks for the great work!
I have a question about the angle of the rbox:
In the paper, the angle regression term is shown as (equation image). Does the angle of the GT rbox (e.g., in the HRSC2016 dataset) range from -90 to 90 in this paper?

Problem in the middle of training

The earlier epochs run normally, but an error occurs at the 13th epoch, as shown below. Does anyone know why?
(screenshot of the error message)

About Loss Weight.

(screenshot of the relevant code)
Regarding the two weights for classification and regression: the former has 1 added before being multiplied in, while the latter does not (it is smaller than 1). Is this because the regression loss is not computed on negative samples, so there is no concern that a weight smaller than 1 would weaken the training of positive samples, and that is why the regression weight has no 1 added? Or is there another reason?

replace RetinaNet with YOLOv5

Hi, I tried to replace RetinaNet with YOLOv5 to get faster speed. The training loss decreases, but the mAP stays very low and it cannot detect anything. Before the replacement, I could successfully get 93% mAP with the default setting (RetinaNet). Could you please tell me if I missed anything here? I read the anchor-related code; it looks like only rectangular anchors are used, but I did not see where the predicted boxes are rescaled to the original image coordinates, and what is the purpose of the weights in class BoxCoder
weights=(10., 10., 10., 5., 15.)

Here is the part I changed. I made another copy of yolo.py in YOLOv5/models and only changed the Detect class to replace the cls head and reg head of RetinaNet. I also changed anchor.py to reduce the pyramid levels from 5 to 3 as in YOLOv5: self.pyramid_levels = [3, 4, 5]

import math
import torch
import torch.nn as nn

class Detect(nn.Module):
    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc + 1   # number of classes
        self.nr = 5        # xywha
        self.na = 3        # len(anchors[0]) // 2  # number of anchors
        self.cls_head = nn.ModuleList(nn.Conv2d(x, self.na * self.nc, 3, 1, 1) for x in ch)  # output conv
        self.reg_head = nn.ModuleList(nn.Conv2d(x, self.na * self.nr, 3, 1, 1) for x in ch)  # output conv
        self.init_weights()

    def init_weights(self):
        prior = 0.01
        for m in self.cls_head.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.fill_(0)
                m.bias.data.fill_(-math.log((1.0 - prior) / prior))
        for m in self.reg_head.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.fill_(0)
                m.bias.data.fill_(0)

    def get_cls_feature(self, index, x):
        x = torch.sigmoid(self.cls_head[index](x))
        x = x.permute(0, 2, 3, 1)
        n, w, h, c = x.shape
        x = x.reshape(n, w, h, self.na, self.nc)
        return x.reshape(x.shape[0], -1, self.nc)

    def get_reg_feature(self, index, x):
        x = self.reg_head[index](x)
        x = x.permute(0, 2, 3, 1)
        return x.reshape(x.shape[0], -1, self.nr)

    def forward(self, x):
        cls_score = torch.cat([self.get_cls_feature(idx, feature) for idx, feature in enumerate(x)], dim=1)
        bbox_pred = torch.cat([self.get_reg_feature(idx, feature) for idx, feature in enumerate(x)], dim=1)
        return [cls_score, bbox_pred]

I define another class to replace the RetinaNet class, changing only a few lines.

import numpy as np

# Model, Anchors, IntegratedLoss and BoxCoder come from the existing code base.
class YOLOV5(nn.Module):
    def __init__(self, cfg='yolov5l.yaml', ch=3, nc=None, export=False):
        super(YOLOV5, self).__init__()
        self.yolov5 = Model(cfg)
        self.num_classes = 1 + 1  # class number + 1, I only have one class in the dataset
        self.anchor_generator = Anchors(
            ratios=np.array([0.5, 1, 2]),
        )
        self.num_anchors = self.anchor_generator.num_anchors
        self.loss = IntegratedLoss(func='smooth')
        self.box_coder = BoxCoder()

    def forward(self, ims, gt_boxes=None, test_conf=None, process=None):
        anchors_list, offsets_list, cls_list, var_list = [], [], [], []
        original_anchors = self.anchor_generator(ims)   # (bs, num_all_anchors, 5)
        anchors_list.append(original_anchors)
        # features = self.fpn(self.ims_2_features(ims))
        # cls_score = torch.cat([self.cls_head(feature) for feature in features], dim=1)
        # bbox_pred = torch.cat([self.reg_head(feature) for feature in features], dim=1)
        [cls_score, bbox_pred] = self.yolov5(ims)
... the other parts are the same.

Low detection performance on small objects

Thank you very much for your work!
Running my own dataset on your network, I found that the detection performance on small objects is not very good. Many objects in my dataset are between 10 and 15 pixels in size, and I noticed that the anchor size on the largest feature map (the P3 level) is set to 16*16. Is this too large for small objects? What should I do to improve small-object detection accuracy? Thanks.

Learning rate

from torch_warmup_lr import WarmupLR
This line raises an error when running train.py, and I could not find the corresponding library online. How do I install this WarmupLR?

python setup.py build_ext --inplace reports a gcc error

Hello author, after running the command python setup.py build_ext --inplace I get the following error. How can I solve it?
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
In file included from /home/flora/miniconda3/envs/DAL/lib/python3.7/site-packages/numpy/core/include/numpy/ndarraytypes.h:1822:0,
from /home/flora/miniconda3/envs/DAL/lib/python3.7/site-packages/numpy/core/include/numpy/ndarrayobject.h:12,
from /home/flora/miniconda3/envs/DAL/lib/python3.7/site-packages/numpy/core/include/numpy/arrayobject.h:4,
from rbbox_overlaps.cpp:449:
/home/flora/miniconda3/envs/DAL/lib/python3.7/site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
#warning "Using deprecated NumPy API, disable it with "
^
rbbox_overlaps.cpp: In function ‘void __Pyx__ExceptionSave(PyThreadState*, PyObject**, PyObject**, PyObject**)’:
rbbox_overlaps.cpp:5670:21: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’
type = tstate->exc_type;
^
rbbox_overlaps.cpp:5671:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’
value = tstate->exc_value;
^
rbbox_overlaps.cpp:5672:19: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’
tb = tstate->exc_traceback;
^
rbbox_overlaps.cpp: In function ‘void __Pyx__ExceptionReset(PyThreadState*, PyObject**, PyObject**, PyObject**)’:
rbbox_overlaps.cpp:5679:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’
tmp_type = tstate->exc_type;
^
rbbox_overlaps.cpp:5680:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’
tmp_value = tstate->exc_value;
^
rbbox_overlaps.cpp:5681:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’
tmp_tb = tstate->exc_traceback;
^
rbbox_overlaps.cpp:5682:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’
tstate->exc_type = type;
^
rbbox_overlaps.cpp:5683:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’
tstate->exc_value = value;
^
rbbox_overlaps.cpp:5684:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’
tstate->exc_traceback = tb;
^
rbbox_overlaps.cpp: In function ‘int __Pyx__GetException(PyThreadState*, PyObject**, PyObject**, PyObject**)’:
rbbox_overlaps.cpp:5739:24: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’
tmp_type = tstate->exc_type;
^
rbbox_overlaps.cpp:5740:25: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’
tmp_value = tstate->exc_value;
^
rbbox_overlaps.cpp:5741:22: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’
tmp_tb = tstate->exc_traceback;
^
rbbox_overlaps.cpp:5742:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_type’
tstate->exc_type = local_type;
^
rbbox_overlaps.cpp:5743:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_value’
tstate->exc_value = local_value;
^
rbbox_overlaps.cpp:5744:13: error: ‘PyThreadState {aka struct _ts}’ has no member named ‘exc_traceback’
tstate->exc_traceback = local_tb;
^
error: command 'gcc' failed with exit status 1

results on coco dataset

I would like to ask the author: are there no results on the COCO dataset because COCO is too large and the GPU hardware is insufficient?

Problem with rbbx_overlaps

The program was still running fine in the morning, but it started failing in the afternoon. The error is as follows:
from utils.overlaps.rbox_overlaps import rbbx_overlaps
ImportError: cannot import name 'rbbx_overlaps' from 'utils.overlaps.rbox_overlaps' (/home/oem/programs/DAL/utils/overlaps/rbox_overlaps.cpython-37m-x86_64-linux-gnu.so)

I deleted the .so file and recompiled, but the problem is still not solved.

dota_datasets.py in the utils folder

            content = f.read()
            objects = content.split('\n')
            for obj in objects:
                if len(obj) != 0 :
                    *box, class_name, difficult = obj.split(' ')
                    if difficult == 2:
                        continues
                    box = [ eval(x) for x in  obj.split(' ')[:8] ]
                    label = self.class_to_ind[class_name] 
                    boxes.append(box)
                    gt_classes.append(label)
Shouldn't the check here be difficult == '2'? split returns a list of strings, but the code compares against the integer 2. Does this matter? And shouldn't the statement below be continue (rather than continues)?
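For reference, a minimal corrected sketch of the quoted loop (same variable names; assuming the space-separated DOTA label format shown above):

content = f.read()
objects = content.split('\n')
for obj in objects:
    if len(obj) != 0:
        *box, class_name, difficult = obj.split(' ')
        if difficult == '2':   # split() returns strings, so compare against '2'
            continue           # 'continue', not 'continues'
        box = [eval(x) for x in obj.split(' ')[:8]]
        label = self.class_to_ind[class_name]
        boxes.append(box)
        gt_classes.append(label)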

too many indices for array

Hello author, I ran into a problem when training on the DOTA dataset. The error is "too many indices for array", at line 50 of dota_dataset.py.

for train.py def eval help

How do I evaluate during training? My data format is {x1,y1,x2,y2,x3,y3,x4,y4,text}, and I want to use DAL for text detection, but I get an error (list index out of range) in train.py (def eval). The training set is the IC15 dataset. Could you write more details about evaluation in train.py, please?

Questions about the mAP of DOTA in paper

  1. Which schedule is adopted? (It seems to be 12 epochs if trained for 30k iterations with batch size 8.)
  2. Are multi-scale training and random rotation used as data augmentation for the results marked below?
    (screenshot of the results in question)

Question about accuracy on the DOTA dataset

Hello author, thank you very much for your work. I tried training and evaluating on DOTA 1.0; here are my training hyper-parameters:

lr

lr0: 0.0001
warmup_lr: 0.00001
warm_epoch:5

setting

num_classes: 15

training

epochs: 24
batch_size: 4
save_interval: 3
test_interval: 1000

The input image size here is 1024x1024, but the accuracy I currently get is (Task 1):
mAP: 0.6017705895879002
ap of each class: plane:0.7982196456704045, baseball-diamond:0.6323334007581611, bridge:0.40836332233092576, ground-track-field:0.46187258734002823, small-vehicle:0.6304814611874696, large-vehicle:0.6403297981395368, ship:0.7565095425203512, tennis-court:0.8959681065504155, basketball-court:0.7136081543213204, storage-tank:0.7763719723087148, soccer-ball-field:0.29802453302029897, roundabout:0.5099949396128973, harbor:0.5388697646315033, swimming-pool:0.5819268480975204, helicopter:0.3836847673289567

What changes do I need to make to reach the 70+ AP reported in this project?

About hyp.py

I would like to ask about the details of hyp.py. In my experiments, the num_classes parameter seems to have no effect. I used the HRSC2016 dataset; at first I did not pay attention and set this parameter to 10, then realized it was wrong and changed it to 2, but the two results were about the same.
Could you explain the exact usage of the parameters in this file?

variable das's meaning

Hi, could you please tell me the meaning of the variable das, or what das is an abbreviation for?

DAL/models/losses.py

Lines 37 to 50 in a4f625a

das = True
cls_losses = []
reg_losses = []
batch_size = classifications.shape[0]
alpha, beta, var = mining_param
# import ipdb;ipdb.set_trace()
for j in range(batch_size):
    classification = classifications[j, :, :]
    regression = regressions[j, :, :]
    bbox_annotation = annotations[j, :, :]
    bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1]
    if bbox_annotation.shape[0] == 0:
        cls_losses.append(torch.tensor(0).float().cuda())
        reg_losses.append(torch.tensor(0).float().cuda())

Does it mean Dynamic Anchor Select?

Compilation Error

rbbox_overlaps_kernel.cu:2:10: fatal error: rbbox_overlaps.hpp: No such file or directory
#include "rbbox_overlaps.hpp"
^~~~~~~~~~~~~~~~~~~~
compilation terminated.
error: command '/usr/local/cuda-10.0/bin/nvcc' failed with exit status 1
Hello author, may I ask what is causing this?

my dataset

How do I prepare my own dataset and train it with your repo?

Problem encountered during training

Hello author, I trained on the UCAS-AOD dataset using its two classes, car and airplane. After 300 epochs of training, both cls-loss and reg-loss were only around 0.0x, but during testing I found that the model could only detect the airplane class; cars were not detected at all. Why is that?

train on DOTA dataset

Hi, thanks for your great work! But when I run the code on the DOTA dataset, something goes wrong in the data processing:
Traceback (most recent call last):
File "train.py", line 279, in <module>
train_model(arg, hyps)
File "train.py", line 147, in train_model
for i, (ni, batch) in enumerate(pbar):
File "/home/cyli/dat01/anaconda3/envs/dal/lib/python3.7/site-packages/tqdm/std.py", line 1166, in __iter__
for obj in iterable:
File "/home/cyli/dat01/anaconda3/envs/dal/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
return self._process_next_batch(batch)
File "/home/cyli/dat01/anaconda3/envs/dal/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
File "/home/cyli/dat01/anaconda3/envs/dal/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/cyli/dat01/anaconda3/envs/dal/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/scratch/cyli/lcy/DAL/datasets/dota_dataset.py", line 40, in __getitem__
bboxes = roidb['boxes'][gt_inds, :]
IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

The IoU computed by rbbx_overlaps is always 0

Hello, I want to use your rotated-box IoU function rbbx_overlaps in my own project, but I found that no matter what I input, the output IoU is always 0. Below is my test code; is there something wrong with the way I call it?
import numpy as np
from DAL_utils.overlaps_cuda.rbbox_overlaps import rbbx_overlaps
from DAL_utils.overlaps.rbox_overlaps import rbox_overlaps

a = np.array([[10, 10, 20, 10, 0]], dtype=np.float32)
b = np.array([[10, 10, 20, 10, 0]], dtype=np.float32)
c = rbbx_overlaps(a, b)
d = rbox_overlaps(a, b)
print(c)
print(d)
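As an independent sanity check (not part of this repository), the same IoU can be computed with shapely, assuming the (cx, cy, w, h, angle in degrees) convention of the test boxes above. Two identical boxes should give 1.0, so a result of 0 from the extension would point to a build or runtime issue rather than the call convention:

import numpy as np
from shapely.geometry import Polygon

def rbox_to_polygon(cx, cy, w, h, angle_deg):
    # Corners of an axis-aligned box centred at the origin, then rotated and shifted.
    theta = np.deg2rad(angle_deg)
    c, s = np.cos(theta), np.sin(theta)
    corners = np.array([[-w / 2, -h / 2], [w / 2, -h / 2], [w / 2, h / 2], [-w / 2, h / 2]])
    rot = np.array([[c, -s], [s, c]])
    return Polygon(corners @ rot.T + np.array([cx, cy]))

p1 = rbox_to_polygon(10, 10, 20, 10, 0)
p2 = rbox_to_polygon(10, 10, 20, 10, 0)
print(p1.intersection(p2).area / p1.union(p2).area)   # expected: 1.0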

Hello, I want to run your code on the DOTA dataset, but I ran into a problem.

I first used DOTA_devkit to split DOTA into train and val and generated the corresponding images and labelTxt folders, using the DOTA 1.5 OBB labels. Then I used generate_imageset to produce the txt files containing the paths of all trainval images, and modified the default arguments in train.py as follows, after which the error below appeared. I don't quite understand it, so I am asking here.
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Train a detector')
    # config
    parser.add_argument('--hyp', type=str, default='hyp.py', help='hyper-parameter path')
    # network
    parser.add_argument('--backbone', type=str, default='res50')
    parser.add_argument('--freeze_bn', type=bool, default=False)
    parser.add_argument('--weight', type=str, default='')   #
    parser.add_argument('--multi-scale', action='store_true', help='adjust (67% - 150%) img_size every 10 batches')

    # NWPU-VHR10
    parser.add_argument('--dataset', type=str, default='DOTA')
    parser.add_argument('--train_path', type=str, default='/home/ff/WangZF/datasets/dota/train.txt')
    parser.add_argument('--test_path', type=str, default='/home/ff/WangZF/datasets/dota/val.txt')

    parser.add_argument('--training_size', type=int, default=800)
    parser.add_argument('--resume', action='store_true', help='resume training from last.pth')
    parser.add_argument('--load', action='store_true', help='load training from last.pth')
    parser.add_argument('--augment', action='store_true', help='data augment')
    parser.add_argument('--target_size', type=int, default=[800])
    #

    arg = parser.parse_args()
    hyps = hyp_parse(arg.hyp)
    print(arg)
    print(hyps)

    train_model(arg, hyps)

/home/ff/anaconda3/envs/pytorch1.6/bin/python /home/ff/WangZF/remoteSense/DAL/train.py
fail to speed up training via apex

Namespace(augment=False, backbone='res50', dataset='DOTA', freeze_bn=False, hyp='hyp.py', load=False, multi_scale=False, resume=False, target_size=[800], test_path='/home/ff/WangZF/datasets/dota/val.txt', train_path='/home/ff/WangZF/datasets/dota/train.txt', training_size=800, weight='')
{'lr0': 0.0001, 'warmup_lr': 1e-05, 'warm_epoch': 5.0, 'num_classes': 10.0, 'epochs': 100.0, 'batch_size': 2.0, 'save_interval': 5.0, 'test_interval': 5.0}
Model Summary: 195 layers, 3.63368e+07 parameters, 3.63368e+07 gradients

 Epoch   gpu_mem       cls       reg     total   targets  img_size

0%| | 0/4043 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/ff/WangZF/remoteSense/DAL/train.py", line 278, in <module>
train_model(arg, hyps)
File "/home/ff/WangZF/remoteSense/DAL/train.py", line 146, in train_model
for i, (ni, batch) in enumerate(pbar):
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/tqdm/std.py", line 1193, in __iter__
for obj in iterable:
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
data = self._next_data()
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
return self._process_data(data)
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
data.reraise()
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ff/anaconda3/envs/pytorch1.6/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ff/WangZF/remoteSense/DAL/datasets/dota_dataset.py", line 40, in __getitem__
bboxes = roidb['boxes'][gt_inds, :]
IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

Parameter Settings

Could the author provide all the parameter settings in hyp.py, e.g., for HRSC2016?

About HRSC dataset

Hello author, how should the HRSC dataset be organized and placed for this code? I hope to get your answer.
