
kaushalya / medclip

107 stars · 3 watchers · 17 forks · 428.71 MB

A multi-modal CLIP model trained on the medical dataset ROCO

Home Page: https://huggingface.co/spaces/kaushalya/medclip-roco

License: Apache License 2.0

Languages: Jupyter Notebook 58.66%, Python 40.94%, Shell 0.39%
Topics: transformers, jax, flax, huggingface, machine-learning, medical-image-analysis

medclip's Introduction

---
title: Medical image retrieval using a CLIP model
emoji: 🩺
colorFrom: red
colorTo: white
sdk: streamlit
app_file: app.py
pinned: true
---

MedCLIP: Fine-tuning a CLIP model on the ROCO medical dataset


Summary

This repository contains the code for fine-tuning a CLIP model [arXiv paper] [OpenAI GitHub repo] on the ROCO dataset, a dataset of radiology images paired with textual captions. This work was done as part of the Flax/JAX community week organized by Hugging Face and Google.

SciBERT (allenai/scibert_scivocab_uncased on 🤗) is used as the text encoder.

[🤗 Model card] [Streamlit demo]

Demo

You can try a Streamlit demo app that uses this model on 🤗 Spaces. You may have to sign up for the 🤗 Spaces private beta to access this app (screenshot shown below).

[Screenshot: Streamlit app]

The demo can be run locally in the browser by running, from the repository root,

streamlit run app.py

Dataset 🧩

Each image is accompanied by a textual caption. Caption length varies from a few characters (a single word) to 2,000 characters (multiple sentences). During preprocessing we remove all images whose caption is shorter than 10 characters.

  • Training set: 57,780 images with their captions
  • Validation set: 7,200 images
  • Test set: 7,650 images
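As a rough illustration of the length filter, a minimal sketch (the `caption` field name is an assumption, not the actual ROCO metadata schema):

```python
MIN_CAPTION_LENGTH = 10  # threshold described above

def filter_short_captions(samples):
    """Drop samples whose caption has fewer than 10 characters."""
    return [s for s in samples if len(s["caption"]) >= MIN_CAPTION_LENGTH]
```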

[ ] Give an example

Installation 💽

This repo depends on the master branch of the Hugging Face Transformers library. First clone the transformers repository, then install it locally (preferably inside a virtual environment) with pip install -e ".[flax]".
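For example (a sketch of those steps; creating and activating the virtual environment beforehand is left out):

```sh
git clone https://github.com/huggingface/transformers.git
cd transformers
pip install -e ".[flax]"
```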

The Model ⚙️

You can load the pretrained model from the Hugging Face Hub with

from medclip.modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")

Alternatively you can download the model checkpoint from [🤗 Model card].
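As a rough retrieval sketch building on the snippet above (assumptions: the text side uses the SciBERT tokenizer, the image side uses the CLIP ViT-B/32 feature extractor, `FlaxHybridCLIP` exposes `get_text_features`/`get_image_features` like the upstream hybrid-CLIP example, and `scan.png` is a placeholder file):

```python
import jax.numpy as jnp
from PIL import Image
from transformers import AutoTokenizer, CLIPFeatureExtractor

from medclip.modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# Embed a query caption and a candidate image.
text = tokenizer(["chest X-ray showing pneumonia"], return_tensors="np", padding=True)
image = extractor(images=Image.open("scan.png").convert("RGB"), return_tensors="np")

text_emb = model.get_text_features(text["input_ids"], attention_mask=text["attention_mask"])
image_emb = model.get_image_features(image["pixel_values"])

# Cosine similarity: normalize both embeddings, then take the dot product.
text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)
image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
print(text_emb @ image_emb.T)  # higher score = closer match
```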

Training

The model is trained using Flax/JAX on a cloud TPU v3-8. You can fine-tune a CLIP model implemented in Flax by simply running sh run_medclip.sh. This is the validation loss curve we observed when training the model with the run_medclip.sh script:

[Plot: validation loss]
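For intuition, CLIP-style fine-tuning minimizes a symmetric cross-entropy over the in-batch image–text similarity matrix. A minimal JAX sketch of that objective (an illustration, not the exact loss code used by run_medclip.sh):

```python
import jax.numpy as jnp
import optax

def clip_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE loss; matching image-text pairs sit on the diagonal."""
    # Normalize so the dot products below are cosine similarities.
    image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
    text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)
    logits = image_emb @ text_emb.T / temperature
    labels = jnp.arange(logits.shape[0])
    loss_i2t = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    loss_t2i = optax.softmax_cross_entropy_with_integer_labels(logits.T, labels)
    return (loss_i2t.mean() + loss_t2i.mean()) / 2
```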

Limitations 🚨

The current model is capable of identifying higher-level features such as the modality of an image (e.g., whether a given radiology image is a PET scan or an ultrasound scan). However, it fails at distinguishing a brain scan from a lung scan. ❗️This model should not be used in a medical setting without further evaluation❗️.

Acknowledgements

Huge thanks to the Hugging Face 🤗 team and the Google JAX/Flax team for organizing the community week and letting us use cloud compute for two weeks. We especially thank @patil-suraj & @patrickvonplaten for the continued support on Slack and the detailed feedback.

TODO

[ ] Mention more examples

[ ] Evaluation on downstream tasks

[ ] Zero-shot learning performance

medclip's People

Contributors

kaushalya, raulcarlomagno


medclip's Issues

App fails to load

The Streamlit app fails with the following error message. Looks like pickle files are being deserialized with protocol=5.

AttributeError: Can't get attribute 'new_block' on <module 'pandas.core.internals.blocks' from '/home/user/.local/lib/python3.8/site-packages/pandas/core/internals/blocks.py'>

Deprecation warning

st.cache is deprecated. Please use one of Streamlit's new caching commands, st.cache_data or st.cache_resource.
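A likely migration, assuming the cached object is the model itself (a resource, so st.cache_resource rather than st.cache_data):

```python
import streamlit as st

@st.cache_resource  # replaces the deprecated @st.cache for models/clients
def load_model():
    from medclip.modeling_hybrid_clip import FlaxHybridCLIP
    return FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
```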

Throws this error in search

2023-09-07 17:27:37.161 Invalid arguments were passed to "st.write" function. Support for passing such unknown keywords arguments will be dropped in future. Invalid arguments were: {'help': 'score'}

Streamlit app fails to load

Stacktrace:

AttributeError: module 'jaxlib.pocketfft' has no attribute 'pocketfft'
Traceback:

File "/home/user/.local/lib/python3.8/site-packages/streamlit/script_runner.py", line 354, in _run_script
    exec(code, module.__dict__)
File "/home/user/app/app.py", line 7, in <module>
    from medclip.modeling_hybrid_clip import FlaxHybridCLIP
File "/home/user/app/medclip/modeling_hybrid_clip.py", line 18, in <module>
    import flax.linen as nn
File "/home/user/.local/lib/python3.8/site-packages/flax/__init__.py", line 36, in <module>
    from . import core
File "/home/user/.local/lib/python3.8/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
File "/home/user/.local/lib/python3.8/site-packages/flax/core/axes_scan.py", line 17, in <module>
    import jax
File "/home/user/.local/lib/python3.8/site-packages/jax/__init__.py", line 108, in <module>
    from .experimental.maps import soft_pmap
File "/home/user/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 25, in <module>
    from .. import numpy as jnp
File "/home/user/.local/lib/python3.8/site-packages/jax/numpy/__init__.py", line 16, in <module>
    from . import fft
File "/home/user/.local/lib/python3.8/site-packages/jax/numpy/fft.py", line 17, in <module>
    from jax._src.numpy.fft import (
File "/home/user/.local/lib/python3.8/site-packages/jax/_src/numpy/fft.py", line 19, in <module>
    from jax import lax
File "/home/user/.local/lib/python3.8/site-packages/jax/lax/__init__.py", line 331, in <module>
    from jax._src.lax.fft import (
File "/home/user/.local/lib/python3.8/site-packages/jax/_src/lax/fft.py", line 144, in <module>
    xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft

error when trying to load pretrained model

Hello, I have this error when trying to load the model from your notebooks:

[Screenshot: error message]

  • transformers 4.24.0
  • torch 1.13.0
  • jax 0.3.24
  • flax 0.6.1

Improve the efficiency of embedding search

The current implementation performs a brute-force search to find the images whose embeddings lie in the k-neighborhood of a text embedding. This could be made more efficient with an external library such as faiss.
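A possible sketch with faiss, using an exact inner-product index (the random embeddings are placeholders; unit-normalizing them makes inner product equal cosine similarity):

```python
import faiss
import numpy as np

dim = 512  # embedding dimension (assumption)
image_embeddings = np.random.rand(1000, dim).astype("float32")
image_embeddings /= np.linalg.norm(image_embeddings, axis=1, keepdims=True)

index = faiss.IndexFlatIP(dim)  # exact inner-product (cosine) search
index.add(image_embeddings)

text_embedding = np.random.rand(1, dim).astype("float32")
text_embedding /= np.linalg.norm(text_embedding, axis=1, keepdims=True)
scores, ids = index.search(text_embedding, 10)  # top-10 nearest images
```

For larger collections an approximate index (e.g. faiss.IndexIVFFlat) would reduce latency further.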

Update README

  • Mention how and where to download ROCO dataset.
  • Mention details on the dataset.json files

You are using this path for the image embeddings

Where is the embedding file that is used here?

vision_model_name = "openai/clip-vit-base-patch32"
img_dir = "/Users/kaumad/Documents/coding/hf-flax/demo/medclip-roco/images"

Please tell me.

Stuck at training?

I am using CUDA 11.2 and have 8 RTX 5000 GPUs configured. The GPU is not being picked up; can you help me please?

Stuck at training

When I run run_medclip.sh, everything initializes perfectly, but when it gets to the training loop it gets stuck and does not pass 0% progress!

loading weights file https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/pytorch_model.bin from cache at /home/**me**/.cache/huggingface/transformers/8a82711445c5200c2b4fd30739df371f5b3ce2d7e316418d58dd290bae1c1cc8.dabcc684421296ebcdafd583a4415c1757ae007787f2d0e17b87482d9b8cf197
Loading PyTorch weights from /home/**me**/.cache/huggingface/transformers/8a82711445c5200c2b4fd30739df371f5b3ce2d7e316418d58dd290bae1c1cc8.dabcc684421296ebcdafd583a4415c1757ae007787f2d0e17b87482d9b8cf197
PyTorch checkpoint contains 151,277,440 parameters.
Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing FlaxCLIPModel: {('vision_model', 'embeddings', 'position_ids'), ('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of FlaxCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxCLIPModel for predictions without further training.
text_config_dict is None. Initializing the CLIPTextConfig with default values.
vision_config_dict is None. initializing the CLIPVisionConfig with default values.
08/04/2021 09:57:44 - INFO - absl -   A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.
/home/**me**/.local/lib/python3.9/site-packages/jax/lib/xla_bridge.py:364: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
  warnings.warn(
/home/**me**/.local/lib/python3.9/site-packages/jax/lib/xla_bridge.py:351: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
08/04/2021 09:57:45 - INFO - __main__ -   ***** Running training *****
08/04/2021 09:57:45 - INFO - __main__ -     Num examples = 65420
08/04/2021 09:57:45 - INFO - __main__ -     Num Epochs = 40
08/04/2021 09:57:45 - INFO - __main__ -     Instantaneous batch size per device = 64
08/04/2021 09:57:45 - INFO - __main__ -     Total train batch size (w. parallel & distributed) = 64
08/04/2021 09:57:45 - INFO - __main__ -     Total optimization steps = 40880
Epoch ... (1/40):   0%| | 0/40 [00:00<?, ?it/s]
2021-08-04 10:00:17.638335: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module pmap_train_step.103881
********************************
