Commit 60ee574

retune lowVramPatch VRAM accounting (#11173)
In the lowvram case, the patch math is now done in the model dtype, in the post-de-quantization domain, so account for that in the estimate. The patching was also put back on the compute stream, which moves it off the peak, so relax the MATH_FACTOR to only x2 instead of assuming the worst case of everything peaking at once.
1 parent 8e889c5 commit 60ee574

File tree

1 file changed: +3 -3 lines


comfy/model_patcher.py

Lines changed: 3 additions & 3 deletions
@@ -132,14 +132,14 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
     def __call__(self, weight):
         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
-#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
-LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
 
 def low_vram_patch_estimate_vram(model, key):
     weight, set_func, convert_func = get_key_weight(model, key)
     if weight is None:
         return 0
-    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
 
 def get_key_weight(model, key):
     set_func = None
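
Below is a rough sketch (not part of the commit) of how the estimate changes under the new accounting, assuming a hypothetical 4096 x 4096 weight and a model whose manual_cast_dtype is torch.float16:

import torch

numel = 4096 * 4096  # hypothetical weight tensor size

# Old accounting: assume the patch math runs in fp32 with a worst-case 3x factor.
old_estimate = numel * torch.float32.itemsize * 3   # 201,326,592 bytes (192 MiB)

# New accounting: the math happens in the model dtype after de-quantization, and the
# patching sits on the compute stream (off the peak), so a 2x factor suffices.
model_dtype = torch.float16  # stands in for model.manual_cast_dtype
new_estimate = numel * model_dtype.itemsize * 2     # 67,108,864 bytes (64 MiB)

print(old_estimate, new_estimate)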
