Comments (2)
Take the Fabric DCGAN example: I changed the `backward` calls to address the error `When using multiple models + deepspeed, please provide the model used to perform the optimization: self.backward(loss, model=model)`.
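In isolation, the change is just passing the owning module to `fabric.backward` so Fabric can route the call to the right DeepSpeed engine. A minimal sketch of the pattern (where `d_loss` and `g_loss` are stand-ins for the discriminator and generator losses):

```python
# Each fabric.setup(...) call wraps one model/optimizer pair in its own
# DeepSpeed engine, so backward must be told which engine owns the loss.
discriminator, optimizer_d = fabric.setup(discriminator, optimizer_d)
generator, optimizer_g = fabric.setup(generator, optimizer_g)

fabric.backward(d_loss, model=discriminator)  # discriminator update
fabric.backward(g_loss, model=generator)      # generator update
```

The full reproduction script: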
```python
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils
from lightning.fabric import Fabric, seed_everything
from torchvision.datasets import CelebA

torch.set_float32_matmul_precision("medium")

# Root directory for dataset
dataroot = "data/"
# Number of workers for dataloader
workers = os.cpu_count()
# Batch size during training
batch_size = 128
# Spatial size of training images
image_size = 64
# Number of channels in the training images
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs to use
num_gpus = 1


def main():
    # Set random seed for reproducibility
    seed_everything(999)

    fabric = Fabric(accelerator="auto", devices=1, strategy="deepspeed")
    fabric.launch()

    dataset = CelebA(
        root=dataroot,
        split="all",
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

    output_dir = Path("outputs-fabric", time.strftime("%Y%m%d-%H%M%S"))
    output_dir.mkdir(parents=True, exist_ok=True)

    # Plot some training images
    real_batch = next(iter(dataloader))
    torchvision.utils.save_image(
        real_batch[0][:64],
        output_dir / "sample-data.png",
        padding=2,
        normalize=True,
    )

    # Create the generator
    generator = Generator()
    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)

    # Create the Discriminator
    discriminator = Discriminator()
    # Apply the weights_init function to randomly initialize all weights
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    # the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=fabric.device)

    # Establish convention for real and fake labels during training
    real_label = 1.0
    fake_label = 0.0

    # Set up Adam optimizers for both G and D
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

    discriminator, optimizer_d = fabric.setup(discriminator, optimizer_d)
    generator, optimizer_g = fabric.setup(generator, optimizer_g)
    dataloader = fabric.setup_dataloaders(dataloader)

    # Lists to keep track of progress
    losses_g = []
    losses_d = []
    iteration = 0

    # Training loop
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # (a) Train with all-real batch
            discriminator.zero_grad()
            real = data[0]
            b_size = real.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=fabric.device)
            # Forward pass real batch through D
            output = discriminator(real).view(-1)
            # Calculate loss on all-real batch
            err_d_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            fabric.backward(err_d_real, model=discriminator)
            d_x = output.mean().item()

            # (b) Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=fabric.device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            err_d_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            fabric.backward(err_d_fake, model=discriminator)
            d_g_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            err_d = err_d_real + err_d_fake
            # Update D
            optimizer_d.step()

            # (2) Update G network: maximize log(D(G(z)))
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            err_g = criterion(output, label)
            # Calculate gradients for G
            fabric.backward(err_g, model=generator)
            d_g_z2 = output.mean().item()
            # Update G
            optimizer_g.step()

            # Output training stats
            if i % 50 == 0:
                fabric.print(
                    f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\t"
                    f"Loss_D: {err_d.item():.4f}\t"
                    f"Loss_G: {err_g.item():.4f}\t"
                    f"D(x): {d_x:.4f}\t"
                    f"D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}"
                )

            # Save Losses for plotting later
            losses_g.append(err_g.item())
            losses_d.append(err_d.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iteration % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                if fabric.is_global_zero:
                    torchvision.utils.save_image(
                        fake,
                        output_dir / f"fake-{iteration:04d}.png",
                        padding=2,
                        normalize=True,
                    )
                fabric.barrier()

            iteration += 1


def weights_init(m):
    # custom weights initialization called on netG and netD
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)


if __name__ == "__main__":
    main()
```
But I get an error like this:
```
Traceback (most recent call last):
  File "/home/ichigo/LocalCodes/dl_playground/train.py", line 261, in <module>
    main()
  File "/home/ichigo/LocalCodes/dl_playground/train.py", line 156, in main
    fabric.backward(err_g, model=generator)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 448, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/strategies/strategy.py", line 191, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/lightning/fabric/plugins/precision/deepspeed.py", line 91, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2051, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 899, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1412, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/home/ichigo/miniconda3/envs/dl/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 945, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
TypeError: 'NoneType' object is not subscriptable
```
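`ipg_buffer` is an internal buffer that a DeepSpeed ZeRO engine allocates at the start of its own backward pass. The `NoneType` error therefore suggests that the generator's backward is also firing the gradient-reduction hooks registered on the discriminator's parameters (the generator loss flows through `discriminator(fake)` without a detach), while the discriminator's engine has no backward in progress. If that reading is right, one possible workaround is to keep gradients from flowing into the discriminator during the generator update. This is an untested sketch based on that assumption, not an official fix:

```python
# Hypothetical workaround sketch: freeze D's parameters for the G step so
# DeepSpeed's reduction hooks on D never fire outside D's own backward.
for p in discriminator.parameters():
    p.requires_grad_(False)

output = discriminator(fake).view(-1)
err_g = criterion(output, label)
fabric.backward(err_g, model=generator)  # only G's engine sees gradients

for p in discriminator.parameters():
    p.requires_grad_(True)
```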