few-shot-learning / keras-fewshotlearning

Some State-of-the-Art few shot learning algorithms in tensorflow 2


keras-fewshotlearning's Introduction


Currently supporting Python 3.6 and 3.7, and TensorFlow ^2.1.

Welcome to keras-fsl!

As years go by, Few-Shot Learning (FSL), and especially Metric Learning, is becoming a hot topic not only in academic papers but also in production applications.

While many researchers now publish their code on GitHub, there is still no easy framework for getting started with FSL. In particular, when it comes to benchmarking existing models on personal datasets, it is not always easy to find one's way into each individual repo, not to mention the TensorFlow/PyTorch divide.

This repo aims to fill this gap by providing a single entry point for Few-Shot Learning. It is deeply inspired by Keras because it shares the same philosophy:

It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.

Thus this repo mainly relies on two of the main high-level Python packages for data science: Keras and Pandas. While Pandas may not seem very useful to researchers working with static datasets, it becomes a strong backbone in production applications where you constantly need to tinker with your data.

Few-Shot Learning

Few-shot learning is the task of classifying unseen samples into n classes (a so-called n-way task) where each class is described by only a few (1 to 5 in usual benchmarks) examples.
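
For concreteness, here is a minimal sketch of sampling the support set of such a task from an array of labels (plain numpy; the function name and signature are illustrative, not part of this repo):

import numpy as np

def sample_support_set(labels, n_way=5, k_shot=1, seed=None):
    # draw n_way distinct classes, then the indices of k_shot examples of each
    rng = np.random.default_rng(seed)
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    return {c: rng.choice(np.flatnonzero(labels == c), size=k_shot, replace=False) for c in classes}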

Most of the state-of-the-art algorithms try to learn a metric in a well-suited (optimized) feature space. Deep networks usually first encode the base images into a feature space on which a distance or similarity is then learnt.

This similarity is meant to be used later to classify samples according to their relative distances: either in a pair-wise manner, where the nearest support-set sample is used to classify the query sample (a Voronoi diagram), or with a more advanced classifier. Indeed, this philosophy is most commonly known as the kernel trick, where the kernel is the similarity learnt during training. Hence any kind of usual kernel-based machine learning could potentially be plugged onto this learnt similarity (see the min_eigenvalue metric to track the eigenvalues of the learnt similarity and check whether it is actually a kernel).
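
As a minimal sketch of the pair-wise (nearest support-set sample) case, assuming hypothetical encoder and kernel callables where a higher kernel score means more similar:

import tensorflow as tf

def nearest_support_label(encoder, kernel, query_image, support_images, support_labels):
    query_embedding = encoder(query_image[tf.newaxis])  # shape (1, d)
    support_embeddings = encoder(support_images)        # shape (n, d)
    n = tf.shape(support_embeddings)[0]
    tiled_query = tf.tile(query_embedding, [n, 1])      # shape (n, d)
    # the kernel maps a pair of embedding batches to one score per pair
    scores = kernel([tiled_query, support_embeddings])  # shape (n,)
    return tf.gather(support_labels, tf.argmax(scores))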

There is no easy answer to the optimal choice of such a classifier in the feature space. This may depend on performance as well as on complexity and real application parameters. For instance if the support set is strongly imbalanced, you may not want to fit an advanced classifier onto it but rather use a raw nearest neighbor approach.

All these considerations lead to the need of a code architecture that will let you play with these parameters with your own data in order to take the best from them.

Amongst others, the original Siamese Nets architecture is usually known as the network from Koch et al. This algorithm learns a pair-wise similarity between images. More precisely, it uses densely connected layers on top of the absolute difference between the two embeddings to predict 0 (different) or 1 (same).
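
For instance, a minimal sketch of such a pair-wise head (layer sizes and names are illustrative, not the repo's exact API):

import tensorflow as tf
from tensorflow.keras import layers, models

def koch_head(embedding_dim=512):
    e1 = layers.Input(shape=(embedding_dim,))
    e2 = layers.Input(shape=(embedding_dim,))
    # densely connected layer on top of the absolute difference of the embeddings
    abs_diff = layers.Lambda(lambda t: tf.abs(t[0] - t[1]))([e1, e2])
    output = layers.Dense(1, activation="sigmoid")(abs_diff)  # 1 = same, 0 = different
    return models.Model([e1, e2], output)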

As it is now expressed in recent papers, the representation learning framework is as follows:

  • a data augmentation module A
  • an encoder network E
  • a projection network P
  • a loss L

This repo mimics this framework by providing model builders and notebooks to implement current SOTA algorithms and your own tweaks seamlessly:

  • use tf.data.Dataset.map to apply data augmentation
  • define a tf.keras.Sequential model for your encoder
  • define a kernel, i.e. a tf.keras.layers.Layer with two inputs and a real-valued output (see head models)
  • use any support_layers to wrap the kernel and compute similarities in a tf.keras.Sequential manner (see notebooks for instance)
  • use any loss chosen according to the output of the tf.keras.Sequential model (GramMatrix or CentroidsMatrix for instance)

As an example, the TripletLoss algorithm indeed uses:

  • data augmentation: whatever you want
  • encoder: any backbone like ResNet50 or MobileNet
  • kernel: the squared l2 distance k(x, x') = ||x - x'||^2, i.e. tf.keras.layers.Lambda(lambda inputs: tf.reduce_sum(tf.square(inputs[0] - inputs[1]), axis=1))
  • support_layer: triplet loss uses all the pair-wise distances, hence it is the GramMatrix
  • loss: k(a, p) + margin - k(a, n) with semi-hard mining (see triplet_loss)
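
Putting these pieces together could look like the following sketch (the triplet_loss import path and its margin argument are assumed here, mirroring the BinaryCrossentropy import used in the Getting started section below; the backbone is arbitrary):

import tensorflow as tf
from tensorflow.keras.models import Sequential

from keras_fsl.layers import GramMatrix
from keras_fsl.losses.gram_matrix_losses import triplet_loss  # assumed path

# squared l2 distance between the two embeddings of each pair
kernel = tf.keras.layers.Lambda(lambda inputs: tf.reduce_sum(tf.square(inputs[0] - inputs[1]), axis=1))
encoder = tf.keras.applications.MobileNet(input_shape=(224, 224, 3), include_top=False, pooling="avg")
model = Sequential([encoder, GramMatrix(kernel=kernel)])
model.compile(optimizer="Adam", loss=triplet_loss(margin=1.0))  # margin value is illustrative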

Overview

This repo provides several tools for few-shot learning:

  • Keras layers and models
  • Keras sequences and Tensorflow datasets for training the models
  • Notebooks with proven learning sequences

All these tools can be used together or separately. One may want to stick with a Keras model trained on regular numpy arrays, with or without callbacks. When designing more advanced keras.Sequence or tf.data.Dataset pipelines for training, it is advised to use Pandas (some examples are provided), though it is not necessary at all.
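
For instance, a minimal sketch of the kind of Pandas-backed input pipeline meant here (file paths and column names are illustrative):

import pandas as pd
import tensorflow as tf

# a dataframe of image paths and labels is easy to filter, balance, or relabel
annotations = pd.DataFrame({
    "image_path": ["images/a.jpg", "images/b.jpg"],
    "label": [0, 1],
})

def load_image(path, label):
    image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
    return tf.image.convert_image_dtype(image, tf.float32), label

dataset = (
    tf.data.Dataset.from_tensor_slices((annotations.image_path.values, annotations.label.values))
    .map(load_image)
    .batch(32)
)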

Feel free to experiment and share your thought on this repo by contributing to it!

Getting started

The notebooks section provides some examples. For instance, just run:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Sequential

from keras_fsl.models.encoders import BasicCNN
from keras_fsl.layers import GramMatrix
from keras_fsl.losses.gram_matrix_losses import BinaryCrossentropy
from keras_fsl.metrics.gram_matrix_metrics import classification_accuracy, min_eigenvalue
from keras_fsl.utils.tensors import get_dummies


#%% Get data
train_dataset, val_dataset, test_dataset = [
    dataset.shuffle(1024).batch(64).map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), get_dummies(y)[0]))
    for dataset in tfds.load(name="omniglot", split=["train[:90%]", "train[90%:]", "test"], as_supervised=True)
]
input_shape = next(iter(tfds.as_numpy(train_dataset.take(1))))[0].shape[1:]  # first dimension is the batch size

#%% Training
encoder = BasicCNN(input_shape=input_shape)
support_layer = GramMatrix(kernel="DenseSigmoid")
model = Sequential([encoder, support_layer])
model.compile(optimizer="Adam", loss=BinaryCrossentropy(), metrics=[classification_accuracy(), min_eigenvalue])
model.fit(train_dataset, validation_data=val_dataset, epochs=5)
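
Once trained, the model can be evaluated on the held-out test split in the usual Keras way (a natural follow-up, not part of the original snippet):

model.evaluate(test_dataset)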

keras-fewshotlearning's People

Contributors

clementwalter, dependabot[bot], julienperichon, wirg


keras-fewshotlearning's Issues

How to use the Classification layer

Hi Clement,

I think there is a problem with the support_set_loss attribute in the Classification layer.
Here is a minimal example to reproduce the bug:

import tensorflow as tf
from keras_fsl.layers import Classification
from keras_fsl.models.head_models import LearntNorms

kernel = LearntNorms(input_shape=(512,), activation="sigmoid")
classifier = Classification(kernel=kernel)

classifier.set_support_set(
    support_tensors=tf.random.uniform(shape=(10, 512)),
    support_labels_name=tf.constant(list("AAABBCCCDD")),
    overwrite=tf.constant(True),
)

Here is the error I got:

Traceback (most recent call last):
  File "<input>", line 4, in <module>
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 506, in _initialize
    *args, **kwds))
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3299, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
    /Users/toubi/.pyenv/versions/3.6.9/envs/totem/src/keras-fsl/keras_fsl/layers/classification.py:83 set_support_set  *
        self.support_set_loss.assign(class_consistency_loss(support_labels_one_hot, pair_wise_scores))
    /Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:846 assign  **
        self._shape.assert_is_compatible_with(value_tensor.shape)
    /Users/toubi/.pyenv/versions/totem/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))
    ValueError: Shapes () and (None, None) are incompatible

The problem is that it cannot assign the new support set loss to the tf.Variable support_set_loss because the shapes do not match. This is line 83 in keras_fsl/layers/classification.py:

self.support_set_loss.assign(class_consistency_loss(support_labels_one_hot, pair_wise_scores))

I was able to fix this by redefining the support_set_loss variable before calling set_support_set:

import numpy as np

classifier.support_set_loss = tf.Variable(
    np.zeros((4, 4), dtype=np.float32),
    name="support_set_loss",
)

Also, there is a problem with the Classification layer docstring: the documented args do not correspond to the actual __init__ arguments:

[Screenshot of the Classification layer docstring, 2020-07-25]

Input data format

Hi there, it's not clear to me what format the input data should be in! Is it a tf.data.Dataset? Could you provide a sample dataset to test with, for example with multiple classes? Any advice would be welcome.
