
How to use quantization and device_map='balanced' to run Qwen-Image on Kaggle T4 * 2 #12719

@chaowenguo

Description

!python3 -m pip install -U diffusers peft bitsandbytes protobuf

import diffusers, torch, math

qwen = diffusers.QwenImagePipeline.from_pretrained(
    'Qwen/Qwen-Image',
    quantization_config=diffusers.PipelineQuantizationConfig(
        quant_backend='bitsandbytes_4bit',
        quant_kwargs={'load_in_4bit': True, 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_compute_dtype': torch.float16},
        components_to_quantize=['transformer', 'text_encoder'],
    ),
    torch_dtype=torch.float16,
    device_map='balanced',
)
print(qwen.hf_device_map)
qwen.scheduler = diffusers.FlowMatchEulerDiscreteScheduler.from_config({
    'base_image_seq_len': 256,
    'base_shift': math.log(3),
    'invert_sigmas': False,
    'max_image_seq_len': 8192,
    'max_shift': math.log(3),
    'num_train_timesteps': 1000,
    'shift': 1,
    'shift_terminal': None,
    'stochastic_sampling': False,
    'time_shift_type': 'exponential',
    'use_beta_sigmas': False,
    'use_dynamic_shifting': True,
    'use_exponential_sigmas': False,
    'use_karras_sigmas': False,
})
qwen.load_lora_weights(
    'lightx2v/Qwen-Image-Lightning',
    weight_name='Qwen-Image-Lightning-4steps-V2.0.safetensors',
    adapter_name='lightning',
)
qwen.set_adapters('lightning', adapter_weights=1)
qwen(prompt='a beautiful girl', height=1280, width=720, num_inference_steps=4, true_cfg_scale=1).images[0].save('a.png')

WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu.
{'text_encoder': 'cpu', 'vae': 0}

Where is the transformer?
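A quick way to check where each component actually ended up (a minimal sketch continuing the snippet above; I am assuming DiffusionPipeline.components exposes every loaded module and that offloaded parameters simply report device meta):

for name, component in qwen.components.items():
    # only torch modules have parameters to inspect
    if isinstance(component, torch.nn.Module):
        try:
            # report the device of the first parameter of each component
            print(name, next(component.parameters()).device)
        except StopIteration:
            print(name, '(no parameters)')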

NotImplementedError: Cannot copy out of meta tensor; no data!

I want to ask how to make the above code work on Kaggle. Why is 16 GB * 2 of VRAM still not enough to run 4-bit quantized Qwen-Image? I want to take full advantage of both GPUs. Do I need max_memory?
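If max_memory is the way to go, is something like the following the intended usage? This is only a sketch of what I would try: the 14GiB caps are my guess to leave headroom on each 16 GB T4, and I am assuming from_pretrained forwards max_memory to accelerate's dispatch.

import diffusers, torch

quant = diffusers.PipelineQuantizationConfig(
    quant_backend='bitsandbytes_4bit',
    quant_kwargs={'load_in_4bit': True, 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_compute_dtype': torch.float16},
    components_to_quantize=['transformer', 'text_encoder'],
)
qwen = diffusers.QwenImagePipeline.from_pretrained(
    'Qwen/Qwen-Image',
    quantization_config=quant,
    torch_dtype=torch.float16,
    device_map='balanced',
    # assumption: cap below the physical 16GB so activations and the LoRA still fit
    max_memory={0: '14GiB', 1: '14GiB'},
)
print(qwen.hf_device_map)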

Full error logs:
/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
117
118 return decorate_context

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py in __call__(self, prompt, negative_prompt, true_cfg_scale, height, width, num_inference_steps, sigmas, guidance_scale, num_images_per_prompt, generator, latents, prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
566 )
567 do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
--> 568 prompt_embeds, prompt_embeds_mask = self.encode_prompt(
569 prompt=prompt,
570 prompt_embeds=prompt_embeds,

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py in encode_prompt(self, prompt, device, num_images_per_prompt, prompt_embeds, prompt_embeds_mask, max_sequence_length)
252
253 if prompt_embeds is None:
--> 254 prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
255
256 prompt_embeds = prompt_embeds[:, :max_sequence_length]

/usr/local/lib/python3.11/dist-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py in _get_qwen_prompt_embeds(self, prompt, device, dtype)
203 txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
204 ).to(device)
--> 205 encoder_hidden_states = self.text_encoder(
206 input_ids=txt_tokens.input_ids,
207 attention_mask=txt_tokens.attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177

/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py in wrapper(self, *args, **kwargs)
941
942 try:
--> 943 output = func(self, *args, **kwargs)
944 if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
945 output = output.to_tuple()

/usr/local/lib/python3.11/dist-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
1507 )
1508
-> 1509 outputs = self.model(
1510 input_ids=input_ids,
1511 pixel_values=pixel_values,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177

/usr/local/lib/python3.11/dist-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
1328 position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1329
-> 1330 outputs = self.language_model(
1331 input_ids=None,
1332 position_ids=position_ids,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177

/usr/local/lib/python3.11/dist-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **kwargs)
918 all_hidden_states += (hidden_states,)
919
--> 920 layer_outputs = decoder_layer(
921 hidden_states,
922 attention_mask=causal_mask_mapping[decoder_layer.attention_type],

/usr/local/lib/python3.11/dist-packages/transformers/modeling_layers.py in __call__(self, *args, **kwargs)
81
82 return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
---> 83 return super().__call__(*args, **kwargs)

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177

/usr/local/lib/python3.11/dist-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
775
776 # Self Attention
--> 777 hidden_states, self_attn_weights, present_key_value = self.self_attn(
778 hidden_states=hidden_states,
779 attention_mask=attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177

/usr/local/lib/python3.11/dist-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
681 bsz, q_len, _ = hidden_states.size()
682
--> 683 query_states = self.q_proj(hidden_states)
684 key_states = self.k_proj(hidden_states)
685 value_states = self.v_proj(hidden_states)

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
168
169 def new_forward(module, *args, **kwargs):
--> 170 args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
171 if module._hf_hook.no_grad:
172 with torch.no_grad():

/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in pre_forward(self, module, *args, **kwargs)
358 self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
359
--> 360 set_module_tensor_to_device(
361 module,
362 name,

/usr/local/lib/python3.11/dist-packages/accelerate/utils/modeling.py in set_module_tensor_to_device(module, tensor_name, device, value, dtype, fp16_statistics, tied_params_map, non_blocking, clear_cache)
361 new_value.SCB = new_value.SCB.to("cpu")
362 else:
--> 363 new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
364 device, non_blocking=non_blocking
365 )

/usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py in to(self, *args, **kwargs)
338 else:
339 if self.quant_state is not None:
--> 340 self.quant_state.to(device)
341
342 new_param = Params4bit(

/usr/local/lib/python3.11/dist-packages/bitsandbytes/functional.py in to(self, device)
537 def to(self, device):
538 # make sure the quantization state is on the right device
--> 539 self.code = self.code.to(device)
540 self.absmax = self.absmax.to(device)
541 if self.nested:

NotImplementedError: Cannot copy out of meta tensor; no data!

@yiyixuxu @DN6
