bjarten / early-stopping-pytorch

Early stopping for PyTorch

License: MIT License

Python 2.34% Jupyter Notebook 97.66%
pytorch early-stopping early stopping pytorch-tutorial python mnist regularization tutorial

early-stopping-pytorch's People

Contributors

adilzouitine, anshulrai, bjarten, eddinho, simonmossmyr, wolframalpha

early-stopping-pytorch's Issues

How do we install pytorchtools?

Even after running pip install pytorchtools, this error comes up:
cannot import name 'EarlyStopping' from 'pytorchtools' (/usr/local/lib/python3.7/dist-packages/pytorchtools/__init__.py)

early-stopping-pytorch.py: shouldn't the script save the previous model?

Hi. Thanks for your resources :) They are giving great help to my research.

BTW, isn't it correct to save the previous model when the loss decreases? It seems like your current early-stopping-pytorch.py script is intended to save the current model, not the previous one.

ImportError: cannot import name 'EarlyStopping' from 'pytorchtools'

Hello @Bjarten

I am trying to use your approach for the early stopping. I installed pytorchtools by using "pip install pytorchtools"
Then I wanted to import EarlyStopping using "from pytorchtools import EarlyStopping", but eventually I received the following error:

ImportError: cannot import name 'EarlyStopping' from 'pytorchtools' (C:\Users\Name\anaconda3\envs\abys\lib\site-packages\pytorchtools_init_.py)

Can you please let me know why I am receiving this error and how to fix it potentially? Thank you in advance!
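One way to check which file an import actually resolves to is importlib.util.find_spec from the standard library. The error messages above show pytorchtools being loaded from site-packages (the pip-installed package), which evidently has no EarlyStopping. The sketch below assumes the goal is to import this repository's pytorchtools.py placed next to your script; module_origin is a hypothetical helper name.

```python
import importlib.util


def module_origin(name):
    """Return the file Python would load for `name`, or None if it is not importable.

    Hypothetical diagnostic helper: if the result points into site-packages,
    `import name` picks up the installed package rather than a local file.
    """
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Prints a path such as .../site-packages/pytorchtools/__init__.py when the
    # pip-installed package is what gets imported, or None if nothing is found.
    print(module_origin("pytorchtools"))
```

If the printed path points into site-packages, Python is loading the installed package, not the repository's file. When a script is run directly, its own directory is normally searched first on sys.path, so a pytorchtools.py copied there is usually what gets imported.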

unsupported operand type(s) for +: 'builtin_function_or_method' and 'builtin_function_or_method'

Hi, I used that example in my code but got the following error:

Traceback (most recent call last):
  File "/mnt/HDD/train_Model_1.py", line 237, in <module>
    train_loss = np.average(train_losses)
  File "<__array_function__ internals>", line 6, in average
  File "/home/saida/.local/lib/python3.6/site-packages/numpy/lib/function_base.py", line 380, in average
    avg = a.mean(axis)
  File "/home/saida/.local/lib/python3.6/site-packages/numpy/core/_methods.py", line 160, in _mean
    ret = umr_sum(arr, axis, dtype, out, keepdims)
TypeError: unsupported operand type(s) for +: 'builtin_function_or_method' and 'builtin_function_or_method'
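The message says NumPy tried to add two builtin methods, which suggests train_losses holds method objects rather than numbers. A likely cause (an assumption, since the training loop isn't shown) is appending loss.item instead of loss.item(). The sketch below reproduces both the error and the fix, with NumPy scalars standing in for PyTorch loss tensors because both expose an item() method; epoch_average is a hypothetical helper name.

```python
import numpy as np


def epoch_average(batch_losses):
    """Average per-batch losses after converting each one to a Python float.

    `loss.item()` (with parentheses) returns the number; `loss.item` without
    them is the bound method itself, which is what triggers the TypeError.
    """
    return float(np.average([loss.item() for loss in batch_losses]))


if __name__ == "__main__":
    # NumPy scalars stand in for scalar loss tensors here (illustrative only).
    losses = [np.float64(0.5), np.float64(0.25), np.float64(0.75)]

    try:
        np.average([loss.item for loss in losses])  # buggy: no parentheses
    except TypeError as exc:
        print("buggy:", exc)

    print("fixed:", epoch_average(losses))  # fixed: 0.5
```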

Another problem about pytorchtools

Error message: ImportError: cannot import name 'BalancedDataParallel' from 'pytorchtools'
Sorry to trouble you! I have followed your method and solved "ModuleNotFoundError: No module named 'pytorchtools'", but ran into another problem. How can I solve it? Thanks.

Update: EarlyStop condition with default delta

Do you think the early stopping condition check on line 36, in the __call__() method of the EarlyStopping class, should use less than or equal to (<=) instead of just less than (<)?

The logic behind this is:
With the default values, delta is 0, which means the early stopping condition check reduces to:
score < self.best_score

Now consider the case where the score is not improving and stays constant across epochs. In this case, score is never less than best_score, so training never stops early.
To handle this case, i.e. when the score is not improving and stays constant across epochs, the EarlyStopping counter should start incrementing, and training should terminate once patience runs out.

Hence, the early stop condition check should change from:
elif score < self.best_score + self.delta:
to:
elif score <= self.best_score + self.delta:
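The effect of the proposed change can be checked with a stripped-down tracker. This is a hypothetical sketch, not the repository's EarlyStopping class: it omits checkpointing and assumes the convention score = -val_loss, so a higher score is better. With a constant validation loss and the <= comparison, the counter advances every epoch and the stop flag is raised once patience is exhausted.

```python
class EarlyStoppingSketch:
    """Hypothetical minimal tracker using the proposed `<=` comparison.

    Assumes score = -val_loss (higher is better); no model checkpointing.
    """

    def __init__(self, patience=3, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            # First epoch: nothing to compare against yet.
            self.best_score = score
        elif score <= self.best_score + self.delta:
            # No improvement, including a perfectly flat loss: count it.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Genuine improvement: remember it and restart the count.
            self.best_score = score
            self.counter = 0
        return self.early_stop


if __name__ == "__main__":
    stopper = EarlyStoppingSketch(patience=3)
    # A flat validation loss; with a strict `<` this would never stop.
    print([stopper(1.0) for _ in range(5)])  # [False, False, False, True, True]
```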

Check for nan validation loss

If gradients are exploding, the loss function can return nan, which is interpreted as a decrease in the validation loss. An additional if statement using np.isnan(score) or similar should be added.
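Why nan slips through: every ordered comparison involving nan is False, so a check of the form score < best_score + delta fails and control falls into the branch that records an improvement. Below is a minimal sketch of the guard the issue proposes, using math.isnan (np.isnan behaves the same on a scalar). The function names are hypothetical and model only the improvement decision, not the repository's class; score = -val_loss is assumed.

```python
import math


def naive_is_improvement(val_loss, best_score, delta=0.0):
    # Anything that is not "score < best_score + delta" counts as an
    # improvement, so a nan loss is wrongly accepted.
    score = -val_loss
    return not (score < best_score + delta)


def guarded_is_improvement(val_loss, best_score, delta=0.0):
    # Reject nan explicitly before the comparison.
    score = -val_loss
    if math.isnan(score):
        return False
    return not (score < best_score + delta)


if __name__ == "__main__":
    best = -0.5  # best score so far, i.e. a validation loss of 0.5
    print(naive_is_improvement(float("nan"), best))    # True
    print(guarded_is_improvement(float("nan"), best))  # False
    print(guarded_is_improvement(0.4, best))           # True
```

Returning False treats a nan epoch as a non-improvement, so the diverged model is never checkpointed; aborting training outright on nan is an equally reasonable design choice.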

Small bug

I think there is a small bug in your code: you need to reset the early_stop variable, as follows:

        self.best_score = score
        self.save_checkpoint(val_loss, model)
        self.counter = 0
        self.early_stop = False  # reset here

Otherwise, if you check this flag from outside the class, training may stop early because of what happened in previous iterations, even after the loss has improved again.
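The scenario can be reproduced with a small hypothetical tracker. It is not the repository's class: score = -val_loss is assumed, checkpointing is omitted, and reset_flag is an illustrative switch. After the flag is raised, a later improvement clears it only when the reset is in place.

```python
class StopFlagSketch:
    """Hypothetical tracker showing why early_stop should be reset on improvement."""

    def __init__(self, patience=2, reset_flag=True):
        self.patience = patience
        self.reset_flag = reset_flag
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        score = -val_loss
        if self.best_score is not None and score < self.best_score:
            # Worse than the best so far: count toward patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # First call or an improvement.
            self.best_score = score
            self.counter = 0
            if self.reset_flag:
                self.early_stop = False  # the reset proposed in the issue
        return self.early_stop


if __name__ == "__main__":
    losses = [1.0, 1.1, 1.2, 0.5]  # worsens twice, then improves
    for reset in (True, False):
        tracker = StopFlagSketch(patience=2, reset_flag=reset)
        print(reset, [tracker(loss) for loss in losses])
    # True [False, False, True, False]
    # False [False, False, True, True]
```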

Request: packaging

Packaging is pretty easy these days: just add a pyproject.toml. It would be great to add one to this repo.

Then, people can at least pip install directly from GitHub like so:

pip install git+ssh://git@github.com/Bjarten/early-stopping-pytorch.git
