
lightning-universe / lightning-flash


Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains

Home Page: https://lightning-flash.readthedocs.io

License: Apache License 2.0

Python 99.86% Makefile 0.03% Shell 0.10% HTML 0.01%
pytorch-lightning deep-learning machine-learning pytorch tasks-flash classification tabular-data object-detection icevision open3d

lightning-flash's Introduction

Your PyTorch AI Factory



Flash makes complex AI recipes for over 15 tasks across 7 data domains accessible to all.
In a nutshell, Flash is the production-grade research framework you always dreamed of but didn't have time to build.

Getting Started

From PyPI:

pip install lightning-flash

See our installation guide for more options.

Flash in 3 Steps

Step 1: Load your data

All data loading in Flash is performed via a from_* classmethod on a DataModule. Which DataModule to use, and which from_* methods are available, depends on the task you want to perform. For example, for image segmentation where your data is stored in folders, you would use the from_folders method of the SemanticSegmentationData class:

from flash.image import SemanticSegmentationData

dm = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    image_size=(256, 256),
    num_classes=21,
)

Step 2: Configure your model

Our tasks come loaded with pre-trained backbones and (where applicable) heads. You can list the options for your task with available_heads, available_backbones, and available_pretrained_weights. Once you've chosen, create the model:

from flash.image import SemanticSegmentation

print(SemanticSegmentation.available_heads())
# ['deeplabv3', 'deeplabv3plus', 'fpn', ..., 'unetplusplus']

print(SemanticSegmentation.available_backbones('fpn'))
# ['densenet121', ..., 'xception'] # + 113 models

print(SemanticSegmentation.available_pretrained_weights('efficientnet-b0'))
# ['imagenet', 'advprop']

model = SemanticSegmentation(
  head="fpn", backbone='efficientnet-b0', pretrained="advprop", num_classes=dm.num_classes)

Step 3: Finetune!

from flash import Trainer

trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=dm, strategy="freeze")
trainer.save_checkpoint("semantic_segmentation_model.pt")

PyTorch Recipes

Make predictions with Flash!

Serve in just 2 lines:

from flash.image import SemanticSegmentation

model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.serve()

or make predictions from raw data directly.

from flash import Trainer
from flash.image import SemanticSegmentationData

trainer = Trainer(strategy='ddp', accelerator="gpu", gpus=2)
dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB")
predictions = trainer.predict(model, dm)

Flash Training Strategies

Training strategies are state-of-the-art (SOTA) PyTorch training recipes that can be used with a given task.

Check out this example, where the ImageClassifier supports 4 meta-learning algorithms from Learn2Learn. This is particularly useful if you use this model in production and want to make sure it adapts quickly to its new environment with minimal labelled data.

import torch

from flash.image import ImageClassifier

model = ImageClassifier(
    backbone="resnet18",
    optimizer=torch.optim.Adam,
    optimizer_kwargs={"lr": 0.001},
    training_strategy="prototypicalnetworks",
    training_strategy_kwargs={
        "epoch_length": 10 * 16,
        "meta_batch_size": 4,
        "num_tasks": 200,
        "test_num_tasks": 2000,
        "ways": datamodule.num_classes,
        "shots": 1,
        "test_ways": 5,
        "test_shots": 1,
        "test_queries": 15,
    },
)
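The `ways`, `shots`, and `test_queries` arguments above use standard few-shot terminology. Each episode samples `ways` classes and draws `shots` labelled support examples from each one, plus some query examples that are used for evaluation. The sketch below shows one library-independent way to draw such an episode from a labelled dataset. Two caveats apply. First, `sample_episode` is a hypothetical helper, not part of Flash or Learn2Learn; both libraries implement their own task samplers. Second, this sketch counts queries per class, so check the Learn2Learn documentation for the exact semantics of each argument.

```python
import random
from collections import defaultdict


def sample_episode(labels, ways, shots, queries, seed=None):
    """Draw one N-way K-shot episode (hypothetical helper, for illustration only).

    Returns (support, query): lists of dataset indices. Each of the `ways`
    sampled classes contributes `shots` support and `queries` query indices.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # Only classes with enough examples for both the support and query sets qualify.
    eligible = sorted(c for c, idx in by_class.items() if len(idx) >= shots + queries)
    if len(eligible) < ways:
        raise ValueError(f"need {ways} classes with >= {shots + queries} examples each")

    support, query = [], []
    for cls in rng.sample(eligible, ways):
        picked = rng.sample(by_class[cls], shots + queries)
        support.extend(picked[:shots])  # first `shots` indices form the support set
        query.extend(picked[shots:])  # the remaining indices are queries
    return support, query


# 10 classes with 20 examples each; a 5-way 1-shot episode with 15 queries per class,
# mirroring test_ways=5, test_shots=1, test_queries=15 above.
labels = [i % 10 for i in range(200)]
support, query = sample_episode(labels, ways=5, shots=1, queries=15, seed=0)
print(len(support), len(query))  # 5 75
```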

In detail, the following methods are currently implemented:

Flash Optimizers / Schedulers

With Flash, swapping among 40+ optimizers and 15+ scheduler recipes is simple. Find the list of available optimizers and schedulers as follows:

from flash.image import ImageClassifier

ImageClassifier.available_optimizers()
# ['A2GradExp', ..., 'Yogi']

ImageClassifier.available_schedulers()
# ['CosineAnnealingLR', 'CosineAnnealingWarmRestarts', ..., 'polynomial_decay_schedule_with_warmup']

Once you've chosen, create the model:

import functools

import torch
from torch.optim.lr_scheduler import CyclicLR

from flash.image import ImageClassifier

#### The optimizer of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), lr_scheduler=None)

# - Tuple[string, dict]: (The dict takes in the optimizer kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("Adadelta", {"eps": 0.5}), lr_scheduler=None)

#### The scheduler of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")

# - Callable
# CyclicLR requires base_lr and max_lr; the values here are illustrative.
# cycle_momentum=False because Adam does not expose a "momentum" hyperparameter to cycle.
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(CyclicLR, base_lr=1e-4, max_lr=1e-2, step_size_up=1500, mode='exp_range', gamma=0.5, cycle_momentum=False))

# - Tuple[string, dict]: (The dict takes in the scheduler kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10}))

You can also register your own custom scheduler recipes beforehand and use them as shown above:

import torch

from flash.image import ImageClassifier

@ImageClassifier.lr_schedulers_registry
def my_steplr_recipe(optimizer):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")
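For reference, `StepLR(optimizer, step_size=10)` in the recipe above multiplies the learning rate by `gamma` (0.1 by default in PyTorch) every 10 scheduler steps. The closed form is lr = base_lr * gamma ** (step // step_size). Below is a minimal pure-Python sketch of that formula. It assumes the scheduler is stepped once per epoch and uses a base learning rate of 0.001, chosen only for illustration.

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Closed-form learning rate of torch.optim.lr_scheduler.StepLR after `epoch` steps."""
    return base_lr * gamma ** (epoch // step_size)


# The rate stays flat for 10 epochs, then drops by a factor of 10 at each boundary.
for epoch in (0, 9, 10, 25):
    print(epoch, step_lr(1e-3, epoch))
```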

Flash Transforms

Flash includes some simple augmentations for each task by default; however, you will often want to override these and control your own augmentation recipe. To this end, Flash supports custom transformations with the InputTransform. The InputTransform is like a callback for transforms, with hooks that can be used to apply transforms to samples or batches, on and off the device / accelerator. In addition, hooks can be specialized to apply transforms only to the input or target. With these hooks, complex transforms like MixUp can be implemented with ease. Here's an example (with an albumentations transform thrown in too!):

import torch
import numpy as np
import albumentations
from flash import InputTransform
from flash.image import ImageClassificationData
from flash.image.classification.input_transform import AlbumentationsAdapter


def mixup(batch, alpha=1.0):
    images = batch["input"]
    targets = batch["target"].float().unsqueeze(1)

    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0))

    batch["input"] = images * lam + images[perm] * (1 - lam)
    batch["target"] = targets * lam + targets[perm] * (1 - lam)
    return batch


class MixUpInputTransform(InputTransform):

    def train_input_per_sample_transform(self):
        return AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))

    # This will be applied after transferring the batch to the device!
    def train_per_batch_transform_on_device(self):
        return mixup


datamodule = ImageClassificationData.from_folders(
    train_folder="data/train",
    transform=MixUpInputTransform,
    batch_size=2,
)
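The `mixup` function above interpolates raw target values. That is meaningful when the target is a single binary or continuous value, but mixing class indices such as 0 and 2 would produce the unrelated index 1. For multi-class targets, MixUp is usually applied to one-hot label vectors instead. The following is a minimal sketch of that variant in plain PyTorch, using the same `"input"`/`"target"` batch keys as above. Note the following:

- `mixup_one_hot` and its `num_classes` argument are hypothetical.
- The mixed targets need a loss that accepts class probabilities. For example, `torch.nn.functional.cross_entropy` accepts probability targets from PyTorch 1.10 onward.

```python
import numpy as np
import torch
import torch.nn.functional as F


def mixup_one_hot(batch, num_classes, alpha=1.0):
    """MixUp with one-hot (soft) targets for multi-class classification (sketch)."""
    images = batch["input"]
    targets = F.one_hot(batch["target"].long(), num_classes).float()

    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(images.size(0))

    # Mix the inputs and their label distributions with the same coefficient.
    batch["input"] = images * lam + images[perm] * (1 - lam)
    batch["target"] = targets * lam + targets[perm] * (1 - lam)
    return batch


batch = {"input": torch.rand(4, 3, 8, 8), "target": torch.tensor([0, 1, 2, 1])}
mixed = mixup_one_hot(batch, num_classes=3)
print(mixed["target"].sum(dim=1))  # each row sums to (approximately) 1
```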

Flash Zero - PyTorch Recipes from the Command Line!

Flash Zero is a zero-code machine learning platform built directly into lightning-flash using the Lightning CLI.

To get started and view the available tasks, run:

  flash --help

For example, to train an image classifier for 10 epochs with a resnet50 backbone on 2 GPUs using your own data, you can do:

  flash image_classification --trainer.max_epochs 10 --trainer.gpus 2 --model.backbone resnet50 from_folders --train_folder {PATH_TO_DATA}

Kaggle Notebook Examples

Contribute!

The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks!

Join our Slack and/or read our CONTRIBUTING guidelines to get help becoming a contributor!

Note: Flash is currently being tested on real-world use cases and is in active development. Please open an issue if you find anything that isn't working as expected.


Community

Flash is maintained by our core contributors.

For help or questions, join our huge community on Slack!


Citations

We’re excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if additional papers are written about this, we’ll be happy to cite these frameworks and the corresponding authors.

Flash leverages models from many different frameworks in order to cover such a wide range of domains and tasks. The full list of providers can be found in our documentation.


License

Please observe the Apache 2.0 license that is listed in this repository.

lightning-flash's People

Contributors

actis92, akihironitta, ananyahjha93, aniketmaurya, aribornstein, borda, carmocca, deepsource-autofix[bot], dependabot[bot], edenlightning, edgarriba, ehofesmann, ethanwharris, flozi00, ibraheemmmoosa, justusschock, karthikrangasai, kaushikb11, kingyiusuen, krshrimali, mabu-dev, pietrolesci, pre-commit-ci[bot], seannaren, skaftenicki, sumanmichael, tchaton, teddykoker, uakarsh, williamfalcon

lightning-flash's Issues

Add Backbone API

🚀 Feature

Each backbone is more or less hard-coded for each task. I think there should be a unified API for extending Task backbones.

Pass Trainer Arguments into Model.predict

🚀 Feature

I should have a way to override the predict trainer in model.predict.

Motivation

Flash's model.predict doesn't support features like distributed inference. This causes significant bottlenecks when evaluating large models, so we need a way to streamline this and make evaluation faster.

[RFC] Simplify optional package imports in the codebase

(Originally suggested in Bolts Lightning-Universe/lightning-bolts#346)

🚀 Feature

I would like to suggest implementing a context manager class in order to handle optional imports better. With the context manager implemented, importing optional packages and deferring raising ModuleNotFoundError will look like the following.

# === before ===
_KORNIA_AVAILABLE = _module_available("kornia")
if _KORNIA_AVAILABLE:
    import kornia

class SomeClass:
    def __init__(self, ...):
        if not _KORNIA_AVAILABLE:
            raise ModuleNotFoundError("Package `kornia` is not installed, install it with `pip install lightning-flash[image]`")

# === after ===
with ContextManager() as cm:
    import kornia  # doesn't raise `ModuleNotFoundError` here even if it fails.

class SomeClass:
    def __init__(self, ...):
        cm.check()  # raises `ModuleNotFoundError` if `import kornia` failed

Motivation

We can simplify and unify "raise error if not found" code blocks.

# === before ===
if not _KORNIA_AVAILABLE:
    raise ModuleNotFoundError("Package `kornia` is not installed, install it with `pip install lightning-flash[image]`")

# === after ===
cm.check()  # <- simple :)

As Flash depends on quite a number of optional packages, I believe it would reduce a lot of code duplication.

Pitch

Here is a minimal implementation of a context manager and a demo.

Example

class ContextManager:
    def __init__(self):
        self.exc = None
        
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.exc = exc_type, exc_value, traceback
        return True  # Ignores exception if True and reraises exception if False

    def check(self):
        if self.exc is not None and self.exc[0] is not None:  # __exit__ also runs when no exception occurred
            exc_type, exc_value, traceback = self.exc
            raise exc_type(exc_value).with_traceback(traceback)

# Here's how it behaves
print("Try to import.")
with ContextManager() as cm:
    import aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
print("Tried to import.")
cm.check()

Output:

Try to import.
Tried to import.
Traceback (most recent call last):
  File "main.py", line 22, in <module>
    cm.check()
  File "main.py", line 15, in check
    raise exc_type(exc_value).with_traceback(traceback)
  File "main.py", line 19, in <module>
    import aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
ModuleNotFoundError: No module named 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
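The sketch below offers one possible refinement; it is not part of the proposal. The class above returns `True` from `__exit__` for every exception, so an unrelated error inside the `with` block, such as a `NameError` from a typo, would also be silently deferred. The hypothetical `OptionalImports` variant changes two things:

- It suppresses only `ImportError`, which includes `ModuleNotFoundError`. All other exceptions propagate immediately.
- It re-raises the original exception object, so its message and traceback are preserved.

```python
class OptionalImports:
    """Defer ImportError raised inside a `with` block until `check()` is called (sketch)."""

    def __init__(self):
        self.error = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None and issubclass(exc_type, ImportError):
            self.error = exc_value
            return True  # suppress only import failures
        return False  # let any other exception propagate immediately

    @property
    def available(self):
        return self.error is None

    def check(self):
        if self.error is not None:
            raise self.error  # the original exception, with its traceback attached


with OptionalImports() as kornia_imports:
    import kornia_does_not_exist  # hypothetical missing package

print(kornia_imports.available)  # False
```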

See also

Built-in Types - Context Manager Types
Stack Overflow - How to safely handle an exception inside a context manager

Other comments

We can implement this suggested context manager class in PL and use it from Flash.

Let me know if it sounds reasonable or too much engineering for optional imports!

Installation fails when trying to install pandas

🐛 Bug

When trying to install lightning-flash via pip install lightning-flash, I get an error caused by the dependency on pandas==1.1.2:

ERROR: Could not build wheels for pandas which use PEP 517 and cannot be installed directly

I get the same error when I try to install pandas==1.1.2 independent of lightning-flash. I am able to install pandas==1.2.2 without any problems. I'm using Python 3.9.1 and Ubuntu 20.04.

Is it necessary to have the pandas version fixed at 1.1.2?

Expected behavior

Installation without any problems

Environment

  • PyTorch Version (e.g., 1.0): 1.7
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.9.1
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: 1080TI
  • Any other relevant information:

Support for Multi Label Classification

🚀 Feature

Currently, classification assumes that each image has only one label. This is not always the case, so we should brainstorm how to support multi-label classification scenarios.
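For reference, here is one common way to frame multi-label classification, sketched in plain PyTorch. It does not describe how Flash implements the feature. The approach has three parts:

- Encode each sample's labels as a multi-hot vector.
- Train with an independent sigmoid per class via `BCEWithLogitsLoss`.
- At inference, threshold each class probability separately.

The linear layer stands in for a backbone plus head, and the 0.5 threshold is an illustrative assumption.

```python
import torch
from torch import nn

num_classes = 4
# Multi-hot targets: sample 0 has labels {0, 2}, sample 1 has label {3}.
targets = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                        [0.0, 0.0, 0.0, 1.0]])

head = nn.Linear(16, num_classes)  # stand-in for a backbone + classification head
features = torch.randn(2, 16)
logits = head(features)

# One sigmoid + binary cross-entropy per class, instead of a softmax over classes.
loss = nn.BCEWithLogitsLoss()(logits, targets)
loss.backward()

# At inference, each class is decided independently.
predicted = torch.sigmoid(logits) > 0.5
print(predicted.shape)  # torch.Size([2, 4])
```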

Add split support in from folders

🚀 Feature

ImageClassificationData.from_filepaths allows me to define a validation split, but ImageClassificationData.from_folders does not. This is inconsistent; splitting should be a standard feature.

How do I ensure no data leakage in the validation split

What is your question?

I'm working on a dataset with identifiable information that would lead to data leakage if the data is not split properly. Currently, the validation split is hard-coded in each respective DataModule:

        if valid_split:
            full_length = len(train_ds)
            train_split = int((1.0 - valid_split) * full_length)
            valid_split = full_length - train_split
            train_ds, valid_ds = torch.utils.data.random_split(
                train_ds,
                [train_split, valid_split],
                generator=torch.Generator().manual_seed(seed)
            )

Ideally, I'd like a flag that ensures there is no overlap between the train and validation data on these fields, by rebalancing any overlap. Because dataset initialization is hard-coded, the validation dataset becomes immutable once it is created.

What is the best way to handle such a check in Flash?
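One way to guarantee that no identifying value, such as a patient or user ID, appears on both sides is to split by group rather than by sample. Below is a minimal pure-Python sketch under that assumption. `group_split` is a hypothetical helper, not a Flash API. The resulting index lists could be wrapped with `torch.utils.data.Subset`, and scikit-learn's `GroupShuffleSplit` offers similar behaviour.

```python
import random


def group_split(groups, val_fraction=0.1, seed=42):
    """Split sample indices so that every group lands entirely in train or in val.

    `groups[i]` is the identifying value (e.g. a patient ID) of sample i; values
    must be sortable. Roughly `val_fraction` of the *groups* (not samples) go to val.
    """
    unique_groups = sorted(set(groups))  # sort first so the seed fully determines the split
    rng = random.Random(seed)
    rng.shuffle(unique_groups)

    n_val = max(1, round(val_fraction * len(unique_groups)))
    val_groups = set(unique_groups[:n_val])

    train_idx = [i for i, g in enumerate(groups) if g not in val_groups]
    val_idx = [i for i, g in enumerate(groups) if g in val_groups]
    return train_idx, val_idx


# Six samples from four patients; both samples of a patient stay together.
groups = ["p1", "p1", "p2", "p3", "p3", "p4"]
train_idx, val_idx = group_split(groups, val_fraction=0.25)
```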

ImageClassifier Predict Label

🚀 Feature

Prediction currently returns the class index, based on the folder's position in the training data. It would be great to be able to see the target labels instead.

Motivation

It's hard to interpret class id numbers when you have more than a few classes.
model.predict(["cat.png"], target=True)
output -> 1

What does one mean?

Pitch

model.predict(["cat.png"], target=True)
output -> cat
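Until something like this exists, the index can be mapped back to a name by hand. The sketch below assumes that class indices follow the alphabetical order of the training sub-folders. That is the convention used by torchvision's `ImageFolder`; whether a given Flash version does the same should be checked. `class_names_from_folders` is a hypothetical helper.

```python
import os
import tempfile


def class_names_from_folders(train_folder):
    """Return class names in index order, assuming indices follow sorted sub-folder names."""
    return sorted(entry.name for entry in os.scandir(train_folder) if entry.is_dir())


# Build a throwaway train folder with one sub-folder per class.
with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat", "bird"):
        os.makedirs(os.path.join(root, name))
    names = class_names_from_folders(root)

predicted_index = 1  # e.g. the "1" returned by model.predict above
print(names[predicted_index])  # cat
```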

Alternatives

Open to discussion

Additional context

Training the Image Embedder

I would like to perform large-scale image retrieval. How do I retrain the Image Embedder?

I clustered the images into multiple folders based on pose and trained the ImageClassifier, thinking that I could use it as a model to extract embeddings. Unfortunately, the loss is not decreasing at all. Is there a specific way to train for image retrieval?
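However the embedder is trained, the retrieval step itself is typically a nearest-neighbour search over embedding vectors. Below is a minimal sketch in plain PyTorch that uses cosine similarity as the metric. It assumes the embeddings have already been computed; random tensors stand in for them here, and `retrieve` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def retrieve(query, gallery, k=3):
    """Return (scores, indices) of the k gallery embeddings most similar to `query`.

    query:   (D,) embedding of the query image
    gallery: (N, D) embeddings of the indexed images
    """
    query = F.normalize(query, dim=0)
    gallery = F.normalize(gallery, dim=1)
    scores = gallery @ query  # cosine similarity for each gallery item, shape (N,)
    return torch.topk(scores, k)


torch.manual_seed(0)
gallery = torch.randn(100, 128)
query = gallery[42] + 0.01 * torch.randn(128)  # a slightly perturbed copy of item 42
scores, indices = retrieve(query, gallery, k=3)
print(indices[0].item())  # 42
```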

FilePathDataset does not apply transform to image.

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Go to '...'
  2. Run '....'
  3. Scroll down to '....'
  4. See error

Code sample

Expected behavior

Environment

  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Cannot import name 'PY3' with PyTorch 1.8

🐛 Bug

After a fresh install (which installs PyTorch 1.8), I get the following error when I try to run the image classification example:

ImportError: cannot import name 'PY3' from 'torch._six'

With PyTorch 1.7.1 installed instead, the script works as expected, so something in PyTorch 1.8 triggers this.

To Reproduce

Steps to reproduce the behavior:

After creating a fresh venv

git clone https://github.com/PyTorchLightning/lightning-flash.git
cd lightning-flash
pip install -e .
python flash_examples/finetuning/image_classification.py

This is the full error message: https://pastebin.com/4xEtr5tv

Index out of bounds error

It throws the following error during training:

self_supervised/simclr/simclr_module.py", line 249, in optimizer_step
    param_group["lr"] = self.lr_schedule[self.trainer.global_step]
IndexError: index 900 is out of bounds for axis 0 with size 900
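The traceback suggests an off-by-one mismatch between the length of the precomputed per-step schedule and the number of optimizer steps actually taken. An array with 900 entries has valid indices 0 to 899, so step 900 falls off the end. This could happen, for example, if the real number of steps exceeds the estimate used to build the schedule.

Below is a hedged sketch of a defensive lookup that holds the final value instead of raising. `lr_at_step` is hypothetical and is not the fix applied in the library. Building the schedule from the trainer's actual step count is the more robust remedy, because clamping only hides the mismatch.

```python
def lr_at_step(lr_schedule, global_step):
    """Look up a precomputed per-step learning rate, clamping past the end.

    Indexing lr_schedule[global_step] directly raises IndexError once
    global_step == len(lr_schedule); clamping reuses the final value instead.
    """
    if len(lr_schedule) == 0:  # len() also works for NumPy arrays, unlike `not arr`
        raise ValueError("lr_schedule must not be empty")
    index = min(global_step, len(lr_schedule) - 1)
    return lr_schedule[index]


# A linear decay over 900 steps (indices 0..899).
total_steps = 900
schedule = [0.1 * (1 - step / total_steps) for step in range(total_steps)]
print(lr_at_step(schedule, 900) == schedule[-1])  # True
```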

[text] Extendable Backbones and Dataloaders

🚀 Feature

Backbones and DataLoaders are tightly coupled to tasks, which makes it hard, for example, to use a non-Hugging Face model for text classification or Flair embeddings for featurization.

Motivation

In order to prevent duplicate tasks for different backbone libraries, Flash needs an extendable way of adding backbones to a task.

Pitch

Alternatives

Additional context

SWA not recognized by flash trainer

🐛 Bug

trainer = Trainer(
    gpus=1,
    auto_lr_find=True,
    precision=16,
    stochastic_weight_avg=True,
    max_epochs=10,
    auto_scale_batch_size='binsearch',
    callbacks=[EarlyStopping(monitor='val_binary_cross_entropy_with_logits')]
)
trainer.finetune(clf, data, strategy='freeze_unfreeze')

TypeError Traceback (most recent call last)
in
6 max_epochs=10,
7 auto_scale_batch_size='binsearch',
----> 8 callbacks=[EarlyStopping(monitor='val_binary_cross_entropy_with_logits')]
9 )
10 trainer.finetune(clf, data, strategy='freeze_unfreeze')

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py in overwrite_by_env_vars(self, *args, **kwargs)
39
40 # all args were already moved to kwargs
---> 41 return fn(self, **kwargs)
42
43 return overwrite_by_env_vars

TypeError: __init__() got an unexpected keyword argument 'stochastic_weight_avg'

Flash Requires Internet to Load a Local Checkpoint

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Turn off the internet on the device
  2. Load checkpoint from local file
clf = ImageClassifier.load_from_checkpoint('/kaggle/input/baserazcrmodels/razcr_resnet50_base_model.pt')
  3. See the error, caused by ImageClassifier's __init__, which pulls weights from torchvision on init even though it's loading from a local checkpoint:
---------------------------------------------------------------------------
gaierror                                  Traceback (most recent call last)
/opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1349                 h.request(req.get_method(), req.selector, req.data, headers,
-> 1350                           encode_chunked=req.has_header('Transfer-encoding'))
   1351             except OSError as err: # timeout error
/opt/conda/lib/python3.7/http/client.py in request(self, method, url, body, headers, encode_chunked)
   1276         """Send a complete request to the server."""
-> 1277         self._send_request(method, url, body, headers, encode_chunked)
   1278 
/opt/conda/lib/python3.7/http/client.py in _send_request(self, method, url, body, headers, encode_chunked)
   1322             body = _encode(body, 'body')
-> 1323         self.endheaders(body, encode_chunked=encode_chunked)
   1324 
/opt/conda/lib/python3.7/http/client.py in endheaders(self, message_body, encode_chunked)
   1271             raise CannotSendHeader()
-> 1272         self._send_output(message_body, encode_chunked=encode_chunked)
   1273 
/opt/conda/lib/python3.7/http/client.py in _send_output(self, message_body, encode_chunked)
   1031         del self._buffer[:]
-> 1032         self.send(msg)
   1033 
/opt/conda/lib/python3.7/http/client.py in send(self, data)
    971             if self.auto_open:
--> 972                 self.connect()
    973             else:
/opt/conda/lib/python3.7/http/client.py in connect(self)
   1438 
-> 1439             super().connect()
   1440 
/opt/conda/lib/python3.7/http/client.py in connect(self)
    943         self.sock = self._create_connection(
--> 944             (self.host,self.port), self.timeout, self.source_address)
    945         self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
/opt/conda/lib/python3.7/socket.py in create_connection(address, timeout, source_address)
    706     err = None
--> 707     for res in getaddrinfo(host, port, 0, SOCK_STREAM):
    708         af, socktype, proto, canonname, sa = res
/opt/conda/lib/python3.7/socket.py in getaddrinfo(host, port, family, type, proto, flags)
    751     addrlist = []
--> 752     for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
    753         af, socktype, proto, canonname, sa = res
gaierror: [Errno -3] Temporary failure in name resolution
During handling of the above exception, another exception occurred:
URLError                                  Traceback (most recent call last)
<ipython-input-8-d3a1d7810d85> in <module>
      5 #                       num_classes=len(columns))
      6 
----> 7 clf = ImageClassifier.load_from_checkpoint("../input/baserazcrmodels/razcr_resnet50_base_model.pt")
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    155         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    156 
--> 157         model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    158         return model
    159 
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
    196             _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
    197 
--> 198         model = cls(**_cls_kwargs)
    199 
    200         # give model a chance to load something
/opt/conda/lib/python3.7/site-packages/flash/vision/classification/model.py in __init__(self, num_classes, backbone, pretrained, loss_fn, optimizer, metrics, learning_rate, multilabel)
     61         self.save_hyperparameters()
     62 
---> 63         self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)
     64 
     65         self.head = nn.Sequential(
/opt/conda/lib/python3.7/site-packages/flash/vision/backbones.py in backbone_and_num_features(model_name, fpn, pretrained, trainable_backbone_layers, **kwargs)
     69 
     70     if model_name in TORCHVISION_MODELS:
---> 71         return torchvision_backbone_and_num_features(model_name, pretrained)
     72 
     73     raise ValueError(f"{model_name} is not supported yet.")
/opt/conda/lib/python3.7/site-packages/flash/vision/backbones.py in torchvision_backbone_and_num_features(model_name, pretrained)
    128 
    129     elif model_name in RESNET_MODELS:
--> 130         model = model(pretrained=pretrained)
    131         # remove the last two layers & turn it into a Sequential model
    132         backbone = nn.Sequential(*list(model.children())[:-2])
/opt/conda/lib/python3.7/site-packages/torchvision/models/resnet.py in resnet50(pretrained, progress, **kwargs)
    263     """
    264     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
--> 265                    **kwargs)
    266 
    267 
/opt/conda/lib/python3.7/site-packages/torchvision/models/resnet.py in _resnet(arch, block, layers, pretrained, progress, **kwargs)
    225     if pretrained:
    226         state_dict = load_state_dict_from_url(model_urls[arch],
--> 227                                               progress=progress)
    228         model.load_state_dict(state_dict)
    229     return model
/opt/conda/lib/python3.7/site-packages/torch/hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
    553             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
    554             hash_prefix = r.group(1) if r else None
--> 555         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    556 
    557     if _is_legacy_zip_format(cached_file):
/opt/conda/lib/python3.7/site-packages/torch/hub.py in download_url_to_file(url, dst, hash_prefix, progress)
    423     # certificates in older Python
    424     req = Request(url, headers={"User-Agent": "torch.hub"})
--> 425     u = urlopen(req)
    426     meta = u.info()
    427     if hasattr(meta, 'getheaders'):
/opt/conda/lib/python3.7/urllib/request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    220     else:
    221         opener = _opener
--> 222     return opener.open(url, data, timeout)
    223 
    224 def install_opener(opener):
/opt/conda/lib/python3.7/urllib/request.py in open(self, fullurl, data, timeout)
    523             req = meth(req)
    524 
--> 525         response = self._open(req, data)
    526 
    527         # post-process response
/opt/conda/lib/python3.7/urllib/request.py in _open(self, req, data)
    541         protocol = req.type
    542         result = self._call_chain(self.handle_open, protocol, protocol +
--> 543                                   '_open', req)
    544         if result:
    545             return result
/opt/conda/lib/python3.7/urllib/request.py in _call_chain(self, chain, kind, meth_name, *args)
    501         for handler in handlers:
    502             func = getattr(handler, meth_name)
--> 503             result = func(*args)
    504             if result is not None:
    505                 return result
/opt/conda/lib/python3.7/urllib/request.py in https_open(self, req)
   1391         def https_open(self, req):
   1392             return self.do_open(http.client.HTTPSConnection, req,
-> 1393                 context=self._context, check_hostname=self._check_hostname)
   1394 
   1395         https_request = AbstractHTTPHandler.do_request_
/opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1350                           encode_chunked=req.has_header('Transfer-encoding'))
   1351             except OSError as err: # timeout error
-> 1352                 raise URLError(err)
   1353             r = h.getresponse()
   1354         except:
URLError: <urlopen error [Errno -3] Temporary failure in name resolution>

Image classification Finetuning tutorial.

📚 Documentation

Having pictures and text with the predictions would be nice. There is probably a library that already does that, but this was the first thing I did:

image

Also, in the tutorial, finetuning and then loading the pre-finetuned model from s3 doesn't make a whole lot of sense to me. I would suggest loading from the just-created file and commenting out the line that loads from s3.

Finally, the s3 weight file name seems awfully generic for a bee vs. ant model.

Flash doesn't support albumentations transforms

🐛 Bug

You have to pass data to augmentations as named arguments, for example: aug(image=image). Flash doesn't provide functionality for this with the existing data pipeline.

To Reproduce

import os

import albumentations

train_transform = albumentations.Compose(
    [
        albumentations.RandomResizedCrop(height=image_size, width=image_size, scale=(0.9, 1), p=1),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
    ]
)

data = ImageClassificationData.from_filepaths(
    train_transform=train_transform,
    batch_size=32,
    train_filepaths=os.path.join(root, 'train'),
    train_labels=train_labels,
    valid_split=0.10,
    num_workers=4,
)

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataset.py", line 272, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/opt/conda/lib/python3.7/site-packages/flash/vision/classification/data.py", line 77, in __getitem__
    img = self.transform(img)
  File "/opt/conda/lib/python3.7/site-packages/albumentations/core/composition.py", line 165, in __call__
    raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
KeyError: 'You have to pass data to augmentations as named arguments, for example: aug(image=image)'
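The error arises because albumentations requires keyword arguments (`aug(image=image)`) and returns a dict, while the dataset calls `self.transform(img)` positionally. Below is a minimal sketch of a wrapper that bridges the two calling conventions. `KeywordTransformAdapter` is hypothetical; newer Flash versions ship their own `AlbumentationsAdapter`, as used earlier in this README. With real images, the input may also need converting to a NumPy array and the output back to a tensor; that is left to the caller here.

```python
class KeywordTransformAdapter:
    """Let a keyword-only, dict-returning transform be called as transform(img) (sketch).

    albumentations pipelines are invoked as aug(image=array) and return a dict
    such as {"image": transformed_array}; this wrapper hides that difference.
    """

    def __init__(self, transform, key="image"):
        self.transform = transform
        self.key = key

    def __call__(self, img):
        return self.transform(**{self.key: img})[self.key]


# A stand-in with the same calling convention as an albumentations pipeline:
# positional arguments are rejected, keyword data is returned as a dict.
def fake_albumentations_transform(*args, **data):
    if args:
        raise KeyError("You have to pass data to augmentations as named arguments")
    return {name: value[::-1] for name, value in data.items()}


adapter = KeywordTransformAdapter(fake_albumentations_transform)
print(adapter([1, 2, 3]))  # [3, 2, 1]
```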

Code sample

Expected behavior

Additional context

How do I use the LR Scheduler on a general task training?

Hi,
This project is sophisticated and exciting to me.

How do I use the LR Scheduler on a general task training?
As far as I can see, flash.Task is not designed to let me configure the LR scheduler.
Also, there is no way to set any optimizer parameters other than the LR (like momentum, weight_decay, etc.).

To get around this, I created a class inherited from flash.Task as follows:

import flash
import torch


class MyTask(flash.Task):
    def __init__(self, model, optimizer: torch.optim.Optimizer, scheduler: dict, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.optimizer = optimizer
        self.scheduler = scheduler

    def configure_optimizers(self):
        return [self.optimizer], [self.scheduler]

and use like:

model = ...
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
scheduler = {'scheduler': CosineAnnealingLR(optimizer, T_max=200), 'interval': 'epoch', 'frequency': 1}

task = MyTask(model, loss_fn=torch.nn.functional.cross_entropy, optimizer=optimizer, scheduler=scheduler)

flash.Trainer(...).fit(task, train_data_loader, val_data_loader)

(I wanted this in order to reproduce the results of https://github.com/kuangliu/pytorch-cifar.)

Actually, I don't like this extension.
It loses the simplicity of Flash.

Are there any good ideas? Thank you.
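The workaround above uses `CosineAnnealingLR(optimizer, T_max=200)` with `'interval': 'epoch'`, which decays the learning rate along a half cosine over 200 epochs. The stdlib sketch below evaluates the closed-form expression given in the PyTorch documentation for that scheduler (with the default `eta_min=0`). It is handy for sanity-checking per-epoch values. `base_lr=0.1` is an assumed value standing in for `args.lr`.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing, as documented for torch.optim.lr_scheduler.CosineAnnealingLR."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# base_lr=0.1 is an assumption standing in for args.lr
for epoch in (0, 50, 100, 150, 200):
    print(epoch, round(cosine_annealing_lr(epoch, base_lr=0.1, t_max=200), 4))
```

The rate starts at `base_lr`, is halved at `T_max / 2` and reaches `eta_min` at `T_max`.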

Dependency version conflicts on Kaggle Notebooks

🐛 Bug

In Kaggle Notebooks, installation prints dependency warnings but succeeds. When importing flash, however, an error is thrown complaining about the version of tokenizers.

To Reproduce

Steps to reproduce the behavior:

  1. Create a Kaggle Notebook with GPU accelerator
  2. Run !pip install lightning-flash
  3. Try importing flash with import flash
  4. See error:
VersionConflict: tokenizers==0.9.4 is required for a normal functioning of this module, but found tokenizers==0.9.3.
Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master

Code sample

!pip install lightning-flash 
import flash
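A conflict like this can be diagnosed before `import flash` by reading installed distribution versions directly. Below is a minimal sketch using `importlib.metadata`, which requires Python 3.8+; the `importlib_metadata` backport offers the same API on 3.7. `installed_version` and `check_pin` are hypothetical helper names.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_pin(package: str, required: str) -> str:
    """Compare an installed package against an exact pin and describe the result."""
    found = installed_version(package)
    if found is None:
        return f"{package} is not installed (need {required})"
    if found != required:
        return f"{package}=={required} is required, but found {package}=={found}"
    return f"{package}=={found} OK"


print(check_pin("tokenizers", "0.9.4"))
```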

Expected behavior


Environment

  • Kaggle Notebooks
  • GPU Compute

Additional context

Customizable data pipeline for object detection

🚀 Feature

I would like to have a flexible interface to customize dataset and data pipeline for object detection

Motivation

Thanks for creating this fantastic library. For research or applications, I want to use datasets other than CustomCOCODataset. There are two possible scenarios:

  • Using datasets readily available in a different format (e.g. YOLO) without converting them from YOLO to COCO. Here, I assume my model knows how to read and interpret the labels (e.g. xyxy, xywh) and build targets from the dataset labels.
  • Applying multi-image data augmentation, such as Mixup or mosaic augmentation, to create new training images by combining multiple images from the dataset.

Is either of these scenarios possible? Can I swap CustomCOCODataset for my own LightningDataModule? Do we need to customize ObjectDetectionDataPipeline? I am not sure what the task pipeline is for. Some guidance would be appreciated. Thanks.
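In the first scenario, the format-specific piece is usually just label parsing. Below is a minimal sketch that converts labels to the absolute `xyxy` boxes a COCO-style target would use. It assumes the common YOLO text format: one `class x_center y_center width height` line per object, with coordinates normalized by image size. `yolo_to_xyxy` is a hypothetical helper, not a Flash API.

```python
from typing import List, Tuple

Box = Tuple[int, float, float, float, float]  # (class_id, x1, y1, x2, y2)


def yolo_to_xyxy(label_text: str, img_w: int, img_h: int) -> List[Box]:
    """Parse YOLO label lines ("cls xc yc w h", normalized to [0, 1])
    into absolute-pixel (cls, x1, y1, x2, y2) boxes."""
    boxes = []
    for line in label_text.strip().splitlines():
        if not line.strip():
            continue  # tolerate blank lines
        cls, xc, yc, w, h = line.split()
        xc, w = float(xc) * img_w, float(w) * img_w
        yc, h = float(yc) * img_h, float(h) * img_h
        boxes.append((int(cls), xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2))
    return boxes


print(yolo_to_xyxy("0 0.5 0.5 0.25 0.5", img_w=640, img_h=480))
# [(0, 240.0, 120.0, 400.0, 360.0)]
```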

Improve DataPipeline API

🚀 Feature

Motivation

  • Should DataPipeline be split in two, collate_fn and uncollate_fn?

Reason:

  • collate_fn really comes from the dataset, to process new raw data.
  • uncollate_fn is usually related to a task such as classification, so having both in one class can be confusing.

# default DataPipeline
class DataPipeline(CollatePipeline, UnCollatePipeline):
    ...


# create text classification data pipeline
class TextClassificationDataPipeline(TextCollatePipeline, ClassificationUnCollatePipeline):
    ...

  • Create a BaseDataModule per data type, such as TextDataModule and ImageDataModule, with their associated TextDataPipeline and ImageDataPipeline as defaults.
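The proposed split can be sketched with plain mixins that reuse the class names from the proposal. These classes are hypothetical here, not the current Flash API. The toy whitespace tokenizer and the argmax are placeholders; they only show that each half can be swapped independently.

```python
from typing import Any, List


class CollatePipeline:
    """Dataset side: turns a list of raw samples into a batch."""

    def collate_fn(self, samples: List[Any]) -> Any:
        return list(samples)


class UnCollatePipeline:
    """Task side: turns a batch of model outputs back into per-sample results."""

    def uncollate_fn(self, batch: Any) -> List[Any]:
        return list(batch)


class DataPipeline(CollatePipeline, UnCollatePipeline):
    """Default pipeline: both halves keep their pass-through behaviour."""


class TextCollatePipeline(CollatePipeline):
    def collate_fn(self, samples: List[str]) -> List[List[str]]:
        # toy "tokenizer": whitespace split
        return [s.split() for s in samples]


class ClassificationUnCollatePipeline(UnCollatePipeline):
    def uncollate_fn(self, batch: List[List[float]]) -> List[int]:
        # argmax over class scores for each sample
        return [max(range(len(scores)), key=scores.__getitem__) for scores in batch]


class TextClassificationDataPipeline(TextCollatePipeline, ClassificationUnCollatePipeline):
    pass


pipe = TextClassificationDataPipeline()
print(pipe.collate_fn(["a b", "c"]))               # [['a', 'b'], ['c']]
print(pipe.uncollate_fn([[0.1, 0.9], [0.7, 0.3]]))  # [1, 0]
```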

Pitch

Alternatives

Additional context

Train general translation task

🚀 Feature

Currently we use mbart-large-en-ro as defined here: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/text/seq2seq/translation/model.py#L39

We should move towards mbart-large and pre-train from this backbone. This would be cleaner and more applicable to a large number of languages.

The other consideration is that this model has around 600M parameters (~2 GB), which is a reasonable size for the translation task. We may want to consider smaller variants for CI and quick iterations!
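The ~2 GB figure follows from the parameter count. At 4 bytes per fp32 weight, 600M parameters take about 2.4 GB (roughly 2.2 GiB). The quick estimator below counts weights only. Adam-style optimizers keep two extra per-parameter buffers during training, which roughly triples that footprint.

```python
def param_memory_gb(num_params: float, bytes_per_param: int = 4) -> float:
    """Approximate weight storage in GB (10**9 bytes); fp32 uses 4 bytes, fp16 uses 2."""
    return num_params * bytes_per_param / 1e9


print(param_memory_gb(600e6))     # fp32 weights
print(param_memory_gb(600e6, 2))  # fp16 weights
```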

Adding SOTA Ensembling for vision tasks

🚀 Feature

Ensembling is the de facto way to win competitions, and even in real-world scenarios it is used when the inference time budget allows.

multi model ensembling

classification

better_prediction = flash.ensemble(models=[model1, model2], type='classification', method='some_sota_method')
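The simplest baseline behind a call like the proposed `flash.ensemble(..., type='classification')` is soft voting: average each model's class probabilities, optionally weighted, then take the argmax. Below is a stdlib sketch that works on nested lists rather than tensors. Soft voting is a stand-in here, since `some_sota_method` in the proposal is unspecified. `soft_vote` is a hypothetical name.

```python
from typing import List, Sequence


def soft_vote(prob_lists: Sequence[Sequence[Sequence[float]]],
              weights: Sequence[float] = ()) -> List[List[float]]:
    """Weighted average of per-model class probabilities.

    prob_lists[m][i][c] is model m's probability for sample i and class c.
    """
    weights = list(weights) or [1.0] * len(prob_lists)
    total = sum(weights)
    n_samples, n_classes = len(prob_lists[0]), len(prob_lists[0][0])
    return [
        [sum(w * probs[i][c] for w, probs in zip(weights, prob_lists)) / total
         for c in range(n_classes)]
        for i in range(n_samples)
    ]


model1 = [[0.6, 0.4], [0.2, 0.8]]
model2 = [[0.2, 0.8], [0.4, 0.6]]
avg = soft_vote([model1, model2])
print([row.index(max(row)) for row in avg])  # [1, 1]
```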

object detection

# Weighted Boxes Fusion
better_prediction = flash.ensemble(models=[model1, model2], type='object_detection', method='WBF')
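Weighted Boxes Fusion clusters overlapping boxes from all models and replaces each cluster with a confidence-weighted average box, instead of discarding boxes as NMS does. Below is a simplified single-class sketch in plain Python. It assumes boxes are `(x1, y1, x2, y2)` tuples and uses an IoU threshold of 0.55. A full implementation would also fuse per class and support per-model weights.

```python
from typing import List, Optional, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)
Pred = Tuple[Box, float]                 # (box, confidence score)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two xyxy boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def weighted_boxes_fusion(model_preds: Sequence[Sequence[Pred]], iou_thr: float = 0.55) -> List[Pred]:
    """Simplified single-class WBF; model_preds[m] holds model m's (box, score) pairs."""
    n_models = len(model_preds)
    ordered = sorted((p for preds in model_preds for p in preds), key=lambda p: p[1], reverse=True)
    clusters: List[List[Pred]] = []
    fused: List[Pred] = []
    for box, score in ordered:
        # find the fused box this prediction overlaps most (above the threshold)
        best: Optional[int] = None
        best_iou = iou_thr
        for k, (fbox, _) in enumerate(fused):
            overlap = iou(box, fbox)
            if overlap > best_iou:
                best, best_iou = k, overlap
        if best is None:
            clusters.append([(box, score)])
            fused.append((box, score))
            continue
        members = clusters[best]
        members.append((box, score))
        total = sum(s for _, s in members)
        # confidence-weighted mean of coordinates, plain mean of scores
        coords = tuple(sum(b[i] * s for b, s in members) / total for i in range(4))
        fused[best] = (coords, total / len(members))
    # rescale: clusters supported by fewer models get lower confidence
    return [(fbox, fscore * min(len(members), n_models) / n_models)
            for (fbox, fscore), members in zip(fused, clusters)]


model_a = [((10, 10, 50, 50), 0.9)]
model_b = [((12, 12, 52, 52), 0.8), ((100, 100, 120, 120), 0.3)]
for box, score in weighted_boxes_fusion([model_a, model_b]):
    print([round(c, 1) for c in box], round(score, 2))
```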

TTA support is possible too.

e.g.

# classification
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
config = {'transformations': ['original', 'flip_lr']}
better_predictions = model.ensemble(type='TTA', task='classification', config=config)
# object detection
model = ...
config = {'transformations': ['original', [1024, 1024], [1080, 1920]], 'fusion': 'WBF'}  # ['original', 'resize', 'resize']
better_predictions = model.ensemble(type='TTA', task='detection', config=config)  # fuse using WBF, i.e. weighted boxes fusion (SOTA)
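The classification TTA config above (`'original'` plus `'flip_lr'`) amounts to running the model on each augmented view and averaging the scores. Below is a stdlib sketch with images as nested lists. `toy_model` is a hypothetical stand-in for `ImageClassifier`. Because `model.ensemble(...)` is only a proposal, the sketch does not call it.

```python
from typing import Callable, List

Image = List[List[float]]


def hflip(img: Image) -> Image:
    """Mirror each row ('flip_lr')."""
    return [row[::-1] for row in img]


def tta_predict(model: Callable[[Image], List[float]], img: Image) -> List[float]:
    """Average class scores over the original and the horizontally flipped image."""
    preds = [model(view) for view in (img, hflip(img))]
    return [sum(p[c] for p in preds) / len(preds) for c in range(len(preds[0]))]


# Stand-in "model" with a left/right bias: favours class 1 when the left column is brighter.
# (Toy only: assumes the two edge columns are not both zero.)
def toy_model(img: Image) -> List[float]:
    left = sum(row[0] for row in img)
    right = sum(row[-1] for row in img)
    return [right / (left + right), left / (left + right)]


print(tta_predict(toy_model, [[1.0, 0.0], [1.0, 0.0]]))  # [0.5, 0.5]
```

Here the flipped view exactly cancels the toy model's left/right bias, which is the kind of effect flip TTA is meant to average out.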

Motivation

Pitch

Alternatives

Additional context

Unable to run ImageEmbedder

I used the example code to run ImageEmbedder, but it throws the following error:


anaconda3/envs/pose/lib/python3.8/site-packages/pycocotools/mask.py", line 3, in <module>
    import pycocotools._mask as _mask
  File "pycocotools/_mask.pyx", line 1, in init pycocotools._mask
ValueError: numpy.ndarray size changed, may indicate binary incompatibility. Expected 88 from C header, got 80 from PyObject

Make categorical_input/numerical_input optional for TabularData.from_df

🐛 Bug

If I have only numerical features, I still have to pass an empty list for categorical_input. It should be optional.

To Reproduce

Steps to reproduce the behavior:

  1. df = pd.DataFrame({'digit': [1,2,3], 'odd_even':[0,1,0]})
  2. datamodule = TabularData.from_df(df, 'odd_even',
    numerical_input=['digit'],
    )
TypeError                                 Traceback (most recent call last)
<ipython-input-122-405a8bb49976> in <module>
      1 datamodule = TabularData.from_df(final_data, 'target', 
----> 2                     numerical_input=train_x.columns.tolist(),
      3 #                     categorical_input=[],
      4                    )

TypeError: from_df() missing 1 required positional argument: 'categorical_input'

Code sample

df = pd.DataFrame({'digit': [1,2,3], 'odd_even':[0,1,0]})
datamodule = TabularData.from_df(df, 'odd_even', 
                    numerical_input=['digit'],
                   )

Expected behaviour

If only one of categorical or numerical input is passed then users should not be forced to enter an empty list.
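One way to get the expected behaviour is to give both arguments a `None` default and normalize them before use. Below is a sketch of that normalization as a standalone helper. `resolve_feature_columns` is hypothetical and not part of Flash; the real `from_df` signature may differ. `None` is used instead of `[]` as the default to avoid Python's shared mutable default pitfall.

```python
from typing import List, Optional, Tuple


def resolve_feature_columns(
    categorical_input: Optional[List[str]] = None,
    numerical_input: Optional[List[str]] = None,
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: default missing feature lists to empty lists,
    but require at least one feature column overall."""
    categorical_input = list(categorical_input or [])
    numerical_input = list(numerical_input or [])
    if not categorical_input and not numerical_input:
        raise ValueError("Pass at least one of categorical_input or numerical_input.")
    return categorical_input, numerical_input


print(resolve_feature_columns(numerical_input=["digit"]))  # ([], ['digit'])
```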

Environment

  • PyTorch Version (e.g., 1.0): '1.7.1'
  • OS (e.g., Linux): MacOS
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: Python 3.7.9

I would love to start my contribution to Flash by fixing this issue.

ImageClassificationData.from_arrays

🚀 Feature

In addition to reading data from file paths, it should be possible to load image data and labels from numpy arrays, as required in the MNIST Kaggle challenge: https://www.kaggle.com/c/digit-recognizer

Motivation

Data processing in Flash needs to be improved. There are many different ways to ingest data, but for image classification Flash requires the data to be in folders. This forces users to do inefficient preprocessing even for basic scenarios.
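Below is a sketch of the data side that such a method could wrap. It assumes the digit-recognizer CSV layout: a `label` column followed by flattened pixel columns, 28x28 for MNIST. `csv_to_arrays` and `ArrayDataset` are hypothetical names. A real `from_arrays` would likely accept numpy arrays directly.

```python
import csv
import io
from typing import Any, List, Sequence, Tuple


def csv_to_arrays(text: str, side: int = 28) -> Tuple[List[List[List[int]]], List[int]]:
    """Parse 'label,pixel0,...' rows into side x side images and integer labels."""
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header row
    images, labels = [], []
    for row in reader:
        labels.append(int(row[0]))
        pixels = [int(v) for v in row[1:]]
        images.append([pixels[r * side:(r + 1) * side] for r in range(side)])
    return images, labels


class ArrayDataset:
    """Hypothetical map-style dataset over in-memory images and labels."""

    def __init__(self, images: Sequence[Any], labels: Sequence[int]) -> None:
        if len(images) != len(labels):
            raise ValueError(f"{len(images)} images but {len(labels)} labels")
        self.images, self.labels = images, labels

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        return self.images[idx], self.labels[idx]


# Tiny 2x2 "images" instead of 28x28 to keep the example short.
demo = "label,pixel0,pixel1,pixel2,pixel3\n7,0,255,255,0\n1,10,20,30,40\n"
images, labels = csv_to_arrays(demo, side=2)
ds = ArrayDataset(images, labels)
print(len(ds), ds[0])  # 2 ([[0, 255], [255, 0]], 7)
```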

Pitch

Alternatives

Additional context

[RFC] Hosting and Serving Models with Flash

Flash Serve !!

Motivation

Most people also want to try a demo of a model as a REST API or some other kind of inference API.
E.g. with the HuggingFace API, people can run models out of the box.

This would also showcase some of the out-of-the-box models associated with Tasks.
Sometimes people need a simple service, either for hobbies or for small tasks.
Maybe someone is doing a small project and needs an API that does object detection, segmentation, etc. out of the box.
Alternatives such as Google Cloud AI or Azure AI services exist, but they are paid and too difficult to work with.
Many such use cases are possible!

Pitch

Serve Flash models with TorchServe. TorchServe works out of the box with torchvision models (which we use in Flash).
TorchServe makes it easy to create a hosted API; we don't have to worry about protocols or the underlying inference logic.
It is well supported, scalable and quite fast.
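From a client's point of view, TorchServe's inference API is a plain HTTP call to `/predictions/<model_name>`, which listens on port 8080 by default. The stdlib sketch below builds such a request without sending it. It assumes a Flash model has already been archived and registered under the hypothetical name `flash_image_classifier`. The archiving step and the custom handler are not shown, and the content type a handler expects may differ.

```python
import urllib.request


def build_prediction_request(payload: bytes, model_name: str,
                             host: str = "http://127.0.0.1:8080") -> urllib.request.Request:
    """Build (but do not send) a POST to TorchServe's inference endpoint."""
    return urllib.request.Request(
        url=f"{host}/predictions/{model_name}",
        data=payload,
        method="POST",
        headers={"Content-Type": "application/octet-stream"},
    )


# b"\x89PNG..." is a placeholder for real image bytes.
req = build_prediction_request(b"\x89PNG...", model_name="flash_image_classifier")
print(req.get_method(), req.full_url)
# With a running server: urllib.request.urlopen(req).read()
```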

Why TorchServe?

Writing our own protocols and a simple REST API is sometimes not enough; we need a scalable approach that works really well.
TorchServe works directly on all OSes and supports cloud platforms too. Its frontend is written in Java, so it is really fast.
We don't have to worry about scalability, and end users don't have to worry about models!

Additional Context

There are various ways of creating a RESTful service, such as FastAPI, Flask, Django, Node.js and more. But if we use TorchServe, we don't have to worry about multiprocessing over CPU and GPU, supporting all OSes, scaling to Kubernetes, etc.

Thanks a lot to @aniketmaurya for sparking this in the Slack discussion!

P.S. I'm highly willing to contribute to this. Let me know !

Add Standardization for Flash DataModules

🚀 Feature

Flash DataModules have a lot of overlapping boilerplate and inconsistent functionality. We should standardize best practices such as:

  • Train, test and validation transforms
  • Managing Train, test and validation splits
  • Ensuring no overlap between train and test/val
  • Enabling training folds

Loading

  • From filepaths
  • From pandas
  • From folders
  • From Numpy
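Two of the items above, non-overlapping splits and training folds, can share one mechanism: shuffle the indices once with a fixed seed, then slice. Below is a stdlib sketch with hypothetical helper names (`split_indices`, `k_folds`). It is independent of any current Flash API.

```python
import random
from typing import List, Sequence, Tuple


def split_indices(n: int, val_frac: float, test_frac: float, seed: int = 0
                  ) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle once, then slice, so train/val/test can never overlap."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_test, n_val = int(n * test_frac), int(n * val_frac)
    return idx[n_test + n_val:], idx[n_test:n_test + n_val], idx[:n_test]


def k_folds(indices: Sequence[int], k: int) -> List[Tuple[List[int], List[int]]]:
    """Return (train, val) pairs; each index appears in exactly one val fold."""
    folds = [list(indices[i::k]) for i in range(k)]
    splits = []
    for i in range(k):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        splits.append((train, folds[i]))
    return splits


train, val, test = split_indices(100, val_frac=0.1, test_frac=0.2)
print(len(train), len(val), len(test))  # 70 10 20
```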

align with PL 1.2.0

🚀 Feature

At the moment, we have a hard link to 1.2.0rc0, which is a bit outdated...
Some API adjustments will probably be needed; continued in #62.

Motivation

Pitch

Alternatives

Additional context

Regenerate weights with lower PT versions

🚀 Feature

Regenerate weights (re-train models) with lower PT versions

Motivation

Extend usability to users who, for some reason, cannot use the latest PT version.

Pitch

Alternatives

Additional context

For example, for PL legacy testing we create all checkpoints with PT 1.4, and they are fully compatible with all PT versions from 1.4 up (even the latest, 1.8).

Add sample visualization option to the DataModule API

🚀 Feature

It would be great to have the ability to sample image, text, tabular and audio samples from the flash datamodule.

Motivation

In ImageClassificationData, TextClassificationData, TabularClassificationData etc it would be great to have an overridable .visualize() method that plots or visualizes a sample of data.

Pitch

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

datamodule.visualize("train", num_samples=6)

[ 1, 2, 3,
4, 5, 6]
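Whatever the plotting backend, the data side of `visualize` is small: draw `num_samples` items and lay them out in rows, as in the 2x3 grid sketched above. Below is a stdlib sketch of that step. `sample_grid` is hypothetical, and the actual rendering (for example with matplotlib subplots) is left out.

```python
import random
from typing import Any, List, Sequence


def sample_grid(dataset: Sequence[Any], num_samples: int = 6, ncols: int = 3,
                seed: int = 0) -> List[List[Any]]:
    """Pick num_samples distinct random items and arrange them row by row."""
    picked = random.Random(seed).sample(list(dataset), k=min(num_samples, len(dataset)))
    return [picked[i:i + ncols] for i in range(0, len(picked), ncols)]


grid = sample_grid(range(100), num_samples=6)
print([len(row) for row in grid])  # [3, 3]
```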

Alternatives

Open to other ideas.

Additional context

from_datamodules and dataset flexibility

🚀 Feature

With the current data pipeline, if I want to customize my data, there is no easy way to provide my own DataModule and still take advantage of Flash's existing capabilities for validation splits and default transforms.

Motivation

While in theory you can provide any loss function to an ImageClassificationModule, in practice any non-multinomial loss, such as binary cross entropy, causes Flash to crash.

Ideally it would be easy to override this, but the way Flash's create_from_folders abstraction hard-codes dataset creation prevents me from easily overriding the underlying filepath_dataset and folder_dataset classes. So if I want to do this myself, I need to create my own DataModule.

If I create my own DataModule, I lose all the Flash features I get from the from_folders and from_filepaths methods, such as applying default transforms, splitting train and validation data, and any future capabilities we may add. This leads to more boilerplate.

Pitch

One way to improve this would be a from_datamodule feature in the data pipeline, though I think this only papers over the core issue: the underlying dataset class is hard-coded in these functions, with no mechanism to override it.

I'm not sure of the right way to make this change to Flash without potentially breaking things or conflicting with the current Flash refactor.
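The core issue, a hard-coded dataset class, is commonly solved by making the class an overridable parameter with a default. Below is a stdlib sketch of that pattern. Every name in it is hypothetical. `from_samples` stands in for `from_folders` so the example does not touch the filesystem, and the one-hot targets mirror the binary cross entropy case from the motivation.

```python
from typing import Any, List, Optional, Tuple, Type


class FolderDataset:
    """Stand-in for a hard-coded dataset: class-per-folder, integer labels."""

    def __init__(self, samples: List[Tuple[str, str]]) -> None:
        classes = sorted({label for _, label in samples})
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        self.samples = samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, i: int) -> Tuple[str, Any]:
        path, label = self.samples[i]
        return path, self.class_to_idx[label]


class OneHotFolderDataset(FolderDataset):
    """User override: one-hot float targets, e.g. for binary cross entropy."""

    def __getitem__(self, i: int) -> Tuple[str, Any]:
        path, idx = super().__getitem__(i)
        target = [0.0] * len(self.class_to_idx)
        target[idx] = 1.0
        return path, target


class DataModuleSketch:
    dataset_cls: Type[FolderDataset] = FolderDataset  # default, overridable

    def __init__(self, dataset: FolderDataset) -> None:
        self.dataset = dataset

    @classmethod
    def from_samples(cls, samples: List[Tuple[str, str]],
                     dataset_cls: Optional[Type[FolderDataset]] = None) -> "DataModuleSketch":
        # split/transform logic would stay here, shared by every dataset class
        return cls((dataset_cls or cls.dataset_cls)(samples))


samples = [("ants/1.jpg", "ants"), ("bees/1.jpg", "bees")]
dm = DataModuleSketch.from_samples(samples, dataset_cls=OneHotFolderDataset)
print(dm.dataset[1])  # ('bees/1.jpg', [0.0, 1.0])
```

Users keep the shared split and transform logic and swap only the dataset class, either per call or by subclassing and setting `dataset_cls`.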

Alternatives

Additional context

Flash fails to load image labels from CSV on Kaggle RANZCR CLiP - Catheter and Line Position Challenge

🐛 Bug

The master branch does not properly load labels from CSV for image classification. I am trying to get a baseline working for the Catheter and Line Position Challenge and am currently blocked.

To Reproduce

Steps to reproduce the behavior:

  1. !pip install git+https://github.com/PyTorchLightning/lightning-flash.git@master
  2. Run '....'
import os
import flash
from flash.core.data import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

data = ImageClassificationData(
    train_filepaths='/kaggle/input/ranzcr-clip-catheter-line-classification/train', 
    train_labels='/kaggle/input/ranzcr-clip-catheter-line-classification/train.csv', 
    valid_filepaths='/kaggle/input/ranzcr-clip-catheter-line-classification/test', 
    valid_labels='/kaggle/input/ranzcr-clip-catheter-line-classification/test.csv'
)

Stack Trace

Successfully installed datasets-1.2.1 lightning-flash-0.2.1.dev0 pandas-1.1.2 pycocotools-2.0.2 pytorch-lightning-1.2.0rc0 pytorch-lightning-bolts-0.3.0 pytorch-tabnet-3.1.1 rouge-score-0.0.4 scikit-learn-0.24.0 torch-1.7.1 torchvision-0.8.2 tqdm-4.49.0 xxhash-2.0.0
TypeError

TypeError                                 Traceback (most recent call last)
in <module>
      3     train_labels='/kaggle/input/ranzcr-clip-catheter-line-classification/train.csv',
      4     valid_filepaths='/kaggle/input/ranzcr-clip-catheter-line-classification/test',
----> 5     valid_labels='/kaggle/input/ranzcr-clip-catheter-line-classification/test.csv'
      6 )

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/core/datamodule.py in __call__(cls, *args, **kwargs)
     48 
     49         # Get instance of LightningDataModule by mocking its __init__ via __call__
---> 50         obj = type.__call__(cls, *args, **kwargs)
     51 
     52         return obj

TypeError: __init__() got an unexpected keyword argument 'train_filepaths'

Expected behavior

Data Module loads labels without issue
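Independent of the constructor error, `train_labels` here points at a CSV with one row per image, while a filepath-based loader needs parallel lists of paths and labels. Below is a stdlib sketch of that conversion. It assumes an ID column that names each image file plus one 0/1 column per label, since this competition is multi-label. The column names in the demo are placeholders, not the actual competition schema, and `filepaths_and_labels` is a hypothetical helper.

```python
import csv
import io
import os
from typing import IO, List, Sequence, Tuple


def filepaths_and_labels(csv_file: IO[str], image_dir: str, id_column: str,
                         label_columns: Sequence[str], ext: str = ".jpg"
                         ) -> Tuple[List[str], List[List[int]]]:
    """One row per image: path from id_column, multi-hot labels from label_columns."""
    paths, labels = [], []
    for row in csv.DictReader(csv_file):
        paths.append(os.path.join(image_dir, row[id_column] + ext))
        labels.append([int(row[c]) for c in label_columns])
    return paths, labels


# Placeholder columns; with the real file, pass open(".../train.csv") and its column names.
demo = io.StringIO("ImageId,A,B\nimg1,1,0\nimg2,0,1\n")
paths, labels = filepaths_and_labels(demo, "train", "ImageId", ["A", "B"])
print(paths, labels)
```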

Fix data caching for faster testing

🐛 Bug

It seems that data caching isn't working properly. Every CI run downloads the datasets again, which takes too long.

Currently, the data is saved to:

/tmp/pytest-of-nitta/pytest-0/test_finetune_example_finetuni0/data/hymenoptera_data.zip

To Reproduce

Example: https://github.com/PyTorchLightning/lightning-flash/runs/1874512493

Expected behavior

Downloaded datasets should be saved within the project root to share data across all test cases:

${PROJECT_ROOT}/data/hymenoptera_data.zip
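The expected behaviour amounts to keying each download on a fixed location under the project root and skipping it when the file already exists. Below is a stdlib sketch with a hypothetical `cached_download` helper and an injected downloader, so it runs without network access. In the test suite, `project_root` would be the repository root rather than pytest's per-test temporary directory.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def cached_download(url: str, project_root: Path,
                    download: Callable[[str, Path], None]) -> Path:
    """Download url into <project_root>/data once; later calls reuse the file."""
    target = project_root / "data" / url.rsplit("/", 1)[-1]
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        download(url, target)
    return target


calls: List[str] = []


def fake_download(url: str, dest: Path) -> None:
    """Stand-in downloader that records calls instead of hitting the network."""
    calls.append(url)
    dest.write_bytes(b"zip")


with tempfile.TemporaryDirectory() as tmp:
    for _ in range(2):
        path = cached_download("https://example.com/hymenoptera_data.zip", Path(tmp), fake_download)
    print(path.name, len(calls))  # downloaded only once
```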

NLTK being loaded on image classification

🐛 Bug

To Reproduce

from flash.data import labels_from_csv
from flash.vision import ImageClassificationData
from flash.vision import ImageClassifier
from flash import Trainer

[nltk_data] Error loading punkt: <urlopen error [Errno -3] Temporary
[nltk_data] failure in name resolution>
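The log suggests that importing the vision modules also pulls in the text stack, which appears to try downloading NLTK's `punkt` data at import time. A common remedy is to defer optional, domain-specific imports until they are actually used. Below is a minimal stdlib sketch of a lazy module proxy. This is a generic pattern, not Flash's actual fix, and `LazyModule` is a hypothetical name. The demo uses the stdlib `colorsys` module so it runs without NLTK.

```python
import importlib
from types import ModuleType
from typing import Any, Optional


class LazyModule:
    """Import the named module only on first attribute access."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._module: Optional[ModuleType] = None

    def __getattr__(self, attr: str) -> Any:
        # only called for attributes not found normally, i.e. the module's own names
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


cs = LazyModule("colorsys")
print(cs._module is None)             # True: nothing imported yet
print(cs.rgb_to_hsv(1.0, 0.0, 0.0))  # first use triggers the import
```

In a text module, `nltk = LazyModule("nltk")` at the top would then import NLTK only when a text feature first touches it.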

ValueError while finetuning

I'm using the following code to fine-tune the embedder:


import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageEmbedder


datamodule = ImageClassificationData.from_folders(
    train_folder="assets/classes/train/",
    valid_folder="assets/classes/val/",
    test_folder="assets/classes/test/",
)

# 3. Build the model
embedder = ImageEmbedder(backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.finetune(embedder, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Test the model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("image_embedder_model.pt")

Unfortunately, it throws the following error:

/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pl_bolts/utils/warnings.py:30: UserWarning: You want to use `wandb` which is not installed yet, install it with `pip install wandb`.
  stdout_func(
/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pl_bolts/utils/warnings.py:30: UserWarning: You want to use `gym` which is not installed yet, install it with `pip install gym`.
  stdout_func(
GPU available: True, used: False
TPU available: None, using: 0 TPU cores
/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: GPU available but not used. Set the --gpus flag when calling the script.
  warnings.warn(*args, **kwargs)
Traceback (most recent call last):
  File "/home/oggie/Workspace/Python/KapschUI/embedded_train.py", line 20, in <module>
    trainer.finetune(embedder, datamodule=datamodule, strategy="freeze_unfreeze")
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/flash/core/trainer.py", line 90, in finetune
    return super().fit(model, train_dataloader, val_dataloaders, datamodule)
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 468, in fit
    self.accelerator_backend.setup(model)
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pytorch_lightning/accelerators/legacy/cpu_accelerator.py", line 49, in setup
    self.setup_optimizers(model)
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pytorch_lightning/accelerators/legacy/accelerator.py", line 140, in setup_optimizers
    optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 30, in init_optimizers
    optim_conf = model.configure_optimizers()
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/flash/core/model.py", line 153, in configure_optimizers
    return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/torch/optim/sgd.py", line 68, in __init__
    super(SGD, self).__init__(params, defaults)
  File "/home/oggie/anaconda3/envs/pose/lib/python3.8/site-packages/torch/optim/optimizer.py", line 47, in __init__
    raise ValueError("optimizer got an empty parameter list")
ValueError: optimizer got an empty parameter list

Process finished with exit code 1

TabNet Classification Broken

🐛 Bug

Tabular classification throws an IndexError on prediction.

To Reproduce

Steps to reproduce the behavior:
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
predictions = model.predict("../input/titanic/test.csv")
print(predictions)

See error

IndexError                                Traceback (most recent call last)
<ipython-input-18-17cd8b85d1d8> in <module>
      1 model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
----> 2 predictions = model.predict("../input/titanic/test.csv")
      3 print(predictions)

/opt/conda/lib/python3.7/site-packages/flash/tabular/classification/model.py in predict(self, x, batch_idx, skip_collate_fn, dataloader_idx, data_pipeline)
     80         data_pipeline = data_pipeline or self.data_pipeline
     81         batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
---> 82         predictions = self.forward(batch)
     83         return data_pipeline.uncollate_fn(predictions)
     84 

/opt/conda/lib/python3.7/site-packages/flash/tabular/classification/model.py in forward(self, x_in)
     86         # TabNet takes single input, x_in is composed of (categorical, numerical)
     87         x = torch.cat([x for x in x_in if x.numel()], dim=1)
---> 88         return self.model(x)[0]
     89 
     90     @classmethod

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/pytorch_tabnet/tab_network.py in forward(self, x)
    580 
    581     def forward(self, x):
--> 582         x = self.embedder(x)
    583         return self.tabnet(x)
    584 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/pytorch_tabnet/tab_network.py in forward(self, x)
    847             else:
    848                 cols.append(
--> 849                     self.embeddings[cat_feat_counter](x[:, feat_init_idx].long())
    850                 )
    851                 cat_feat_counter += 1

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    124         return F.embedding(
    125             input, self.weight, self.padding_idx, self.max_norm,
--> 126             self.norm_type, self.scale_grad_by_freq, self.sparse)
    127 
    128     def extra_repr(self) -> str:

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1850         # remove once script supports set_grad_enabled
   1851         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1852     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1853 
   1854 

IndexError: index out of range in self
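`index out of range in self` is raised by `torch.embedding` when a categorical index is greater than or equal to the embedding table's size. One common cause at prediction time (not confirmed for this report) is a category code that the training-time encoding never produced, for example from encoding `test.csv` afresh. Below is a stdlib sketch of an encoding that reserves index 0 for unseen values, so every index stays below `num_embeddings`. `fit_codes` and `encode` are hypothetical helpers.

```python
from typing import Dict, Iterable, List

UNKNOWN = 0  # reserved index for categories not seen during training


def fit_codes(train_values: Iterable[str]) -> Dict[str, int]:
    """Map each training category to 1..K, keeping 0 for unknowns."""
    return {v: i + 1 for i, v in enumerate(sorted(set(train_values)))}


def encode(values: Iterable[str], codes: Dict[str, int]) -> List[int]:
    """Encode with the training-time mapping; unseen values go to UNKNOWN."""
    return [codes.get(v, UNKNOWN) for v in values]


codes = fit_codes(["S", "C", "Q", "S"])  # e.g. Titanic 'Embarked'
num_embeddings = len(codes) + 1           # embedding table size incl. UNKNOWN
test_idx = encode(["S", "X"], codes)
print(codes, test_idx)  # {'C': 1, 'Q': 2, 'S': 3} [3, 0]
assert max(test_idx) < num_embeddings
```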

Code sample

Expected behavior

Environment

  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Download Data Fails if Content Length Not Defined in Header

🐛 Bug

When I try to download a zip file using download_data from flash.core.data, it fails because the response header does not contain a Content-Length value. This should be checked for and handled in the code.

To Reproduce

Steps to reproduce the behavior:

KeyError                                  Traceback (most recent call last)
in <module>()
      1 # 1. Download the data
----> 2 download_data("https://github.com/karoldvl/ESC-50/archive/master.zip", 'data/')

2 frames
/content/gdrive/MyDrive/lightning-flash/flash/core/data/utils.py in download_data(url, path)
     75 
     76     """
---> 77     download_file(url, path)
     78 
     79 

/content/gdrive/MyDrive/lightning-flash/flash/core/data/utils.py in download_file(url, path, verbose)
     36     local_filename = os.path.join(path, url.split('/')[-1])
     37     r = requests.get(url, stream=True)
---> 38     file_size = int(r.headers['Content-Length'])
     39     chunk = 1
     40     chunk_size = 1024

/usr/local/lib/python3.6/dist-packages/requests/structures.py in __getitem__(self, key)
     52 
     53     def __getitem__(self, key):
---> 54         return self._store[key.lower()][1]
     55 
     56     def __delitem__(self, key):

KeyError: 'content-length'
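For this URL the response carries no `Content-Length` header, as the `KeyError` above shows, so the size has to be treated as optional. Below is a stdlib sketch of two hypothetical helpers. The first reads the header defensively. The second copies a stream in chunks until EOF, which works whether or not the size is known. With `requests`, the source would be the response's `iter_content(chunk_size)` loop or its `raw` stream, and a progress bar such as tqdm can take `total=None` when the size is unknown.

```python
import io
from typing import BinaryIO, Mapping, Optional


def content_length(headers: Mapping[str, str]) -> Optional[int]:
    """Declared size in bytes, or None if the server omits it (e.g. chunked transfer)."""
    value = headers.get("Content-Length")
    return int(value) if value is not None and value.isdigit() else None


def copy_in_chunks(src: BinaryIO, dst: BinaryIO, chunk_size: int = 1024) -> int:
    """Copy until EOF and return the number of bytes written."""
    written = 0
    while True:
        chunk = src.read(chunk_size)
        if not chunk:
            return written
        dst.write(chunk)
        written += len(chunk)


print(content_length({}), copy_in_chunks(io.BytesIO(b"x" * 3000), io.BytesIO()))  # None 3000
```

Note that `requests` headers are case-insensitive, while the plain dicts used here are not.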

Code sample

import flash
from flash.core.data import download_data
download_data("https://github.com/karoldvl/ESC-50/archive/master.zip", 'data/')

Expected behavior

File downloads and extracts ESC-50 data into datasets folder

Environment

Default Colab configuration

Additional context
