Comments (8)
Hi @MarkBenjamin, does it re-download the model if you pass the cache_dir as a parameter?
model = HQQModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation=attn_imp)
It seems to try to re-download it, for instance with:
HQQModelForCausalLM.from_quantized('mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq', cache_dir='/home/mark/.cache/huggingface/hub/')
It looks as though cache_dir isn't being passed from HQQWrapper.from_quantized() to BaseHQQModel.try_snapshot_download():
diff --git a/hqq/engine/base.py b/hqq/engine/base.py
index 27b3050..e2f4acf 100755
--- a/hqq/engine/base.py
+++ b/hqq/engine/base.py
@@ -74,12 +74,12 @@ class HQQWrapper:
         cls,
         save_dir_or_hub,
         compute_dtype: torch.dtype = float16,
-        device="cuda",
-        cache_dir: str = "",
+        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
+        cache_dir: str | None = "",
         adapter: str = None,
     ):
         # Both local and hub-support
-        save_dir = BaseHQQModel.try_snapshot_download(save_dir_or_hub)
+        save_dir = BaseHQQModel.try_snapshot_download(save_dir_or_hub, cache_dir=cache_dir)
         arch_key = cls._get_arch_key_from_save_dir(save_dir)
         cls._check_arch_support(arch_key)
However, even after that change, HQQModelForCausalLM.from_quantized('mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq', cache_dir=None) still tries to re-download the model.
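(To illustrate the pjoin() behaviour involved — pjoin is os.path.join, as aliased in hqq/models/base.py — a quick sketch:)

from os.path import join as pjoin  # aliased like this in hqq/models/base.py

# With the default cache_dir="" this yields a relative path, which usually
# doesn't exist locally, so snapshot_download() gets triggered:
pjoin("", "mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq")
# -> 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'

# And cache_dir=None fails outright before any download logic runs:
pjoin(None, "mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq")
# -> TypeError: expected str, bytes or os.PathLike object, not NoneType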
You need to pass it here as well: https://github.com/mobiusml/hqq/blob/master/hqq/models/base.py#L303
That part in hqq/engine/base.py is not supposed to download the whole repo, only the config.json, to figure out which architecture the model is, but somehow it wasn't changed.
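(For that architecture lookup, something like the following should be enough — a sketch using huggingface_hub's hf_hub_download, not what the repo currently does:)

from huggingface_hub import hf_hub_download

# Fetch only config.json (to read the architecture) instead of the full snapshot
config_path = hf_hub_download(
    repo_id="mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq",
    filename="config.json",
    cache_dir=None,  # None falls back to the default HF cache
)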
I actually don't see this issue. Just use a fixed cache_dir, like cache_dir='', and it should work just fine.
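(i.e. something like:)

model = HQQModelForCausalLM.from_quantized(
    'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq',
    cache_dir='',  # fixed value; resolves relative to the working directory
)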
I seem to have it working now (possibly the most necessary change is passing the cache_dir value from HQQWrapper.from_quantized() to BaseHQQModel.try_snapshot_download()), whether I specify the huggingface cache path or specifically send None (which then requires guarding the pjoin() call). The diff (plus some type hints specifically allowing None, including in BaseHQQModel.from_quantized() as you suggest), cleaned of the torch CPU code, would be:
diff --git a/hqq/engine/base.py b/hqq/engine/base.py
index 27b3050..a767fad 100755
--- a/hqq/engine/base.py
+++ b/hqq/engine/base.py
@@ -75,11 +75,11 @@ class HQQWrapper:
         save_dir_or_hub,
         compute_dtype: torch.dtype = float16,
         device="cuda",
-        cache_dir: str = "",
+        cache_dir: str | None = "",
         adapter: str = None,
     ):
         # Both local and hub-support
-        save_dir = BaseHQQModel.try_snapshot_download(save_dir_or_hub)
+        save_dir = BaseHQQModel.try_snapshot_download(save_dir_or_hub, cache_dir=cache_dir)
         arch_key = cls._get_arch_key_from_save_dir(save_dir)
         cls._check_arch_support(arch_key)
diff --git a/hqq/models/base.py b/hqq/models/base.py
index e305391..7ba9494 100755
--- a/hqq/models/base.py
+++ b/hqq/models/base.py
@@ -278,8 +278,11 @@ class BaseHQQModel:
         cls.save_weights(weights, save_dir)

     @classmethod
-    def try_snapshot_download(cls, save_dir_or_hub: str, cache_dir: str = ""):
-        save_dir = pjoin(cache_dir, save_dir_or_hub)
+    def try_snapshot_download(cls, save_dir_or_hub: str, cache_dir: str | None = ""):
+        if cache_dir is None:
+            save_dir = pjoin("", save_dir_or_hub)
+        else:
+            save_dir = pjoin(cache_dir, save_dir_or_hub)
         if not os.path.exists(save_dir):
             save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)
@@ -305,13 +308,11 @@ class BaseHQQModel:
         save_dir_or_hub,
         compute_dtype: torch.dtype = float16,
         device="cuda",
-        cache_dir: str = "",
+        cache_dir: str | None = "",
         adapter: str = None,
     ):
         # Get directory path
         save_dir = cls.try_snapshot_download(save_dir_or_hub, cache_dir)
-
-        # Load model from config
         model = cls.create_model(save_dir)

         # Track save directory
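With that applied, both of these call patterns resolve for me without a fresh download (assuming the snapshot already exists in the respective cache):

from hqq.engine.hf import HQQModelForCausalLM

repo_id = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'

# Explicit huggingface cache path
model = HQQModelForCausalLM.from_quantized(repo_id, cache_dir='/home/mark/.cache/huggingface/hub/')

# None, deferring to huggingface_hub's own cache resolution
model = HQQModelForCausalLM.from_quantized(repo_id, cache_dir=None)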
I guess you wouldn't see the issue unless you had already (or subsequently?) directly called

from huggingface_hub import snapshot_download
snapshot_download('mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq')

unless your envvars have altered the HF_HUB_CACHE or (legacy) HUGGINGFACE_HUB_CACHE or (climbing the tree) HF_HOME or XDG_CACHE_HOME values, or, more generally, unless there is some coincidence between whatever "" substitutes for as cache_dir (your working directory?) and the HF_HUB_CACHE value.
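(For comparing environments, the resolved default can be inspected directly; a sketch assuming a recent huggingface_hub, which exposes it as a constant:)

import os
from huggingface_hub.constants import HF_HUB_CACHE

# Where snapshot_download() lands after the HF_HUB_CACHE / HUGGINGFACE_HUB_CACHE /
# HF_HOME / XDG_CACHE_HOME resolution chain
print(HF_HUB_CACHE)
print(os.environ.get('HF_HOME'), os.environ.get('XDG_CACHE_HOME'))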
Sorry @MarkBenjamin, I missed this.
In general, if you specify the cache directory when you load via from_pretrained, this problem doesn't happen; at least, no one has reported this issue before.
I can try using the logic you mentioned in try_snapshot_download, or you can do a pull request if you want. Happy to review it!
Will do 🙂
Best regards
Mark
Closing since this was merged, thanks!