From ae8fa7e0b452b0fb999cf86d14a33f704ebb8984 Mon Sep 17 00:00:00 2001 From: zhuxiaoxuhit Date: Fri, 5 Jun 2026 11:02:56 +0000 Subject: [PATCH] Fix load saved lora_config.json when loading LoRA weights from_pretrained(lora_weights_path=...) built a default r=8 LoRAConfig and crashed for checkpoints trained with other ranks (e.g. r=32). Load the checkpoint's lora_config.json so r/alpha match; fall back to the default. --- src/voxcpm/core.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py index 81692c9c..84872989 100644 --- a/src/voxcpm/core.py +++ b/src/voxcpm/core.py @@ -45,14 +45,17 @@ def __init__( file=sys.stderr, ) - # If lora_weights_path is provided but no lora_config, create a default one + # If lora_weights_path is provided but no lora_config, load the saved + # lora_config.json (so r/alpha match the checkpoint); else use a default. if lora_weights_path is not None and lora_config is None: - lora_config = LoRAConfig( - enable_lm=True, - enable_dit=True, - enable_proj=False, - ) - print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr) + cfg_path = os.path.join(lora_weights_path, "lora_config.json") + if os.path.isdir(lora_weights_path) and os.path.isfile(cfg_path): + with open(cfg_path, "r", encoding="utf-8") as f: + lora_config = LoRAConfig(**json.load(f)["lora_config"]) + print(f"Loaded LoRAConfig from: {cfg_path}", file=sys.stderr) + else: + lora_config = LoRAConfig(enable_lm=True, enable_dit=True, enable_proj=False) + print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr) # Determine model type from config.json architecture field config_path = os.path.join(voxcpm_model_path, "config.json")