mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 08:33:38 +00:00
allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
@@ -104,10 +104,12 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
return {}
|
||||
|
||||
|
||||
def list_dataset(
|
||||
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
|
||||
) -> Dict[str, Any]:
|
||||
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
|
||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
|
||||
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
|
||||
return gr.update(value=[], choices=datasets)
|
||||
|
||||
|
||||
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
|
||||
return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))
|
||||
|
||||
Reference in New Issue
Block a user