support BLOOM models

Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
This commit is contained in:
hiyouga
2023-05-31 16:54:06 +08:00
parent 181c776b58
commit 693c049eac
16 changed files with 134 additions and 90 deletions

View File

@@ -22,7 +22,7 @@ logger = get_logger(__name__)
@dataclass
class ComputeMetrics:
r"""
Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA.
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
"""
@@ -62,7 +62,7 @@ class ComputeMetrics:
return {k: float(np.mean(v)) for k, v in score_dict.items()}
class Seq2SeqTrainerForLLaMA(PeftTrainer):
class Seq2SeqPeftTrainer(PeftTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""