Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 04:22:02 +00:00)
Recover the ability to load old checkpoints by patching in the lambdas if they don't exist in the checkpoint
@@ -20,6 +20,16 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)
 
+def _patch_missing_keys(model_data, model_config):
+    """Add default values for new parameters that may be missing in old checkpoints."""
+    n_layer = model_config.n_layer
+    # resid_lambdas defaults to 1.0 (identity scaling)
+    if "resid_lambdas" not in model_data:
+        model_data["resid_lambdas"] = torch.ones(n_layer)
+    # x0_lambdas defaults to 0.0 (disabled)
+    if "x0_lambdas" not in model_data:
+        model_data["x0_lambdas"] = torch.zeros(n_layer)
+
 def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
     if rank == 0:
         os.makedirs(checkpoint_dir, exist_ok=True)
@@ -76,6 +86,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model_config_kwargs = meta_data["model_config"]
     log0(f"Building model with config: {model_config_kwargs}")
     model_config = GPTConfig(**model_config_kwargs)
+    _patch_missing_keys(model_data, model_config)
     with torch.device("meta"):
         model = GPT(model_config)
     # Load the model state
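The pattern can be exercised outside nanochat. Below is a minimal, self-contained sketch (the `TinyBlockStack` module and `patch_missing_keys` helper are hypothetical stand-ins, not nanochat code) showing why backfilling defaults into the state dict before a strict `load_state_dict()` lets a checkpoint saved before `resid_lambdas`/`x0_lambdas` existed load into the newer model definition.

import torch
import torch.nn as nn

class TinyBlockStack(nn.Module):
    """Stand-in for a model that recently grew per-layer lambda parameters."""
    def __init__(self, n_layer=2, dim=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layer))
        # New parameters: old checkpoints will not contain these keys.
        self.resid_lambdas = nn.Parameter(torch.ones(n_layer))
        self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))

def patch_missing_keys(state_dict, n_layer):
    """Backfill defaults so a strict load succeeds on checkpoints from older code."""
    if "resid_lambdas" not in state_dict:
        state_dict["resid_lambdas"] = torch.ones(n_layer)  # 1.0 = identity scaling
    if "x0_lambdas" not in state_dict:
        state_dict["x0_lambdas"] = torch.zeros(n_layer)    # 0.0 = disabled
    return state_dict

model = TinyBlockStack(n_layer=2)
# Simulate an old checkpoint that predates the lambda parameters.
old_ckpt = {k: v for k, v in model.state_dict().items() if "lambdas" not in k}
model.load_state_dict(patch_missing_keys(old_ckpt, n_layer=2))  # strict load, no missing keys
print(model.resid_lambdas, model.x0_lambdas)

In build_model() above, the same idea is applied to model_data before the state is loaded into the meta-initialized GPT, which is what restores compatibility with checkpoints saved before these parameters were introduced.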