
Comments (31)

conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains I forgot to mention that I used fp16 in the training above. This is likely one of the causes of the numerical instability and NaNs in this experiment. Since I did not want to alter the training script, I did not apply any stabilization techniques to shrink the token embedding gradient. For example: x = self.token_emb(x) * self.alpha + self.token_emb(x).detach() * (1 - self.alpha). This is what Tsinghua did to help stabilize fp16 training for GLM-130B. I can add this to a new script and post the results for training with fp16 again.
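
As a minimal sketch of the gradient shrink expression quoted above: the wrapper below keeps the forward value identical to the plain embedding while scaling the gradient that reaches the embedding weights by `alpha`. The class name `ShrunkEmbedding` and the default `alpha = 0.1` are assumptions made for this illustration; this is not the GLM-130B code.

```python
import torch
from torch import nn

class ShrunkEmbedding(nn.Module):
    """Token embedding whose gradient is scaled by alpha (hypothetical helper)."""

    def __init__(self, num_tokens, dim, alpha=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.alpha = alpha  # assumed value, chosen only for the example

    def forward(self, ids):
        emb = self.token_emb(ids)
        # forward value: emb * alpha + emb * (1 - alpha) == emb
        # backward: only the first term carries gradient, so it is scaled by alpha
        return emb * self.alpha + emb.detach() * (1 - self.alpha)

torch.manual_seed(0)
model = ShrunkEmbedding(num_tokens=10, dim=4, alpha=0.1)
ids = torch.tensor([[1, 2, 3]])
out = model(ids)
out.sum().backward()
grad = model.token_emb.weight.grad
print(grad[1])  # gradient row for token 1, scaled by alpha
```

Computing the embedding once and reusing it, rather than calling `self.token_emb(x)` twice as in the one-line form, gives the same result with a single lookup.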

I will run the normal script again with fp32 on an A100 and document the results here as well.

Thank you,

Enrico

from flash-cosine-sim-attention.

lucidrains avatar lucidrains commented on May 30, 2024 1

@conceptofmind ohh yes the fp16 is the likely cause, as i was in the middle of fixing an underflow issue with the way cosine sim attention was approached in the CUDA kernel (which should be fixed in 0.1.38). If you have time to retry it on the latest version, that would be greatly appreciated!
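
For readers unfamiliar with fp16 underflow in general, here is a small standard-library illustration. It is not a reproduction of the specific kernel issue mentioned above, only a demonstration that products of small values can fall below the smallest half-precision subnormal (about 6e-8) and round to zero. The helper name `to_fp16` is hypothetical.

```python
import struct

def to_fp16(value: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('<e', struct.pack('<e', value))[0]

small = to_fp16(1e-4)             # still representable in fp16, up to rounding
product = to_fp16(small * small)  # about 1e-8, below the fp16 subnormal range

print(small)
print(product)  # 0.0
```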

are you using character level enwik8 for training from this repo, or modifying another gpt2 codebase? if modifying another codebase, could you share the code you have?

there is no need for the gradient shrinking technique from Tsinghua, as the whole idea behind the repository was to explore whether cosine sim attention can bring about greater stability without any cost
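
To make the stability argument concrete, here is a rough sketch of plain (non-flash, non-causal) cosine similarity attention in PyTorch. It is an illustration under assumptions, not the repository's implementation: the function name and the use of `F.normalize` are choices made for this example, and `scale = 8` simply mirrors the default that appears in the code posted in this thread. Because queries and keys are unit-normalized, every attention logit is bounded by the scale regardless of activation magnitude.

```python
import torch
import torch.nn.functional as F

def cosine_sim_attention(q, k, v, scale=8.0):
    """Non-causal cosine similarity attention on (batch, heads, seq, dim_head) tensors."""
    # unit-normalize queries and keys so every q . k is a cosine in [-1, 1]
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    # a fixed scale takes the place of 1 / sqrt(dim_head); logits stay in [-scale, scale]
    sim = (q @ k.transpose(-1, -2)) * scale
    attn = sim.softmax(dim=-1)
    return attn @ v, sim

torch.manual_seed(0)
# deliberately huge activations, whose raw dot products would exceed the fp16 range
q = torch.randn(1, 2, 16, 64) * 1e4
k = torch.randn(1, 2, 16, 64) * 1e4
v = torch.randn(1, 2, 16, 64)

out, sim = cosine_sim_attention(q, k, v)
print(out.shape, sim.abs().max().item())
```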


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains I will update the repository and rerun the test with fp16 enabled, then post the new fp16 training results. I am using the character-level enwik8 from this repository for these tests. I have not made any alterations to the training script except for logging the loss with wandb. I did not want to change anything from your original work in order to remain consistent.

Here are the training results with fp32 on an A100 (40GB) for almost 40k steps:
Screenshot from 2022-11-01 12-40-30

The training has remained more stable.

I am additionally testing flash-cosine-sim-attention in another GPT-like model, a Vision Transformer, and a PaLM model. I will post all of the code and results for these additional tests when I am confident everything meets a certain level of correctness. I will not apply the gradient shrinking technique from Tsinghua to any of these additional tests.

Thank you,

Enrico


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains I am using the current CosineSimCausalTransformer available in the repository for the GPT-2 run. I believe the architecture used post-norm layers with DeepNorm. I did not see a specific place for a PreNorm wrapper or where pre-layernorm was explicitly defined. I saw that DeepNorm was applied to attn.to_v, attn.to_out, ff[0], and ff[2].
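
For context, DeepNorm (from the DeepNet paper) uses a post-norm residual of the form LayerNorm(alpha * x + f(x)) and scales the initialization of the value, output, and feedforward projections by a gain beta. The sketch below uses the decoder-only constants from that paper, alpha = (2N)^(1/4) and beta = (8N)^(-1/4) for N layers. The helper names are hypothetical and this is not the repository's code.

```python
import torch
from torch import nn

def deepnorm_constants(depth: int):
    """Decoder-only DeepNorm constants for a model with `depth` layers."""
    alpha = (2 * depth) ** 0.25   # residual scaling
    beta = (8 * depth) ** -0.25   # initialization gain
    return alpha, beta

def deepnorm_init_(linear: nn.Linear, beta: float):
    """Xavier-normal init with gain beta, as applied to v / out / feedforward projections."""
    nn.init.xavier_normal_(linear.weight, gain=beta)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

alpha, beta = deepnorm_constants(depth=8)
to_v = nn.Linear(512, 512, bias=False)  # stand-in for a value projection
deepnorm_init_(to_v, beta)
print(alpha, beta)
```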


lucidrains avatar lucidrains commented on May 30, 2024 1

@conceptofmind that would actually be great! i'll add a prenorm option tomorrow morning 🙏


conceptofmind avatar conceptofmind commented on May 30, 2024 1

Side note: I am going to spend the next few weeks working on a Triton version of Flash Cosine Similarity Attention as well. I think it would be an interesting comparative benchmark!


lucidrains avatar lucidrains commented on May 30, 2024 1

Nice! I added the prenorm option in the transformer this morning as a simple flag 0c260d1

Is it possible to move the plots you have above into the same graph for comparison?


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains Of course! Here are the grouped graphs.

fp16 training without pre-layernorm (validating every 100)
Screenshot from 2022-11-02 12-28-39

fp16 training with pre-layernorm (validating every 10)
Screenshot from 2022-11-02 12-29-55

fp32 training with pre-layernorm (validating every 10)
Screenshot from 2022-11-02 12-30-49

I can start validating every step if that would be better as well.

I will post the results of training a PaLM-like model soon.

Thank you,

Enrico


lucidrains avatar lucidrains commented on May 30, 2024 1

@conceptofmind yup that looks good! so for PaLM, because it uses rotary embeddings, the l2 normalization needs to come before the rotation of the queries and keys

something like this in the readme

q, k = l2norm_tensors(q, k)

positions = self.get_rotary_embedding(n, device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups, **kwargs)
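
One property that makes this ordering workable is that the rotary rotation preserves vector norms, so queries and keys that are unit-normalized first are still unit length after rotation. The check below is a self-contained sketch under assumptions: `F.normalize` stands in for `l2norm_tensors`, and the rotary helpers are re-implemented here with plain PyTorch following the formulas in the PaLM code posted in this thread. It does not reproduce the library's internals.

```python
import torch
import torch.nn.functional as F

def rotary_angles(seq_len, dim_head):
    """Per-position rotation angles, shape (seq_len, dim_head)."""
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return torch.cat((freqs, freqs), dim=-1)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(pos, t):
    # rotates each (x1_i, x2_i) pair by its position-dependent angle
    return t * pos.cos() + rotate_half(t) * pos.sin()

torch.manual_seed(0)
q = F.normalize(torch.randn(1, 8, 16, 64), dim=-1)  # l2-normalize first
q_rot = apply_rotary(rotary_angles(16, 64), q)      # then rotate

# the norms stay at 1 (up to floating point error)
print(q_rot.norm(dim=-1).min().item(), q_rot.norm(dim=-1).max().item())
```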


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains I will make the adjustments to the model to do the l2 normalization before the rotation of the queries and keys, and post the results.

Thank you,

Enrico


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains Here are the results for training the ViT-16 with flash-cosine-sim-attention on CIFAR10 for 100 epochs.

Train and Validation loss:
Screenshot from 2022-11-12 10-54-01
Train and Validation accuracy:
Screenshot from 2022-11-12 10-54-58
Training performance and accuracy were great. Validation shows possible overfitting, which can be expected.

What do you think about trying flash-cosine-sim-attention in MEGA-pytorch?

Best,

Enrico


lucidrains avatar lucidrains commented on May 30, 2024 1

🙏 thank you for running these experiments Enrico

do you think you could also run the same experiments against regular attention and compare the curves side by side? i am concerned about expressiveness issues after Robin's experiments


lucidrains avatar lucidrains commented on May 30, 2024 1

@conceptofmind you can name your wandb runs by doing

wandb.run.name = 'regular attention'
wandb.run.save()

right after wandb.init


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains Of course. I will train a ViT-16 with regular attention on CIFAR10 for 100 epochs and compare the curves side by side now. I will update you when everything is done running for the ViT experiments. All of these experiments were conducted on an A100 (40GB).

Here is a chart of the train and val losses for PaLM with and without flash-cosine-sim-attention:
Screenshot from 2022-11-14 23-10-40
Here is a link to that chart:
https://wandb.ai/please/my-test-project/reports/val-loss-train-loss-22-11-14-23-08-14---VmlldzoyOTcxODQx?accessToken=733c0o8hpgpklqq42phj7rem7rznxcny5tca2q7v7bnbxm9gwppa1p0gv5fbo19n
With Label smoothing:
Screenshot from 2022-11-14 23-12-13


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains Here are the results for ViT-16 with and without flash-cosine-sim attention on CIFAR10 for 100 epochs with the same learning rate of 2e-4. I am using an A100(40 GB). Definitely good news!

Train/Validation Loss:
Screenshot from 2022-11-16 12-07-03

Train/Validation Accuracy:
Screenshot from 2022-11-16 12-09-49

So far in my testing, I have seen better accuracy and faster convergence with flash-cosine-sim attention. I will need to keep training more models, and I am thinking about including some other improvements too, possibly FastLayerNorm from Apex. I am getting everything set up for a training run on ImageNet and will post the code for that soon.

I am still working on the Triton version as well (it is my first time using Triton).


lucidrains avatar lucidrains commented on May 30, 2024 1

@conceptofmind thank you Enrico

the results actually look more promising than i expected, if conditions are held equal between the two runs


lucidrains avatar lucidrains commented on May 30, 2024 1

@conceptofmind if you have some time, do you think you could try flash cosine sim attention on some generative models and see if the FID scores between regular and cosine sim attention differ at all? that is what i worry about, as @rromb showed evidence that loss curves are not everything. however, the ViT accuracy curves you got from above do look good 🙏


conceptofmind avatar conceptofmind commented on May 30, 2024 1

@lucidrains Absolutely! I am currently finishing setting everything up for a run on ImageNet with flash-cosine-sim-attention and will hopefully have the results for that soon. Hugging Face has a standard version of ImageNet which fortunately can be downloaded in a reasonable amount of time.

Here is the training script for the run on ImageNet with no data augmentation:

import torch
import tqdm
import argparse
import wandb
from datasets import load_dataset
from transformers import AutoFeatureExtractor

from vit import ViT

wandb.init(project="my-test-project")
wandb.run.name = 'regular attention - imagenet'
wandb.run.save()

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default = 4, type = int)
args = parser.parse_args()

BATCH_SIZE = args.batch_size
DEVICE = 'cuda'
EPOCHS = 100

imagenet_1k_train = load_dataset('imagenet-1k', 'train')
imagenet_1k_test = load_dataset('imagenet-1k', 'test')

model_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)

def preprocess_function(images):
    image_tensors = [image.convert("RGB") for image in images['image']]
    inputs = feature_extractor(image_tensors, return_tensors="pt")
    inputs['label'] = images['label']
    return inputs

def collate_function(batches):
    pixel_values = torch.stack([batch['pixel_values'] for batch in batches])
    label = torch.tensor([batch['label'] for batch in batches])
    return {'pixel_values': pixel_values, 'label': label}

train_dataset = imagenet_1k_train.with_transform(preprocess_function)
test_dataset = imagenet_1k_test.with_transform(preprocess_function)

train_loader = torch.utils.data.DataLoader(
    train_dataset['train'], 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    collate_fn=collate_function,
    )

test_loader = torch.utils.data.DataLoader(
    test_dataset['test'],
    batch_size=BATCH_SIZE,
    collate_fn=collate_function,
    )

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr = 2e-4)

for epoch in tqdm.tqdm(range(EPOCHS), desc='training'):
    epoch_loss = 0
    epoch_acc = 0

    # enable dropout for the training pass
    model.train()
    for batch in train_loader:
        images = batch['pixel_values'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields Python floats, so the running means do not keep autograd history
        acc = (outputs.argmax(dim = 1) == labels).float().mean()
        epoch_acc += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # disable dropout for validation
    model.eval()
    with torch.no_grad():
        epoch_val_acc = 0
        epoch_val_loss = 0

        for batch in test_loader:
            images = batch['pixel_values'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_acc += acc.item() / len(test_loader)
            epoch_val_loss += val_loss.item() / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_acc:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_acc:.4f}\n"
    )
    wandb.log({"epoch": epoch, "train loss": epoch_loss, "train acc": epoch_acc, "val loss": epoch_val_loss, "val acc": epoch_val_acc})

I am looking into testing flash-cosine-sim-attention with PyTorch DDP, DeepSpeed, or OSLO for distributed training.

Is there a specific diffusion or generative model which you want to be run on CIFAR10? I can do a wide range of them as well.

Best,

Enrico


lucidrains avatar lucidrains commented on May 30, 2024

@conceptofmind Hi again Enrico, and thank you for running this experiment

Was the above run done in f32 or f16?


lucidrains avatar lucidrains commented on May 30, 2024

Thanks Enrico! Definitely also compare the run to non-cosine sim attention, and obtain a validation curve while you are at it 🙏


lucidrains avatar lucidrains commented on May 30, 2024

@conceptofmind is the GPT-2 run using the pre-layernorm architecture?


lucidrains avatar lucidrains commented on May 30, 2024

@conceptofmind ah ok! thanks for clearing that up!


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains If I missed something, or if you want me to add a PreNorm to the Attention layer, I am more than willing to test with that as well.


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains Here are the results for fp16 training without pre-layernorm for 30k steps on an A100 (40GB). The recent update greatly improved numerical stability for fp16 training.

Training loss:
Screenshot from 2022-11-01 21-41-19

Validation Loss (Validating every 100 steps):
Screenshot from 2022-11-01 21-42-56

I will provide an update for fp16 with pre-layernorm when it gets around 30k steps. I will also train one model with fp32 and pre-layernorm as well as one model with non-cosine sim attention. So an additional 3 baseline tests!

Thank you,

Enrico


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains Here are the results for fp16 training with pre-layernorm for 30k steps on an A100 (40GB). Training remained more stable as well. I changed to validating every 10 steps as that gave a better idea of the results.

Training loss:
Screenshot from 2022-11-01 23-03-37

Validation Loss (Validating every 10 steps):
Screenshot from 2022-11-01 23-04-18

Here is the code for the slight change made to attention to include Pre-LayerNorm:

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8,
        l2norm_groups = 1,
        use_cuda_kernel = False,
        **kwargs
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = scale
        self.heads = heads

        self.norm = nn.LayerNorm(dim)

        self.l2norm_groups = l2norm_groups
        self.attn_fn = plain_cosine_sim_attention if not use_cuda_kernel else partial(flash_cosine_sim_attention, **kwargs)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        h, scale, l2norm_groups = self.heads, self.scale, self.l2norm_groups

        # pre layernorm
        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        o = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups)

        o = rearrange(o, 'b h n d -> b n (h d)')
        return self.to_out(o)
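
For readers comparing the two layouts discussed in this thread, here is a rough side-by-side sketch of a pre-norm and a post-norm residual block. The class names are hypothetical and the wrapped function is just a linear layer for illustration. In the pre-norm form only the branch input is normalized, so the residual stream keeps its scale, whereas the post-norm form renormalizes the sum after every block.

```python
import torch
from torch import nn

class PreNormResidual(nn.Module):
    """x + fn(LayerNorm(x)): the residual stream itself is never normalized."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return x + self.fn(self.norm(x))

class PostNormResidual(nn.Module):
    """LayerNorm(x + fn(x)): the sum is normalized after every block."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.norm(x + self.fn(x))

torch.manual_seed(0)
x = torch.randn(2, 5, 32) * 10  # large-scale input to make the difference visible
pre = PreNormResidual(32, nn.Linear(32, 32))
post = PostNormResidual(32, nn.Linear(32, 32))
print(pre(x).std().item(), post(x).std().item())
```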

I am training the model with fp32 and pre-layernorm now.

Thank you,

Enrico


conceptofmind avatar conceptofmind commented on May 30, 2024

Here are the results for fp32 training with pre-layernorm for 30k steps on an A100 (40GB).

Training loss:
Screenshot from 2022-11-02 11-38-00

Validation Loss (Validating every 10 steps):
Screenshot from 2022-11-01 23-04-18


conceptofmind avatar conceptofmind commented on May 30, 2024

Results for training standard PaLM on an A100 (40 GB) for 30k steps:

  • 2.15it/s
  • Sequence Length 1024
  • fp32

Screenshot from 2022-11-02 17-56-01


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains Here is the code for the PaLM model with flash cosine sim attention. The model is currently training and I will update the results likely later tonight or tomorrow morning.

  • 5.94s/it
  • Sequence Length 8192
  • fp32

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
from functools import partial
from flash_cosine_sim_attention import flash_cosine_sim_attention

# normalization
# they use layernorm without bias, something that pytorch does not offer


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


# rotary positional embedding
# https://arxiv.org/abs/2104.09864


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame


class ParallelTransformerBlock(nn.Module):
    def __init__(
        self, 
        dim, 
        dim_head=64, 
        heads=8,
        scale = 8,
        l2norm_groups = 1, 
        ff_mult=4,
        **kwargs
    ):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.attn_fn = partial(flash_cosine_sim_attention, **kwargs)

        self.heads = heads
        self.scale = scale
        self.l2norm_groups = l2norm_groups
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching rotary embeddings

        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h, scale, l2norm_groups = x.shape[1], x.device, self.heads, self.scale, self.l2norm_groups

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # flash cosine similarity attention

        out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)


# transformer


def PaLM_flash(
    *, 
    dim, 
    num_tokens, 
    depth,
    attn_scale = 8,
    attn_l2norm_groups = 1, 
    dim_head=64, 
    heads=8, 
    ff_mult=4,
    **kwargs
):
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        *[
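            # pass l2norm_groups by its constructor name; a `groups` keyword would fall into **kwargs and be overridden in forward
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, scale=attn_scale, l2norm_groups=attn_l2norm_groups, **kwargs))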
            for _ in range(depth)
        ],
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # they used embedding weight tied projection out to logits, not common, but works
    net[-1].weight = net[0].weight

    nn.init.normal_(net[0].weight, std=0.02)
    return net


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains Here are the runs for PaLM with flash-cosine-sim-attention.

  • 6.52s/it
  • Sequence Length 8192
  • fp32

For 14k steps:
Screenshot from 2022-11-10 14-40-17

And for the whole training run:
Screenshot from 2022-11-10 14-40-47

I am working on the ViT now.


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains Here is the code for a ViT-16 with flash-cosine-sim-attention:

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from flash_cosine_sim_attention import flash_cosine_sim_attention

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        out = flash_cosine_sim_attention(q, k, v)
        
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
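
As a quick sanity check on the shapes this ViT produces, the arithmetic from `__init__` can be worked through in isolation. The helper name `vit_token_shapes` is hypothetical; the defaults assume 224x224 RGB inputs with 16x16 patches.

```python
def vit_token_shapes(image_size=224, patch_size=16, channels=3):
    """Return (num_patches, patch_dim, sequence_length) for a square image."""
    assert image_size % patch_size == 0, 'image size must be divisible by patch size'
    num_patches = (image_size // patch_size) ** 2   # 14 * 14 patches
    patch_dim = channels * patch_size * patch_size  # flattened pixels per patch
    seq_len = num_patches + 1                       # plus the cls token
    return num_patches, patch_dim, seq_len

shapes = vit_token_shapes()
print(shapes)  # (196, 768, 197)
```

Each image therefore becomes a sequence of 197 tokens, which is the sequence length the attention layers see.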

Here is the training script for the ViT on CIFAR10:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms as T
import tqdm

import wandb

from vit_cosine_sim_flash import ViT

wandb.init(project="my-test-project")

DEVICE = 'cuda'
IMAGE_SIZE = 224
BATCH_SIZE = 4
LEARNING_RATE = 6e-4
EPOCHS = 100

train_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.AutoAugment(policy = T.AutoAugmentPolicy.CIFAR10),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

test_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])    

train_dataset = CIFAR10(
    root = './cifar_data_train/', 
    train = True,
    download = True,
    transform = train_transform,
)

test_dataset = CIFAR10(
    root = './cifar_data_train/',
    train = False,
    download = True,
    transform = test_transform,
)

train_loader = DataLoader(
    train_dataset, 
    shuffle = True,
    batch_size = BATCH_SIZE, 
)

test_loader = DataLoader(
    test_dataset, 
    batch_size = BATCH_SIZE,
)

model = ViT(
    image_size = IMAGE_SIZE,
    patch_size = 16,
    num_classes = 10,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    model.parameters(), 
    lr = LEARNING_RATE,
)

for epoch in tqdm.tqdm(range(EPOCHS), desc='training'):
    epoch_loss = 0
    epoch_acc = 0

    # enable dropout for the training pass
    model.train()
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields Python floats, so the running means do not keep autograd history
        acc = (outputs.argmax(dim = 1) == labels).float().mean()
        epoch_acc += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # disable dropout for validation
    model.eval()
    with torch.no_grad():
        epoch_val_acc = 0
        epoch_val_loss = 0

        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_acc += acc.item() / len(test_loader)
            epoch_val_loss += val_loss.item() / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_acc:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_acc:.4f}\n"
    )
    wandb.log({"epoch": epoch, "train loss": epoch_loss, "train acc": epoch_acc, "val loss": epoch_val_loss, "val acc": epoch_val_acc})

I will post the training results soon.


conceptofmind avatar conceptofmind commented on May 30, 2024

@lucidrains Here are the results for the ViT-16 experiments with and without flash-cosine-sim attention.

For regular attention I used a learning rate of 2e-4. For flash-cosine-sim I tested with a learning rate of 6e-4.

Train/validation loss:
Screenshot from 2022-11-15 18-36-34

Train/validation accuracy:
Screenshot from 2022-11-15 18-36-54

I am running flash-cosine-sim again with a learning rate of 2e-4 this time instead. I will provide an update with that soon.

Best,

Enrico
