refactor dataset_attr, add eos in pt, fix #757
Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
This commit is contained in:
@@ -6,7 +6,7 @@ import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
|
||||
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
|
||||
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
||||
|
||||
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
@@ -78,11 +78,10 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
|
||||
def list_dataset(
|
||||
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
|
||||
) -> Dict[str, Any]:
|
||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||
if stage:
|
||||
dataset_stage = DATASET_STAGE_MAP[stage]
|
||||
dataset_info = {key: value for key, value in dataset_info.items()
|
||||
if ("stage" not in value) or value["stage"] == dataset_stage}
|
||||
|
||||
return gr.update(value=[], choices=list(dataset_info.keys()))
|
||||
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
|
||||
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
|
||||
return gr.update(value=[], choices=datasets)
|
||||
|
||||
Reference in New Issue
Block a user