Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-02-02 20:43:38 +00:00)
[v1] add batch generator (#9744)
@@ -16,7 +16,7 @@ import os
 from dataclasses import dataclass, field
 from uuid import uuid4
 
-from .arg_utils import PluginConfig, get_plugin_config
+from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
 
 
 @dataclass
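Note: `BatchingStrategy` is newly imported from `.arg_utils` alongside the existing `PluginConfig`, but its definition is not part of this diff. A minimal sketch of what such an enum could look like, assuming a string-valued `Enum`; only the `NORMAL` member is confirmed by the hunk below, everything else here is a guess:

from enum import Enum, unique


@unique
class BatchingStrategy(str, Enum):
    # Sketch only: the real definition lives in arg_utils and is not shown in this diff.
    NORMAL = "normal"  # plain fixed-size micro batches; used as the default below
    # Additional members (e.g. a sequence-packing strategy) would be plausible for a
    # "batch generator" commit, but none are confirmed by this diff.

Subclassing `str` would let the strategy round-trip through CLI or YAML configs as a plain string, which fits how dataclass-based argument parsers commonly handle enum fields.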
@@ -29,18 +29,30 @@ class TrainingArguments:
         default=1,
         metadata={"help": "Micro batch size for training."},
     )
-    global_batch_size: int = field(
-        default=1,
-        metadata={"help": "Global batch size for training."},
+    global_batch_size: int | None = field(
+        default=None,
+        metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
     )
     learning_rate: float = field(
         default=1e-4,
         metadata={"help": "Learning rate for training."},
     )
     cutoff_len: int = field(
         default=2048,
         metadata={"help": "Maximum sequence length for training."},
     )
+    bf16: bool = field(
+        default=False,
+        metadata={"help": "Use bf16 for training."},
+    )
+    batching_strategy: BatchingStrategy = field(
+        default=BatchingStrategy.NORMAL,
+        metadata={"help": "Batching strategy for training."},
+    )
+    batching_workers: int = field(
+        default=16,
+        metadata={"help": "Number of workers for batching."},
+    )
     dist_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Distribution configuration for training."},
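The `global_batch_size` change is the behavioral core of this hunk: the field becomes optional and, per the new help text, defaults to DP size * micro batch size. A minimal sketch of that fallback and the gradient-accumulation arithmetic it implies; `resolve_global_batch_size` and `dp_size` are illustrative names, not identifiers from this commit:

def resolve_global_batch_size(
    micro_batch_size: int,
    global_batch_size: int | None,
    dp_size: int,
) -> tuple[int, int]:
    # Illustrative helper: returns (global_batch_size, gradient_accumulation_steps).
    if global_batch_size is None:
        # Per the help text, "default to DP size * micro batch size",
        # i.e. one micro batch per data-parallel rank and no accumulation.
        global_batch_size = dp_size * micro_batch_size
    per_step = dp_size * micro_batch_size
    if global_batch_size % per_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} must be divisible by "
            f"dp_size * micro_batch_size = {per_step}"
        )
    return global_batch_size, global_batch_size // per_step


# 8 DP ranks, micro batch 2 -> default global batch 16, 1 step (no accumulation).
assert resolve_global_batch_size(2, None, 8) == (16, 1)
# Explicit global batch 64 with the same setup -> 4 accumulation steps.
assert resolve_global_batch_size(2, 64, 8) == (64, 4)

The divisibility check is an assumption about how a trainer would validate the value; the diff itself only changes the field's type, default, and help text.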