Comments (31)
@lucidrains I forgot to mention that I used fp16 in the training above. This is likely one of the causes of the numerical instability and NaNs in this experiment. Since I did not want to alter the training script, I did not apply any stabilization technique to shrink the token embedding gradient, for example: x = self.token_emb(x) * self.alpha + self.token_emb(x).detach() * (1 - self.alpha). This is what Tsinghua did to help stabilize fp16 training for GLM-130B. I can add this to a new script and post the results for training with fp16 again.
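As a module, the shrink would look roughly like this (a sketch; alpha = 0.1 is the value I believe GLM-130B used, but it is a tunable hyperparameter):

import torch
from torch import nn

class ShrunkEmbedding(nn.Module):
    # sketch of the embedding gradient shrink: the forward value is unchanged,
    # but only an alpha fraction of the gradient flows back into the embedding table
    def __init__(self, num_tokens, dim, alpha = 0.1):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.alpha = alpha

    def forward(self, x):
        x = self.token_emb(x)
        return x * self.alpha + x.detach() * (1 - self.alpha)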
I will run the normal script again with fp32 on an A100 and document the results here as well.
Thank you,
Enrico
@conceptofmind ohh yes the fp16 is the likely cause, as i was in the middle of fixing an underflow issue with the way cosine sim attention was approached in the CUDA kernel (which should be fixed in 0.1.38). If you have time to retry it on the latest version, that would be greatly appreciated!
are you using character level enwik8 for training from this repo, or modifying another gpt2 codebase? if modifying another codebase, could you share the code you have?
there is no need for the gradient shrinking technique from Tsinghua, as the whole idea behind the repository was to explore whether cosine sim attention can bring about greater stability without any cost
@lucidrains I will update to the latest version of the repository and rerun the test with fp16 enabled. I will post the new fp16 training results. I am using the character-level enwik8 training from this repository for these tests. I have not made any alterations to the training script except for logging the loss with wandb. I did not want to change anything from your original work in order to remain consistent.
Here are the training results with fp32 on an A100 (40GB) for almost 40k steps:
The training has remained more stable.
I am additionally testing flash-cosine-sim-attention in another GPT-like model, a Vision Transformer, and a PALM model. I will post all of the code and results for these additional tests when I am confident everything meets a certain level of correctness. I will not apply the gradient shrinking technique from Tsinghua to any of these additional tests.
Thank you,
Enrico
@lucidrains I am using the current CosineSimCausalTransformer available in the repository for the GPT-2 run. I believe the architecture used post-norm layers with DeepNorm. I did not see a specific place for a PreNorm wrapper or where pre-layernorm was explicitly defined. I saw that DeepNorm was applied to attn.to_v, attn.to_out, ff[0], and ff[2].
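For reference, my understanding of the DeepNorm scaling being applied to those modules follows the DeepNet paper; a rough sketch of the constants and init (my reading, not necessarily exactly how the repository does it):

from torch import nn

# rough sketch of DeepNorm (DeepNet, Wang et al. 2022) for a decoder-only stack of N layers
def deepnorm_constants(num_layers):
    alpha = (2 * num_layers) ** 0.25      # scales the residual branch: x = norm(x * alpha + sublayer(x))
    beta = (8 * num_layers) ** -0.25      # init gain applied to attn.to_v, attn.to_out, ff[0], ff[2]
    return alpha, beta

def deepnorm_init(linear, beta):
    # scale the initialization of the selected projections by beta
    nn.init.xavier_normal_(linear.weight, gain = beta)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)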
@conceptofmind that would actually be great! i'll add a prenorm option tomorrow morning 🙏
Sidenote: I am going to spend the next few weeks working on a Triton version of Flash Cosine Similarity Attention as well. I think it would be an interesting comparative benchmark!
Nice! I added the prenorm option in the transformer this morning as a simple flag 0c260d1
Is it possible to move the plots you have above into the same graph for comparison?
@lucidrains Of course! Here are the grouped graphs.
fp16 training without pre-layernorm (validating every 100)
fp16 training with pre-layernorm (validating every 10)
fp32 training with pre-layernorm (validating every 10)
I can start validating every step if that would be better as well.
I will post the results of training a PALM-like model soon.
Thank you,
Enrico
@conceptofmind yup that looks good! so for PaLM, because it uses rotary embeddings, the l2 normalization needs to come before the rotation of the queries and keys
something like this in the readme
q, k = l2norm_tensors(q, k)
positions = self.get_rotary_embedding(n, device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups, **kwargs)
@lucidrains I will make the adjustments to the model to do the l2 normalization before the rotation of the queries and keys, and post the results.
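Roughly, I am planning the change like this (a sketch; it assumes l2norm_tensors can be imported from the package as in your snippet above, and reuses the apply_rotary_pos_emb helper from the PaLM code I am posting below):

from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors  # l2norm_tensors assumed exported, per the snippet above

def cosine_sim_attend(q, k, v, positions, scale = 8, l2norm_groups = 1):
    # l2-normalize the queries and keys first, then apply the rotary embedding,
    # then call the cosine sim attention kernel, matching the ordering you described
    q, k = l2norm_tensors(q, k)
    q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
    return flash_cosine_sim_attention(q, k, v, causal = True, scale = scale, groups = l2norm_groups)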
Thank you,
Enrico
@lucidrains Here are the results for training the ViT-16 with flash-cosine-sim-attention on CIFAR10 for 100 epochs.
Train and Validation loss:
Train and Validation accuracy:
Training performance and accuracy were great. Validation shows possible overfitting, which is to be expected.
What do you think about trying flash-cosine-sim-attention in MEGA-pytorch?
Best,
Enrico
🙏 thank you for running these experiments Enrico
do you think you could also run the same experiments against regular attention and compare the curves side by side? i am concerned about expressiveness issues after Robin's experiments
@conceptofmind you can name your wandb runs by doing
wandb.run.name = 'regular attention'
wandb.run.save()
right after wandb.init
@lucidrains Of course. I will train a ViT-16 with regular attention on CIFAR10 for 100 epochs and compare the curves side by side now. I will update you when everything is done running for the ViT experiments. All of these experiments were conducted on an A100 (40GB).
Here is a chart of the train and val losses for PaLM with and without flash-cosine-sim-attention:
Here is a link to that chart:
https://wandb.ai/please/my-test-project/reports/val-loss-train-loss-22-11-14-23-08-14---VmlldzoyOTcxODQx?accessToken=733c0o8hpgpklqq42phj7rem7rznxcny5tca2q7v7bnbxm9gwppa1p0gv5fbo19n
With Label smoothing:
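For reference, label smoothing is enabled directly on the loss in PyTorch; a sketch with an illustrative smoothing value (not necessarily the one used for this run):

import torch

# illustrative value; the smoothing used for the run above may differ
criterion = torch.nn.CrossEntropyLoss(label_smoothing = 0.1)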
@lucidrains Here are the results for ViT-16 with and without flash-cosine-sim attention on CIFAR10 for 100 epochs, both with the same learning rate of 2e-4. I am using an A100 (40 GB). Definitely good news!
So far in my testing, I have seen better accuracy and faster convergence with flash-cosine-sim attention. I will need to keep training more models, and I am thinking about including some other improvements too, possibly FastLayerNorm from Apex. I am getting everything set up for a training run on ImageNet and will post the code for that soon.
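For the Apex idea, the drop-in swap would look roughly like this (a sketch, assuming apex is installed with its CUDA extensions; I use FusedLayerNorm here since FastLayerNorm lives in apex.contrib and has a more restricted interface):

from torch import nn

try:
    from apex.normalization import FusedLayerNorm   # requires apex built with CUDA extensions
    norm_klass = FusedLayerNorm
except ImportError:
    norm_klass = nn.LayerNorm                       # fall back to the standard implementation

norm = norm_klass(1024)   # drop-in replacement wherever nn.LayerNorm(dim) is used in the models above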
I am still working on the Triton version (It is my first time using Triton) as well.
@conceptofmind thank you Enrico
the results actually look more promising than i expected, if conditions are held equal between the two runs
@conceptofmind if you have some time, do you think you could try flash cosine sim attention on some generative models and see if the FID scores between regular and cosine sim attention differ at all? that is what i worry about, as @rromb showed evidence that loss curves are not everything. however, the ViT accuracy curves you got from above do look good 🙏
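something like this with torchmetrics would work for the FID comparison (rough sketch, assuming torchmetrics is installed; images as uint8 tensors of shape (n, 3, h, w)):

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# rough sketch: accumulate real and generated images, then compare the two attention variants by FID
fid = FrechetInceptionDistance(feature = 2048)

real_images = torch.randint(0, 255, (8, 3, 299, 299), dtype = torch.uint8)       # stand-in for dataset images
generated_images = torch.randint(0, 255, (8, 3, 299, 299), dtype = torch.uint8)  # stand-in for model samples

fid.update(real_images, real = True)
fid.update(generated_images, real = False)
print(fid.compute())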
@lucidrains Absolutely! I am currently finishing setting everything up for a run on ImageNet with flash-cosine-sim-attention and will hopefully have the results soon. Hugging Face has a standard version of ImageNet which fortunately can be downloaded in a reasonable amount of time.
Here is the training script for the run on ImageNet with no data augmentation:
import torch
import tqdm
import argparse
import wandb
from datasets import load_dataset
from transformers import AutoFeatureExtractor
from vit import ViT

wandb.init(project="my-test-project")
wandb.run.name = 'regular attention - imagenet'
wandb.run.save()

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default = 4, type = int)
args = parser.parse_args()

BATCH_SIZE = args.batch_size
DEVICE = 'cuda'
EPOCHS = 100

imagenet_1k_train = load_dataset('imagenet-1k', 'train')
imagenet_1k_test = load_dataset('imagenet-1k', 'test')

model_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)

def preprocess_function(images):
    image_tensors = [image.convert("RGB") for image in images['image']]
    inputs = feature_extractor(image_tensors, return_tensors="pt")
    inputs['label'] = images['label']
    return inputs

def collate_function(batches):
    pixel_values = torch.stack([batch['pixel_values'] for batch in batches])
    label = torch.tensor([batch['label'] for batch in batches])
    return {'pixel_values': pixel_values, 'label': label}

train_dataset = imagenet_1k_train.with_transform(preprocess_function)
test_dataset = imagenet_1k_test.with_transform(preprocess_function)

train_loader = torch.utils.data.DataLoader(
    train_dataset['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_function,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset['test'],
    batch_size=BATCH_SIZE,
    collate_fn=collate_function,
)

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-4)

for epoch in tqdm.tqdm(range(EPOCHS), desc='training'):
    epoch_loss = 0
    epoch_acc = 0
    for batch in train_loader:
        images = batch['pixel_values'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (outputs.argmax(dim = 1) == labels).float().mean()
        epoch_acc += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_acc = 0
        epoch_val_loss = 0
        for batch in test_loader:
            images = batch['pixel_values'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_acc += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_acc:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_acc:.4f}\n"
    )

    wandb.log({"epoch": epoch, "train loss": epoch_loss, "train acc": epoch_acc, "val loss": epoch_val_loss, "val acc": epoch_val_acc})
I am looking into testing flash-cosine-sim-attention with PyTorch DDP, DeepSpeed, or OSLO for distributed training.
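For DDP specifically, the change to the ImageNet script above would roughly be the following (a sketch assuming a torchrun launch with one process per GPU; the loaders would additionally need a DistributedSampler):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from vit import ViT

# sketch: initialize the process group and wrap the model, assuming launch via torchrun
dist.init_process_group(backend = 'nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = ViT(
    image_size = 224, patch_size = 16, num_classes = 1000,
    dim = 1024, depth = 6, heads = 16, mlp_dim = 2048
).cuda()
model = DDP(model, device_ids = [local_rank])

# each rank would then consume a distinct shard of the data, e.g.
# sampler = torch.utils.data.distributed.DistributedSampler(train_dataset['train'])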
Is there a specific diffusion or generative model which you want to be run on CIFAR10? I can do a wide range of them as well.
Best,
Enrico
@conceptofmind Hi again Enrico, and thank you for running this experiment.
Was the above run done in fp32 or fp16?
Thanks Enrico! Definitely also compare the run to non-cosine sim attention, and obtain a validation curve while you are at it 🙏
@conceptofmind is the GPT-2 run using the pre-layernorm architecture?
@conceptofmind ah ok! thanks for clearing that up!
@lucidrains If I missed something, or if you would like me to add a PreNorm wrapper to the Attention layer, I am more than willing to test with that as well.
@lucidrains Here are the results for fp16 training without pre-layernorm for 30k steps on an A100 (40GB). The recent update greatly improved numerical stability for fp16 training.
Validation Loss (Validating every 100 steps):
I will provide an update for fp16 with pre-layernorm when it gets around 30k steps. I will also train one model with fp32 and pre-layernorm as well as one model with non-cosine sim attention. So an additional 3 baseline tests!
Thank you,
Enrico
@lucidrains Here are the results for fp16 training with pre-layernorm for 30k steps on an A100 (40GB). Training remained more stable as well. I changed to validating every 10 steps as that gave a better idea of the results.
Validation Loss (Validating every 10 steps):
Here is the code for the slight change made to attention to include Pre-LayerNorm:
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8,
        l2norm_groups = 1,
        use_cuda_kernel = False,
        **kwargs
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = scale
        self.heads = heads

        self.norm = nn.LayerNorm(dim)

        self.l2norm_groups = l2norm_groups

        self.attn_fn = plain_cosine_sim_attention if not use_cuda_kernel else partial(flash_cosine_sim_attention, **kwargs)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        h, scale, l2norm_groups = self.heads, self.scale, self.l2norm_groups

        # pre layernorm
        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        o = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups)

        o = rearrange(o, 'b h n d -> b n (h d)')
        return self.to_out(o)
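A quick smoke test of the layer might look like this (a sketch; it assumes the same imports as the repository's transformer module, i.e. torch, einops.rearrange, functools.partial, and the attention functions from the package):

import torch

# smoke test for the pre-layernorm Attention above, using the plain (non-CUDA) attention path
attn = Attention(dim = 512, dim_head = 64, heads = 8, use_cuda_kernel = False)
x = torch.randn(2, 1024, 512)        # (batch, sequence, dim)
out = attn(x)
assert out.shape == x.shape          # output is residual-compatible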
I am training the model with fp32 and pre-layernorm now.
Thank you,
Enrico
Here are the results for fp32 training with pre-layernorm for 30k steps on an A100 (40GB).
Validation Loss (Validating every 10 steps):
Results for training standard PaLM on an A100 (40 GB) for 30k steps:
@lucidrains Here is the code for the PaLM model with flash cosine sim attention. The model is currently training and I will update the results likely later tonight or tomorrow morning.
- 5.94s/it
- Sequence Length 8192
- fp32
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
from functools import partial

from flash_cosine_sim_attention import flash_cosine_sim_attention

# normalization
# they use layernorm without bias, something that pytorch does not offer

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame

class ParallelTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        scale = 8,
        l2norm_groups = 1,
        ff_mult=4,
        **kwargs
    ):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.attn_fn = partial(flash_cosine_sim_attention, **kwargs)

        self.heads = heads
        self.scale = scale
        self.l2norm_groups = l2norm_groups

        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching rotary embeddings
        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h, scale, l2norm_groups = x.shape[1], x.device, self.heads, self.scale, self.l2norm_groups

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # flash cosine similarity attention
        out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)

# transformer

def PaLM_flash(
    *,
    dim,
    num_tokens,
    depth,
    attn_scale = 8,
    attn_l2norm_groups = 1,
    dim_head=64,
    heads=8,
    ff_mult=4,
    **kwargs
):
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, scale=attn_scale, groups=attn_l2norm_groups, **kwargs))
            for _ in range(depth)
        ],
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # they used embedding weight tied projection out to logits, not common, but works
    net[-1].weight = net[0].weight
    nn.init.normal_(net[0].weight, std=0.02)
    return net
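A quick shape check for this model might look like the following (a sketch with illustrative sizes, not the configuration used in the runs above; it assumes a CUDA device since the attention runs as a CUDA kernel):

import torch

# illustrative smoke test only, not the hyperparameters used in the actual runs
palm = PaLM_flash(num_tokens = 256, dim = 512, depth = 8).cuda()
tokens = torch.randint(0, 256, (1, 1024)).cuda()
logits = palm(tokens)   # (1, 1024, 256)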
@lucidrains Here are the runs for PaLM with flash-cosine-sim-attention.
- 6.52s/it
- Sequence Length 8192
- fp32
And for the whole training run:
I am working on the ViT now.
@lucidrains Here is the code for a ViT-16 with flash-cosine-sim-attention:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from flash_cosine_sim_attention import flash_cosine_sim_attention

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        out = flash_cosine_sim_attention(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
Here is the training script for the ViT on CIFAR10:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms as T

import tqdm
import wandb

from vit_cosine_sim_flash import ViT

wandb.init(project="my-test-project")

DEVICE = 'cuda'
IMAGE_SIZE = 224
BATCH_SIZE = 4
LEARNING_RATE = 6e-4
EPOCHS = 100

train_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.AutoAugment(policy = T.AutoAugmentPolicy.CIFAR10),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

test_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

train_dataset = CIFAR10(
    root = './cifar_data_train/',
    train = True,
    download = True,
    transform = train_transform,
)

test_dataset = CIFAR10(
    root = './cifar_data_train/',
    train = False,
    download = True,
    transform = test_transform,
)

train_loader = DataLoader(
    train_dataset,
    shuffle = True,
    batch_size = BATCH_SIZE,
)

test_loader = DataLoader(
    test_dataset,
    batch_size = BATCH_SIZE,
)

model = ViT(
    image_size = IMAGE_SIZE,
    patch_size = 16,
    num_classes = 10,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr = LEARNING_RATE,
)

for epoch in tqdm.tqdm(range(EPOCHS), desc='training'):
    epoch_loss = 0
    epoch_acc = 0
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (outputs.argmax(dim = 1) == labels).float().mean()
        epoch_acc += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_acc = 0
        epoch_val_loss = 0
        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_acc += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_acc:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_acc:.4f}\n"
    )

    wandb.log({"epoch": epoch, "train loss": epoch_loss, "train acc": epoch_acc, "val loss": epoch_val_loss, "val acc": epoch_val_acc})
I will post the training results soon.
@lucidrains Here are the results for the ViT-16 experiments with and without flash-cosine-sim attention.
For regular attention I used a learning rate of 2e-4. For flash-cosine-sim I tested with a learning rate of 6e-4.
I am running flash-cosine-sim again with a learning rate of 2e-4 this time instead. I will provide an update with that soon.
Best,
Enrico