support extra args in llamaboard

Former-commit-id: da0a5fd612e2214cc4bcb72516efd768fbe18a20
This commit is contained in:
hiyouga
2024-10-30 08:55:54 +00:00
parent c8a1fb99bf
commit aeeee9d4b5
4 changed files with 27 additions and 37 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
@@ -78,6 +79,11 @@ class Runner:
if not get("train.output_dir"):
return ALERTS["err_no_output_dir"][lang]
try:
json.loads(get("train.extra_args"))
except json.JSONDecodeError:
return ALERTS["err_json_schema"][lang]
stage = TRAINING_STAGES[get("train.training_stage")]
if stage == "ppo" and not get("train.reward_model"):
return ALERTS["err_no_reward_model"][lang]
@@ -130,7 +136,6 @@ class Runner:
save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"),
neftune_noise_alpha=get("train.neftune_alpha") or None,
optim=get("train.optim"),
packing=get("train.packing") or get("train.neat_packing"),
neat_packing=get("train.neat_packing"),
train_on_prompt=get("train.train_on_prompt"),
@@ -148,6 +153,7 @@ class Runner:
plot_loss=True,
ddp_timeout=180000000,
include_num_input_tokens_seen=True,
**json.loads(get("train.extra_args")),
)
# checkpoints