
weecology / deepforest-pytorch


Pytorch implementation of the deepforest model for tree crown RGB detection.

License: MIT License

Python 36.26% CSS 1.10% JavaScript 62.64%

deepforest-pytorch's Introduction

This repo has been moved to https://github.com/Weecology/DeepForest

The deepforest-pytorch branch has been moved to the 'main' branch of https://github.com/Weecology/DeepForest. The original DeepForest repo was based on tensorflow and has since been deprecated. Since this project was newer and had fewer users, it made sense to merge the two repos. Please use the DeepForest repo in all future cases.

DeepForest-pytorch


A pytorch implementation of the DeepForest model for individual tree crown detection in RGB images. DeepForest is a python package for training and predicting individual tree crowns from airborne RGB imagery. DeepForest comes with a prebuilt model trained on data from the National Ecological Observatory Network. Users can extend this model by annotating and training custom models starting from the prebuilt model.

DeepForest is a Python package for predicting individual tree crowns using models trained on RGB remote sensing imagery. DeepForest comes with a model trained on data provided by the National Ecological Observatory Network (NEON). Users can extend this prebuilt model by annotating labels and training with local data. The DeepForest documentation is written in English; however, we would welcome contributions to make it accessible in other languages.

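DeepForest (PyTorch version) is a Python package that can be used to train models for, and predict, individual tree crowns in airborne RGB imagery. DeepForest includes a pretrained model trained on data from the National Ecological Observatory Network (NEON). Building on this model, users can annotate new data and then train their own models. The DeepForest documentation is written in English; if you are interested in contributing a translation of the documentation, you are welcome to contact our team.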

Motivation

The original DeepForest repo is written in tensorflow and can be found on pypi, conda and source (https://github.com/Weecology/DeepForest). After https://github.com/fizyr/keras-retinanet was deprecated, it became obvious that the shelf life of models that depend on tensorflow 1.0 was limited. The machine learning community is moving more towards pytorch, where many new models can be found.

Installation

Compiled wheels have been built for Linux, macOS, and Windows.

#Install DeepForest-pytorch
pip install deepforest-pytorch

Usage

Use Benchmark release

from deepforest import main
m = main.deepforest()
m.use_release()

Train a new model

m.create_trainer()
m.trainer.fit(m)
m.evaluate(csv_file=m.config["validation"]["csv_file"], root_dir=m.config["validation"]["root_dir"])

Google colab demo on model training

Predict a single image

from deepforest import main
from deepforest import get_data
m = main.deepforest()
m.use_release()
# Predict boxes for the sample image bundled with the package
df = m.predict_image(path=get_data("OSBS_029.tif"))

Predict a large tile

from deepforest import get_data
raster_path = get_data("OSBS_029.tif")
predicted_boxes = m.predict_tile(raster_path=raster_path,
                                 patch_size=300,
                                 patch_overlap=0.5,
                                 return_plot=False)
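Judging from its patch_size and patch_overlap arguments, predict_tile cuts a large raster into overlapping square windows and predicts on each one. The sketch below shows one way such windows could be enumerated. `tile_windows` is a hypothetical helper, not DeepForest's code, and its handling of the final row and column (a last window placed flush with the raster edge) is an assumption.

```python
from typing import List, Tuple


def _starts(length: int, patch_size: int, stride: int) -> List[int]:
    """Window start offsets along one axis, with a final window flush to the edge."""
    if length <= patch_size:
        return [0]
    positions = list(range(0, length - patch_size + 1, stride))
    if positions[-1] + patch_size < length:
        positions.append(length - patch_size)
    return positions


def tile_windows(width: int, height: int, patch_size: int,
                 patch_overlap: float) -> List[Tuple[int, int, int, int]]:
    """Return (x, y, w, h) windows covering a width x height raster.

    Hypothetical helper; not DeepForest's implementation of predict_tile.
    """
    if not 0 <= patch_overlap < 1:
        raise ValueError("patch_overlap must be in [0, 1)")
    # Neighbouring windows share roughly patch_overlap * patch_size pixels.
    stride = max(1, int(patch_size * (1 - patch_overlap)))
    return [(x, y, min(patch_size, width), min(patch_size, height))
            for y in _starts(height, patch_size, stride)
            for x in _starts(width, patch_size, stride)]


windows = tile_windows(1000, 600, patch_size=300, patch_overlap=0.5)
print(len(windows), windows[0], windows[-1])  # 18 (0, 0, 300, 300) (700, 300, 300, 300)
```

A larger overlap gives a crown cut by one window's edge a better chance of appearing whole in a neighbouring window, at the cost of more windows to predict and more duplicate boxes to merge afterwards.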

Evaluate a file of annotations using intersection-over-union

import os
from deepforest import get_data
csv_file = get_data("example.csv")
root_dir = os.path.dirname(csv_file)
results = m.evaluate(csv_file, root_dir, iou_threshold=0.5)

Config

DeepForest comes with a default config file (deepforest_config.yml) to control the location of training and evaluation data, the number of gpus, batch size and other hyperparameters. This file can be edited directly, or using the config dictionary after loading a deepforest object.

from deepforest import main
m = main.deepforest()
m.config["batch_size"] = 10

Config parameters are documented here.

Tree Detection Benchmark score

Tree detection is a central task in forest ecology and remote sensing. The Weecology Lab at the University of Florida has built a tree detection benchmark for evaluation. After building a model, you can compare it to the benchmark using the evaluate method.

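# In a shell, clone the benchmark data
git clone https://github.com/weecology/NeonTreeEvaluation.git
cd NeonTreeEvaluation

# Then, in Python started from the NeonTreeEvaluation directory, with a trained model m
results = m.evaluate(csv_file="evaluation/RGB/benchmark_annotations.csv", root_dir="evaluation/RGB/")
results["box_recall"]
results["box_precision"]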

deepforest-pytorch's People

Contributors

bw4sz, dingyif, dwaipayan05, ethanwhite, henrykironde


deepforest-pytorch's Issues

More robust to 4 channel tif inputs

Lots of users have tif data with hidden alpha channels; DeepForest should be smarter about converting these data.

OSError: Input file /orange/ewhite/b.weinstein/pfeifer/Fregata_Island_2016_Antarctic_shags.tif has 4 bands. DeepForest only accepts 3 band RGB rasters in the order (height, width, channels). If the image was cropped and saved as a .jpg, please ensure that no alpha channel was used.
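One way to be more forgiving is to drop a trailing fourth band before prediction instead of raising. The numpy sketch below assumes channels-last (height, width, channels) arrays, the order the error message asks for, and that a fourth band is alpha rather than, for example, near-infrared data. `drop_alpha_channel` is a hypothetical helper name.

```python
import numpy as np


def drop_alpha_channel(image: np.ndarray) -> np.ndarray:
    """Return a (height, width, 3) RGB array, discarding a fourth band if present.

    Hypothetical helper: assumes channels-last order and that any fourth band
    is an alpha channel, not another spectral band.
    """
    if image.ndim != 3:
        raise ValueError(f"Expected (height, width, channels), got shape {image.shape}")
    channels = image.shape[2]
    if channels == 4:
        # Keep the first three bands (assumed R, G, B) and drop alpha.
        return image[:, :, :3]
    if channels != 3:
        raise ValueError(f"Expected 3 or 4 bands, got {channels}")
    return image


rgba = np.zeros((5, 5, 4), dtype=np.uint8)
print(drop_alpha_channel(rgba).shape)  # (5, 5, 3)
```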

Evaluate bounding boxes

main.py deepforest class needs an evaluate method that returns

  • mean-average-precision for a given intersection over union threshold between ground truth and predicted bounding boxes.
  • precision at a fixed probability cutoff
  • recall at a fixed probability cutoff
mAP, precision, recall = m.evaluate("/orange/b.weinstein/NeonTreeEvaluation/benchmark_annotations.csv", metrics=["mAP","precision","recall"], iou_threshold=0.4, probability_threshold=0.2)

Much of this code can be taken and generalized from

https://github.com/weecology/NeonTreeEvaluation_python/blob/eaa5dfb3d2a17daf354c146ab158f69b94755535/src/eval.py#L9
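The precision and recall parts of this request can be sketched as below: compute intersection-over-union between axis-aligned boxes, keep predictions above a score cutoff, greedily match each one (highest score first) to at most one unmatched ground-truth box, and count a match as a true positive when IoU reaches the threshold. This is a minimal pure-Python sketch under those assumptions, not the NeonTreeEvaluation code; the function names are hypothetical, and mean-average-precision is not covered.

```python
from typing import Sequence, Tuple

Box = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def iou(box_a: Box, box_b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


def precision_recall(predictions: Sequence[Tuple[Box, float]],
                     ground_truth: Sequence[Box],
                     iou_threshold: float = 0.4,
                     score_threshold: float = 0.2) -> Tuple[float, float]:
    """Precision and recall at fixed score and IoU cutoffs (greedy one-to-one matching)."""
    kept = sorted((p for p in predictions if p[1] >= score_threshold),
                  key=lambda p: p[1], reverse=True)
    matched = set()
    true_positives = 0
    for box, _score in kept:
        # Find the best still-unmatched ground-truth box for this prediction.
        best_iou, best_idx = 0.0, None
        for idx, truth in enumerate(ground_truth):
            if idx in matched:
                continue
            overlap = iou(box, truth)
            if overlap > best_iou:
                best_iou, best_idx = overlap, idx
        if best_idx is not None and best_iou >= iou_threshold:
            matched.add(best_idx)
            true_positives += 1
    precision = true_positives / len(kept) if kept else 0.0
    recall = true_positives / len(ground_truth) if ground_truth else 0.0
    return precision, recall
```

The default cutoffs mirror the proposed call above (iou_threshold=0.4, probability_threshold=0.2). Greedy matching by score is one common convention; other evaluators use optimal (Hungarian) assignment, which can give slightly different counts.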

Missing .jpeg File

I don't know whether this is intentional or a mistake, but the documentation mentions a sample_image fetched through get_data("OSBS_029.jpeg"), and that file is actually missing from the package.

Proposed Solution

  • Replace .jpeg with .png
  • Add OSBS_029.jpeg to Package

callback throws error

#Create objects
eval_callback = evaluate_callback(
    csv_file="/home/b.weinstein/NeonTreeEvaluation/evaluation/RGB/benchmark_annotations.csv", 
    root_dir="/home/b.weinstein/NeonTreeEvaluation/evaluation/RGB/",iou_threshold=0.4, score_threshold=0.1)

m = main.deepforest()
trainer = pytorch_lightning.Trainer(logger=comet_logger, max_epochs=1, callbacks=[evaluate_callback()], limit_train_batches=0.01, limit_val_batches=0.01)
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/b.weinstein/miniconda3/envs/deepforest_pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py", line 42, in overwrite_by_env_vars
    return fn(self, **kwargs)
  File "/home/b.weinstein/miniconda3/envs/deepforest_pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 347, in __init__
    self.on_init_start()
  File "/home/b.weinstein/miniconda3/envs/deepforest_pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 48, in on_init_start
    callback.on_init_start(self)
TypeError: on_init_start() missing 1 required positional argument: 'trainer'

Labels are currently hard coded

In dataset.py, we need a read_classes() method that takes in a dict of classes, either from a file or as an argument. Currently every label is hard coded as tree.
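A minimal sketch of what read_classes() could do: accept an explicit {name: id} dict, or build one from the labels found in the data (for example the label column of an annotations csv). The signature and the choice to number sorted labels from 0 are assumptions; whether ids should start at 0 or 1 depends on what the underlying torchvision detector expects.

```python
from typing import Dict, Iterable, Optional


def read_classes(labels: Iterable[str] = (),
                 label_dict: Optional[Dict[str, int]] = None) -> Dict[str, int]:
    """Build a {label_name: integer_id} mapping.

    Sketch only: the signature is an assumption. An explicit label_dict wins;
    otherwise ids are assigned to the sorted unique labels, starting at 0.
    """
    if label_dict is not None:
        ids = sorted(label_dict.values())
        if ids != list(range(len(ids))):
            raise ValueError("label_dict ids must be the integers 0..n-1")
        return dict(label_dict)
    # Sorting makes the mapping reproducible regardless of row order.
    return {name: idx for idx, name in enumerate(sorted(set(labels)))}


print(read_classes(["Tree", "Snag", "Tree"]))  # {'Snag': 0, 'Tree': 1}
```

Saving this dict alongside a trained model would also let a reloaded model recover the same classes.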

visualize.plot_predictions should take in a label or an integer

Right now it only accepts integer labels. We need a more general function that converts string labels to integers when they are given, because the integers drive the coloring scheme. Some thought is needed on the most intuitive way forward.

    plot_predictions(image=image, df=df)
Traceback (most recent call last):
  Debug Console, prompt 85, line 1
  File "/Users/benweinstein/opt/miniconda3/envs/DeepForest_pytorch/lib/python3.8/site-packages/deepforest/visualize.py", line 55, in plot_predictions
    color = label_to_color(row["label"])
  File "/Users/benweinstein/opt/miniconda3/envs/DeepForest_pytorch/lib/python3.8/site-packages/deepforest/visualize.py", line 109, in label_to_color
    return color_dict[label]
builtins.KeyError: 'Larch'
df.head()
     image_path  xmin  ymin  xmax  ymax  label
0  B10_0046.JPG    49    39   316   291  Larch
1  B10_0046.JPG   252    21   449   240  Larch
2  B10_0046.JPG   401    45   594   334  Larch
3  B10_0046.JPG   562    48   747   269  Larch
4  B10_0046.JPG   764     1   949   216  Larch

@dingyif, if you're interested.
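One possible direction, sketched below: accept either an integer or a string, and turn strings into a palette index with a stable checksum so that "Larch" always gets the same color instead of raising a KeyError. This is a hypothetical replacement for label_to_color, not the function currently in visualize.py, and the palette values are arbitrary.

```python
import numbers
import zlib
from typing import Tuple, Union

# Arbitrary example palette of color tuples (hypothetical).
PALETTE = [(230, 25, 75), (60, 180, 75), (255, 225, 25),
           (0, 130, 200), (245, 130, 48), (145, 30, 180)]


def label_to_color(label: Union[int, str]) -> Tuple[int, int, int]:
    """Map an integer or string label to a palette color (sketch)."""
    if isinstance(label, numbers.Integral):
        index = int(label)
    else:
        # crc32 is stable across runs, unlike Python's salted str hash().
        index = zlib.crc32(str(label).encode("utf-8"))
    return PALETTE[index % len(PALETTE)]


print(label_to_color(0))  # (230, 25, 75)
```

Hashing can map two different labels to the same color; converting labels to integers with a label dictionary first avoids such collisions when the number of classes is small.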

Optional SoftNMS

see [35] N. Bodla, B. Singh, R. Chellappa and L.S. Davis, “Soft-NMS - Improving Object Detection with One Line of Code,” 2017 IEEE International Conference on Computer Vision (ICCV), IEEE, 2017

Azure devops windows SSL error

@henrykironde, can you follow up on this error? One of the collabs needs newer pip wheels than 0.1.24, and it is blocked by a silly Windows SSL pip error that has nothing to do with our source. I wonder if we should just wait this out? It feels like an Azure problem.

https://dev.azure.com/benweinstein2010/DeepForest/_build/results?buildId=191&view=logs&j=2d2b3007-3c5c-5840-9bb0-2b1ea49925f3&t=168f295b-0553-5364-35f7-923225ecd8b3

Everything on our end passes.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\hostedtoolcache\windows\Python\3.9.2\x64\Scripts\cibuildwheel.exe\__main__.py", line 7, in <module>
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\site-packages\cibuildwheel\__main__.py", line 180, in main
    cibuildwheel.windows.build(**build_options)
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\site-packages\cibuildwheel\windows.py", line 99, in build
    download('https://bootstrap.pypa.io/get-pip.py', get_pip_script)
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\site-packages\cibuildwheel\windows.py", line 70, in download
    response = urlopen(url)
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\urllib\request.py", line 214, in urlopen
    return opener.open(url, data, timeout)
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\urllib\request.py", line 517, in open
    response = self._open(req, data)
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\urllib\request.py", line 534, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\urllib\request.py", line 494, in _call_chain
    result = func(*args)
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\urllib\request.py", line 1389, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "c:\hostedtoolcache\windows\python\3.9.2\x64\lib\urllib\request.py", line 1349, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1123)>
##[error]Bash exited with code '1'.
Finishing: Bash

predict_image error for current PyPi release

from deepforest import main
from deepforest import get_data
model = main.deepforest()
model.use_release()
img=model.predict_image(path="/orange/ewhite/NeonTreeEvaluation/evaluation/RGB/TEAK_049_2019.tif",return_plot=True)

Results in:

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
<ipython-input-4-590a4318d91f> in <module>
----> 1 img=model.predict_image(path="/orange/ewhite/NeonTreeEvaluation/evaluation/RGB/TEAK_049_2019.tif",return_plot=True)

~/miniconda3/envs/deepforest/lib/python3.9/site-packages/deepforest/main.py in predict_image(self, image, path, return_plot)
    222 
    223         # Check if GPU is available and pass image to gpu
--> 224         result = predict.predict_image(model=self.model,
    225                                        image=image,
    226                                        return_plot=return_plot,

~/miniconda3/envs/deepforest/lib/python3.9/site-packages/deepforest/predict.py in predict_image(model, image, return_plot, device, iou_threshold)
     48         # Matplotlib likes no batch dim and channels first
     49         image = np.array(image.squeeze(0))[:,:,::-1]
---> 50         image = visualize.plot_predictions(image, df)
     51         return image
     52     else:

~/miniconda3/envs/deepforest/lib/python3.9/site-packages/deepforest/visualize.py in plot_predictions(image, df, color)
    115         if not color:
    116             color = label_to_color(row["label"])
--> 117         cv2.rectangle(image, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), color=color, thickness=1, lineType=cv2.LINE_AA)
    118 
    119     return image

error: OpenCV(4.5.2) /tmp/pip-req-build-ed2ahlq5/opencv/modules/core/src/copy.cpp:71: error: (-215:Assertion failed) cn <= 4 in function 'scalarToRawData'

Everything runs fine if return_plot=False.
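One plausible cause, not verified against the code: the comment in predict.py says the image is kept channels-first, while OpenCV treats the last axis of an array as channels. A (3, height, width) array would then look to cv2.rectangle like an image with `width` channels, which fails the `cn <= 4` assertion. The sketch below converts to a contiguous channels-last BGR array before drawing; `to_opencv_image` is a hypothetical helper and it leaves the dtype unchanged.

```python
import numpy as np


def to_opencv_image(image: np.ndarray) -> np.ndarray:
    """Return a contiguous (height, width, 3) BGR array suitable for cv2 drawing.

    Hypothetical helper: accepts (1, C, H, W), (C, H, W) or (H, W, C) input.
    """
    if image.ndim == 4:
        image = image[0]  # drop the batch dimension
    if image.ndim != 3:
        raise ValueError(f"Expected a 3-D image, got shape {image.shape}")
    # Move channels last if they appear to be first.
    if image.shape[0] in (3, 4) and image.shape[-1] not in (3, 4):
        image = np.transpose(image, (1, 2, 0))
    # Keep RGB, reverse to BGR for OpenCV, and make the memory contiguous.
    return np.ascontiguousarray(image[:, :, :3][:, :, ::-1])


batch = np.zeros((1, 3, 5, 7), dtype=np.uint8)
print(to_opencv_image(batch).shape)  # (5, 7, 3)
```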

Evaluate creates nonsensical results.

The release model is starting to look pretty decent, but the test recall and the evaluate results are below 1%, which doesn't seem reasonable. Sample image:
[Image: TEAK_043_step_5909]

It seems more likely that the evaluate method isn't operating correctly. Having a release model (0.1.9) will allow us to test more rigorously.

Load and save models

main.py needs a method to save a trained model and a method to reload it, recreating deepforest instances. A test should confirm that predictions before and after saving yield the same results.

@dingyif if you are interested, this seems like a good place to help. No pressure.
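Since this is a PyTorch implementation, one common approach is to save the model's `state_dict` and load it into a freshly constructed instance, then check that predictions match. The sketch below uses a tiny stand-in network in place of a deepforest model; `make_model` and `save_and_reload` are hypothetical names, and how this is wired into main.py is left open.

```python
import os
import tempfile

import torch
from torch import nn


def make_model() -> nn.Module:
    """Stand-in network; a real test would construct a deepforest model instead."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def save_and_reload(model: nn.Module) -> nn.Module:
    """Save the model's weights to disk, then load them into a fresh instance."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        torch.save(model.state_dict(), path)
        reloaded = make_model()
        reloaded.load_state_dict(torch.load(path))
    return reloaded.eval()


torch.manual_seed(0)
original = make_model().eval()
inputs = torch.randn(3, 4)
with torch.no_grad():
    # Predictions before and after the round trip should be identical.
    assert torch.equal(original(inputs), save_and_reload(original)(inputs))
```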

Build Wheel Error

While running pip install -r dev_requirements.txt inside my virtual environment, the following error occurred:

[Screenshot of the error]

System Specifications

  • Python - 3.8.7
  • OS - Windows 10 Home Single Language

Profile IOU eval

import pandas as pd

def _overlap_(test_poly, truth_polys, rtree_index):
    """Calculate overlap between one polygon and all ground truth by area"""
    results = []
    matched_list = list(rtree_index.intersection(test_poly.geometry.bounds))
    for index in truth_polys.index:
        if index in matched_list:
            # get the original index just to be sure
            intersection_result = test_poly.geometry.intersection(
                truth_polys.loc[index].geometry)
            intersection_area = intersection_result.area
        else:
            intersection_area = 0
        results.append(
            pd.DataFrame({
                "prediction_id": [test_poly.prediction_id],
                "truth_id": [truth_polys.loc[index].truth_id],
                "area": intersection_area
            }))
    results = pd.concat(results)

    return results

Creating a pandas DataFrame on every loop iteration seems to be expensive.
[Screenshot of profiling results, 2021-05-16]
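If predictions and ground truth are axis-aligned boxes (as DeepForest's xmin/ymin/xmax/ymax annotations are), the intersection areas against all truths can be computed with array arithmetic and returned as a single DataFrame per prediction. This swaps the polygon intersection for rectangle arithmetic, so it does not apply to arbitrary polygons; `overlap_areas` and the column names are assumptions.

```python
import numpy as np
import pandas as pd


def overlap_areas(prediction_id, box, truths: pd.DataFrame) -> pd.DataFrame:
    """Intersection area between one predicted box and every ground-truth box.

    Hypothetical helper. Assumes axis-aligned rectangles stored in xmin/ymin/xmax/ymax
    columns plus a truth_id column, so no polygon intersection is needed.
    """
    xmin, ymin, xmax, ymax = box
    inter_w = np.clip(np.minimum(xmax, truths["xmax"].to_numpy())
                      - np.maximum(xmin, truths["xmin"].to_numpy()), 0, None)
    inter_h = np.clip(np.minimum(ymax, truths["ymax"].to_numpy())
                      - np.maximum(ymin, truths["ymin"].to_numpy()), 0, None)
    # One DataFrame per prediction instead of one per (prediction, truth) pair.
    return pd.DataFrame({
        "prediction_id": prediction_id,
        "truth_id": truths["truth_id"].to_numpy(),
        "area": inter_w * inter_h,
    })
```

Even with polygons, collecting plain dicts in a list and building one DataFrame at the end (or calling pd.concat once) avoids most of the per-iteration overhead.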

Out of memory error for predict file.

The tensors in model.predict_file seem to accumulate in memory when predicting on large files. This issue needs to be expanded and explored.

https://www.comet.ml/bw4sz/everglades/view/MSbcVMcskDZNgNbj71oW96fJZ

 model.predict_file(csv_file = model.config["train"]["csv_file"], root_dir = model.config["train"]["root_dir"], savedir=model_savedir)
Job ID                  : 66798610
Cluster                 : hipergator
User                    : b.weinstein
Group                   : ewhite
State                   : OUT_OF_MEMORY (exit code 0)
Nodes                   : 1
Cores/Node              : 10
CPU Time Used           : 07:36:08
Wall Time Used          : 02:14:55
CPU Efficiency          : 33.81%
Memory Requested        : 60.00 GB (60.00 GB/node)
Memory Used             : 63.97 GB (6.40 GB/core estimated maximum)
Memory Efficiency       : 106.61%

Some garbage collection is needed to collect unused tensors.
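A common way to keep inference memory flat is to run under `torch.no_grad()`, so no autograd graph is recorded, and to convert each batch's output to a CPU numpy array so no tensor outlives its batch. The sketch below assumes the model returns a single tensor; torchvision detection models return a list of dicts, whose tensors would each need the same treatment. `predict_batches` is a hypothetical name, and this is not a diagnosis of predict_file.

```python
from typing import Iterable, List

import numpy as np
import torch
from torch import nn


def predict_batches(model: nn.Module, batches: Iterable[torch.Tensor]) -> List[np.ndarray]:
    """Run inference batch by batch without holding onto tensors between batches."""
    model.eval()
    outputs = []
    # no_grad stops autograd from recording a graph for every forward pass.
    with torch.no_grad():
        for batch in batches:
            result = model(batch)
            # Copy to a CPU numpy array so no tensor reference outlives the batch.
            outputs.append(result.cpu().numpy())
    return outputs
```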

How to avoid rate limit download error from github during tests?

@henrykironde can you look into something for me? To use the release model in /tests/ we download a file from the latest github release. I make sure to only run

model.use_release()

once in https://github.com/weecology/DeepForest-pytorch/blob/master/tests/conftest.py

and then inherit that fixture in all downstream tests to try to avoid GitHub API problems. Nevertheless, I occasionally get

tests/test_main.py:32: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
deepforest/main.py:61: in use_release
    release_tag, self.release_state_dict = utilities.use_release()
deepforest/utilities.py:66: in use_release
    headers={'Accept': 'application/vnd.github.v3+json'},
/usr/local/miniconda/envs/DeepForest_pytorch/lib/python3.7/urllib/request.py:222: in urlopen
    return opener.open(url, data, timeout)
/usr/local/miniconda/envs/DeepForest_pytorch/lib/python3.7/urllib/request.py:531: in open
    response = meth(req, response)
/usr/local/miniconda/envs/DeepForest_pytorch/lib/python3.7/urllib/request.py:641: in http_response
    'http', request, response, code, msg, hdrs)
/usr/local/miniconda/envs/DeepForest_pytorch/lib/python3.7/urllib/request.py:569: in error
    return self._call_chain(*args)
/usr/local/miniconda/envs/DeepForest_pytorch/lib/python3.7/urllib/request.py:503: in _call_chain
    result = func(*args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <urllib.request.HTTPDefaultErrorHandler object at 0x7fec11aa8350>
req = <urllib.request.Request object at 0x7feb8a288310>
fp = <http.client.HTTPResponse object at 0x7feb6eee1250>, code = 403
msg = 'rate limit exceeded'
hdrs = <http.client.HTTPMessage object at 0x7feb6eee1a50>

    def http_error_default(self, req, fp, code, msg, hdrs):
>       raise HTTPError(req.full_url, code, msg, hdrs, fp)
E       urllib.error.HTTPError: HTTP Error 403: rate limit exceeded

/usr/local/miniconda/envs/DeepForest_pytorch/lib/python3.7/urllib/request.py:649: HTTPError

as in the failed run here

https://github.com/weecology/DeepForest-pytorch/runs/2136417458?check_suite_focus=true

Presumably this is because we ask GitHub to check the release too often during tests? Is there an API key, or maybe something we can mock, that might be a workaround? I think this is only a test problem, since the tests run in parallel and all communicate with GitHub.
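GitHub's REST API allows far more requests per hour to authenticated clients than to anonymous ones (5,000 versus 60, per GitHub's documentation), so one workaround is to send a token when one is available, for example a CI secret exported as an environment variable. The sketch below builds such a request with the standard library; the `GITHUB_TOKEN` variable name and the `github_api_request` helper are assumptions, not what utilities.use_release currently does. Patching `urllib.request.urlopen` with `unittest.mock` in the tests is another option that avoids the network entirely.

```python
import os
import urllib.request


def github_api_request(url: str) -> urllib.request.Request:
    """Build a GitHub API request, authenticating if GITHUB_TOKEN is set (hypothetical helper)."""
    headers = {"Accept": "application/vnd.github.v3+json"}
    token = os.environ.get("GITHUB_TOKEN")
    if token:
        # Authenticated requests get a much higher rate limit.
        headers["Authorization"] = f"token {token}"
    return urllib.request.Request(url, headers=headers)


# Usage: response = urllib.request.urlopen(github_api_request(url))
```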

Add Henry as an author.

@henrykironde, as you contribute more, should we list you as a full author instead of a contributor? The author line appears in a few places and I want to make sure you are well noted.

Evaluate module not in docs.

@henrykironde do you have a second to guess at something?

The evaluate module is a blank page sitting on its own outside of the deepforest package. At the same time, there is no evaluate submodule in the package source docs. There should be a deepforest.evaluate submodule alongside all the others. Sphinx does not raise any warnings. I didn't check the readthedocs logs.

https://github.com/weecology/DeepForest-pytorch/blob/master/deepforest/evaluate.py

https://deepforest-pytorch.readthedocs.io/en/latest/source/deepforest.html

Looking at readthedocs, you can see
[Screenshot of the readthedocs page, 2021-03-15]

I've tried deleting those files and recreating them:

cd /Users/benweinstein/Documents/DeepForest-pytorch/docs
sphinx-apidoc -o source ../

but there was no change (latest commit 86a6a31).

Bug in evaluation script?

Hi everybody,
I am excited that you plan to make DeepForest available in PyTorch as well. While checking out the implementation, I noticed that there might be a bug in evaluate.py >> evaluate: you calculate precision and recall for every image separately and then average the results, which is not the same as precision and recall computed over all images. An average of per-image averages is incorrect in this case.
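A small worked example of the difference, with made-up counts: image A has 1 tree and it is found; image B has 100 trees and 10 are found. Averaging per-image recall gives (1.0 + 0.1) / 2 = 0.55, while recall over all images is 11 / 101, about 0.109. The helper names are hypothetical, and each image is assumed to contain at least one ground-truth box.

```python
from typing import Sequence, Tuple


def mean_per_image_recall(counts: Sequence[Tuple[int, int]]) -> float:
    """Average of per-image recall; counts holds (true_positives, n_ground_truth) per image."""
    return sum(tp / n for tp, n in counts) / len(counts)


def pooled_recall(counts: Sequence[Tuple[int, int]]) -> float:
    """Recall over all images at once: total true positives / total ground-truth boxes."""
    return sum(tp for tp, _ in counts) / sum(n for _, n in counts)


# Image A: 1 of 1 trees found. Image B: 10 of 100 trees found.
counts = [(1, 1), (10, 100)]
print(mean_per_image_recall(counts))  # 0.55
print(pooled_recall(counts))          # about 0.109
```

The per-image average lets a sparse image count as much as a dense one, which is why the two numbers diverge so sharply here.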

utilities.shapefile_to_annotations won't save csv

The code below needs to be added at the end of the function; I will commit it after the previous commit is merged.

#save the data frame to CSV file at the specific location
    result.to_csv(os.path.join(savedir,'annotations.csv'),index=False)

Train release model

Using the data from DeepForest, need to train and validate a release model against the NeonTreeEvaluation benchmark.

Create document for subclassing transform for data augmentation

If a user wants to add new transforms, they can override the methods and create a new class.

from deepforest import main
from deepforest import dataset
from deepforest import utilities
import torch

class DeepForestAugmenter(main.deepforest):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms
        
    def load_dataset(self,
                     csv_file,
                     root_dir=None,
                     augment=False,
                     shuffle=True,
                     batch_size=1):
        """Create a tree dataset for inference
        Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
        Image_path is the relative filename, not absolute path, which is in the root_dir directory. One bounding box per line.

        Args:
            csv_file: path to csv file
            root_dir: directory of images. If none, uses "image_dir" in config
            augment: Whether to create a training dataset, this activates data augmentations
        Returns:
            ds: a pytorch dataset
        """

        ds = dataset.TreeDataset(csv_file=csv_file,
                                 root_dir=root_dir,
                                 transforms=self.transforms(augment=augment),
                                 label_dict=self.label_dict)

        data_loader = torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=utilities.collate_fn,
            num_workers=self.config["workers"],
        )

        return data_loader    

Error in predict_tile

Could someone tell me what I am doing wrong here?

After running pip install deepforest-pytorch, I imported the release model:

test_model = main.deepforest()
test_model.use_release()

I then tried predicting a tile with the sample data already provided:

raster_path = get_data("OSBS_029.tif")
predicted_raster = test_model.predict_tile(raster_path, return_plot = True, patch_size=300,patch_overlap=0.25)

and received the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-26-923b61ec3ff0> in <module>()
      1 raster_path = get_data("OSBS_029.tif")
----> 2 predicted_raster = test_model.predict_tile(raster_path, return_plot = True, patch_size=400,patch_overlap=0.25)

9 frames
/usr/local/lib/python3.7/dist-packages/tifffile/tifffile.py in decode(exc, *args, **kwargs)
   5788 
   5789             def decode(*args, exc=str(exc)[1:-1], **kwargs):
-> 5790                 raise ValueError(f'TiffPage {self.index}: {exc}')
   5791 
   5792             return cache(decode)

ValueError: TiffPage 0: <COMPRESSION.LZW: 5> requires the 'imagecodecs' package

Environment : Google Colab
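The last line of the traceback states the cause: tifffile needs the imagecodecs package to decode LZW-compressed TIFFs such as this sample, so installing it (pip install imagecodecs, for example in a Colab cell) should resolve the error. A guard like the one below could fail earlier with an actionable message; `require_imagecodecs` is a hypothetical helper, not part of DeepForest.

```python
import importlib.util


def require_imagecodecs() -> None:
    """Fail early with an actionable message if imagecodecs is not installed (hypothetical helper)."""
    if importlib.util.find_spec("imagecodecs") is None:
        raise ImportError(
            "Reading LZW-compressed TIFFs with tifffile requires the 'imagecodecs' "
            "package; install it with: pip install imagecodecs")
```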

edge bounding box annotation

Annotation on the left edge of 2019_YELL_2_528000_4978000_image_crop2_34.png results in a bounding box that is not biologically feasible. Default training settings were used.

image_path xmin ymin xmax ymax label
2019_YELL_2_528000_4978000_image_crop2_34.png 135 21 175 61 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 161 0 210 26 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 221 0 270 42 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 190 45 244 102 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 121 80 172 133 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 214 143 264 187 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 207 189 260 233 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 267 172 326 231 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 0 230 20 280 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 397 0 400 16 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 347 144 355 153 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 305 140 325 161 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 326 141 345 167 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 276 139 296 162 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 288 234 325 272 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 301 274 366 326 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 311 327 378 390 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 0 149 13 205 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 8 320 113 400 Tree
2019_YELL_2_528000_4978000_image_crop2_34.png 245 43 334 132 Tree
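To find such boxes automatically, one could flag boxes that touch the image border and are implausibly elongated. The sketch below assumes a 400 x 400 pixel crop (the largest coordinate in the listing is 400) and an arbitrary aspect-ratio cutoff of 3; `flag_edge_slivers` is a hypothetical helper, and whether flagged boxes should be dropped, clipped, or kept as partial crowns is a modelling decision.

```python
from typing import List, Sequence, Tuple


def flag_edge_slivers(boxes: Sequence[Tuple[float, float, float, float]],
                      image_width: float, image_height: float,
                      max_aspect_ratio: float = 3.0) -> List[int]:
    """Return indices of boxes that touch the image border and are implausibly elongated.

    Hypothetical post-processing check; the aspect-ratio cutoff is an assumption.
    """
    flagged = []
    for i, (xmin, ymin, xmax, ymax) in enumerate(boxes):
        width, height = xmax - xmin, ymax - ymin
        if width <= 0 or height <= 0:
            flagged.append(i)  # degenerate box
            continue
        touches_edge = (xmin <= 0 or ymin <= 0
                        or xmax >= image_width or ymax >= image_height)
        aspect = max(width / height, height / width)
        if touches_edge and aspect > max_aspect_ratio:
            flagged.append(i)
    return flagged


rows = [(135, 21, 175, 61), (0, 230, 20, 280), (0, 149, 13, 205), (397, 0, 400, 16)]
print(flag_edge_slivers(rows, image_width=400, image_height=400))  # [2, 3]
```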

Roadmap

Motivation

It is clear that tensorflow's 1.0 -> 2.0 upgrade has severely limited the portability of the DeepForest package (https://travis-ci.org/github/weecology/DeepForest/builds/749697379). While it may be possible to continue to support tensorflow as the primary machine learning engine, the vast majority of code in this project is agnostic to the model prediction step. Judging from fizyr/keras-retinanet#1471, the object detection model that underlies the tensorflow version doesn't have long to live.

As an aside, I, like many people, want to thank @hgaiser for his work in leading the keras-retinanet repo.

Given that we have a number of research goals that relate to this model, it seems like the sooner we decide on a pathway forward the better. This repo and issue will serve as a central place for DeepForest users to see the status of moving towards pytorch. As a starting place I've copied the code from the DeepForest repo and will slowly work to replace the components.

Milestones

  • Swap pytorch for tensorflow dependencies and setup travis tests.
  • Replace prediction and training scripts with pytorch implementations
  • Replace tfrecord generation. This needs more thought; it's not immediately obvious how to store the more than 20 million pretraining images for semi-supervision.

Multi-class models

DeepForest was primarily designed as a 'tree' detector, with one class. I think it is useful to at least allow users to train multi-class models.

Things to investigate

  1. allow multiple labels in the dataset.py
  2. have a lookup dict for multiple labels when visualizing
  3. Add multiple colors for each class
  4. Make sure that evaluation statistics can report multiple classes
  5. Make a tutorial with reproducible data in Colab
  6. Make sure that a reloaded model gets the same classes.

It's possible torchvision has already implemented all of this and it just needs to be checked. I do not yet know.

Add sanity check on image size

Sometimes PNGs are written with a sneaky alpha channel, so they are 4-channel images even though they look like ordinary RGB tree images. Write a check_image function that raises a clear error.

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
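A minimal numpy sketch of such a check, run before an image reaches the model so the user sees a readable message instead of the tensor size mismatch above. The name check_image comes from this issue, but the signature, the channels-last assumption, and the wording of the errors are assumptions.

```python
import numpy as np


def check_image(image: np.ndarray) -> None:
    """Raise a clear error unless image is a (height, width, 3) RGB array (sketch)."""
    if image.ndim != 3:
        raise ValueError(f"Expected a (height, width, channels) array, got shape {image.shape}")
    channels = image.shape[2]
    if channels == 4:
        raise ValueError("Image has 4 channels; it probably contains a hidden alpha channel. "
                         "Convert it to 3-band RGB before prediction.")
    if channels != 3:
        raise ValueError(f"Expected 3 RGB channels, got {channels}")
```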

Update README and docs

The README from DeepForest needs to be updated to represent any changes in the API.

Key steps

  • Copy docs from DeepForest
  • Look for changes in names, predict_generator -> predict_file, check color ordering from tensorflow to pytorch
  • A tutorial on model training, preferably written in Google Colab and/or IPython, is needed.
  • Once a pretrained model is recreated, a Colab tutorial for prediction is needed.

@dingyif, feel free to ignore any of these contributions, but I'll always add you if you're interested. My lab has a strong focus on inclusivity and making tools available to a wide audience. I speak English and Spanish, so I wrote a brief summary of the package at the top. Do you want to add a Mandarin explanation, and would it be useful?
