Description
Describe the bug
Inference not working with quantization
Reproduction
Use the sample code from the quantization section of the Sana docs:
https://github.com/NVlabs/Sana/blob/main/asset/docs/8bit_sana.md#quantization
Replace the model with Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers and the dtype with torch.bfloat16.
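A minimal sketch of the failing script, assuming the bitsandbytes 8-bit loading pattern from the linked 8bit_sana.md example with the model id and dtype swapped in as described above (the exact prompt and save path are placeholders):

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import SanaTransformer2DModel, SanaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import AutoModelForCausalLM

model_id = "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers"

# Quantize the Gemma text encoder to 8-bit with bitsandbytes.
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=TransformersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

# Quantize the Sana transformer to 8-bit as well.
transformer_8bit = SanaTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)

pipeline = SanaPipeline.from_pretrained(
    model_id,
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

prompt = "a photo of an astronaut riding a horse"  # placeholder prompt
image = pipeline(prompt).images[0]  # fails here with the device-mismatch error below
image.save("sana_4k.png")
```

Loading completes, but the first denoising step raises the device-mismatch error shown in the logs.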
Logs
(venv) C:\ai1\diffuser_t2i>python Sana_4K-Quant.py
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:28<00:00, 14.45s/it]
Expected types for text_encoder: ['AutoModelForCausalLM'], got Gemma2Model.
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:15<00:00, 3.17s/it]
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
C:\ai1\diffuser_t2i\venv\lib\site-packages\bitsandbytes\autograd\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
0%| | 0/20 [00:00<?, ?it/s]
Traceback (most recent call last):
File "C:\ai1\diffuser_t2i\Sana_4K-Quant.py", line 30, in <module>
image = pipeline(prompt).images[0]
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\pipelines\sana\pipeline_sana.py", line 882, in __call__
noise_pred = self.transformer(
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\transformers\sana_transformer.py", line 414, in forward
hidden_states = self.patch_embed(hidden_states)
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "C:\ai1\diffuser_t2i\venv\lib\site-packages\diffusers\models\embeddings.py", line 569, in forward
return (latent + pos_embed).to(latent.dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
System Info
python 3.10.11
accelerate 1.2.0.dev0
aiofiles 23.2.1
annotated-types 0.7.0
anyio 4.7.0
bitsandbytes 0.45.0
certifi 2024.12.14
charset-normalizer 3.4.1
click 8.1.8
colorama 0.4.6
diffusers 0.33.0.dev0
einops 0.8.0
exceptiongroup 1.2.2
fastapi 0.115.6
ffmpy 0.5.0
filelock 3.16.1
fsspec 2024.12.0
gguf 0.13.0
gradio 5.9.1
gradio_client 1.5.2
h11 0.14.0
httpcore 1.0.7
httpx 0.28.1
huggingface-hub 0.25.2
idna 3.10
imageio 2.36.1
imageio-ffmpeg 0.5.1
importlib_metadata 8.5.0
Jinja2 3.1.5
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
networkx 3.4.2
ninja 1.11.1.3
numpy 2.2.1
opencv-python 4.10.0.84
optimum-quanto 0.2.6.dev0
orjson 3.10.13
packaging 24.2
pandas 2.2.3
patch-conv 0.0.1b0
pillow 11.1.0
pip 23.0.1
protobuf 5.29.2
psutil 6.1.1
pydantic 2.10.4
pydantic_core 2.27.2
pydub 0.25.1
Pygments 2.18.0
python-dateutil 2.9.0.post0
python-multipart 0.0.20
pytz 2024.2
PyYAML 6.0.2
regex 2024.11.6
requests 2.32.3
rich 13.9.4
ruff 0.8.6
safehttpx 0.1.6
safetensors 0.5.0
semantic-version 2.10.0
sentencepiece 0.2.0
setuptools 65.5.0
shellingham 1.5.4
six 1.17.0
sniffio 1.3.1
starlette 0.41.3
sympy 1.13.1
tokenizers 0.21.0
tomlkit 0.13.2
torch 2.5.1+cu124
torchao 0.7.0
torchvision 0.20.1+cu124
tqdm 4.67.1
transformers 4.47.1
typer 0.15.1
typing_extensions 4.12.2
tzdata 2024.2
urllib3 2.3.0
uvicorn 0.34.0
websockets 14.1
wheel 0.45.1
zipp 3.21.0
Who can help?
No response