bugfix

@@ -149,8 +149,8 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
     parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
-    parser.add_argument('--model_tag', type=str, default=None, help='optional model tag for the output directory name')
-    parser.add_argument('--model_step', type=str, default=None, help='optional model step for the output directory name')
+    parser.add_argument('--model-tag', type=str, default=None, help='optional model tag for the output directory name')
+    parser.add_argument('--step', type=str, default=None, help='optional model step for the output directory name')
     args = parser.parse_args()
 
     # distributed / precision setup
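
Note on the renames above: argparse converts hyphens in long option names to underscores when it builds the attribute name, so the new --model-tag flag is still read back as args.model_tag, and only the --step rename forces a call-site change below. A minimal standalone sketch of this behavior, separate from the commit itself:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-tag', type=str, default=None)
    args = parser.parse_args(['--model-tag', 'd20'])
    print(args.model_tag)  # 'd20': the hyphen becomes an underscore in the dest name
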
@@ -168,7 +168,7 @@ def main():
         model_slug = hf_path.replace("/", "-") # for the output csv file
     else:
         # load a local model from the file system
-        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.model_step)
+        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=args.model_tag, step=args.step)
         model_name = f"base_model (step {meta['step']})" # just for logging
         model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
 
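
This hunk is the companion fix to the flag rename: once the flag is --step, argparse exposes it as args.step, so renaming the flag without updating this call would raise AttributeError at runtime. Illustrative sketch with a hypothetical minimal parser, not the script itself:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=str, default=None)
    args = parser.parse_args([])
    print(args.step)  # None, the default
    # args.model_step would raise:
    # AttributeError: 'Namespace' object has no attribute 'model_step'
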
@@ -31,6 +31,8 @@ from tasks.gsm8k import GSM8K
 # RL hyperparameters
 run = "dummy" # wandb run name
 source = "sft" # mid|sft
+model_tag = None # model tag to load the model from (base model or midtrained model)
+step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
 device_batch_size = 8 # no forward pass will go above this to not OOM
 examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
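
The two new settings follow the script's existing pattern of module-level hyperparameters (run, source, dtype, ...), with None meaning "use the default checkpoint". Assuming a nanoGPT-style configurator, which the user_config passed to wandb.init in the next hunk suggests, globals like these can be overridden from the command line. A rough sketch of that pattern, as an assumption rather than nanochat's exact code:

    import sys

    model_tag = None  # hyperparameters live at module scope
    step = None

    # parse --key=value overrides, e.g. --model_tag=d26 --step=765
    for arg in sys.argv[1:]:
        if arg.startswith('--') and '=' in arg:
            key, val = arg[2:].split('=', 1)
            if key in globals():
                globals()[key] = val  # the real pattern also coerces types
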
@@ -64,7 +66,7 @@ use_dummy_wandb = run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
 
 # Init model and tokenizer
-model, tokenizer, meta = load_model(source, device, phase="eval")
+model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
 engine = Engine(model, tokenizer) # for sampling rollouts
 
 # -----------------------------------------------------------------------------
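
The None defaults added above keep this change backward compatible: when neither setting is given, load_model receives model_tag=None and step=None and can fall back to its previous behavior. In outline (a hypothetical signature, not nanochat's actual implementation):

    def load_model(source, device, phase="eval", model_tag=None, step=None):
        # None acts as a sentinel meaning "pick the default checkpoint";
        # explicit values pin a specific model tag and training step
        ...
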
@@ -307,8 +309,8 @@ for step in range(num_steps):
     if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
         base_dir = get_base_dir()
         depth = model.config.n_layer
-        model_tag = f"d{depth}" # base the model tag on the depth of the base model
-        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", model_tag)
+        output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
+        checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
         model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
         save_checkpoint(
             checkpoint_dir,
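
The rename from model_tag to output_dirname is the substantive fix here: the old line reassigned the module-level model_tag hyperparameter introduced earlier in this commit, so a user-supplied tag would have been silently overwritten at save time. Now the user's tag names the checkpoint directory when set, with d{depth} only as the fallback. An illustrative sketch, outside the script:

    depth = 20
    model_tag = "my-custom-tag"  # user-provided setting

    # old behavior: model_tag = f"d{depth}" clobbered the setting,
    # so the checkpoint always landed in chatrl_checkpoints/d20

    # new behavior: prefer the user's tag, fall back to the depth-based default
    output_dirname = model_tag if model_tag else f"d{depth}"
    print(output_dirname)  # 'my-custom-tag' (or 'd20' when model_tag is None)

Note, from the hunk headers, that the loop variable in "for step in range(num_steps)" also shadows the new step setting, but only after load_model has already consumed it earlier in the script.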