11 Commits

Author SHA1 Message Date
hiyouga
fbd0584391 release v0.6.1
Former-commit-id: a59d823f554505b2e649e6e111b9dee8306d3ad8
2024-03-29 11:36:08 +08:00
hiyouga
50224b09cc update readme
Former-commit-id: 312d4f90784800dc8db4eaa7d908e6761115bc51
2024-03-28 22:02:32 +08:00
hiyouga
32dcc5a491 add project
Former-commit-id: 0418e9fecb2337b5d1b72e8358adb8aa10803c4b
2024-03-28 20:24:27 +08:00
hiyouga
9408366a36 fix #2982
Former-commit-id: e5e6a0c50c7a1c0052ed6b459450b9735ff2c9a1
2024-03-28 20:22:31 +08:00
hiyouga
f0e564beaa update readme
Former-commit-id: 6b634b5c2dbad827e8cc9850b8d7697c2056532a
2024-03-28 18:35:11 +08:00
hiyouga
14b75a0b93 fix #3010
Former-commit-id: a5e823ae75556eaa3b52ce7a887a6e7838a1eb5f
2024-03-28 18:31:17 +08:00
hiyouga
59e6ebf039 update trainers
Former-commit-id: d0dd6eefed0b86895ed00a7cafb331e5193db645
2024-03-28 18:16:27 +08:00
hoshi-hiyouga
dc540dfaa8 fix ds optimizer
Former-commit-id: 2675127070a1e7584e71039a11c1ebac54ddd1db
2024-03-26 23:39:56 +08:00
hiyouga
587e65e442 fix #2981
Former-commit-id: ede2a913856e52c0a96155705116528d3af15998
2024-03-26 17:53:04 +08:00
hiyouga
a916688723 fix bug
Former-commit-id: f513e1415cc3fe87f600318fba855d1286b6d007
2024-03-26 17:30:12 +08:00
hiyouga
3336422760 fix #2961
Former-commit-id: 616917bb3be7f71073b56ad8c7bc4e164b08b9b5
2024-03-26 17:26:14 +08:00
27 changed files with 306 additions and 271 deletions

View File

@@ -5,7 +5,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-26-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-27-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -72,14 +72,14 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
<details><summary>Full Changelog</summary>
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.
@@ -451,7 +451,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model if `--create_new_adapter` was enabled.
> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
@@ -482,7 +482,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```
> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model if `--create_new_adapter` was enabled.
### Distributed Training
@@ -570,7 +570,7 @@ deepspeed --num_gpus 8 src/train_bash.py \
### Merge LoRA weights and export model
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
CUDA_VISIBLE_DEVICES= python src/export_model.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
@@ -586,7 +586,7 @@ CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
> [!TIP]
> Use `--model_name_or_path path_to_export` solely to use the exported model.
>
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model with AutoGPTQ after merging the LoRA weights.
> Use `CUDA_VISIBLE_DEVICES=0`, `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model with AutoGPTQ after merging the LoRA weights.
### Inference with OpenAI-style API
@@ -662,19 +662,23 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
### Dockerize Training
#### Get ready
Necessary dockerized environment is needed, such as Docker or Docker Compose.
#### Docker support
#### Use Docker
```bash
docker build -f ./Dockerfile -t llama-factory:latest .
docker run --gpus=all -v ./hf_cache:/root/.cache/huggingface/ -v ./data:/app/data -v ./output:/app/output -p 7860:7860 --shm-size 16G --name llama_factory -d llama-factory:latest
docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \
-v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \
--shm-size 16G \
--name llama_factory \
-d llama-factory:latest
```
#### Docker Compose support
#### Use Docker Compose
```bash
docker compose -f ./docker-compose.yml up -d
@@ -682,7 +686,7 @@ docker compose -f ./docker-compose.yml up -d
> [!TIP]
> Details about volume:
> * hf_cache: Utilize Huggingface cache on the host machine. Reassignable if a cache already exists in a different directory.
> * hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
> * data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
> * output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
@@ -709,6 +713,7 @@ docker compose -f ./docker-compose.yml up -d
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.

View File

@@ -5,7 +5,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-26-green)](#使用了-llama-factory-的项目)
[![Citation](https://img.shields.io/badge/citation-27-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -72,14 +72,14 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/fsdp_qlora`
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA请先合并权重。
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`
<details><summary>展开日志</summary>
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA请先合并权重。
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `examples/extras/llama_pro`
@@ -450,7 +450,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```
> [!TIP]
> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` 来进行微调模型的推理。
> 如果开启了 `--create_new_adapter`,则使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` 来进行微调模型的推理。
> [!WARNING]
> 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。
@@ -481,7 +481,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```
> [!TIP]
> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` 来进行微调模型的推理。
> 如果开启了 `--create_new_adapter`,则使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` 来进行微调模型的推理。
### 多 GPU 分布式训练
@@ -569,7 +569,7 @@ deepspeed --num_gpus 8 src/train_bash.py \
### 合并 LoRA 权重并导出模型
```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
CUDA_VISIBLE_DEVICES= python src/export_model.py \
--model_name_or_path path_to_llama_model \
--adapter_name_or_path path_to_checkpoint \
--template default \
@@ -585,7 +585,7 @@ CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
> [!TIP]
> 仅使用 `--model_name_or_path path_to_export` 来加载导出后的模型。
>
> 合并 LoRA 权重之后可再次使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 基于 AutoGPTQ 量化模型。
> 合并 LoRA 权重之后可再次使用 `CUDA_VISIBLE_DEVICES=0`、`--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 基于 AutoGPTQ 量化模型。
### 使用 OpenAI 风格 API 推理
@@ -659,6 +659,36 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
> [!TIP]
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
### 使用容器
#### 使用 Docker
```bash
docker build -f ./Dockerfile -t llama-factory:latest .
docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \
-v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \
--shm-size 16G \
--name llama_factory \
-d llama-factory:latest
```
#### 使用 Docker Compose
```bash
docker compose -f ./docker-compose.yml up -d
```
> [!TIP]
> 数据卷详情:
> * hf_cache使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
> * data宿主机中存放数据集的文件夹路径。
> * output将导出目录设置为该路径后即可在宿主机中访问导出后的模型。
## 使用了 LLaMA Factory 的项目
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
@@ -682,6 +712,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。

View File

@@ -10,6 +10,8 @@ services:
- ./hf_cache:/root/.cache/huggingface/
- ./data:/app/data
- ./output:/app/output
environment:
- CUDA_VISIBLE_DEVICES=0
ports:
- "7860:7860"
ipc: host

View File

@@ -9,7 +9,7 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0

View File

@@ -1,31 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16

View File

@@ -1,32 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--optim adamw_8bit \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -1,36 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--optim adamw_8bit \
--use_galore \
--galore_layerwise \
--galore_target mlp,self_attn \
--galore_rank 128 \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -32,4 +32,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16
--pure_bf16

View File

@@ -1,5 +1,5 @@
```bash
pip install git+https://github.com/huggingface/transformers.git
pip install "transformers>=4.39.1"
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
```

View File

@@ -2,7 +2,7 @@ torch>=1.13.1
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.27.2
peft>=0.9.0
peft>=0.10.0
trl>=0.8.1
gradio>=3.38.0,<4.0.0
scipy

View File

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo
__version__ = "0.6.0"
__version__ = "0.6.1"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@@ -689,6 +689,8 @@ _register_template(
_register_template(
name="vanilla",
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)

View File

@@ -1,14 +1,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
@@ -16,22 +12,29 @@ class EvalTemplate:
answer: str
prefix: str
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
messages = []
for k in range(len(support_set)):
prompt, response = self.parse_example(support_set[k])
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(support_set[k])
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
prompt, response = self.parse_example(target_data)
messages.append({"role": Role.USER, "content": prompt})
messages.append({"role": Role.ASSISTANT, "content": response})
prompt, response = self._parse_example(target_data)
messages.append({"role": Role.USER.value, "content": prompt})
messages.append({"role": Role.ASSISTANT.value, "content": response})
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
return messages
@@ -39,7 +42,7 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
@@ -49,7 +52,7 @@ def get_eval_template(name: str) -> "EvalTemplate":
return eval_template
register_eval_template(
_register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
@@ -58,10 +61,10 @@ register_eval_template(
)
register_eval_template(
_register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n",
prefix=" ",
)

View File

@@ -58,9 +58,17 @@ class LogCallback(TrainerCallback):
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
@@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
r"""
Event called after logging the last logs.
"""
if not state.is_local_process_zero:
return
if args.save_on_each_node:
if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict(
current_steps=self.cur_steps,

View File

@@ -64,7 +64,7 @@ def check_dependencies() -> None:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")

View File

@@ -102,6 +102,10 @@ class RLHFArguments:
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
dpo_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
@@ -114,10 +118,6 @@ class RLHFArguments:
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
)
ppo_score_norm: bool = field(
default=False,
metadata={"help": "Use score normalization in PPO training."},
@@ -248,6 +248,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

View File

@@ -8,7 +8,6 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies
@@ -119,6 +118,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to is not None
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
@@ -128,12 +134,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if finetuning_args.use_dora:
if model_args.quantization_bit is not None:
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
if model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.use_dora and model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():

View File

@@ -173,7 +173,7 @@ def _configure_quantization(
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)

View File

@@ -47,6 +47,8 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
elif model.config.model_type == "internlm2":
output_layer_names.append("output")
module_names = set()
for name, module in model.named_modules():

View File

@@ -8,7 +8,7 @@ from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -20,12 +20,9 @@ if TYPE_CHECKING:
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
finetuning_args: "FinetuningArguments",
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: bool = True,
**kwargs,
):
@@ -47,10 +44,10 @@ class CustomDPOTrainer(DPOTrainer):
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self.beta = beta
self.label_smoothing = 0
self.loss_type = loss_type
self.ftx_gamma = ftx_gamma
self.beta = finetuning_args.dpo_beta
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
@@ -66,12 +63,16 @@ class CustomDPOTrainer(DPOTrainer):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.create_optimizer()
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
r"""

View File

@@ -28,6 +28,7 @@ def run_dpo(
tokenizer = load_tokenizer(model_args)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DPODataCollatorWithPadding(
tokenizer=tokenizer,
pad_to_multiple_of=8,
@@ -45,13 +46,10 @@ def run_dpo(
# Initialize our Trainer
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
loss_type=finetuning_args.dpo_loss,
ftx_gamma=finetuning_args.dpo_ftx,
finetuning_args=finetuning_args,
model=model,
ref_model=ref_model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,

View File

@@ -6,20 +6,23 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOTrainer
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..utils import create_custom_optimzer, create_custom_scheduler
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from datasets import Dataset
from transformers import DataCollatorWithPadding, PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
@@ -40,10 +43,53 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs,
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
tokenizer: "PreTrainedTokenizer",
dataset: "Dataset",
data_collator: "DataCollatorWithPadding",
):
PPOTrainer.__init__(self, **kwargs)
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=finetuning_args.ppo_epochs,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False},
log_with=training_args.report_to[0] if training_args.report_to is not None else None,
project_kwargs={"logging_dir": training_args.logging_dir},
)
# Create optimizer and scheduler
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
optimizer = self.create_optimizer(model, training_args, finetuning_args)
scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
PPOTrainer.__init__(
self,
config=ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
lr_scheduler=scheduler,
)
self.args = training_args
self.model_args = model_args
@@ -205,6 +251,44 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
def create_optimizer(
self,
model: "AutoModelForCausalLMWithValueHead",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
if optimizer is None:
decay_params, nodecay_params = [], []
decay_param_names = self.get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=nodecay_params),
dict(params=decay_params, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
return optimizer
def create_scheduler(
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(training_args, num_training_steps, optimizer)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
return lr_scheduler
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""

View File

@@ -1,19 +1,15 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from typing import TYPE_CHECKING, List, Optional
from torch.optim import AdamW
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from trl import PPOConfig
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
from ..utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -42,45 +38,6 @@ def run_ppo(
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=finetuning_args.ppo_epochs,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False},
project_kwargs={"logging_dir": training_args.logging_dir},
)
# Create optimizer and scheduler
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
if optimizer is None:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
# Initialize our Trainer
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
@@ -88,15 +45,12 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [FixValueHeadModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
reward_model=reward_model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
# Training

View File

@@ -1,12 +1,14 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
import torch
from ...hparams import FinetuningArguments
@@ -22,9 +24,13 @@ class CustomTrainer(Trainer):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.create_optimizer()
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

View File

@@ -1,12 +1,12 @@
import json
import os
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -29,12 +29,16 @@ class PairwiseTrainer(Trainer):
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.create_optimizer()
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False

View File

@@ -8,7 +8,7 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer
from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
@@ -29,12 +29,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.create_optimizer()
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def prediction_step(
self,

View File

@@ -29,7 +29,13 @@ logger = get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
r"""
A dummy optimizer used for the GaLore algorithm.
"""
def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None:
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
@@ -64,7 +70,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@@ -99,7 +105,7 @@ def create_ref_model(
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
) -> Optional["AutoModelForCausalLMWithValueHead"]:
r"""
Creates reward model for PPO training.
"""
@@ -156,8 +162,9 @@ def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
max_steps: int,
) -> "torch.optim.Optimizer":
require_version("galore_torch", "To fix: pip install galore_torch")
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model)
else:
@@ -212,29 +219,19 @@ def _create_galore_optimizer(
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in galore_params:
for param in galore_params: # galore params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
for param in trainable_params:
scheduler_dict[param] = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer_dict[param],
num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
num_training_steps=max_steps * 2,
)
def optimizer_hook(param: "torch.Tensor"):
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
scheduler_dict[param].step()
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer(lr=training_args.learning_rate) # display scheduler result
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params),
@@ -293,10 +290,34 @@ def create_custom_optimzer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
max_steps: int,
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
return _create_galore_optimizer(model, training_args, finetuning_args)
if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args)
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer_dict[param],
num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
num_training_steps=num_training_steps * 2,
)
def scheduler_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)