The geometry of hidden representations of large transformer models

Source code of the paper: 'The geometry of hidden representations of large transformer models'. This work has been included in the NeurIPS 2023 proceedings.

Platform

We performed the experiments on a Linux cluster with Intel Xeon Gold 6226 processors and V100 GPUs with 32GB of VRAM. We ran the code on Ubuntu 22.04.

Premise

The results of the paper rely on computing the intrinsic dimension (ID) and the neighborhood overlap. For both quantities, we use the implementations in DADApy.

With DADApy, you can compute the intrinsic dimension (Figs. 1 and 3) of the dataset representation X at a given layer as follows:

from dadapy.data import Data

# Initialize the Data class with the layer representation X.
# X must be a 2D NumPy array of shape N x d, where N is the dataset size and d is the embedding dimension.
data = Data(X)

# Compute the intrinsic dimension at increasing scales with the 2NN estimator.
# The two discarded outputs are the ID error estimates and the corresponding length scales.
id_list_2NN, _, _ = data.return_id_scaling_2NN()

# Compute the intrinsic dimension with Gride, up to the scale of the 64th nearest neighbor.
id_list_gride, _, _ = data.return_id_scaling_gride()

The two methods give similar results, so you can use either of them; Gride is slightly faster and more robust. Both return a list of intrinsic dimensions, one per scale; see Appendix B of the paper for how we select the ID plotted in the figures.
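To make the 2NN estimator more concrete, here is a minimal plain-NumPy sketch. It is not DADApy's implementation. For every point, it computes the ratio mu = r2/r1 between the distances to the second and first nearest neighbors (the mu_ratios statistic this README refers to), then returns the closed-form maximum-likelihood estimate d = N / sum(log mu). The original TwoNN method fits a line to the empirical distribution of mu instead; the closed form is used here for brevity. The function names mu_ratios and twonn_id are hypothetical. The brute-force N x N distance computation only suits small N, and DADApy's estimators also provide error estimates across multiple scales.

```python
import numpy as np


def mu_ratios(X):
    """Return mu_i = r2 / r1 for every row of X (self excluded).

    Brute force: builds the full N x N distance matrix, so only use small N.
    Duplicate points (r1 == 0) must be removed beforehand.
    """
    X = np.asarray(X, dtype=np.float64)
    sq = np.sum(X**2, axis=1)
    # Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b.
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    np.maximum(d2, 0.0, out=d2)      # clip small negative round-off
    np.fill_diagonal(d2, np.inf)     # a point is not its own neighbor
    # After partitioning at kth=1, column 0 holds the smallest and column 1
    # the second-smallest squared distance of each row.
    two = np.partition(d2, 1, axis=1)[:, :2]
    return np.sqrt(two[:, 1] / two[:, 0])


def twonn_id(mu):
    """Maximum-likelihood TwoNN estimate d = N / sum_i log(mu_i)."""
    mu = np.asarray(mu, dtype=np.float64)
    return mu.size / np.sum(np.log(mu))


# Demo: 2,000 points uniformly filling a 2-D parallelogram embedded in 10-D.
rng = np.random.default_rng(0)
X = rng.uniform(size=(2000, 2)) @ rng.normal(size=(2, 10))
print(twonn_id(mu_ratios(X)))  # expected to be close to 2
```

The embedding dimension (10) does not matter here. Only the dimension of the manifold the points fill is recovered.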

The overlap with the labels Y (Fig. 4) can be computed as:

# Y must be a 1D NumPy array of shape N holding the integer class label of each example.
overlap_labels = data.return_label_overlap(Y, k=30)
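The following is a minimal NumPy sketch of the label overlap. It assumes the overlap is the fraction of each point's k nearest neighbors that share the point's class label, averaged over all N points; see the Method section of the paper for the exact definition. This sketch is an illustration and not DADApy's return_label_overlap. The helpers knn_indices and label_overlap are hypothetical and use brute-force distances.

```python
import numpy as np


def knn_indices(X, k):
    """Indices of the k nearest neighbors of each row of X (self excluded), brute force."""
    X = np.asarray(X, dtype=np.float64)
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T   # squared distances keep the ranking
    np.fill_diagonal(d2, np.inf)                     # a point is not its own neighbor
    return np.argpartition(d2, k - 1, axis=1)[:, :k]  # k smallest, in arbitrary order


def label_overlap(X, Y, k=30):
    """Mean fraction of each point's k nearest neighbors that carry the same label."""
    Y = np.asarray(Y)
    nn = knn_indices(X, k)
    same = Y[nn] == Y[:, None]   # N x k boolean matrix
    return float(same.mean())


rng = np.random.default_rng(0)
# Two well-separated Gaussian blobs, labelled 0 and 1.
X = np.vstack([rng.normal(0, 1, (200, 5)), rng.normal(20, 1, (200, 5))])
Y = np.repeat([0, 1], 200)
print(label_overlap(X, Y, k=30))  # 1.0: every neighbor shares the label
```

If the labels are shuffled, the same function gives roughly the class frequency (about 0.5 for two balanced classes). Values near 1 therefore indicate that neighborhoods are organized by class.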

In the paper, we also compute the overlap between pairs of representations (Fig. 2). If X2 is a second representation of the same N examples (shape: N x d2), the overlap between X and X2 can be computed as:

overlap_X2 = data.return_data_overlap(X2, k=30)
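Here is a minimal NumPy sketch of the overlap between two representations. It assumes the overlap is the fraction of k nearest neighbors that a point has in common in X and X2, averaged over the N points. This is an illustration, not DADApy's return_data_overlap, and knn_indices and data_overlap are hypothetical helpers. Because only neighbor indices are compared, X and X2 may have different widths d and d2.

```python
import numpy as np


def knn_indices(X, k):
    """Indices of the k nearest neighbors of each row of X (self excluded), brute force."""
    X = np.asarray(X, dtype=np.float64)
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T   # squared distances keep the ranking
    np.fill_diagonal(d2, np.inf)
    return np.argpartition(d2, k - 1, axis=1)[:, :k]  # k smallest, in arbitrary order


def data_overlap(X, X2, k=30):
    """Mean fraction of k-nearest neighbors shared by two representations of the same points."""
    if len(X) != len(X2):
        raise ValueError("X and X2 must describe the same N examples")
    nn1, nn2 = knn_indices(X, k), knn_indices(X2, k)
    shared = [np.intersect1d(a, b).size for a, b in zip(nn1, nn2)]
    return float(np.mean(shared)) / k


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 8))
# Rotate and rescale X, then append 3 zero columns: distances are only scaled,
# so neighborhoods are unchanged even though X2 has a different width (d2 = 11).
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
X2 = np.hstack([3.0 * X @ Q, np.zeros((500, 3))])
print(data_overlap(X, X2, k=10))                         # close to 1.0
print(data_overlap(X, rng.normal(size=(500, 8)), k=10))  # roughly k / N, about 0.02
```

The two demo lines give the extremes: identical neighborhoods yield 1, and unrelated representations yield about k/N.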

In the following, we provide the code to reproduce the paper's results.

Section 1 (Reproduce the paper plots) reproduces some of the paper's plots from precomputed statistics: the mu_ratios for the ID and the nearest-neighbor indices for the overlaps (see the Method section of the paper for the meaning of these quantities). For iGPT, we compute the ID and the overlaps from these statistics with DADApy functions. For esm2, we directly provide the precomputed IDs and overlaps with the labels.

Section 2 (Extract the representations and compute the distance matrices of iGPT) extracts from iGPT the distance matrices required for the ID and overlap computation.



1. Reproduce the paper plots

a. Build an environment with the required dependencies

You can get Miniconda from https://docs.conda.io/en/latest/miniconda.html. Then, install the dependencies:

conda create -n geometry_representations python=3.11 pip
conda activate geometry_representations
pip install -r requirements.txt   

b. Download the precomputed intrinsic dimensions and overlaps.

The download.py script downloads the NumPy arrays needed to reproduce the plots shown in the paper.

python download.py 

c. Plot the intrinsic dimension and overlap profiles

You can plot the intrinsic dimension profiles (Fig. 1) and the overlap with the class labels (Fig. 4).

python plot_id_overlap.py 

By default, the plots are saved in "./results/plots".

The iGPT intrinsic dimensions and overlaps are computed from the nearest-neighbor distance matrices you downloaded in step b. For the neighborhood overlap (bottom-right panel), we provide only a small number of checkpoints.

To extract the distance matrices of all the iGPT layers, use the code in Section 2.



2. Extract the representations and compute the distance matrices of iGPT

We first need to download the pre-trained iGPT models following the instructions in https://github.com/openai/image-gpt. The environment with the required dependencies can be created as follows:

conda create --name image-gpt python=3.7.3 pip
conda activate image-gpt

conda install numpy=1.16.3
conda install tensorflow-gpu=1.13.1

conda install imageio=2.8.0
conda install requests=2.21.0
conda install tqdm=4.46.0
pip install -U scikit-learn

Compared with the environment of https://github.com/openai/image-gpt, we added the scikit-learn package, which we use to compute the distance matrices.


a. Download the pre-trained iGPT models.

You can download the iGPT-small model and the ImageNet dataset (training, validation, test sets) with:

python src/download_igpt.py \
--model s \
--ckpt 1000000 \
--dataset imagenet \
--download_dir igpt_models

--model 's' downloads the small version of iGPT;
--ckpt 1000000 is the training checkpoint to download; 1000000 corresponds to the fully trained model;
--dataset imagenet downloads the ImageNet dataset;
--download_dir igpt_models is the directory where the pre-trained model and the ImageNet dataset are stored.

The disk space occupied by the models and the dataset is:
iGPT-small: 894MB;
iGPT-medium: 5.2GB;
iGPT-large: 15.5GB;
ImageNet dataset: 11GB.


b. Compute the nearest-neighbor distance matrices.

The following command extracts the 24 hidden representations of iGPT-small analyzed in the paper and computes their distance matrices.

On a V100 GPU, extracting the representations of 90k examples (300 classes x 300 images) takes around one hour and uses the full 32GB of VRAM with a batch size of 8. You can reduce the GPU memory requirement by decreasing the batch size. Once the representations have been extracted, computing the 24 distance matrices takes another 30 minutes (for 90k examples).

The representations are kept in RAM throughout the process, so depending on your memory budget you may need to extract the layers in smaller subsets.
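A rough estimate can show whether the extraction fits in RAM. The sketch below assumes that each example is stored as a single float32 vector per layer (for instance, averaged over the token sequence) and that iGPT-small has an embedding width of 512. Both are assumptions, as are the helper names and the k = 100 neighbors used for comparison. The sketch compares the memory for the representations with that of a dense N x N distance matrix and of N x k nearest-neighbor matrices.

```python
# Back-of-the-envelope RAM estimates in decimal GB. Sizes marked "assumed" are assumptions.
N_CLASSES, NIMG_CAT = 300, 300   # 300 classes x 300 images = 90k examples
N_LAYERS = 24                    # hidden representations of iGPT-small
WIDTH = 512                      # assumed iGPT-small embedding width
FLOAT32 = 4                      # bytes per stored value


def reps_ram_gb(n, layers, width, bytes_per_value=FLOAT32):
    """RAM to hold one width-sized vector per example for each layer."""
    return n * layers * width * bytes_per_value / 1e9


def dense_dist_gb(n, bytes_per_value=FLOAT32):
    """RAM for one dense N x N distance matrix."""
    return n * n * bytes_per_value / 1e9


def knn_gb(n, k, bytes_per_value=FLOAT32, bytes_per_index=8):
    """RAM for an N x k distance matrix plus its N x k int64 index matrix."""
    return n * k * (bytes_per_value + bytes_per_index) / 1e9


n = N_CLASSES * NIMG_CAT
print(f"all layers:        {reps_ram_gb(n, N_LAYERS, WIDTH):.2f} GB")  # 4.42 GB
print(f"one layer:         {reps_ram_gb(n, 1, WIDTH):.2f} GB")         # 0.18 GB
print(f"dense N x N dist:  {dense_dist_gb(n):.1f} GB")                 # 32.4 GB
print(f"k=100 NN matrices: {knn_gb(n, 100):.2f} GB")                   # 0.11 GB (k assumed)
```

Under these assumptions, the 24 layers of 90k examples take about 4.4 GB. A single dense 90k x 90k float32 distance matrix would take about 32 GB, so storing only nearest-neighbor distances is much cheaper.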

python src/run.py \
--data_dir igpt_models \
--ckpt_dir igpt_models \
--model "s" \
--results_dir "./results" \
--nimg_cat 300 \
--n_sub_batch 8

--ckpt_dir is the directory where you stored the model checkpoints downloaded in step a.;
--model 's' means that you are analyzing the small model;
--data_dir is the directory where you stored the ImageNet dataset downloaded in step a.;
--results_dir is the directory where the representations and distance matrices are saved;
--nimg_cat is the number of images analyzed per class (300 in the paper);
--n_sub_batch is the batch size.

In run.py, we extract only the 300 classes of the ImageNet TRAINING SET analyzed in the paper.
The class labels are stored in the './hier_nucl_labels.npy' array.
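The repository computes the distance matrices with scikit-learn. As an illustration of the idea only (not the code in run.py), the NumPy sketch below builds the N x k nearest-neighbor distance and index matrices in chunks of rows. Peak memory therefore scales with chunk_size x N instead of N x N. The function knn_chunked and its parameters are hypothetical, and the sketch assumes Euclidean distances and k <= N - 1.

```python
import numpy as np


def knn_chunked(X, k, chunk_size=1024):
    """Return (dist, ind), each N x k and sorted by distance, with self-matches excluded."""
    X = np.asarray(X, dtype=np.float64)
    n = X.shape[0]
    sq = np.sum(X**2, axis=1)
    dist = np.empty((n, k))
    ind = np.empty((n, k), dtype=np.int64)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Squared distances from rows start:stop to all N points: (stop - start) x N.
        d2 = sq[start:stop, None] + sq[None, :] - 2.0 * X[start:stop] @ X.T
        np.maximum(d2, 0.0, out=d2)                  # clip negative round-off
        rows = np.arange(stop - start)
        d2[rows, start + rows] = np.inf              # exclude each point itself
        part = np.argpartition(d2, k - 1, axis=1)[:, :k]   # k smallest, unordered
        part_d2 = np.take_along_axis(d2, part, axis=1)
        order = np.argsort(part_d2, axis=1)                 # sort the k candidates
        ind[start:stop] = np.take_along_axis(part, order, axis=1)
        dist[start:stop] = np.sqrt(np.take_along_axis(part_d2, order, axis=1))
    return dist, ind


rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 64)).astype(np.float32)
dist, ind = knn_chunked(X, k=30, chunk_size=1000)
mu = dist[:, 1] / dist[:, 0]    # r2 / r1 ratios for the 2NN estimator
print(dist.shape, ind.shape)    # (5000, 30) (5000, 30)
```

Because self-matches are excluded, column 0 of dist holds each point's first-neighbor distance r1 and column 1 holds r2. For 90k examples, each float64 chunk of 1024 rows of squared distances takes about 1024 x 90000 x 8 bytes, or roughly 0.74 GB, plus temporaries.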


c. Extract the hidden layer representations.

If you only want to extract the hidden-layer representations, add the --activations flag to the arguments above:

python src/run.py \
--activations \
--data_dir igpt_models \
--ckpt_dir igpt_models \
--model "s" \
--results_dir "./results" \
--nimg_cat 300 \
--n_sub_batch 8

With this setup, the distance matrices are not computed.
