Former-commit-id: bb57478366a70a0871af30ab31c890f471e27ff4
This commit is contained in:
hiyouga
2024-06-25 01:15:19 +08:00
parent c6b17ebc20
commit 135bfbf7c1
8 changed files with 23 additions and 18 deletions

View File

@@ -216,7 +216,7 @@ class ToolFormatter(Formatter):
self._tool_formatter = glm4_tool_formatter
self._tool_extractor = glm4_tool_extractor
else:
raise ValueError("Tool format was not found.")
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")

View File

@@ -387,8 +387,9 @@ def get_template_and_fix_tokenizer(
template = TEMPLATES.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
if tool_format:
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
template.format_tools = ToolFormatter(tool_format=tool_format)
stop_words = template.stop_words
@@ -625,7 +626,6 @@ _register_template(
_register_template(
name="empty",
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
)

View File

@@ -29,10 +29,6 @@ class DataArguments:
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."},
)
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Specifies the tool format template for function calling ."},
)
dataset: Optional[str] = field(
default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
@@ -105,6 +101,10 @@ class DataArguments:
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
},
)
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the tokenized datasets."},

View File

@@ -291,7 +291,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
input_elems.update({output_dir, config_path, ds_stage, ds_offload})
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,

View File

@@ -306,7 +306,7 @@ class Runner:
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
elem_id = self.manager.get_id_by_elem(elem)
if elem_id not in skip_ids: