[breaking] bump transformers to 4.45.0 & improve ci (#7746)
* update ci * fix * fix * fix * fix * fix
This commit is contained in:
@@ -188,7 +188,7 @@ class LogCallback(TrainerCallback):
|
||||
self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
|
||||
if self.webui_mode and not use_ray():
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
||||
self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR"))
|
||||
logging.add_handler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import calculate_tps, get_logits_processor
|
||||
from ...extras.misc import calculate_tps
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
@@ -82,7 +82,6 @@ def run_sft(
|
||||
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
|
||||
Reference in New Issue
Block a user