matthias-wright / flaxmodels
Pretrained deep learning models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet, etc.
Home Page: https://github.com/matthias-wright/flaxmodels
Hi @matthias-wright, I've been playing around with your project for a couple of days and it's so cool, thanks for building some pure Flax models here 👍🏻
Don't know if you're aware, but @huggingface developed a new format for storing tensors named safetensors. Most serialized PyTorch models use pickle to store their tensors, which does not seem very efficient and also has some known security issues. So I'd like to know whether you're considering porting the current weights to safetensors instead.
I've recently built safejax to make that easy, which means that the storage is optimal and safer! If this is something you would consider to improve flaxmodels, please let me know and I can try to help if applicable!
P.S. Have you considered publishing the Python package to PyPI, with versions tracked through GitHub Releases? It would probably attract more users, since installing with pip from PyPI is easier than installing from source as described in the README.md.
Hi,
FYI: The ResNet demo notebook needs to install a later jaxlib version in order to be compatible with the jax version that it installs in its first cell.
I suggest using this pip install instead:
!pip install --upgrade jax jaxlib==0.1.75+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Having done that (whether despite or because of this change), the line params = resnet18.init(key, x) fails with this error:
UnfilteredStackTrace: RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,180,240,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,360,480,3]{2,1,3,0} %copy.3, f32[7,7,3,64]{1,0,2,3} %copy.4), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 360, 480, 3)\n padding=((3, 3), (3, 3))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(7, 7, 3, 64)\n window_strides=(2, 2)\n]" source_file="/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py" source_line=291}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
Original error: UNIMPLEMENTED: DNN library is not found.
To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
Hope this helps you in some way.
How do I run the training on my custom dataset using a Colab TPU?
From the docstring of ResNet18, I saw the following comment:
The pretrained parameters are taken from:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
How did you convert the PyTorch model weights into Flax? I didn't find any reference that touches upon this area. Many thanks!
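A minimal sketch of the usual conversion, assuming torchvision-style state_dict names (conv1.weight, bn1.running_mean, fc.weight, ...) already converted to NumPy; this is not necessarily how flaxmodels did it. The core step is transposing layouts: PyTorch stores conv weights as (out, in, kH, kW) and Linear weights as (out, in), while Flax's nn.Conv and nn.Dense expect (kH, kW, in, out) and (in, out). The cuDNN log earlier on this page shows the Flax layout for torchvision's conv1 (rhs_shape=(7, 7, 3, 64), padding=((3, 3), (3, 3))); note that PyTorch's symmetric padding=3 is best given explicitly in Flax, since 'SAME' pads asymmetrically for stride 2 on even-sized inputs. The helper names below are hypothetical.

```python
import numpy as np


def conv_torch_to_flax(weight):
    """PyTorch Conv2d weight (out, in, kH, kW) -> Flax nn.Conv kernel (kH, kW, in, out)."""
    return np.transpose(weight, (2, 3, 1, 0))


def linear_torch_to_flax(weight):
    """PyTorch Linear weight (out, in) -> Flax nn.Dense kernel (in, out)."""
    return np.transpose(weight)


def batchnorm_torch_to_flax(state_dict, prefix):
    """Split BatchNorm2d entries into Flax-style 'params' and 'batch_stats' leaves."""
    params = {"scale": state_dict[f"{prefix}.weight"], "bias": state_dict[f"{prefix}.bias"]}
    stats = {"mean": state_dict[f"{prefix}.running_mean"], "var": state_dict[f"{prefix}.running_var"]}
    return params, stats


# Random arrays standing in for torchvision's conv1 and fc tensors (e.g. tensor.numpy()).
rng = np.random.default_rng(0)
torch_conv = rng.standard_normal((64, 3, 7, 7)).astype(np.float32)
torch_fc = rng.standard_normal((1000, 512)).astype(np.float32)
flax_conv = conv_torch_to_flax(torch_conv)
flax_fc = linear_torch_to_flax(torch_fc)

# Dense layers agree: PyTorch computes x @ W.T, Flax computes x @ kernel.
x = rng.standard_normal((2, 512)).astype(np.float32)
assert np.allclose(x @ torch_fc.T, x @ flax_fc)
```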
I noticed that some links (e.g. from the main README, or from gpt2/README.md) still point to the old Colabs on Drive.
Could you update the links such that they all point to the new Colab paths on Github?
Thanks, Andreas
Thank you for this code! We are using it to implement a simple version of lpips loss here. However, I couldn't find a way to make our package installable from PyPI, as I don't know how to declare a dependency on your repository's distribution, and installing flaxmodels from PyPI doesn't work in our case. See this issue for a brief discussion.
Do you think it would be possible to update your PyPI distribution so that we can in turn build a package that uses it as a dependency? Being new to PyPI myself, I'd also be interested to learn about any drawbacks of doing it this way.
Thanks again!
Hi @matthias-wright, thanks a lot for releasing this nice, complete package of pretrained models! I recently used your package in a tutorial to extract features from a pretrained ResNet34, and noticed that the pip installation of the package requires an old numpy version (v1.19.5, setup.py, line 15). However, the current tensorflow package requires a newer numpy (>v1.21), so installing flaxmodels can break an existing tensorflow installation. In this case, Flax throws an error during import regarding checkpoints from tensorflow, because tensorflow was compiled against a different numpy version than the one flaxmodels installed. Reinstalling the newest numpy version fixes the issue. Is it possible to change the numpy requirement to >= instead of ==, similar to what is currently used for JAX and Flax?
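To illustrate the difference, here is a minimal sketch using the packaging library, which implements PEP 440 version specifiers: the exact pin rejects the numpy version TensorFlow needs, while a lower bound accepts it. The accepts helper is hypothetical; in setup.py the change would amount to listing "numpy>=1.19.5" in install_requires.

```python
from packaging.requirements import Requirement

pinned = Requirement("numpy==1.19.5")   # exact pin, as in the current setup.py
relaxed = Requirement("numpy>=1.19.5")  # lower bound only, like the JAX/Flax entries


def accepts(requirement, version):
    """True if `version` satisfies the requirement's version specifier."""
    return requirement.specifier.contains(version)


for req in (pinned, relaxed):
    print(req, "accepts numpy 1.21.0:", accepts(req, "1.21.0"))
```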
Hello Matthias,
We have noticed that adjusting the value of --fmap_base only affects the generator:
flaxmodels/training/stylegan2/training.py, line 78 (commit 0ec7f22)
but not the discriminator:
flaxmodels/training/stylegan2/training.py, line 101 (commit 0ec7f22)
Is this intentional?
In the StyleGAN 2 paper, both G and D receive increased capacity (bottom of page 7):
This leads us to hypothesize that there is a capacity problem in our networks, which we test by doubling the number of feature maps in the highest-resolution layers of both networks.
We double the number of feature maps in resolutions 64^2–1024^2 while keeping other parts of the networks unchanged. This increases the total number of trainable parameters in the generator by 22% (25M → 30M) and in the discriminator by 21% (24M → 29M).
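In the official StyleGAN2 code, the channel count at each resolution is derived from fmap_base via nf(stage) = clip(fmap_base / 2^(stage * fmap_decay), fmap_min, fmap_max), evaluated at stage = log2(resolution) - 1. The sketch below reproduces that rule (the function name and defaults are assumptions, not flaxmodels' code) and shows that doubling fmap_base from 8 << 10 to 16 << 10 doubles the channels exactly at 64^2 to 1024^2, as in the quoted passage. As far as I know, the official release applies the doubled base (config-f) to both G and D, which supports passing the same value to the discriminator.

```python
import numpy as np


def num_features(res_log2, fmap_base=16384, fmap_decay=1.0, fmap_min=1, fmap_max=512):
    """Channels at resolution 2**res_log2, following StyleGAN2's nf(res - 1) rule."""
    stage = res_log2 - 1
    return int(np.clip(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_min, fmap_max))


# Compare a base of 8 << 10 with a doubled base of 16 << 10 for resolutions 4^2 ... 1024^2.
for res in range(2, 11):
    small, large = num_features(res, fmap_base=8 << 10), num_features(res, fmap_base=16 << 10)
    print(f"{2 ** res:>5}x{2 ** res:<5} {small:>4} -> {large:>4}")
```

Lower resolutions are already capped at fmap_max = 512, which is why only the high-resolution layers change.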
Sorry to open another issue: I am trying to apply transfer learning to ResNet using your flaxmodels.
However, I got stuck on how to get the backbone and add a new head on top of the ResNet.
Do you have any sample code or guidance I could use as a reference?
When trying to run the StyleGAN 2 training code on Google Colab, I'm getting:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
But that's after confirming that the TPU is set up correctly.
Here's a minimal example: https://colab.research.google.com/gist/josephrocca/5e64c9906db96f27b583f0a577ef9b4a/debugging-matthias-wright-s-stylegan2-jax-tpu-not-detected.ipynb
If I set TF_CPP_MIN_LOG_LEVEL=0, I get:
2021-10-08 16:05:10.421297: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2021-10-08 16:05:12.352286: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x55a8d14dddc0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2021-10-08 16:05:12.352348: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Interpreter, <undefined>
2021-10-08 16:05:12.358082: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:163] TfrtCpuClient created.
2021-10-08 16:05:12.371498: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-10-08 16:05:12.371542: I external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (4b26ff987b6b): /proc/driver/nvidia/version does not exist
2021-10-08 16:05:12.371984: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2021-10-08 16:05:12.508418: I tensorflow/core/platform/cloud/google_auth_provider.cc:180] Attempting an empty bearer token since no token was retrieved from files, and GCE metadata check was skipped.
2021-10-08 16:05:12.545218: I tensorflow/core/platform/cloud/google_auth_provider.cc:180] Attempting an empty bearer token since no token was retrieved from files, and GCE metadata check was skipped.
2021-10-08 16:05:12.586517: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-10-08 16:05:12.586652: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2021-10-08 16:05:12.595212: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-10-08 16:05:12.595243: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (4b26ff987b6b): /proc/driver/nvidia/version does not exist
2021-10-08 16:05:12.595544: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-08 16:05:12.597344: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
Not sure if this problem is specific to the StyleGAN 2 training code, since I haven't tried any of the other models. I'm going to continue debugging this tomorrow and will update this post if I find out what's going on here.
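A quick way to see what JAX itself detects (independently of the TensorFlow lines in the log above) is to ask it for its default backend and devices. This minimal sketch uses only public JAX calls; device_summary is a hypothetical helper. On the legacy Colab TPU runtime, older JAX versions additionally required calling jax.tools.colab_tpu.setup_tpu() before any other JAX operation; whether that helper exists depends on your JAX version.

```python
import jax


def device_summary():
    """Return JAX's default backend, total device count and the platforms it sees."""
    backend = jax.default_backend()  # e.g. 'cpu', 'gpu' or 'tpu'
    platforms = sorted({d.platform for d in jax.devices()})
    return backend, jax.device_count(), platforms


backend, n_devices, platforms = device_summary()
print(f"backend={backend} devices={n_devices} platforms={platforms}")
if backend == "cpu":
    print("JAX sees no GPU/TPU here, so pmap-based training would run on CPU only.")
```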
It doesn't look like there's a way to remove the head in the ResNet models, although there is a way to dump the activations. It would be nice if I could specify up to which ResNet block I'd like outputs, to avoid wasting computation on stages I don't need (e.g., if I just want to use it as a low-level feature extractor).
Hi,
First of all, you did a really good job of converting the StyleGAN2 implementation to JAX.
I am not sure whether this is an issue or an intentional design choice, but one thing I noticed is that you pass the PRNG key for weight initialization as a constructor argument of the Flax module.
This means that initializing the module with flax.Module.init using different PRNG keys results in identical weights, for both the generator and the discriminator. The only way to get a different initialization is to pass different PRNGKeys when creating the StyleGAN module.
Here's a minimal example that demonstrates this:
```python
import jax
import jax.numpy as jnp
# stylegan2 is the StyleGAN2 package from flaxmodels; the exact import is not shown in this report.

# Create a StyleGAN2 Generator Flax.Module
G_model = stylegan2.generator.Generator(pretrained=None)

# Invoking this method will initialize the module based on the PRNGKey passed (i.e., g_rng)
def init_g(g_rng):
    z_shape = (4, 512)

    @jax.jit
    def _init(*args):
        return G_model.init(*args, train=True)

    variables = _init({'params': g_rng}, jnp.ones(z_shape, G_model.dtype))
    return variables['params'], variables['moving_stats'], variables['noise_consts']
```
Now initializing the module using flax.Module.init gives the same params:
```python
params_1, _, _ = init_g(jax.random.PRNGKey(10))
params_2, _, _ = init_g(jax.random.PRNGKey(58))

# The following returns True
jnp.all(params_1['mapping_network']['LinearLayer_0']['weight'] == params_2['mapping_network']['LinearLayer_0']['weight'])
```
To initialize the module with different seeds, you need to pass the PRNGKey explicitly when creating the Module.
If this was intended, then I think you also need to split the RNG key whenever it is passed to submodules, so that different weights receive different random numbers. Otherwise, creating the same layer twice (for example, two ops.LinearLayer instances with the same hyperparameters, including the PRNGKey) will produce exactly the same parameters for both layers.
Thanks again for the hard work on making the module Jax accessible 👍
Hi,
Thanks for your very useful work!
It would be very interesting to experiment with other pretrained weights. I am particularly interested in weights from self-supervised learning, e.g. https://github.com/facebookresearch/barlowtwins
Could you share the weight processing script and/or add the proposed weights?
Thank you again!
Thank you so much for creating this. For ResNet, I am wondering how your BatchNorm implementation differs from the Flax one. Basically, I'm wondering whether I can replace ops.BatchNorm with flax.linen.BatchNorm to reduce dependencies. Thanks!
Hi,
Nice repo! Is this line a bug? I think batch['images'] is N x B x H x W x C, so the indices should be shifted up by 1.
flaxmodels/training/stylegan2/training.py, line 239 (commit edc6a85)
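If batch['images'] is indeed laid out with a leading device axis (as a pmap-style shard step produces), every image axis moves one position to the right. This NumPy sketch (all shapes are illustrative assumptions) shows the shift, and why negative axis indices refer to the same dimensions in both layouts:

```python
import numpy as np

num_devices, per_device_batch = 2, 4
H, W, C = 6, 10, 3  # non-square so height and width cannot be confused

# A flat batch reshaped to (devices, batch, H, W, C), as a pmap data pipeline would do.
flat = np.zeros((num_devices * per_device_batch, H, W, C), dtype=np.float32)
sharded = flat.reshape(num_devices, -1, H, W, C)

# The extra leading axis shifts every image axis by one...
height, width = sharded.shape[2], sharded.shape[3]
# ...whereas negative axes name the same dimensions in both layouts.
assert flat.shape[-3] == sharded.shape[-3] == H
assert flat.shape[-2] == sharded.shape[-2] == W
```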
Running tests with the new flax (0.4.2) yields
E TypeError: __call__() got an unexpected keyword argument 'rng'
from the line:
x = nn.Dropout(rate=self.embd_dropout)(x, deterministic=not training, rng=rng)
Looking at the Dropout docs, this keyword argument does not seem to be available in flax.linen: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dropout.html
Some problems I ran into:
- tfds.ImageFolder doesn't work with a "flat" folder of images. I had to nest a dummy label folder inside a dummy split folder, following the instructions here: https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder
- There is no num_examples property in tfds.core.DatasetInfo, so I had to use builder.info.splits['fake_split'].num_examples, where fake_split is the name of my dummy split folder. There does seem to be a total_num_examples property, but I'm not sure how to access it; maybe it's a private field (though I'm not sure those exist in Python)?
- I had to rewrite pre_process, because it was expecting protobufs instead of {image, label} objects.
Note that the reason I am using the ImageFolder approach is that the tfrecords approach blew my 3GB dataset up to 200GB, I think because it stores the raw tensor data? I'm new to this, but it seems like it would make more sense to just store the data in JPEG format, since JPEG decoding is so fast. That said, even if the tfrecords approach used a reasonable amount of space, I'd probably still prefer the ImageFolder approach since it just seems nicer and more portable. Even better, from my (newbie) perspective, would be the ability to load a tar of images with any internal directory structure.
Below is my new data_pipeline.py so far. It seems to work okay now, but I haven't got training to work yet as I'm still debugging some stuff. Will update this post if I run into any more problems with data_pipeline.py.
```python
import os
import json
from typing import Sequence

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import flax.jax_utils
from PIL import Image
from tqdm import tqdm


def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    ds_iter = iter(dataset)
    # Convert every TF tensor in each batch to a NumPy array
    ds_iter = map(lambda x: jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter


def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000):
    """
    Args:
        data_dir (str): Root directory of the dataset.
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.
    Returns:
        (tf.data.Dataset, dict): Dataset and dataset info.
    """
    def pre_process(example):
        # Original TFRecord parsing, replaced by the ImageFolder version below:
        # feature = {'height': tf.io.FixedLenFeature([], tf.int64),
        #            'width': tf.io.FixedLenFeature([], tf.int64),
        #            'channels': tf.io.FixedLenFeature([], tf.int64),
        #            'image': tf.io.FixedLenFeature([], tf.string),
        #            'label': tf.io.FixedLenFeature([], tf.int64)}
        # example = tf.io.parse_single_example(serialized_example, feature)
        # height = tf.cast(example['height'], dtype=tf.int64)
        # width = tf.cast(example['width'], dtype=tf.int64)
        # channels = tf.cast(example['channels'], dtype=tf.int64)
        # image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        # image = tf.reshape(image, shape=[height, width, channels])
        image = example['image']
        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
        # because the first dimension will be mapped across devices using jax.pmap
        data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes])
        return data

    # print('Loading TFRecord...')
    # with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
    #     dataset_info = json.load(fin)
    # ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    # ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
    builder = tfds.ImageFolder(data_dir)
    print(builder.info)
    ds = builder.as_dataset(split='fake_split', shuffle_files=True)
    num_examples = builder.info.splits['fake_split'].num_examples
    dataset_info = {'num_examples': num_examples, 'num_classes': 1}
    ds = ds.shuffle(min(num_examples, shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    ds = ds.batch(batch_size * num_devices, drop_remainder=True)
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)
    return ds, dataset_info
```
Hi,
In the README, it is mentioned that input should be between 0 and 1.
In the training code, they seem to be between -1 and 1.
In the torchvision doc, they seem to be loaded between 0 and 1 and then normalized with
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
Should they be preprocessed as per the torchvision docs?
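For reference, here are the conventions in play as a small NumPy sketch (function names are hypothetical): torchvision-style [0, 1] scaling followed by ImageNet mean/std normalization, and the [-1, 1] scaling (x - 127.5) / 127.5 used in the data pipeline posted above. Since the pretrained weights come from torchvision, they were trained on mean/std-normalized inputs; whether the Flax model applies that normalization internally (which would make [0, 1] inputs correct) is something to confirm in the flaxmodels source.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_unit_range(img_uint8):
    """uint8 [0, 255] -> float32 [0, 1] (the range torchvision's ToTensor produces)."""
    return img_uint8.astype(np.float32) / 255.0


def imagenet_normalize(img01):
    """Per-channel (x - mean) / std on an HWC image in [0, 1], like transforms.Normalize on CHW."""
    return (img01 - IMAGENET_MEAN) / IMAGENET_STD


def to_symmetric_range(img_uint8):
    """uint8 [0, 255] -> [-1, 1], i.e. (x - 127.5) / 127.5."""
    return (img_uint8.astype(np.float32) - 127.5) / 127.5


white = np.full((2, 2, 3), 255, dtype=np.uint8)
print(imagenet_normalize(to_unit_range(white))[0, 0])
```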
I'm struggling to use the stylegan 2 training code on my local machine due to dependency-resolver issues. I've tried several docker images (including deepo images and tensorflow-gpu, for example), but it always ends up taking so long that I have to cancel it (I waited an hour the first time 😳). The install process throws warnings like this:
INFO: pip is looking at multiple versions of attrs to determine which version is compatible with other requirements. This could take a while.
INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. If you want to abort this run, you can press Ctrl + C to do so. To improve how pip performs, tell us what happened here: https://pip.pypa.io/surveys/backtracking
Using pip install --use-deprecated=legacy-resolver -r requirements.txt results in this error at the end of installation:
ERROR: pip's legacy dependency resolver does not consider dependency conflicts when selecting packages. This behaviour is the source of the following dependency conflicts.
chex 0.0.8 requires dataclasses>=0.7; python_version >= "3.6" and python_version < "3.7", but you'll have dataclasses 0.6 which is incompatible.
tensorflow-gpu 2.5.0 requires gast==0.4.0, but you'll have gast 0.3.3 which is incompatible.
tensorflow-gpu 2.5.0 requires grpcio~=1.34.0, but you'll have grpcio 1.32.0 which is incompatible.
tensorflow-gpu 2.5.0 requires h5py~=3.1.0, but you'll have h5py 2.10.0 which is incompatible.
tensorflow-gpu 2.5.0 requires tensorflow-estimator<2.6.0,>=2.5.0rc0, but you'll have tensorflow-estimator 2.4.0 which is incompatible.
tensorflow-metadata 1.2.0 requires absl-py<0.13,>=0.9, but you'll have absl-py 0.13.0 which is incompatible.
yaspin 2.1.0 requires dataclasses<0.9,>=0.8; python_version >= "3.6" and python_version < "3.7", but you'll have dataclasses 0.6 which is incompatible.
flaxmodels 0.1.1 requires tqdm==4.60.0, but you'll have tqdm 4.61.1 which is incompatible.
Could specific versions be pinned in requirements.txt, assuming that would resolve these problems? Apparently pip has a freeze command that can be used to grab the versions you're using?
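The usual way is running pip freeze > requirements.txt inside a working environment. If only a few packages should be pinned, the same information is available from the standard library; this is a minimal sketch (frozen_requirements is hypothetical) that, unlike pip freeze, does not treat editable or VCS installs specially:

```python
from importlib import metadata


def frozen_requirements(only=None):
    """Return sorted 'name==version' lines for installed distributions, similar to `pip freeze`.

    only: optional iterable of distribution names to keep (case-insensitive).
    """
    wanted = {name.lower() for name in only} if only is not None else None
    lines = set()
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if not name:
            continue  # skip broken metadata without a name
        if wanted is None or name.lower() in wanted:
            lines.add(f"{name}=={dist.version}")
    return sorted(lines, key=str.lower)


if __name__ == "__main__":
    print("\n".join(frozen_requirements(only=["jax", "jaxlib", "flax", "tensorflow", "numpy"])))
```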
Hey,
Super cool project!
I discovered it as I plan to try to port lpips to JAX (VGG16 and inference only, no training) and I see that the VGG16 part is already done so only the lpips module needs to be ported.
I noticed that the models are hosted on Dropbox. May I suggest hosting them on the Hugging Face model hub for more reliability and control (versions, etc.)? Storage is also free there, so it's probably more attractive!
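For the LPIPS side, the distance itself is small once the VGG16 activations are available: unit-normalize each layer's features along the channel axis, take squared differences, weight them per channel, average over space and sum over layers. The sketch below follows that structure as I understand the reference lpips package (which applies its learned 1x1 weights to the squared differences); lpips_distance is hypothetical, and the all-ones weights and random "features" are placeholders for the learned weights and real VGG16 activations.

```python
import jax
import jax.numpy as jnp


def normalize_channels(feat, eps=1e-10):
    """Unit-normalize each spatial position's feature vector along the channel axis (NHWC)."""
    norm = jnp.sqrt(jnp.sum(feat ** 2, axis=-1, keepdims=True))
    return feat / (norm + eps)


def lpips_distance(feats_x, feats_y, weights):
    """LPIPS-style distance from per-layer feature maps.

    feats_x, feats_y: lists of NHWC arrays, one per layer.
    weights: list of non-negative per-channel weight vectors, shape (C_l,) per layer.
    Returns one distance per batch element.
    """
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        diff_sq = (normalize_channels(fx) - normalize_channels(fy)) ** 2  # (N, H, W, C)
        per_pixel = jnp.sum(diff_sq * w, axis=-1)                          # weighted channel sum
        total = total + jnp.mean(per_pixel, axis=(1, 2))                   # spatial average
    return total


# Toy usage with random "feature maps" standing in for real VGG16 activations.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
feats_a = [jax.random.normal(keys[0], (2, 8, 8, 4)), jax.random.normal(keys[1], (2, 4, 4, 6))]
feats_b = [jax.random.normal(keys[2], (2, 8, 8, 4)), jax.random.normal(keys[3], (2, 4, 4, 6))]
layer_weights = [jnp.ones((4,)), jnp.ones((6,))]
d_same = lpips_distance(feats_a, feats_a, layer_weights)
d_diff = lpips_distance(feats_a, feats_b, layer_weights)
```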
Demo notebook: flaxmodels/resnet/resnet_demo.ipynb
Notebook executed on Google Colab (GPU runtime)
Descriptive error message:
UnfilteredStackTrace: OSError: Unable to open file (file signature not found)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
OSError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/h5py/_hl/files.py in make_fid(name, mode, userblock_size, fapl, fcpl, swmr)
171 if swmr and swmr_support:
172 flags |= h5f.ACC_SWMR_READ
--> 173 fid = h5f.open(name, flags, fapl=fapl)
174 elif mode == 'r+':
175 fid = h5f.open(name, h5f.ACC_RDWR, fapl=fapl)
h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
h5py/h5f.pyx in h5py.h5f.open()
OSError: Unable to open file (file signature not found)
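"file signature not found" means h5py did not find the HDF5 magic bytes in the file. A plausible cause (an assumption, not confirmed here) is an incomplete or failed download, for example an HTML error page saved in place of the weights; deleting the cached file and downloading again is then the usual fix. A minimal check, similar in spirit to h5py.is_hdf5, with hypothetical file names:

```python
import os
import tempfile

HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"  # 8-byte magic number at the start of HDF5 files


def looks_like_hdf5(path):
    """Return True if the file starts with the HDF5 signature (offset 0 only).

    The HDF5 format also allows the signature at offsets 512, 1024, 2048, ... when a
    user block is present; this sketch ignores that case.
    """
    with open(path, "rb") as f:
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE


# Demo with two throwaway files.
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "weights.h5")
    bad = os.path.join(tmp, "weights_download.h5")
    with open(good, "wb") as f:
        f.write(HDF5_SIGNATURE + b"\x00" * 32)
    with open(bad, "wb") as f:
        f.write(b"<!DOCTYPE html><html><body>Error</body></html>")  # e.g. an error page saved as .h5
    good_ok, bad_ok = looks_like_hdf5(good), looks_like_hdf5(bad)
```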
Hi! I am relatively new to flax, and trying to train the model for stylegan. When I try initializing it using the .init(...) method, I get this error:
flax.errors.SetAttributeFrozenModuleError: Can't set w_avg=[-4.48062038e-03 -2.26424704e-03 -1.13881414e-03 1.17254574e-04
6.46775961e-03 9.12655238e-03 9.49112512e-03 9.94393881e-03
4.56054788e-03 4.13127383e-03 8.64951871e-03 -1.31167704e-03
-1.39767269e-03 -1.31127622e-03 8.68047995e-04 -1.76230317e-03
...
-3.35969799e-03 1.63379814e-02 -3.18925083e-03 -2.73406412e-03] for Module of type MappingNetwork: Module instance is frozen outside of setup method. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.SetAttributeFrozenModuleError)
Here is the line which causes the error:
...
from flaxmodels.stylegan2 import *
from jax import random
...
dummy_z = random.normal(key, (1, 512))
generator = Generator()
params = generator.init(key, dummy_z)['params'] #The problematic line, works the same even without ['params']
...
Thanks!