support extra args in llamaboard
Former-commit-id: da0a5fd612e2214cc4bcb72516efd768fbe18a20
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
@@ -78,6 +79,11 @@ class Runner:
|
||||
if not get("train.output_dir"):
|
||||
return ALERTS["err_no_output_dir"][lang]
|
||||
|
||||
try:
|
||||
json.loads(get("train.extra_args"))
|
||||
except json.JSONDecodeError:
|
||||
return ALERTS["err_json_schema"][lang]
|
||||
|
||||
stage = TRAINING_STAGES[get("train.training_stage")]
|
||||
if stage == "ppo" and not get("train.reward_model"):
|
||||
return ALERTS["err_no_reward_model"][lang]
|
||||
@@ -130,7 +136,6 @@ class Runner:
|
||||
save_steps=get("train.save_steps"),
|
||||
warmup_steps=get("train.warmup_steps"),
|
||||
neftune_noise_alpha=get("train.neftune_alpha") or None,
|
||||
optim=get("train.optim"),
|
||||
packing=get("train.packing") or get("train.neat_packing"),
|
||||
neat_packing=get("train.neat_packing"),
|
||||
train_on_prompt=get("train.train_on_prompt"),
|
||||
@@ -148,6 +153,7 @@ class Runner:
|
||||
plot_loss=True,
|
||||
ddp_timeout=180000000,
|
||||
include_num_input_tokens_seen=True,
|
||||
**json.loads(get("train.extra_args")),
|
||||
)
|
||||
|
||||
# checkpoints
|
||||
|
||||
Reference in New Issue
Block a user