boyuanjiang / matching-networks-pytorch
Matching Networks for one shot learning
Because PyTorch has changed the API of some functions, this code is not compatible with PyTorch 0.4.0 or later.
You need to create a new conda environment with Python 3 and PyTorch 0.3.1 to run it.
Otherwise, you will get stuck with the error "TypeError: iteration over a 0-d tensor".
That's all.
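A small start-up guard can make this incompatibility explicit instead of failing later with the TypeError above. This is a sketch; the helper name `is_pytorch_0_4_or_newer` is hypothetical and not part of the repo. It only parses the leading "major.minor" of a version string such as "0.3.1" or "2.1.0+cu118".

```python
import re


def is_pytorch_0_4_or_newer(version_string):
    """Return True if a PyTorch version string is 0.4.0 or later (hypothetical helper)."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognised version string: {version_string!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Tuple comparison: (0, 3) < (0, 4) <= (1, 13) <= (2, 1)
    return (major, minor) >= (0, 4)


# Usage, assuming torch is importable in the active environment:
#   import torch
#   if is_pytorch_0_4_or_newer(torch.__version__):
#       raise SystemExit("Use a PyTorch 0.3.1 environment or port the code to the 0.4+ API.")
```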
Do you know what the final accuracies of this model trained on Omniglot are?
The results from the paper are:
5-way/1-shot = 98.1%
5-way/5-shot = 98.9%
20-way/1-shot = 93.8%
20-way/5-shot = 98.5%
Thanks
Ben
In the original paper, the training objective is written as an arg max over θ. Should the loss value be θ? Why is the loss in your program a cross-entropy loss?
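In the paper, θ denotes the network parameters, not a loss value. The objective picks θ to maximise the expected log-likelihood of the true labels, roughly θ = argmax_θ E[ Σ_{(x,y)∈B} log P_θ(y | x, S) ]. Maximising a log-likelihood is the same as minimising the negative log-likelihood, which for one-hot labels is exactly the cross-entropy loss. The sketch below does not reproduce the repo's code. It computes P(y | x, S) the Matching Networks way (softmax attention over the support set, summed per class) and checks that, for a 1-shot support set, the resulting negative log-likelihood equals `F.cross_entropy` on the raw similarities. It assumes a recent PyTorch (for `F.one_hot`).

```python
import torch
import torch.nn.functional as F


def matching_nll_loss(similarities, support_labels, target, n_classes):
    """Mean negative log-likelihood of the true class under a Matching Networks classifier.

    similarities:   (batch, n_support) similarity of each query to each support item
    support_labels: (batch, n_support) int64 class index of each support item
    target:         (batch,) int64 class index of each query
    """
    attention = F.softmax(similarities, dim=1)                      # a(x, x_i)
    one_hot = F.one_hot(support_labels, n_classes).float()          # y_i as one-hot rows
    probs = torch.bmm(attention.unsqueeze(1), one_hot).squeeze(1)   # P(y | x, S): (batch, n_classes)
    # Maximising sum log P(y | x, S) over theta == minimising this negative log-likelihood.
    return F.nll_loss(torch.log(probs.clamp(min=1e-12)), target)


torch.manual_seed(0)
sims = torch.randn(4, 5)                   # 4 queries against a 5-way 1-shot support set
labels = torch.arange(5).repeat(4, 1)      # support item i belongs to class i
target = torch.tensor([0, 2, 4, 1])
print(matching_nll_loss(sims, labels, target, 5).item(), F.cross_entropy(sims, target).item())
```

With more than one shot per class, the attention weights of all support items of a class are summed before taking the log, so the loss is still a cross-entropy, just over P(y | x, S) rather than over the raw similarities.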
In the forward function of matching_net.py, the output of self.g is passed to self.lstm, but the output of self.lstm is never used.
Note that 'output' becomes 'outputs' in 'outputs = self.lstm(output)', so the similarities are still computed from the embeddings produced by self.g alone:
# use fce?
if self.fce:
    outputs = self.lstm(output)
# get similarities between support set embeddings and target
similarities = self.dn(support_set=output[:-1], input_image=output[-1])
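If the full-context embedding (FCE) is meant to feed the distance network, the LSTM result has to replace `output` before the support/target split. Below is a minimal sketch under stated assumptions: `FullContextEmbedding` and `split_support_and_target` are hypothetical names, the module runs a bidirectional `nn.LSTM` over the whole stacked sequence (support embeddings plus target, as described in this issue), and it sums the two directions with a skip connection in the spirit of the paper's g(x_i, S) = h_forward + h_backward + g'(x_i). The repo's own BidirectionalLSTM may return a tuple, so check its forward before copying the unpacking.

```python
import torch
import torch.nn as nn


class FullContextEmbedding(nn.Module):
    """Hypothetical stand-in for the repo's self.lstm: a bidirectional LSTM over stacked embeddings."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.lstm = nn.LSTM(input_size=dim, hidden_size=dim, bidirectional=True)

    def forward(self, embeddings):
        # embeddings: (seq_len, batch, dim) = support embeddings followed by the target embedding
        outputs, _ = self.lstm(embeddings)                  # (seq_len, batch, 2 * dim)
        forward_h = outputs[..., :self.dim]
        backward_h = outputs[..., self.dim:]
        # Sum of both directions plus a skip connection to the original embedding
        return forward_h + backward_h + embeddings


def split_support_and_target(output, fce_module=None):
    """Apply FCE if given, then split into support embeddings and target embedding."""
    if fce_module is not None:
        output = fce_module(output)  # reassigning `output` is the actual fix
    return output[:-1], output[-1]


torch.manual_seed(0)
stacked = torch.randn(6, 2, 16)        # 5 support embeddings + 1 target, batch of 2
fce = FullContextEmbedding(16)
support, target = split_support_and_target(stacked, fce)
print(support.shape, target.shape)
```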
  File "mainOmniglot.py", line 1, in <module>
    from data_loader import OmniglotNShotDataset
ValueError: source code string cannot contain null bytes
Can anyone help? Thank you.
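This error means a Python source file being compiled contains NUL (0x00) bytes. Since the traceback stops at the import on line 1, the affected file is probably data_loader.py. A common cause is a file saved as UTF-16 (some Windows editors do this) or a corrupted download. The helpers below are a hypothetical diagnostic sketch using only the standard library; the conversion step assumes the file really is UTF-16 with a byte-order mark.

```python
from pathlib import Path


def null_byte_offsets(path):
    """Return the byte offsets of every NUL (0x00) byte in the file at `path`."""
    data = Path(path).read_bytes()
    return [i for i, b in enumerate(data) if b == 0]


def has_utf16_bom(path):
    """True if the file starts with a UTF-16 byte-order mark (little- or big-endian)."""
    return Path(path).read_bytes()[:2] in (b"\xff\xfe", b"\xfe\xff")


def convert_utf16_to_utf8(path):
    """Re-save a UTF-16 file as UTF-8. Only call this when has_utf16_bom(path) is True."""
    text = Path(path).read_text(encoding="utf-16")
    Path(path).write_text(text, encoding="utf-8")
```

For example, `null_byte_offsets("data_loader.py")` returning a non-empty list together with `has_utf16_bom("data_loader.py")` being True points to an encoding problem that `convert_utf16_to_utf8` fixes. If there is no BOM, re-downloading the file is the safer option.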
You need to change the code at line 133 from:
def repackage_hidden(self, h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(self.repackage_hidden(v) for v in h)
to:
def repackage_hidden(self, h):
    """Detaches hidden states from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(self.repackage_hidden(v) for v in h)
and change every acc.data[0] and loss.data[0] to acc.item() and loss.item().
Then you can run the code with the latest PyTorch.
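A standalone check of both changes, assuming PyTorch 1.0 or later. Here `repackage_hidden` is written as a free function rather than a method so it runs on its own. The hidden state of `nn.LSTM` is an `(h, c)` tuple, which is why the recursion over tuples is needed. On recent PyTorch, `loss.data[0]` on a 0-dim loss raises `IndexError: invalid index of a 0-dim tensor`, while `loss.item()` returns a plain Python float.

```python
import torch
import torch.nn as nn


def repackage_hidden(h):
    """Detach a hidden state (a tensor or nested tuple of tensors) from its autograd history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4)
out, hidden = lstm(torch.randn(2, 1, 3))   # hidden is the (h, c) tuple
hidden = repackage_hidden(hidden)          # no gradient will flow back through earlier steps
loss = out.pow(2).mean()                   # 0-dim tensor
print(loss.item())                         # Python float; use this instead of loss.data[0]
```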
The source project: https://github.com/gitabcworld/MatchingNetworks
I have a small question about testing.
In the test phase, the query image always comes from one of the 5 support classes. In practice, however, the class of an incoming query is completely unknown, so how should we choose the support set?
I tested such a query against every possible support set, but many of them gave high confidence.
That makes no sense: we CANNOT recognise the query from the output confidence.
This confuses me about how to use Matching Networks in practice.
Is my understanding correct?
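One likely reason for the many high confidences: the attention softmax in Matching Networks normalises only over the support set it is given, so every candidate support set hands out a total probability of 1, even a set that does not contain the query's class. The sketch below uses random, hypothetical embeddings (not trained features) to compare one support set holding all 10 candidate classes with a 5-way subset that lacks the true class. `class_probs` is an illustrative helper, not the repo's code.

```python
import torch
import torch.nn.functional as F


def class_probs(query, support, support_labels, n_classes):
    """P(y | x, S) for one query, normalised only over the support set S passed in."""
    sims = F.cosine_similarity(query.unsqueeze(0), support, dim=1)    # (n_support,)
    attention = F.softmax(sims, dim=0)                                # sums to 1 within S
    return attention @ F.one_hot(support_labels, n_classes).float()   # (n_classes,)


torch.manual_seed(0)
embeddings = torch.randn(10, 64)                # hypothetical: one support embedding per class
query = embeddings[3] + 0.1 * torch.randn(64)   # a query that really belongs to class 3

p_all = class_probs(query, embeddings, torch.arange(10), 10)            # one 10-way support set
p_subset = class_probs(query, embeddings[5:], torch.arange(5, 10), 10)  # 5-way set without class 3
print(p_all.argmax().item(), p_subset.sum().item())
```

Scoring the query against a single support set that covers all candidate classes makes the probabilities comparable across classes. It still cannot flag a query whose class is missing from that set.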
In the source code, the author computes the cosine similarity as follows:
sum_support = torch.sum(torch.pow(support_image, 2), 1)
support_manitude = sum_support.clamp(eps, float("inf")).rsqrt()
dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
cosine_similarity = dot_product * support_manitude
similarities.append(cosine_similarity)
But in my opinion, the correct cosine similarity should be computed as follows:
sum_support = torch.sum(torch.pow(support_image, 2), 1)
support_manitude = sum_support.clamp(eps, float("inf")).rsqrt()
sum_input = torch.sum(torch.pow(input_image, 2), 1)
input_manitude = sum_input.clamp(eps, float("inf")).rsqrt()
dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).squeeze()
cosine_similarity = dot_product * support_manitude * input_manitude
similarities.append(cosine_similarity)
Am I right? If not, what is my mistake?
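A quick numerical check of the two variants, under the shape assumptions in the comments (support_set is `(n_support, batch, dim)`, input_image is `(batch, dim)`). `cosine_sims` is a hypothetical re-implementation, not the repo's DistanceNetwork. With `normalise_input=True` it matches `F.cosine_similarity`. Without the input norm, every similarity of a given query is multiplied by the same positive factor |x|, so the ranking over the support set is unchanged, but the softmax applied afterwards becomes sharper or flatter depending on |x|.

```python
import torch
import torch.nn.functional as F


def cosine_sims(support_set, input_image, eps=1e-10, normalise_input=True):
    """Similarities between input_image (batch, dim) and each entry of support_set (n_support, batch, dim)."""
    similarities = []
    for support_image in support_set:
        support_magnitude = support_image.pow(2).sum(1).clamp(min=eps).rsqrt()
        # Batched dot product; view(-1) instead of squeeze() keeps shape (batch,) even when batch == 1
        dot_product = input_image.unsqueeze(1).bmm(support_image.unsqueeze(2)).view(-1)
        sim = dot_product * support_magnitude
        if normalise_input:
            sim = sim * input_image.pow(2).sum(1).clamp(min=eps).rsqrt()
        similarities.append(sim)
    return torch.stack(similarities)  # (n_support, batch)


torch.manual_seed(0)
support = torch.randn(5, 3, 8)
query = torch.randn(3, 8)
full = cosine_sims(support, query)
partial = cosine_sims(support, query, normalise_input=False)
reference = F.cosine_similarity(query.unsqueeze(0).expand_as(support), support, dim=2)
print(full.argmax(dim=0), partial.argmax(dim=0))
```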