Conversation
Hi, I think we could add more documentation or comments to the new code.
Pull Request Overview
This PR adds image generation capabilities to the BailingMM2 model by integrating diffusion-based text-to-image generation. The changes enable both standalone image generation and multi-modal inference (text, image, video, audio) in a unified test framework.
- Adds image generation functionality with diffusion model integration
- Refactors the test file to support multi-modal inference examples
- Updates model components to handle image generation parameters and workflows
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 9 comments.
Summary per file:
| File | Description |
|---|---|
| test_infer_gen_image.py | Completely rewritten to demonstrate image generation and multi-modal inference capabilities |
| modeling_bailingmm2.py | Adds image generation methods, diffusion model loading, and condition embedding generation |
| processing_bailingmm2.py | Adds image generation input processing for reference images |
| modeling_bailing_moe_v2.py | Updates forward method to handle image/audio masks for generation |
| image_processing_bailingmm2.py | Updates VideoInput import source |
| bailingmm_utils.py | Adds VideoInput type definition and aspect ratio processing utilities |
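The aspect-ratio utilities added to bailingmm_utils.py are not shown in this review, but a generic sketch of how such a `process_ratio`-style bucketing helper typically works may be useful context. The candidate sizes and the exact selection rule here are hypothetical, not taken from the PR:

```python
import math

# Hypothetical candidate (height, width) buckets; the real utility
# in bailingmm_utils.py defines its own set.
CANDIDATE_SIZES = [(512, 512), (512, 768), (768, 512)]

def process_ratio(ori_h: int, ori_w: int):
    # Pick the candidate whose aspect ratio is closest to the input's,
    # then compute a resize size large enough to cover that crop window.
    ratio = ori_h / ori_w
    closest = min(CANDIDATE_SIZES, key=lambda s: abs(s[0] / s[1] - ratio))
    th, tw = closest
    scale = max(th / ori_h, tw / ori_w)
    resize = (math.ceil(ori_h * scale), math.ceil(ori_w * scale))
    return closest, resize
```

The returned `closest` size is used for the center crop and `resize` for the preceding bilinear resize, matching how the processor code below consumes the two values.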
```python
        List["np.ndarray"],
        List["torch.Tensor"],
        List[List["Image.Image"]],
        List[List["np.ndarrray"]],
```

There's a typo: 'np.ndarrray' should be 'np.ndarray' (one extra 'r').

```diff
-        List[List["np.ndarrray"]],
+        List[List["np.ndarray"]],
```
```python
    return output_text


if __name__ == '__main__':
    model_name_or_path = "/nativemm/share/cpfs/weilong.cwl/checkpoints/Ming_Flash_2.0_sft1_merged"
```
Hard-coded file paths make the code non-portable. Consider using environment variables or configuration files to specify these paths.
```diff
-    model_name_or_path = "/nativemm/share/cpfs/weilong.cwl/checkpoints/Ming_Flash_2.0_sft1_merged"
+    model_name_or_path = os.environ.get(
+        "MODEL_NAME_OR_PATH",
+        "/nativemm/share/cpfs/weilong.cwl/checkpoints/Ming_Flash_2.0_sft1_merged"
+    )
```
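A minimal, self-contained sketch of the pattern the suggestion proposes. The variable name `MODEL_NAME_OR_PATH` and the default path come from the suggestion itself; note that `import os` would also need to be added to the test script:

```python
import os

# Resolve a path from the environment, falling back to a default so the
# script still runs unchanged in the original development environment.
def resolve_path(env_var: str, default: str) -> str:
    return os.environ.get(env_var, default)

model_name_or_path = resolve_path(
    "MODEL_NAME_OR_PATH",
    "/nativemm/share/cpfs/weilong.cwl/checkpoints/Ming_Flash_2.0_sft1_merged",
)
```

The same pattern applies to the `vision_path` suggestion further down; an argparse flag would work equally well if the paths change per run.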
```python
        attn_implementation="flash_attention_2",
        load_image_gen=True,
    ).to("cuda")
    vision_path = "/input/sunyunxiao.syx/assets/"
```
Hard-coded file paths make the code non-portable. Consider using environment variables or configuration files to specify these paths.
```diff
-    vision_path = "/input/sunyunxiao.syx/assets/"
+    vision_path = os.environ.get("VISION_PATH", "/input/sunyunxiao.syx/assets/")
```
```python
    print("image_gen_seed: ", image_gen_seed)
    print("image_gen_steps: ", image_gen_steps)
    print("image_gen_height: ", image_gen_height)
    print("image_gen_width: ", image_gen_width)
```
Debug print statements should be removed from production code or replaced with proper logging using the logging module.
```diff
-    print("image_gen_seed: ", image_gen_seed)
-    print("image_gen_steps: ", image_gen_steps)
-    print("image_gen_height: ", image_gen_height)
-    print("image_gen_width: ", image_gen_width)
+    logging.info(f"image_gen_seed: {image_gen_seed}")
+    logging.info(f"image_gen_steps: {image_gen_steps}")
+    logging.info(f"image_gen_height: {image_gen_height}")
+    logging.info(f"image_gen_width: {image_gen_width}")
```
```python
        with safe_open(safetensors_path, framework="pt") as f:
            temp_state_dict = {key: f.get_tensor(key) for key in f.keys()}
        self.query_tokens_dict = nn.ParameterDict()
        #self.img_gen_scales = [4, 8, 16]
```
Commented-out code should be removed to improve code clarity. If this code might be needed later, consider using version control instead of comments.
```diff
-        #self.img_gen_scales = [4, 8, 16]
```
```python
        for scale in self.img_gen_scales:
            num_tokens = scale * scale
            scale_name = f"{scale}x{scale}"
            #weights = temp_state_dict[f"query_tokens_dict.{scale_name}"]
```
Commented-out code should be removed to improve code clarity. If this code might be needed later, consider using version control instead of comments.
```diff
-            #weights = temp_state_dict[f"query_tokens_dict.{scale_name}"]
```
```python
        else:
            raise ValueError("unsupported dit type: {}".format(dit_type))
        self.diffusion_loss.to(device)
        #self.norm_query_embeds = True
```
Commented-out code should be removed to improve code clarity. If this code might be needed later, consider using version control instead of comments.
```diff
-        #self.norm_query_embeds = True
```
```python
        ref_pil = ref_pil.convert("RGB")
        closest_size, resize_size = process_ratio(ori_h=ref_pil.size[1], ori_w=ref_pil.size[0])
        ref_pil = torchvision.transforms.functional.resize(ref_pil, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        ref_pil = torchvision.transforms.functional.center_crop(ref_pil, closest_size)
        ref_tensor = ((torchvision.transforms.functional.to_tensor(ref_pil) - 0.5) * 2.0).unsqueeze(0)
        image_gen_inputs = {
            "image_gen_pixel_values_reference": ref_tensor,
            "image_gen_height": torch.LongTensor([ref_pil.size[1]]),
            "image_gen_width": torch.LongTensor([ref_pil.size[0]]),
        }
```
This image processing logic is complex and could be extracted into a separate helper method for better readability and reusability.
```diff
-        ref_pil = ref_pil.convert("RGB")
-        closest_size, resize_size = process_ratio(ori_h=ref_pil.size[1], ori_w=ref_pil.size[0])
-        ref_pil = torchvision.transforms.functional.resize(ref_pil, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
-        ref_pil = torchvision.transforms.functional.center_crop(ref_pil, closest_size)
-        ref_tensor = ((torchvision.transforms.functional.to_tensor(ref_pil) - 0.5) * 2.0).unsqueeze(0)
-        image_gen_inputs = {
-            "image_gen_pixel_values_reference": ref_tensor,
-            "image_gen_height": torch.LongTensor([ref_pil.size[1]]),
-            "image_gen_width": torch.LongTensor([ref_pil.size[0]]),
-        }
+        image_gen_inputs = self._process_reference_image(ref_pil)
```
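A sketch of what the extracted helper could look like. To keep the example self-contained, the torchvision calls are replaced by NumPy equivalents operating on HWC uint8 arrays and the `process_ratio` lookup is replaced by explicit target dimensions; the real `_process_reference_image` (a name taken from the suggestion) would keep the torchvision pipeline and the repository's `process_ratio`:

```python
import numpy as np

def center_crop(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    # Crop the central target_h x target_w window, mirroring what
    # torchvision.transforms.functional.center_crop does.
    h, w = img.shape[:2]
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return img[top:top + target_h, left:left + target_w]

def process_reference_image(img: np.ndarray, target_h: int, target_w: int) -> dict:
    # Stand-in for the suggested helper: crop to the target size, then
    # map uint8 pixels to [-1, 1] via the (x - 0.5) * 2.0 normalization
    # used in the original code.
    cropped = center_crop(img, target_h, target_w)
    tensor = (cropped.astype(np.float32) / 255.0 - 0.5) * 2.0
    return {
        "image_gen_pixel_values_reference": tensor[None, ...],  # batch dim
        "image_gen_height": np.array([target_h]),
        "image_gen_width": np.array([target_w]),
    }
```

Extracting the logic this way also makes the normalization and crop behavior unit-testable without loading the full model.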
```python
            inputs_embeds.size,
        )
        batch_size, seq_length = inputs_embeds.shape[:2]
        #raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
```
Commented-out error handling should be removed or replaced with proper logic. If this validation is no longer needed, remove the comment entirely.
```diff
-        #raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
```