
Comments (8)

david-wb commented on June 14, 2024

@Zeleni9 Great question.

Using no_grad() without eval mode essentially keeps batch normalization in its training behavior at inference time: each layer normalizes with the current batch's statistics instead of the stored running averages. This makes sense to me because the inference-time data (webcam) is very different from the training data (UnityEyes). You might try experimenting with setting track_running_stats=False in the batch-norm layers and then running in eval mode, but you'll need to retrain the model. Please feel free to submit a PR if you find that this works. If you find other relevant info that sheds more light on this issue, please comment here again; I'd appreciate it.
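To make the distinction concrete, here is a minimal pure-Python sketch of a batch-norm layer's modes. This is NOT the actual PyTorch implementation (no learned affine parameters, biased variance throughout); it only mirrors the documented statistic-selection logic, with the default-style momentum of 0.1 assumed.

```python
def batch_stats(xs):
    """Mean and (biased) variance of one batch."""
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / len(xs)

class TinyBatchNorm:
    def __init__(self, track_running_stats=True, momentum=0.1, eps=1e-5):
        self.track = track_running_stats
        self.momentum, self.eps = momentum, eps
        self.running_mean, self.running_var = 0.0, 1.0  # initial values
        self.training = True  # note: torch.no_grad() alone does NOT flip this

    def forward(self, xs):
        if self.training or not self.track:
            # Current-batch statistics: what you get under no_grad() without
            # eval(), and always when track_running_stats=False.
            m, v = batch_stats(xs)
            if self.training and self.track:
                self.running_mean += self.momentum * (m - self.running_mean)
                self.running_var += self.momentum * (v - self.running_var)
        else:
            # eval() with tracked stats: use the training-set running averages.
            m, v = self.running_mean, self.running_var
        return [(x - m) / (v + self.eps) ** 0.5 for x in xs]

bn = TinyBatchNorm()
batch = [4.0, 6.0]                 # batch mean 5, far from the initial mean 0
train_out = bn.forward(batch)      # ~[-1, 1]: normalized with batch stats
bn.training = False
eval_out = bn.forward(batch)       # ~[3.5, 5.5]: running stats barely moved
```

The gap between `train_out` and `eval_out` is exactly the mismatch this issue is about: after only one update the running averages are still near their initial values, so eval-mode normalization lands far from the batch-stat result.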

Thanks!

from gaze-estimation.

Zeleni9 commented on June 14, 2024

Hi David,

I have trained the model with nn.BatchNorm2d(features, track_running_stats=False) and some small changes to keep all operations in float throughout the net. After training it for 10 epochs and running the evaluation, I got a mean angular error of around 12-13 degrees.

I did test the model under torch.no_grad() and with eyenet.eval(), and both yield similar results.

Results of model - 44999

error [17.87877716]
mean error 12.241872422782441
side right
gaze [[0.04301077 0.11114367]]
gaze pred [[0.05413113 0.42336184]]

But unfortunately the trained model is not valid when exported to ONNX, and I am not sure whether the problem is track_running_stats=False or something else, since your pretrained model exports to ONNX without issues.

Update:
It is track_running_stats=False that changes the nodes, so the resulting graph is invalid in the ONNX conversion (see the discussion).
With your code and only track_running_stats=False changed, I got results like this:

error [5.87883281]
mean error 11.070940301345876
Len of mean error list: 45000
side right
gaze [[0.04301077 0.11114367]]
gaze pred [[-0.03775143 0.17444813]]


david-wb commented on June 14, 2024

Hi, thanks for the updates and clarification. I have observed the same effects on my end using track_running_stats=False.

I can't help you with the ONNX conversion, but I am watching your PyTorch discussion.


Zeleni9 commented on June 14, 2024

Hey David,

I did train the model now with the nn.BatchNorm2d() layers removed, and training went smoothly. Evaluation on MPIIGaze was around:

Trained model without batch norm Eyenet - 44999

error [8.79008233]
mean error 10.929743261465049
Len of mean error list: 45000
side right
gaze [[0.04301077 0.11114367]]
gaze pred [[0.08114091 0.260042 ]]

But the visualization is totally off: gaze vectors jump around, and the landmarks are not even close to the pupil. Did you do some post-processing that makes it work on the webcam, and if so, how can I turn it off? I am a bit confused that a model with a better evaluation error behaves much worse than the one trained with nn.BatchNorm2d(features, track_running_stats=False).

I see that it is trained on synthetic eyes and the webcam data is different, but how can the BatchNorm layer change it that much?
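One possible explanation can be sketched with made-up numbers (these are illustrative only, not measurements from the model): if a batch-norm layer's stored statistics were accumulated on UnityEyes activations, normalizing domain-shifted webcam activations with them leaves the outputs far from zero mean, whereas per-batch statistics re-center every incoming batch.

```python
# Hypothetical numbers: training-set stats vs. stats of shifted webcam data.
train_mean, train_var = 0.0, 1.0        # averages accumulated on synthetic eyes
webcam = [2.0, 2.5, 3.0, 3.5]           # activations shifted by the new domain

def normalize(xs, m, v, eps=1e-5):
    return [(x - m) / (v + eps) ** 0.5 for x in xs]

# eval()-style: normalize webcam data with the training-set running stats.
with_running = normalize(webcam, train_mean, train_var)

# batch-stat style (no_grad() without eval, or track_running_stats=False):
m = sum(webcam) / len(webcam)
v = sum((x - m) ** 2 for x in webcam) / len(webcam)
with_batch = normalize(webcam, m, v)

# with_running keeps the domain shift (mean ~2.75); with_batch is re-centered
# to zero mean, so downstream layers see inputs closer to what they trained on.
```

This is only a sketch of the mechanism, but it shows how a single layer's frozen statistics can move every downstream activation off-distribution, which could plausibly account for the large visual difference.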

Thanks.


david-wb commented on June 14, 2024

The Residual layers use batch norm, and it can't be turned off with a flag. I'm guessing that is why you see that behavior.

When you retrain make sure you start a new model file too.


Zeleni9 commented on June 14, 2024

No, I didn't use the flag; I deleted the BatchNorm code completely in the Residual layers.

What do you mean by starting a new model file? I am removing checkpoint.pt before each training run.


david-wb commented on June 14, 2024

With BN completely removed, I'm surprised there would be any difference at eval time. Afraid I can't answer this one.

Regarding starting a new model file, we are on the same page.


Zeleni9 commented on June 14, 2024

With BN completely removed, torch.no_grad() and eval() give the same result, but visually on the webcam it looks very incorrect compared with, for example, nn.BatchNorm2d(features, track_running_stats=False).

Update:
Results were different in .eval() because I called .eval() on the model directly after loading it from checkpoint.pt; I think it is possible the batch-norm layers still had their initial, never-updated statistics. After I ran 5-10 inferences in .train() mode and then switched to .eval(), I got the same correct results in both modes.

I have managed to convert the model to ONNX format and it works. Interesting behaviour, to be honest.

Well, no problem, you helped a lot. Thank you for the insightful information.
Best.


