support report custom args
Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea
@@ -42,10 +42,13 @@ if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import save_file


 if TYPE_CHECKING:
     from transformers import TrainerControl, TrainerState, TrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
+
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+

 logger = logging.get_logger(__name__)

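Note: the new `..hparams` import is guarded by `if TYPE_CHECKING:`, so the dataclasses are only seen by static type checkers, and the quoted annotations used further down (e.g. `model_args: "ModelArguments"`) never require the import at runtime. A minimal sketch of that pattern, using a hypothetical module name rather than anything from this repository:

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated by type checkers only; skipped at runtime
    from heavy_module import HeavyConfig  # hypothetical import

def configure(cfg: "HeavyConfig") -> None:  # string annotation, so no runtime import is needed
    ...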
@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):

     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a checkpoint save.
-        """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):

     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
                 remaining_time=self.remaining_time,
             )
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
+
+
+class ReporterCallback(TrainerCallback):
+    r"""
+    A callback for reporting training status to external logger.
+    """
+
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.model_args = model_args
+        self.data_args = data_args
+        self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
+        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
+
+    @override
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if not state.is_world_process_zero:
+            return
+
+        if "wandb" in args.report_to:
+            import wandb
+
+            wandb.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
+
+        if self.finetuning_args.use_swanlab:
+            import swanlab
+
+            swanlab.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
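For orientation, a minimal usage sketch of the new callback follows (not part of this commit). The direct construction of the hparams dataclasses and the example model path are illustrative assumptions; in LLaMA-Factory these objects normally come from the argument parser before being handed to the trainer:

from llamafactory.hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from llamafactory.train.callbacks import ReporterCallback

# Illustrative only: in a real run these come from the CLI/YAML argument parser.
model_args = ModelArguments(model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct")
data_args = DataArguments()
finetuning_args = FinetuningArguments()
generating_args = GeneratingArguments()

# Creating the callback also defaults WANDB_PROJECT to "llamafactory" when the variable is unset.
reporter = ReporterCallback(model_args, data_args, finetuning_args, generating_args)

# Passed to a Hugging Face Trainer via callbacks=[reporter], it pushes all four argument
# dicts to wandb (when "wandb" is in report_to) and/or SwanLab (when
# finetuning_args.use_swanlab is enabled) at the start of training on the main process.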