fix ppo trainer save zero3 model

accelerator.get_state_dict(ds_model) should be called at all ranks


Former-commit-id: 3a0f60f0aa072531e4ae5819ec00c8fa42aa0913
This commit is contained in:
hiyouga
2024-06-07 05:14:19 +08:00
parent 8692796c9b
commit b0e5a76f4c
2 changed files with 22 additions and 10 deletions

View File

@@ -10,12 +10,15 @@ from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_a
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba # type: ignore
if is_nltk_available():
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available():
from rouge_chinese import Rouge