bjarten / early-stopping-pytorch
Early stopping for PyTorch
License: MIT License
Even if we run pip install pytorchtools, this error comes up:
cannot import name 'EarlyStopping' from 'pytorchtools' (/usr/local/lib/python3.7/dist-packages/pytorchtools/__init__.py)
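The error message shows that the name pytorchtools resolved to an installed package under dist-packages, and that package has no EarlyStopping. A quick way to see which file Python would load for a module name is importlib.util.find_spec. The helper below is a small sketch of that check; the name module_origin is hypothetical and not part of this repository.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return the file a top-level module would be loaded from, or None if it cannot be found."""
    spec = importlib.util.find_spec(name)
    # spec.origin is the module's file path (the __init__.py file for a package)
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # Shows whether "pytorchtools" comes from site-packages/dist-packages or from a local pytorchtools.py
    print(module_origin("pytorchtools"))
```

If the printed path points into site-packages or dist-packages rather than your project directory, the import is not picking up a local copy of pytorchtools.py.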
early-stopping-pytorch/pytorchtools.py
Line 31 in effbcce
Should this line be changed as below?
score - self.best_score < self.delta
or
score - self.best_score <= self.delta
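The two proposals differ only when the change in score is exactly equal to delta. With delta = 0, that boundary case is a score that does not change at all. The sketch below uses hypothetical helper names and assumes score is higher-is-better (for example, the negated validation loss). It shows how each comparison classifies such an epoch.

```python
def no_improvement_strict(score: float, best_score: float, delta: float = 0.0) -> bool:
    # Proposal 1: a change of exactly delta still counts as an improvement
    return score - best_score < delta


def no_improvement_non_strict(score: float, best_score: float, delta: float = 0.0) -> bool:
    # Proposal 2: a change of exactly delta is treated as "no improvement"
    return score - best_score <= delta


# Validation loss stuck at 0.5, so score = -0.5 on both epochs
print(no_improvement_strict(-0.5, -0.5))      # False: counted as an improvement
print(no_improvement_non_strict(-0.5, -0.5))  # True: counts toward patience
```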
Is it normal that the validation set loss is less than the training set loss here?
!pip install pytorchtools
from pytorchtools import EarlyStopping
When I run the lines above, I get an error in this notebook.
Hi. Thanks for your resources :) They have been a great help in my research.
BTW, isn't it correct to save the previous model when the loss decreases? It seems like your current early-stopping-pytorch.py script is intended to save the current model, not the previous one.
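For context on this question, consider a typical loop where the validation loss passed in at a given epoch is computed with the model's current weights. In that case, saving the current weights when the loss improves keeps exactly the weights that produced the best validation loss. The sketch below illustrates that pattern. It uses a plain dict as a stand-in for model.state_dict() (an assumption, to keep it framework-free), and BestWeights is a hypothetical name.

```python
import copy
from typing import Any, Dict, Optional


class BestWeights:
    """Keep a copy of the weights that achieved the lowest validation loss so far."""

    def __init__(self) -> None:
        self.best_loss = float("inf")
        self.best_state: Optional[Dict[str, Any]] = None

    def update(self, val_loss: float, state: Dict[str, Any]) -> bool:
        # val_loss was measured with `state`, so these are the weights worth keeping
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            # Deep copy so later training steps don't overwrite the saved weights
            self.best_state = copy.deepcopy(state)
            return True
        return False


tracker = BestWeights()
weights = {"w": [0.0]}
for epoch, loss in enumerate([0.9, 0.7, 0.8]):
    weights["w"][0] = float(epoch)  # stand-in for one epoch of training
    tracker.update(loss, weights)
print(tracker.best_loss, tracker.best_state)  # 0.7 {'w': [1.0]}
```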
Hello @Bjarten
I am trying to use your approach for the early stopping. I installed pytorchtools by using "pip install pytorchtools"
Then I tried to import EarlyStopping using "from pytorchtools import EarlyStopping", but I received the following error:
ImportError: cannot import name 'EarlyStopping' from 'pytorchtools' (C:\Users\Name\anaconda3\envs\abys\lib\site-packages\pytorchtools\__init__.py)
Can you please tell me why I am receiving this error and how I might fix it? Thank you in advance!
Hi, I used that example in my code, but I got this error:
Traceback (most recent call last):
File "/mnt/HDD/train_Model_1.py", line 237, in
train_loss = np.average(train_losses)
File "<array_function internals>", line 6, in average
File "/home/saida/.local/lib/python3.6/site-packages/numpy/lib/function_base.py", line 380, in average
avg = a.mean(axis)
File "/home/saida/.local/lib/python3.6/site-packages/numpy/core/_methods.py", line 160, in _mean
ret = umr_sum(arr, axis, dtype, out, keepdims)
TypeError: unsupported operand type(s) for +: 'builtin_function_or_method' and 'builtin_function_or_method'
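The TypeError means the list being averaged contains method objects rather than numbers. One common way this happens is appending loss.item instead of calling loss.item(); this is an assumption, since the training loop isn't shown. The sketch below reproduces the same kind of error with a hypothetical FakeLoss class that stands in for a scalar loss tensor. It uses a plain sum/len average in place of np.average.

```python
class FakeLoss:
    """Hypothetical stand-in for a scalar loss tensor with an .item() method."""

    def __init__(self, value: float) -> None:
        self._value = value

    def item(self) -> float:
        return self._value


losses = [FakeLoss(0.5), FakeLoss(0.3)]

bad = [loss.item for loss in losses]     # bug: stores the bound methods themselves
good = [loss.item() for loss in losses]  # fix: call .item() to get Python floats

try:
    sum(bad) / len(bad)
except TypeError as exc:
    print("TypeError:", exc)  # adding method objects is not supported

print(sum(good) / len(good))  # average of the recorded losses
```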
Error message: ImportError: cannot import name 'BalancedDataParallel' from 'pytorchtools'
Sorry to trouble you! I followed your method and solved "ModuleNotFoundError: No module named 'pytorchtools'", but then ran into another problem. How can I solve it? Thanks.
Do you think the early stopping condition check on line 36, in the __call__() function of the EarlyStopping class, should use less than or equal to instead of just less than?
The logic behind this is:
With the default values, delta is 0, so the early stopping condition check reduces to:
score < self.best_score
Now consider the case where the score is not improving and stays constant throughout the epochs. In this case, score will never be less than best_score, so the patience counter never increases and training never stops early.
To handle this case, i.e. when the score stays constant across epochs, the EarlyStopping counter should start, and training should terminate once patience runs out.
Hence, the early stopping condition check should change from:
elif score < self.best_score + self.delta:
to
elif score <= self.best_score + self.delta:
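The sketch below is a simplified, hypothetical early-stopping class that uses the proposed <= check. It is not the repository's EarlyStopping, and it does no checkpoint saving. It sets score = -val_loss so that higher is better. With a constant validation loss, the counter increases every epoch, and the stop flag is set once patience is exhausted.

```python
from typing import Optional


class SimpleEarlyStopping:
    """Hypothetical minimal early stopper using the proposed `<=` comparison."""

    def __init__(self, patience: int = 3, delta: float = 0.0) -> None:
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score: Optional[float] = None
        self.early_stop = False

    def __call__(self, val_loss: float) -> None:
        score = -val_loss  # higher score is better
        if self.best_score is None:
            self.best_score = score
        elif score <= self.best_score + self.delta:
            # No improvement larger than delta (this now includes an unchanged score)
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: remember the new best and restart the patience count
            self.best_score = score
            self.counter = 0


stopper = SimpleEarlyStopping(patience=3)
for epoch in range(10):
    stopper(0.5)  # validation loss never changes
    if stopper.early_stop:
        print(f"stopped after epoch {epoch}")  # stopped after epoch 3
        break
```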
If gradients explode, the loss function can return nan, which is then interpreted as a decrease in the validation loss. An additional check, such as np.isnan(score), should be added to handle this case.
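Comparisons involving NaN are always False. Under an if / elif score < best_score + delta / else structure, a NaN score therefore falls through to the "improvement" branch. Below is a minimal sketch of a guard; is_improvement is a hypothetical helper. It uses math.isnan for a plain Python float, and np.isnan behaves the same way for NumPy values. Treating a NaN loss as "no improvement" is one possible policy (an assumption here); stopping immediately would be another.

```python
import math


def is_improvement(val_loss: float, best_score: float, delta: float = 0.0) -> bool:
    """Return True only if val_loss is a real number whose score beats best_score by more than delta."""
    if math.isnan(val_loss):
        return False  # never treat a NaN loss as an improvement
    score = -val_loss  # higher score is better
    return score > best_score + delta


nan = float("nan")
print(-nan < -0.5 + 0.0)          # False: why an unguarded "no improvement" check misses NaN
print(is_improvement(nan, -0.5))  # False
print(is_improvement(0.4, -0.5))  # True
```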
I think you have a small bug in your code: you need to reset the early_stop variable as follows:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
self.early_stop = False # reset here
Otherwise, if you check this flag from outside, training may stop earlier than it should, depending on what happened in previous iterations.
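Here is a hypothetical sketch of the suggested change. When the score improves, both the counter and the early_stop flag are reset, so a flag set earlier does not persist if the caller keeps training. It is a stripped-down illustration (no checkpointing, delta omitted, StopState and update are made-up names), not the repository's class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StopState:
    """Hypothetical early-stopping state (higher score is better)."""
    patience: int
    counter: int = 0
    best_score: Optional[float] = None
    early_stop: bool = False


def update(state: StopState, score: float) -> None:
    if state.best_score is None or score > state.best_score:
        state.best_score = score
        state.counter = 0
        state.early_stop = False  # clear a flag set by an earlier plateau
    else:
        state.counter += 1
        if state.counter >= state.patience:
            state.early_stop = True


state = StopState(patience=1)
for score in [-0.5, -0.6, -0.3]:
    update(state, score)
print(state.early_stop)  # False: the flag was cleared by the later improvement
```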
Packaging is pretty easy these days: just add a pyproject.toml. It would be cool to add one to this repo.
Then people could at least pip install directly from GitHub like so:
pip install git+ssh://git@github.com/Bjarten/early-stopping-pytorch.git