support report custom args

Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea
This commit is contained in:
hiyouga
2024-12-19 14:57:09 +00:00
parent adff887659
commit a897d46049
20 changed files with 164 additions and 124 deletions

View File

@@ -42,10 +42,13 @@ if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = logging.get_logger(__name__)
@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint(
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)
class ReporterCallback(TrainerCallback):
r"""
A callback for reporting training status to external logger.
"""
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.model_args = model_args
self.data_args = data_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not state.is_world_process_zero:
return
if "wandb" in args.report_to:
import wandb
wandb.config.update(
{
"model_args": self.model_args.to_dict(),
"data_args": self.data_args.to_dict(),
"finetuning_args": self.finetuning_args.to_dict(),
"generating_args": self.generating_args.to_dict(),
}
)
if self.finetuning_args.use_swanlab:
import swanlab
swanlab.config.update(
{
"model_args": self.model_args.to_dict(),
"data_args": self.data_args.to_dict(),
"finetuning_args": self.finetuning_args.to_dict(),
"generating_args": self.generating_args.to_dict(),
}
)