
ykumards / simtorch


PyTorch library to compare similarity between NN representations

License: Apache License 2.0

cka deep-learning interpretability neural-networks pytorch


simtorch


A PyTorch library to measure the similarity between the representations of two neural networks. The library currently supports the following (dis)similarity measures:

Design

The package consists of two components:

  • SimilarityModel: a thin wrapper around a torch.nn.Module that registers forward hooks to store the layer-wise activations (aka representations) in a dictionary.
  • BaseSimilarity: defines the interface for classes that compute the similarity between network representations.
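The forward-hook mechanism behind SimilarityModel can be sketched in plain PyTorch. This is a minimal illustration, not simtorch's implementation: the class name `ActivationRecorder`, the rule that a layer is selected when one of the given strings is a substring of its module name, and the flattening of each activation to `(batch, features)` are all assumptions made for this example.

```python
import torch
from torch import nn


class ActivationRecorder:
    """Hypothetical stand-in for a SimilarityModel-style wrapper.

    Registers a forward hook on every submodule whose name contains one of
    the given strings and stores that submodule's output in a dictionary.
    """

    def __init__(self, model: nn.Module, layers_to_include: list[str]):
        self.model = model
        self.activations: dict[str, torch.Tensor] = {}
        self.handles = []
        for name, module in model.named_modules():
            # Assumption: substring match; the root module (name "") is skipped.
            if name and any(pattern in name for pattern in layers_to_include):
                self.handles.append(module.register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name: str):
        def hook(module, inputs, output):
            # Detach from the autograd graph and flatten all but the batch dimension.
            self.activations[name] = output.detach().flatten(start_dim=1)
        return hook

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def remove_hooks(self) -> None:
        for handle in self.handles:
            handle.remove()
        self.handles.clear()


model = nn.Sequential()
model.add_module("conv1", nn.Conv2d(3, 8, kernel_size=3, padding=1))
model.add_module("relu", nn.ReLU())
model.add_module("pool", nn.AdaptiveAvgPool2d(1))
model.add_module("flatten", nn.Flatten())
model.add_module("fc", nn.Linear(8, 10))

recorder = ActivationRecorder(model, layers_to_include=["conv", "fc"])
with torch.no_grad():
    recorder(torch.randn(4, 3, 16, 16))
print({name: tuple(act.shape) for name, act in recorder.activations.items()})
# {'conv1': (4, 2048), 'fc': (4, 10)}
```

Keeping the hook handles makes it possible to detach the wrapper cleanly with `remove_hooks()`, so the underlying model is left unchanged after the comparison.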

Installation

The package is available on PyPI:

pip install simtorch

Usage

The torch model objects need to be wrapped with SimilarityModel. A list with the names of the layers whose representations we wish to compute is passed as an argument to this class.

import torchvision

model1 = torchvision.models.densenet121()
model2 = torchvision.models.resnet101()

sim_model1 = SimilarityModel(
    model1,
    model_name="DenseNet 121",
    layers_to_include=["conv", "classifier",]
)

sim_model2 = SimilarityModel(
    model2,
    model_name="ResNet 101",
    layers_to_include=["conv", "fc",]
)
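The strings passed in `layers_to_include` refer to module names assigned by PyTorch, which can be listed with the standard `named_modules()` method. For `torchvision.models.resnet101()`, for example, the names include `conv1`, `layer1.0.conv1` and `fc`; for `densenet121()` they include `features.conv0` and `classifier`. This README does not say whether simtorch matches these strings exactly or as substrings, so the helper `matching_layer_names` below is hypothetical and assumes substring matching; the small nested model stands in for a torchvision model.

```python
from torch import nn


def matching_layer_names(model: nn.Module, patterns: list[str]) -> list[str]:
    """Hypothetical helper: names of submodules that contain any of the patterns."""
    return [
        name
        for name, _ in model.named_modules()
        if name and any(pattern in name for pattern in patterns)  # skip root ""
    ]


# Small stand-in model with a nested block (only inspected, never run).
model = nn.Sequential()
model.add_module("conv1", nn.Conv2d(3, 8, 3))
block = nn.Sequential()
block.add_module("conv", nn.Conv2d(8, 8, 3))
block.add_module("bn", nn.BatchNorm2d(8))
model.add_module("block", block)
model.add_module("fc", nn.Linear(8, 10))

print(matching_layer_names(model, ["conv", "fc"]))
# ['conv1', 'block.conv', 'fc']
```

Under substring matching, a short pattern such as "conv" selects every convolution in a deep network, which produces a large similarity matrix; more specific names keep it small.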

An instance of a similarity metric can then be initialized with these SimilarityModels. Its compute() method returns a similarity matrix $S$ for the two models, where $S[i, j]$ is the similarity between the $i^{th}$ layer of the first model and the $j^{th}$ layer of the second model.

# torch_dataloader: a DataLoader over the inputs on which both models are compared
sim_cka = CKA(sim_model1, sim_model2, device="cuda")
cka_matrix = sim_cka.compute(torch_dataloader)
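For reference, the commonly used linear CKA compares activations $X \in \mathbb{R}^{n \times p_1}$ and $Y \in \mathbb{R}^{n \times p_2}$ collected on the same $n$ inputs. This README does not document which estimator simtorch uses (biased or unbiased HSIC, full-dataset or minibatch), so treat the following as a sketch of standard linear CKA rather than the library's implementation:

$$
\mathrm{HSIC}(K, L) = \frac{\operatorname{tr}(K H L H)}{(n-1)^2}, \qquad H = I_n - \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top
$$

$$
\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}, \qquad K = XX^\top,\; L = YY^\top
$$

With column-centered $X$ and $Y$ this reduces to $\|Y^\top X\|_F^2 / (\|X^\top X\|_F \, \|Y^\top Y\|_F)$, which the NumPy code below computes for every pair of layers to build $S$. The names `linear_cka` and `similarity_matrix` are hypothetical, and the activations are random stand-ins for real layer outputs.

```python
import numpy as np


def center(x: np.ndarray) -> np.ndarray:
    """Subtract each feature's mean over the n inputs (same as H @ x)."""
    return x - x.mean(axis=0, keepdims=True)


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, p1) and y (n, p2) on the same n inputs."""
    x, y = center(x), center(y)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


def similarity_matrix(acts1: dict, acts2: dict) -> np.ndarray:
    """S[i, j] = CKA(layer i of model 1, layer j of model 2), in dict order."""
    s = np.zeros((len(acts1), len(acts2)))
    for i, a in enumerate(acts1.values()):
        for j, b in enumerate(acts2.values()):
            s[i, j] = linear_cka(a, b)
    return s


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 32))
q, _ = np.linalg.qr(rng.normal(size=(32, 32)))  # random orthogonal matrix

acts1 = {"conv": x, "classifier": rng.normal(size=(100, 10))}
# A rotated and rescaled copy of x: linear CKA is invariant to both.
acts2 = {"conv": 3.0 * x @ q, "fc": rng.normal(size=(100, 5))}
S = similarity_matrix(acts1, acts2)
print(S.round(3))
```

Because linear CKA is invariant to orthogonal transformations and isotropic scaling, `S[0, 0]` comes out as 1 even though the two "conv" activations differ entry by entry.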

The similarity matrix can be visualized with the sim_cka.plot_similarity() method, which produces the CKA similarity plot.

[Figure: Centered Kernel Alignment matrix]
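plot_similarity() is the built-in route. To restyle the figure yourself, one option is a matplotlib heatmap of the matrix. This sketch assumes the result of compute() can be converted with `np.asarray` (for a torch tensor, call `.cpu().numpy()` first); the helper `plot_matrix` and the 2×2 values below are hypothetical placeholders, not simtorch output.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so this also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_matrix(s, row_labels, col_labels, title="CKA similarity"):
    """Hypothetical helper: heatmap of S[i, j], rows = model 1 layers, cols = model 2."""
    s = np.asarray(s)
    fig, ax = plt.subplots(figsize=(4, 4))
    # CKA lies in [0, 1], so fix the color scale for comparable plots.
    im = ax.imshow(s, origin="lower", cmap="magma", vmin=0.0, vmax=1.0)
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha="right")
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_xlabel("Model 2 layers")
    ax.set_ylabel("Model 1 layers")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig


# Placeholder values for illustration only.
fig = plot_matrix([[0.9, 0.2], [0.3, 0.7]], ["conv", "classifier"], ["conv", "fc"])
fig.savefig("cka_matrix.png", dpi=150)
```

Fixing `vmin` and `vmax` to the CKA range keeps colors comparable across different model pairs, at the cost of less contrast when all values are similar.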

Citations

If you use Deconfounded Centered Kernel Alignment (dCKA) for your research, please cite:

@article{cui2022deconfounded,
  title={Deconfounded Representation Similarity for Comparison of Neural Networks},
  author={Cui, Tianyu and Kumar, Yogesh and Marttinen, Pekka and Kaski, Samuel},
  journal={Neural Information Processing Systems (NeurIPS)},
  year={2022}
}

Credits

This library was built using the following awesome repos as references:
