msamogh / nonechucks

Deal with bad samples in your dataset dynamically, use Transforms as Filters, and more!

License: MIT License

Python 100.00%
pytorch data-processing data-preprocessing data-pipeline data-cleaning preprocessing machine-learning torch

nonechucks's Introduction

nonechucks

nonechucks is a library that provides wrappers for PyTorch's datasets, samplers, and transforms to allow for dropping unwanted or invalid samples dynamically.


What if you have a dataset of 1000s of images, out of which a few dozen images are unreadable because the image files are corrupted? Or what if your dataset is a folder full of scanned PDFs that you have to OCRize, and then run a language detector on the resulting text, because you want only the ones that are in English? Or maybe you have an AlternateIndexSampler, and you want to be able to move to dataset[6] after dataset[4] fails while attempting to load!

PyTorch's data processing module expects you to rid your dataset of any unwanted or invalid samples before you feed them into its pipeline, and provides no easy way to define a "fallback policy" in case such samples are encountered during dataset iteration.

Why do I need it?

You might be wondering why this is such a big deal when you could simply filter out bad samples before sending them to your PyTorch dataset or sampler! Well, it turns out that it can be a huge deal in many cases:

  1. When you have a small fraction of undesirable samples in a large dataset, or
  2. When your sample-loading operation is expensive, or
  3. When you want to let downstream consumers know that a sample is undesirable (with nonechucks, transforms are not restricted to modifying samples; they can drop them as well), or
  4. When you want your dataset and sampler to be decoupled.

In such cases, it's either simply too expensive to have a separate step to weed out bad samples, or it's just plain impossible because you don't even know what constitutes "bad", or worse - both!

nonechucks allows you to wrap your existing datasets and samplers with "safe" versions of them, which can fix all these problems for you.

1. Dealing with bad samples

Let's start with the simplest use case, which involves wrapping an existing Dataset instance with SafeDataset.

Create a dataset (the usual way)

Using something like torchvision's ImageFolder dataset class, we can load an entire folder of labelled images for a typical supervised classification task.

import torchvision.datasets as datasets
fruits_dataset = datasets.ImageFolder('fruits/')

Without nonechucks

Now, if you have a sneaky fruits/apple/143.jpg (that is corrupted) sitting in your fruits/ folder, to keep the entire pipeline from failing unexpectedly, you would have to resort to something like this:

import random

# Shuffle dataset
indices = list(range(len(fruits_dataset)))
random.shuffle(indices)

batch_size = 4
for i in range(0, len(indices), batch_size):
    try:
        batch = [fruits_dataset[idx] for idx in indices[i:i + batch_size]]
        # Do something with it
        pass
    except IOError:
        # Skip the entire batch
        continue

Not only do you have to put your code inside an extra try-except block, but you are also forced to use a for-loop, depriving yourself of PyTorch's built-in DataLoader, which means you can't use features like batching, shuffling, multiprocessing, and custom samplers for your dataset.

I don't know about you, but not being able to do that kind of defeats the whole point of using a data processing module for me.

With nonechucks

You can transform your dataset into a SafeDataset with a single line of code.

import nonechucks as nc
fruits_dataset = nc.SafeDataset(fruits_dataset)

That's it! Seriously.

And that's not all. You can also use a DataLoader on top of this.

dataloader = nc.SafeDataLoader(fruits_dataset, batch_size=4, shuffle=True)
for i_batch, sample_batched in enumerate(dataloader):
    # Do something with it
    pass

In this case, SafeDataset will skip the erroneous image and use the next one in its place (as opposed to dropping the entire batch).
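
The skip-and-substitute behaviour can be pictured with a minimal pure-Python sketch. This is not nonechucks's actual implementation: SkipBadSamples and FlakyList are hypothetical names, and the wrapper simply assumes that a bad sample either raises an exception or comes back as None.

```python
class SkipBadSamples:
    """Hypothetical wrapper: on a bad sample, fall through to the next index."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.bad_indices = set()  # remember failures so they are not retried

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Walk forward from idx until a loadable sample is found.
        for candidate in range(idx, len(self.dataset)):
            if candidate in self.bad_indices:
                continue
            try:
                sample = self.dataset[candidate]
            except Exception:
                sample = None
            if sample is not None:
                return sample
            self.bad_indices.add(candidate)
        raise IndexError(idx)


class FlakyList:
    """Toy dataset (an assumption for this sketch): index 1 is 'corrupted'."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        if idx == 1:
            raise IOError("corrupted sample")
        return self.items[idx]


safe = SkipBadSamples(FlakyList(["apple", "banana", "cherry"]))
print(safe[0], safe[1], safe[2])  # index 1 falls through to "cherry"
```

Note that in this naive version two different indices can return the same sample, so whatever draws the indices has to be aware of the skipped ones if duplicates matter.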

2. Use Transforms as Filters!

The function of transforms in PyTorch is restricted to modifying samples. With nonechucks, you can simply return None (or raise an exception) from the transform's __call__ method, and nonechucks will drop the sample from the dataset for you, allowing you to use transforms as filters!

For the example, we'll assume a PDFDocumentsDataset, which reads PDF files from a folder, a PlainTextTransform, which transforms the files into raw text, and a LanguageFilter, which retains only documents of a particular language.

class LanguageFilter:
    def __init__(self, language):
        self.language = language
        
    def __call__(self, sample):
        # Do machine learning magic
        document_language = detect_language(sample)
        if document_language != self.language:
            return None
        return sample

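from torchvision import transforms

# Use a separate name so the torchvision.transforms module is not shadowed.
pipeline = transforms.Compose([
    PlainTextTransform(),
    LanguageFilter('en')
])
en_documents = PDFDocumentsDataset(data_dir='pdf_files/', transform=pipeline)
en_documents = nc.SafeDataset(en_documents)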
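
To see the filtering idea in isolation, here is a small self-contained sketch that needs neither PyTorch nor nonechucks. The helpers compose and keep_if and the toy documents are hypothetical stand-ins; the point is only that a transform returning None marks the sample for removal.

```python
def compose(*steps):
    """Chain transforms; stop early once any step returns None."""
    def run(sample):
        for step in steps:
            sample = step(sample)
            if sample is None:
                return None
        return sample
    return run


def keep_if(predicate):
    """Turn a boolean predicate into a filter-style transform."""
    return lambda sample: sample if predicate(sample) else None


pipeline = compose(
    str.strip,                              # ordinary transform
    keep_if(lambda text: text.isascii()),   # crude stand-in for a language check
    str.lower,                              # ordinary transform
)

raw_documents = ["  Hello World ", " héllo ", "PyTorch  "]
kept = [out for out in map(pipeline, raw_documents) if out is not None]
print(kept)  # ['hello world', 'pytorch']
```

The early return in compose is a choice made for this sketch, so that later steps never see a None sample.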

To install nonechucks, simply use pip:

$ pip install nonechucks

or clone this repo, and build from source with:

$ python setup.py install

All PRs are welcome.

nonechucks is MIT licensed.

nonechucks's Issues

ValueError: dataset attribute should not be set after SafeDataLoader is initialized

Hi!

I got this error and can't figure out why it is happening.

Traceback (most recent call last):
  File "train.py", line 65, in <module>
    dataset = dataloader(opt)
  File "/root/gans_depth/Synthetic2Realistic/dataloader/data_loader.py", line 130, in dataloader
    dataset = nc.SafeDataLoader(datasets, batch_size=opt.batchSize, shuffle=opt.shuffle, num_workers=int(opt.nThreads))
  File "/root/anaconda3/lib/python3.7/site-packages/nonechucks/dataloader.py", line 25, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/root/anaconda3/lib/python3.7/site-packages/nonechucks/dataloader.py", line 141, in __init__
    self.dataset = _OriginalDataset(self.safe_dataset)
  File "/root/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 271, in __setattr__
    'initialized'.format(attr, self.__class__.__name__))
ValueError: dataset attribute should not be set after SafeDataLoader is initialized

My dataset is:

class CreateDataset(data.Dataset):

    def initialize(self, opt):
        self.opt = opt
        self.cat_depth = opt.cat_depth
        self.depth_separate = opt.depth_separate
        self.img_source_paths, self.img_source_size = make_dataset(opt.img_source_file)
        self.img_target_paths, self.img_target_size = make_dataset(opt.img_target_file)

        if True:
            self.lab_source_paths, self.lab_source_size = make_dataset(opt.lab_source_file)
            # for visual results, not for training
            self.lab_target_paths, self.lab_target_size = make_dataset(opt.lab_target_file)

        self.transform_augment = get_transform(opt, True)
        self.transform_no_augment = get_transform(opt, False)
        self.transform_no_augment_lab = get_transform(opt, False, ch=1)

    def __getitem__(self, item):
        .......
        return ...

I call dataloader like this:

def dataloader(opt):
    datasets = CreateDataset()
    datasets.initialize(opt)
    datasets = nc.SafeDataset(datasets)
    dataloader = nc.SafeDataLoader(datasets, batch_size=opt.batchSize, shuffle=opt.shuffle, num_workers=int(opt.nThreads))
    return dataloader

Thank you in advance!

Potential bug in `_reset_index`

SafeDataset._reset_index is defined as follows:

    def _reset_index(self):
        """Resets the safe and unsafe samples indices."""
        self._safe_indices = self._unsafe_indices = []

I believe this has a bug, as after calling this function _safe_indices and _unsafe_indices will both point to the same underlying list in memory.

Compare:

>> a = []
>> b = []
>> a.append(2)
>> print(b)
[]
  --> expected!

with

>> a = b = []
>> a.append(2)
>> print(b)
[2]
  --> unexpected!
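
A possible fix, sketched on a hypothetical stand-in class rather than on the real SafeDataset (IndexTracker is an illustrative name; only the two attribute names come from the snippet above), is to give each attribute its own list:

```python
class IndexTracker:
    """Hypothetical stand-in for the index bookkeeping in SafeDataset."""

    def __init__(self):
        self._reset_index()

    def _reset_index(self):
        """Resets the safe and unsafe sample indices to two separate lists."""
        self._safe_indices = []
        self._unsafe_indices = []


tracker = IndexTracker()
tracker._safe_indices.append(2)
print(tracker._unsafe_indices)  # [] because the two lists are no longer aliased
```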

batch size reduction

Hi,

how can I not reduce the size of the batch when I use nonechucks?

Thanks in advance!

KeyError:(<function SafeDataset.__getitem__ at ...>)

Hi, I tried to create a dataset with SafeDataset from a CSV file, but it failed with the following error:
KeyError:(<function SafeDataset.__getitem__ at ...>)
Details

Traceback (most recent call last):
  File "test.py", line 77, in <module>
    main()
  File "test.py", line 59, in main
    for batch in test_loader:
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
IndexError: Traceback (most recent call last):
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/nonechucks/utils.py", line 49, in __call__
    res = cache[key]
KeyError: (<function SafeDataset.__getitem__ at 0x7f81b186b950>, (0,), frozenset())

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/nonechucks/utils.py", line 51, in __call__
    res = cache[key] = self.func(*args, **kw)
  File "/home/hotel_ai/python3/lib/python3.5/site-packages/nonechucks/dataset.py", line 96, in __getitem__
    raise IndexError
IndexError

and here is my code

test_set = ImageSet('./test.csv', test_trainsforms)
test_set = nc.SafeDataset(test_set)

The ImageSet code, which opens images from an HTTP source:

import io
import urllib.request

import torch.utils.data as data
from PIL import Image


class ImageSet(data.Dataset):
    def __init__(self, data_txt, data_transforms):
        data_list = []
        with open(data_txt, "r") as f:
            lines = f.readlines()
        # Skip the CSV header; the image path is in the second column.
        for line in lines[1:]:
            tmp = line.strip().split(',')
            data_list.append(tmp[1])
        self.data_list = data_list
        self.transforms = data_transforms

    def __getitem__(self, index):
        url_prefix = 'this is a http-url-prefix such as: http://images.baidu.com/'

        data_path = self.data_list[index]

        # Download the image and decode it from memory.
        file0 = urllib.request.urlopen(url_prefix + data_path)
        image_file0 = io.BytesIO(file0.read())
        # The original post passed the undefined name `image_file` here.
        data = Image.open(image_file0)
        if data.mode != 'RGB':
            data = data.convert("RGB")

        data = self.transforms(data)

        return data, data_path

    def __len__(self):
        return len(self.data_list)

Pytorch's IterableDataset

Hello, I've been using this (excellent) library for a while, and I just stumbled upon a new feature in pytorch. It seems that pytorch now has an IterableDataset class that is meant to solve the exact issues that this library was trying to solve.

Is this correct? I feel like nonechucks is doing more than what can be done with the class, but it seems to me, safe dataloading and transforms as filters can be done with this (provided one's careful with the multithreading).
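
For comparison, here is a rough sketch of how the iterable approach could skip bad samples. It is written as a plain Python iterable so that it runs without PyTorch; in real code the __iter__ method would live on a subclass of torch.utils.data.IterableDataset. SkippingStream, parse_int, and the toy inputs are hypothetical names chosen for illustration.

```python
class SkippingStream:
    """Iterable that yields only samples that load and pass every filter."""

    def __init__(self, keys, load, filters=()):
        self.keys = keys          # e.g. file paths
        self.load = load          # callable: key -> sample, may raise
        self.filters = filters    # callables: sample -> bool

    def __iter__(self):
        for key in self.keys:
            try:
                sample = self.load(key)
            except (IOError, ValueError):
                continue  # unreadable sample: skip it
            if all(keep(sample) for keep in self.filters):
                yield sample


def parse_int(text):
    return int(text)  # raises ValueError on malformed input


stream = SkippingStream(["1", "oops", "4", "7"], parse_int,
                        filters=[lambda n: n % 2 == 1])
print(list(stream))  # [1, 7]
```

With several DataLoader workers, each worker would iterate over its own copy of such a stream, so the keys would have to be split per worker (for example with torch.utils.data.get_worker_info()) to avoid duplicates. That is the multithreading caveat raised in the question.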

AttributeError: module 'torch.utils.data' has no attribute 'Sampler'

Trying to import nonechucks under Pytorch 0.4 I get an error, is this related to the Pytorch version?

AttributeError                            Traceback (most recent call last)
<ipython-input-3-dc35b3a9e8ee> in <module>()
      2 import torch.utils.data
      3 
----> 4 import nonechucks

/opt/conda/lib/python3.6/site-packages/nonechucks/__init__.py in <module>()
      1 from .dataset import SafeDataset
----> 2 from .sampler import SafeSampler
      3 from .dataloader import SafeDataLoader
      4 
      5 __all__ = ['SafeDataset', 'SafeSampler', 'SafeDataLoader']

/opt/conda/lib/python3.6/site-packages/nonechucks/sampler.py in <module>()
      5 
      6 
----> 7 class SafeSampler(torch.utils.data.Sampler):
      8     """SafeSampler can be used both as a standard Sampler (over a Dataset),
      9     or as a wrapper around an existing `Sampler` instance. It allows you to

AttributeError: module 'torch.utils.data' has no attribute 'Sampler'

Not compatible with PyTorch v1.2 any more

The latest version of PyTorch changed an underlying private API.
That introduces the following error.

lib/python3.6/site-packages/nonechucks/__init__.py in <module>
      1 from .dataset import SafeDataset
      2 from .sampler import SafeSampler
----> 3 from .dataloader import SafeDataLoader
      4 
      5 __all__ = ['SafeDataset', 'SafeSampler', 'SafeDataLoader']

lib/python3.6/site-packages/nonechucks/dataloader.py in <module>
     42 
     43 
---> 44 class _SafeDataLoaderIter(data.dataloader._DataLoaderIter):
     45 
     46     def __init__(self, loader):

AttributeError: module 'torch.utils.data.dataloader' has no attribute '_DataLoaderIter'

How to handle TypeError using nonechucks

Hi,

How to handle TypeError using nonechucks?

I have 162400 annotations (260 labels per annotation) in a JSON file. Each annotation contains a filename, for example aaa.jpg, and I load photos based on that filename. My image dataset contains only 162000 images, so 400 photos are missing from the collection. When the dataloader reaches an annotation whose image is not in the image set, it raises a TypeError.
Do you know how to catch this error using your library? I tried SafeDataLoader and it does not work.
I also tried try/except, but if I return the picture as None, I get a NoneType class error.
Best!
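
When the only problem is missing files, one inexpensive option is to drop the affected annotations before building the dataset. The sketch below assumes each annotation is a dict with a "filename" key, which is a guess at the JSON layout described above; existing_only is a hypothetical helper, not part of nonechucks.

```python
import os
import tempfile


def existing_only(annotations, image_dir):
    """Keep annotations whose image file is actually present on disk."""
    return [ann for ann in annotations
            if os.path.isfile(os.path.join(image_dir, ann["filename"]))]


# Toy demonstration with a temporary folder holding one of two images.
with tempfile.TemporaryDirectory() as image_dir:
    open(os.path.join(image_dir, "aaa.jpg"), "wb").close()
    annotations = [{"filename": "aaa.jpg"}, {"filename": "missing.jpg"}]
    kept = existing_only(annotations, image_dir)

print([ann["filename"] for ann in kept])  # ['aaa.jpg']
```

A check like this only touches file metadata, so it stays cheap even for a few hundred thousand annotations.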

nonechucks breaks when PyTorch's patch version is not an integer

Traceback (most recent call last):
  File "<redacted>", line 24, in <module>
    import nonechucks as nc
  File "/usr/local/lib/python3.6/dist-packages/nonechucks/__init__.py", line 27, in <module>
    MAJOR, MINOR = _get_pytorch_version()
  File "/usr/local/lib/python3.6/dist-packages/nonechucks/__init__.py", line 12, in _get_pytorch_version
    major, minor, patch = [int(x) for x in version.split(".")]
  File "/usr/local/lib/python3.6/dist-packages/nonechucks/__init__.py", line 12, in <listcomp>
    major, minor, patch = [int(x) for x in version.split(".")]
ValueError: invalid literal for int() with base 10: '0+cu92'
root@842e18f1c34c:/<redacted># python
Python 3.6.7 (default, Oct 21 2018, 04:56:05)
[GCC 5.4.0 20160609] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'1.2.0+cu92'
>>>

With SafeDataset and SafeDataLoader, do I need to change my dataset?

Hi, your solution is sensible and reasonable. My simple original dataset's __getitem__ is:

    def __getitem__(self, index):
        data_file = self.files[self.split][index]
        # load image
        img_file = data_file['img']

        # PIL corrupt sometimes
        img = PIL.Image.open(img_file)
        # img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
        img = np.array(img, dtype=np.uint8)

        # load label
        lbl_file = data_file['lbl']
        lbl = PIL.Image.open(lbl_file)
        # lbl = cv2.imread(lbl_file, 0)
        # lbl = np.array(lbl / 255, dtype=np.uint8)
        lbl = np.array(lbl, dtype=np.int32)
        lbl[lbl == 255] = -1
        if self._transform:
            return self.transform(img, lbl)
        else:
            return img, lbl

I wonder, with those wrappers:

train_dataset = SafeDataset(VOC2012ClassSeg(root=voc_root, transform=True))
train_loader = SafeDataLoader(train_dataset, batch_size=1, shuffle=True)

val_dataset = SafeDataset(VOC2011ClassSeg(root=voc_root, split='seg11valid', transform=True))
val_loader = SafeDataLoader(val_dataset, batch_size=1, shuffle=False)

Do I still need a try...except around the corrupt image?
I have one corrupt image, and I don't know which one. I can wrap the loading in try...except, but then I cannot return a None value or move on to the next index.

_get_pytorch_version can't deal with +cuda versions

  File "/home/marcel/.local/lib/python3.8/site-packages/nonechucks/__init__.py", line 27, in <module>
    MAJOR, MINOR = _get_pytorch_version()
  File "/home/marcel/.local/lib/python3.8/site-packages/nonechucks/__init__.py", line 12, in _get_pytorch_version
    major, minor, patch = [int(x) for x in version.split(".")]
  File "/home/marcel/.local/lib/python3.8/site-packages/nonechucks/__init__.py", line 12, in <listcomp>
    major, minor, patch = [int(x) for x in version.split(".")]
ValueError: invalid literal for int() with base 10: '1+cu110'

Skip filtering step for safe samples

The trick of blacklisting the bad samples so that the computation is never run again for these samples is very neat. I think we should extend it to the safe samples. In many cases, the verification can be long and costly (machine learning analysis of the image, etc.) and should only be done once, when the sample is first loaded.

Colab issue

After installing nonechucks on Colab and trying to import it, I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-100-31f98a41c17c> in <module>()
----> 1 import nonechucks as nc
      2 #fruits_dataset = nc.SafeDataset(fruits_dataset)

2 frames
/usr/local/lib/python3.7/dist-packages/nonechucks/__init__.py in <listcomp>(.0)
     10 def _get_pytorch_version():
     11     version = torch.__version__
---> 12     major, minor, patch = [int(x) for x in version.split(".")]
     13     if major != 1:
     14         raise RuntimeError(

ValueError: invalid literal for int() with base 10: '1+cu101'
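
The failing line assumes that every dot-separated component of torch.__version__ is a plain integer, which does not hold for local build suffixes such as +cu101. A more tolerant parser could look like the sketch below; parse_major_minor is a hypothetical helper rather than the library's actual fix, and it reads only the leading major.minor pair.

```python
import re


def parse_major_minor(version):
    """Return (major, minor) from strings like '1.2.0+cu92' or '1.0.1.post2'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: {!r}".format(version))
    return int(match.group(1)), int(match.group(2))


# Version strings taken from the reports on this page.
for version in ["1.2.0+cu92", "1.0.1.post2", "1.11.0+cu113"]:
    print(version, "->", parse_major_minor(version))
```

Ignoring everything after the minor number is deliberate here, since only MAJOR and MINOR are unpacked in the traceback above.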

Override some functions of the original dataset

Hi, thank you for your work.
I am using nonechucks to deal with bad examples. However, my original dataset class defines a function get_img_info. After wrapping the dataset with nc.SafeDataset, I get the error "'SafeDataset' object has no attribute 'get_img_info'".
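
One common way for a wrapper to expose the wrapped object's extra methods is attribute forwarding through __getattr__. The sketch below illustrates the pattern on hypothetical classes (ForwardingWrapper and ImageDataset); it does not describe how SafeDataset currently behaves.

```python
class ForwardingWrapper:
    """Hypothetical wrapper that forwards unknown attributes to the wrapped dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getattr__(self, name):
        # Only called when normal lookup fails, so the wrapper's own attributes win.
        if name == "dataset":
            # Guard against infinite recursion if 'dataset' is not set yet.
            raise AttributeError(name)
        return getattr(self.dataset, name)


class ImageDataset:
    """Toy dataset with an extra helper method, as in the report above."""

    def get_img_info(self, idx):
        return {"index": idx, "height": 480, "width": 640}


wrapped = ForwardingWrapper(ImageDataset())
print(wrapped.get_img_info(3))
```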

Memory leak

Hi,

when the below example is run, the RAM usage grows forever:

import torch, torch.utils.data
import nonechucks

class DummyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 1_000_000

    def __getitem__(self, idx):
        return 666

dataset = nonechucks.SafeDataset(DummyDataset())

for _ in torch.utils.data.DataLoader(dataset):
    pass

Notes:

  • Here the increase is quite slow; for a RAPID bug demonstration, replace 666 with torch.empty(10_000) (be careful to kill the process in time, before you're OOM!).
  • No problems without SafeDataset.
  • Without torch.utils.data.DataLoader, the leak is still there, although at a smaller scale, around 1 MB of RAM is lost per 30000-40000 __getitem__ calls.
  • PyTorch 1.0.1, nonechucks 0.3.1.
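
The traceback in the KeyError report on this page shows __getitem__ being called through a memoizing wrapper (res = cache[key] = self.func(*args, **kw)). If an unbounded cache of that kind is behind the growth, which is an assumption and not something confirmed here, then bounding the cache would cap memory use. A minimal sketch with the standard library's functools.lru_cache, using a hypothetical CachedDataset wrapper:

```python
from functools import lru_cache


class CachedDataset:
    """Hypothetical wrapper that remembers at most `maxsize` loaded samples."""

    def __init__(self, dataset, maxsize=1024):
        self.dataset = dataset
        # Per-instance bounded cache; least recently used entries are evicted first.
        self._load = lru_cache(maxsize=maxsize)(self._load_uncached)

    def _load_uncached(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self._load(idx)


cached = CachedDataset(list(range(100_000)), maxsize=256)
for i in range(len(cached)):
    cached[i]
print(cached._load.cache_info().currsize)  # 256
```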

Poor performance of SafeSampler compared to the default sequential sampler

Hi, I have been playing around with nonechucks a bit. I observed that if I use SafeDataset together with the standard DataLoader (using the default sequential sampler), my CPUs are fully loaded. However, when I use the DataLoader with SafeSampler, I usually see only one process running while the others are sleeping (probably waiting for synchronization). Could it be that the threads need to be synchronized in the SafeSampler __next__() method because of the while loop? There is a really HUGE difference in performance between using and not using SafeSampler...

However, I understand that if I use DataLoader without SafeSampler, then the sampled examples can be returned several times, which is not usable in my case.

Originally posted by @brejchajan in #5 (comment)

Write tests

Write unit test coverage for SafeDataset and SafeDataLoader, along with the functions in utils.py.

Overriding torch's classes

After importing the module, it overrides torch's dataloader so that it can no longer accept other datasets. :(

Failed to shuffle data across epochs

Hi, I found this project quite helpful and integrated it into my project. However, something seems to be wrong with the loader's shuffling. Here is my testing code:

import torch
import torch.utils.data as Data
import nonechucks as nc

torch.manual_seed(1)    # reproducible
BATCH_SIZE = 5
SHUFFLE = True
NUMWORKER = 2

x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)

torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=SHUFFLE,              
    num_workers=NUMWORKER,
    )

loader_safe = nc.SafeDataLoader(nc.SafeDataset(torch_dataset), batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUMWORKER)

print('\nNormal dataloader')
for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batch_x.numpy(), '| batch y: ', batch_y.numpy())
        
print('\nSafe dataloader')
for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader_safe):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batch_x.numpy(), '| batch y: ', batch_y.numpy())

And the results are like this:

Normal dataloader
Epoch:  0 | Step:  0 | batch x:  [10.  3.  1.  6.  5.] | batch y:  [ 1.  8. 10.  5.  6.]
Epoch:  0 | Step:  1 | batch x:  [8. 4. 2. 9. 7.] | batch y:  [3. 7. 9. 2. 4.]
Epoch:  1 | Step:  0 | batch x:  [7. 9. 6. 2. 4.] | batch y:  [4. 2. 5. 9. 7.]
Epoch:  1 | Step:  1 | batch x:  [...] | batch y:  [...  6. 10.  1.]
Epoch:  2 | Step:  0 | batch x:  [4. 9. 1. 8. 5.] | batch y:  [ 7.  2. 10.  3.  6.]
Epoch:  2 | Step:  1 | batch x:  [ 3.  2.  6. 10.  7.] | batch y:  [8. 9. 5. 1. 4.]

Safe dataloader
Epoch:  0 | Step:  0 | batch x:  [6. 7. 2. 3. 1.] | batch y:  [ 5.  4.  9.  8. 10.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.  4.  8.  5.] | batch y:  [2. 1. 7. 3. 6.]
Epoch:  1 | Step:  0 | batch x:  [6. 7. 2. 3. 1.] | batch y:  [ 5.  4.  9.  8. 10.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.  4.  8.  5.] | batch y:  [2. 1. 7. 3. 6.]
Epoch:  2 | Step:  0 | batch x:  [6. 7. 2. 3. 1.] | batch y:  [ 5.  4.  9.  8. 10.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.  4.  8.  5.] | batch y:  [2. 1. 7. 3. 6.]

It seems that the safe loader doesn't shuffle correctly across epochs. Am I somehow misusing the package?

By the way, the pytorch version I used is 1.0.1.post2

Generator keyword error

I get this error when I try to use SafeDataLoader:

    dl = nc.SafeDataLoader(ds, batch_size = batch_ss, shuffle = shuffle)
  File "/lib/python3.8/site-packages/nonechucks/dataloader.py", line 25, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/lib/python3.8/site-packages/nonechucks/dataloader.py", line 138, in __init__
    super(SafeDataLoader, self).__init__(dataset, **kwargs)
  File "/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 277, in __init__
    sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
TypeError: safe_sampler_callable() got an unexpected keyword argument 'generator'

How can I fix it?

P.S. I'm using PyTorch (1.11.0+cu113)

Unable to import nonechucks

Hello,

I get the following error when trying to load nonechucks (after installing either with pip or cloning) in Python 3.6.3:

import nonechucks
Traceback (most recent call last):
  File "", line 1, in <module>
  File "/anaconda3/lib/python3.6/site-packages/nonechucks/__init__.py", line 1, in <module>
    from dataset import *
ModuleNotFoundError: No module named 'dataset'

Thank you!

dataset attribute should not be set after SafeDataLoader is initialized

I have my own custom DataSet class. When I wrap it in nc.SafeDataset it runs with no error. When I then run loader = nc.SafeDataLoader(nc.SafeDataset(my_dataset)) it breaks and I get this error: dataset attribute should not be set after SafeDataLoader is initialized. Is this because I am not using a supported pytorch version?

NotImplementedError for __len__ function in SafeDataLoader

Hi @msamogh,

Having a __len__ function in SafeDataLoader, identical to the one in torch.utils.data.DataLoader, would be very helpful. I currently get the following error:

dataloader = nc.SafeDataLoader(dataset)
if i == len(dataloader):
File "envs/deep/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 504, in __len__
return len(self.batch_sampler)
File "envs/deep/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 150, in __len__
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
File "envs/deep/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 20, in __len__
raise NotImplementedError
NotImplementedError

Thank you,
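
The traceback shows that the batch count is derived from len(self.sampler), so the call fails as soon as the sampler does not define __len__. Below is a sketch of a sampler that reports the wrapped dataset's length, together with the same batch arithmetic. SizedSampler and num_batches are hypothetical names, and the reported length is only an upper bound because samples dropped at load time are not known in advance.

```python
class SizedSampler:
    """Hypothetical sampler that reports the wrapped dataset's length."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(range(len(self.dataset)))

    def __len__(self):
        # Upper bound: samples dropped at load time are not known in advance.
        return len(self.dataset)


def num_batches(sampler, batch_size, drop_last=False):
    """Same arithmetic as the batch-sampler line in the traceback above."""
    if drop_last:
        return len(sampler) // batch_size
    return (len(sampler) + batch_size - 1) // batch_size


sampler = SizedSampler(list(range(10)))
print(num_batches(sampler, batch_size=4))                  # 3
print(num_batches(sampler, batch_size=4, drop_last=True))  # 2
```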

Should SafeDataset drop __getitem__ and inherit from IterableDataset?

I quickly looked under the hood of this library because I needed to handle None values in my own dataset, but felt suspicious that this is trying to do something impossible.

Looking at https://github.com/msamogh/nonechucks/blob/master/nonechucks/dataset.py#L87-L96, I am under the impression that __getitem__ will return the same value for multiple indices. E.g. suppose index 2 is None, then dataset[2] == dataset[3].

Surely that doesn't make sense for a well-behaved map-style dataset?

Alternatively indices could be remapped via a Dict[int,int] for random access.
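
The remapping idea could look roughly like the sketch below. RemappedDataset is a hypothetical class, and it assumes that a bad sample is signalled by None. It builds the Dict[int, int] in one eager pass, which means every sample is loaded once up front.

```python
class RemappedDataset:
    """Hypothetical map-style view exposing only the valid samples."""

    def __init__(self, dataset):
        self.dataset = dataset
        # One eager pass: position in the view -> index in the original dataset.
        self.index_map = {}
        for original_idx in range(len(dataset)):
            if dataset[original_idx] is not None:
                self.index_map[len(self.index_map)] = original_idx

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        return self.dataset[self.index_map[idx]]


view = RemappedDataset(["a", "b", None, "d"])
print(len(view), [view[i] for i in range(len(view))])  # 3 ['a', 'b', 'd']
```

The trade-off is the up-front scan: it gives consistent random access and an exact length, but it is the kind of separate filtering step the README describes as too expensive when loading a sample is costly.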

SafeDataset can't wrap different datasets

trainset_ = my_data((140,224),transform=image_transform)
testset_ = my_data((140,224),image_set='val',transform=image_transform)

trainset = nc.SafeDataset(trainset_)
testset = nc.SafeDataset(testset_)

Before executing trainset = nc.SafeDataset(trainset_), I plot the first image of the test set with plt.imshow(testset_[0][0].permute(1,2,0)) and get the first image of the validation set, which is correct. After executing trainset = nc.SafeDataset(trainset_), the same plt.imshow(testset_[0][0].permute(1,2,0)) call shows the first image of the training set, which is wrong. I checked the image path in the testset object, and it is still the path of the first image in the validation set. Could you give me some suggestions for solving this problem?

from PIL import Image as image
import torch
import numpy as np
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import os
import torchvision.datasets as dset
voc_colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]
class my_data(torch.utils.data.Dataset):
    #if target_transform=mask_transform is
    def __init__(self,data_size,root='data',image_set='train',transform=None,target_transform=None):
        self.shape=data_size
        self.root=os.path.expanduser(root)
        self.transform=transform
        self.target_transform=target_transform
        self.image_set=image_set
        voc_dir=os.path.join(self.root,'VOCdevkit/VOC2012')
        image_dir=os.path.join(voc_dir,'JPEGImages')
        mask_dir=os.path.join(voc_dir,'SegmentationClass')
        splits_dir=os.path.join(voc_dir,'ImageSets/Segmentation')
        splits_f=os.path.join(splits_dir, self.image_set + '.txt')
        with open(os.path.join(splits_f),'r') as f:
            file_name=[x.strip() for x in f.readlines()]
        self.image=[os.path.join(image_dir,x+'.jpg') for x in file_name]
        self.mask=[os.path.join(mask_dir,x+'.png') for x in file_name]
        assert (len(self.image)==len(self.mask))

        self.class_index=np.zeros(256**3)
        for i,j in enumerate(voc_colormap):
            tmp=(j[0]*256+j[1])*256+j[2]
            self.class_index[tmp]=i
    def __getitem__(self, index):
        img=image.open(self.image[index]).convert('RGB')
        target=image.open(self.mask[index]).convert('RGB')
        i,j,h,w=transforms.RandomCrop.get_params(img,self.shape)
        # if i<0 or j<0 or h <0 or w<0:
        #     return None,None
        img=TF.crop(img,i,j,h,w)
        target=TF.crop(target,i,j,h,w)
        if  self.target_transform is not None:
            return self.transform(img),self.target_transform(target)
        target=np.array(target).transpose(2,0,1).astype(np.int32)
        target=(target[0]*256+target[1])*256+target[2]
        target=self.class_index[target]
        return self.transform(img),target

    def __len__(self):
        return len(self.image)
