distnet's Issues

loss not decreasing

Hello, I'm using the 'dla' branch. I downloaded the 'ctw1500_dla_93....pth' checkpoint, and inference with it works fine.
I'm trying to fine-tune the model on a custom dataset. Here's the dataset class:

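import json
import os

import cv2
import numpy as np
from torch.utils import data

# cont2poly, poly2cont, image_label_v4 and configdist come from the project's own code (imports not shown).


class GenericDataset(data.Dataset):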
    def __init__(self, data_dir, data_shape: int = 640, dataset_type='ctw1500', transform=None,
                 target_transform=None, for_test=False, count=None):
        self.for_test = for_test
        self.dataset_type = dataset_type
        self.imgs_dir = os.path.join(data_dir, 'imgs')
        self.annos_dir = os.path.join(data_dir, 'annos')
        self.fns = os.listdir(self.imgs_dir)
        if count is not None:
            self.fns = self.fns[:count]
        # self.data_list = self.load_data(data_dir)
        self.data_shape = data_shape
        self.transform = transform
        self.target_transform = target_transform

    # self.aug = augument()  # 2020-03-02: added a new augmentation method

    def __getitem__(self, index):
        fn = self.fns[index]
        im = cv2.imread(os.path.join(self.imgs_dir, fn), cv2.IMREAD_COLOR)
        fnjson = fn + '.json'
        with open(os.path.join(self.annos_dir, fnjson), 'r') as f:
            cnts = [np.array(cnt).astype(int).reshape(-1, 2) for cnt in json.load(f)]
        polys = [cont2poly(cnt).simplify(1) for cnt in cnts]
        cnts = [poly2cont(poly) for poly in polys]
        # pad
        max_pts = max([len(x) for x in cnts])
        cnts2 = []
        for cnt in cnts:
            if len(cnt) < max_pts:
                v = cnt[-1]
                cnt2 = np.vstack([cnt, np.repeat([v], max_pts - len(cnt), 0)])
                cnts2.append(cnt2)
            else:
                cnts2.append(cnt)
        text_polys = np.array(cnts2, dtype=np.float32)
        tags = [True] * len(text_polys)

        img, training_mask, distance_map = image_label_v4(im, text_polys, tags,
                                                          input_size=self.data_shape,
                                                          scales=np.array(configdist.random_scales),
                                                          for_test=self.for_test)

        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            training_mask = self.target_transform(training_mask)

        return img, training_mask, distance_map

    def __len__(self):
        return len(self.fns)

'image_label_v4' is the same as 'image_label_v3', except that it accepts the image as an ndarray.

Testing the dataset:
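
The padding step inside `__getitem__` can be checked in isolation. The following is a minimal sketch, assuming each contour is an (N, 2) list of points; `pad_contours` is a hypothetical helper name, not part of the project. It repeats the last point of shorter contours, as the class does, and additionally returns an empty array when an image has no annotations, a case in which `max()` on an empty list would raise a `ValueError`.

```python
import numpy as np


def pad_contours(contours):
    """Pad each (N_i, 2) contour to the longest length by repeating its last point."""
    if not contours:
        # No annotations: return an empty array instead of letting max() raise.
        return np.zeros((0, 0, 2), dtype=np.float32)
    max_pts = max(len(c) for c in contours)
    padded = []
    for c in contours:
        c = np.asarray(c, dtype=np.float32).reshape(-1, 2)
        extra = max_pts - len(c)
        if extra > 0:
            # Repeating the final vertex adds degenerate points only,
            # so the outline of the polygon is unchanged.
            c = np.vstack([c, np.repeat(c[-1:], extra, axis=0)])
        padded.append(c)
    return np.stack(padded)


triangle = [[0, 0], [4, 0], [0, 3]]
quad = [[0, 0], [2, 0], [2, 2], [0, 2]]
batch = pad_contours([triangle, quad])
print(batch.shape)  # (2, 4, 2)
```

Whether `image_label_v4` tolerates the repeated points is not something this sketch can confirm; it only shows that the padded array has a consistent shape.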

train_data = GenericDataset('/content/synth_generic_line', data_shape=config.data_shape, dataset_type=config.dataset_type,
                            transform=transforms.ToTensor())
train_loader = DataLoaderX(dataset=train_data, batch_size=1, shuffle=True,
                            num_workers=int(config.workers), pin_memory=config.pin_memory)

for i, (images, training_mask, distance_map) in enumerate(train_loader):
  print('images shape', images.shape)
  print('training_mask shape', training_mask.shape)
  print('distance_map shape', distance_map.shape)
  break
images shape torch.Size([1, 3, 640, 640])
training_mask shape torch.Size([1, 640, 640])
distance_map shape torch.Size([1, 640, 640])

images[0]:
[image: download (26)]

training_mask[0]:
[image: download (27)]

distance_map[0]:
[image: download (28)]
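
Besides looking at the images, it may help to print the value range of each tensor, since shapes alone do not show whether a target is empty or scaled differently from what the loss expects. This is a rough sketch: `summarize` is a hypothetical helper, and the small arrays are stand-ins for one batch from `train_loader` (a torch tensor would first be converted with `.cpu().numpy()`).

```python
import numpy as np


def summarize(name, arr):
    """Return basic statistics for one array from a batch."""
    arr = np.asarray(arr)
    return {
        "name": name,
        "shape": arr.shape,
        "dtype": str(arr.dtype),
        "min": float(arr.min()),
        "max": float(arr.max()),
        # Fraction of pixels that are not exactly zero.
        "nonzero_fraction": float(np.count_nonzero(arr)) / arr.size,
    }


# Stand-in data (assumption): an 8x8 mask of ones and a distance map
# that is non-zero only inside a 4x4 region.
training_mask = np.ones((1, 8, 8), dtype=np.float32)
distance_map = np.zeros((1, 8, 8), dtype=np.float32)
distance_map[0, 2:6, 2:6] = 0.5

for name, arr in [("training_mask", training_mask), ("distance_map", distance_map)]:
    print(summarize(name, arr))
```

A `nonzero_fraction` of 0 for `distance_map`, or a `max` far outside the range seen with the original ctw1500 loader, would be worth comparing against the output of `image_label_v3` on the same image.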

Training:

model = get_dlaseg_net(34, heads={'seg_hm': 2}, down_ratio=4, head_conv=256, bFSM=False)
criterion = Loss(OHEM_ratio=config.OHEM_ratio, reduction='mean')
config.lr = 1e-4
config.weight_decay = 5e-4
optimizer = torch.optim.AdamW([{'params': model.parameters(), 'initial_lr': config.lr}], lr=config.lr,
                              weight_decay=config.weight_decay)
scheduler = None
#----
config.checkpoint = '/content/DistNet/dist_ctw.pth'
load_checkpoint(config.checkpoint, model, logger, device, None)
#------
for epoch in range(start_epoch, config.epochs):
    start = time.time()
    train_loss, lr = train_epoch(model, optimizer, scheduler, train_loader, device, criterion,
                                 epoch, all_step, writer, logger)
    .....
2024-02-17 17:36:56 INFO      train_curve.py: [0/10], [0/834], step: 0, 0.116 samples/sec, loss: 3.0000, dice_center_loss: 1.0000, dice_region_loss: 1.0000, weighted_mse_region_loss: 0.0000, dice_bi_region: 1.0000, time:51.9259, lr:0.0001
INFO:project:[0/10], [0/834], step: 0, 0.116 samples/sec, loss: 3.0000, dice_center_loss: 1.0000, dice_region_loss: 1.0000, weighted_mse_region_loss: 0.0000, dice_bi_region: 1.0000, time:51.9259, lr:0.0001
2024-02-17 17:36:57 INFO      train_curve.py: [0/10], [1/834], step: 1, 5.890 samples/sec, loss: 3.0000, dice_center_loss: 1.0000, dice_region_loss: 1.0000, weighted_mse_region_loss: 0.0000, dice_bi_region: 1.0000, time:1.0186, lr:0.0001
INFO:project:[0/10], [1/834], step: 1, 5.890 samples/sec, loss: 3.0000, dice_center_loss: 1.0000, dice_region_loss: 1.0000, weighted_mse_region_loss: 0.0000, dice_bi_region: 1.0000, time:1.0186, lr:0.0001
2024-02-17 17:36:58 INFO      train_curve.py: [0/10], [2/834], step: 2, 5.560 samples/sec, loss: 3.0000, dice_center_loss: 1.0000, dice_region_loss: 1.0000, weighted_mse_region_loss: 0.0000, dice_bi_region: 1.0000, time:1.0792, lr:0.0001
INFO:project:[0/10], [2/834], step: 2, 5.560 samples/sec, loss: 3.0000, dice_center_loss: 1.0000, dice_region_loss: 1.0000, weighted_mse_region_loss: 0.0000, dice_bi_region: 1.0000, time:1.0792, lr:0.0001
.....

The losses remain the same even after 30 minutes of training. I did not change the project's code except for the dataset preparation.
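
In the log above, every Dice term is exactly 1.0000 and the weighted MSE term is 0.0000, which adds up to the reported total of 3.0000. A generic Dice loss takes the value 1 whenever the masked prediction and the masked target do not overlap at all, for example when the target is all zeros. The project's `Loss` class is not shown in this report, so the `dice_loss` function below is a hypothetical stand-in that only illustrates this behaviour; it is not the project's implementation.

```python
import numpy as np


def dice_loss(pred, target, mask, eps=1e-6):
    """Generic Dice loss on masked maps (assumed form, not the project's Loss)."""
    pred = pred * mask
    target = target * mask
    intersection = (pred * target).sum()
    union = (pred * pred).sum() + (target * target).sum()
    return float(1.0 - 2.0 * intersection / (union + eps))


mask = np.ones((8, 8), dtype=np.float64)
target = np.zeros((8, 8), dtype=np.float64)
target[2:6, 2:6] = 1.0
pred = np.full((8, 8), 0.3, dtype=np.float64)

print(dice_loss(pred, np.zeros_like(target), mask))  # empty target: no overlap
print(dice_loss(np.zeros_like(pred), target, mask))  # all-zero prediction: no overlap
print(dice_loss(target, target, mask))               # perfect match: close to 0
```

If the real loss behaves like this, a value pinned at exactly 1.0 would point toward targets or predictions that are zero after masking rather than toward the learning rate, but that remains an assumption until the values passed to `criterion` are inspected.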
