Merged
6 changes: 3 additions & 3 deletions comfy/model_patcher.py
@@ -132,14 +132,14 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
     def __call__(self, weight):
         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
-#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
-LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
 
 def low_vram_patch_estimate_vram(model, key):
     weight, set_func, convert_func = get_key_weight(model, key)
     if weight is None:
         return 0
-    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
Contributor
`manual_cast_dtype` can be None, so this code needs to either exclude that case or add an `or torch.float32` on the end (a `getattr` default only fires if the attribute doesn't exist, not if it's None), so this PR is erroring for users.

+    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

 def get_key_weight(model, key):
     set_func = None
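
For illustration, a minimal sketch of the fallback the reviewer suggests, reusing the names from the diff above (`model`, `get_key_weight`, `LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR`). This is an assumed shape of the fix, not the code merged in this PR:

```python
import torch

def low_vram_patch_estimate_vram(model, key):
    weight, set_func, convert_func = get_key_weight(model, key)
    if weight is None:
        return 0
    # getattr's default only applies when the attribute is *absent*; if
    # model.manual_cast_dtype exists but is set to None, getattr returns None
    # and None.itemsize raises AttributeError. The `or` fallback covers both
    # the missing-attribute case and the None case.
    model_dtype = getattr(model, "manual_cast_dtype", None) or torch.float32
    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
```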