diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py index 81692c9c..84872989 100644 --- a/src/voxcpm/core.py +++ b/src/voxcpm/core.py @@ -45,14 +45,17 @@ def __init__( file=sys.stderr, ) - # If lora_weights_path is provided but no lora_config, create a default one + # If lora_weights_path is provided but no lora_config, load the saved + # lora_config.json (so r/alpha match the checkpoint); else use a default. if lora_weights_path is not None and lora_config is None: - lora_config = LoRAConfig( - enable_lm=True, - enable_dit=True, - enable_proj=False, - ) - print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr) + cfg_path = os.path.join(lora_weights_path, "lora_config.json") + if os.path.isdir(lora_weights_path) and os.path.isfile(cfg_path): + with open(cfg_path, "r", encoding="utf-8") as f: + lora_config = LoRAConfig(**json.load(f)["lora_config"]) + print(f"Loaded LoRAConfig from: {cfg_path}", file=sys.stderr) + else: + lora_config = LoRAConfig(enable_lm=True, enable_dit=True, enable_proj=False) + print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr) # Determine model type from config.json architecture field config_path = os.path.join(voxcpm_model_path, "config.json")