Compare commits

70 commits (author and date columns omitted; SHA1 only):

c8fe3f544b, 0f1ad7140f, 233e167f68, 1d341dcd83, d16561e7a4, f8e219dc81, 3365cc8cf0,
3a5e68b7d9, 0cb596fee1, b3b5b530d1, 9225c15c88, abd9fed445, 44cda2eece, 8397808d1d,
9e1bd6420d, 619264c854, 1ebac62e3d, ce9bdb3509, 0c8d6369ac, bee796f6b5, 9f6349a333,
171a029c5e, eaefaa0fe0, d301f0a64b, 0a1578e4e3, a4167fd925, 42084e08ae, 9d23f5dc89,
5978427ae0, c7c216069c, cde9d1b917, 96213f04b0, 7ecea08b9b, 191971865d, ff4f587dd9,
de728d0371, d08e09642d, 351493b183, 86ab47e121, 6dd6b3e396, 5f1418a68b, 7b97a79efc,
ce4f653121, b053c6454e, ebf0f4a77c, efa808069a, b5c5283dd6, b638c65519, d4d471450f,
3144bdec2c, c6d6c4c209, f5f1589662, 276f2cb24e, 952b785bb3, 72dd676208, dfaa31e991,
86556b1c74, 0c80751e87, 9338f878a3, fde3d91242, 19adfb88a9, daaafa900a, 0dcc9e0bca,
aeec78b35c, c991654cb4, f328413646, 106a0104da, 5486ea09e3, 31bbbb6d13, 1a77de82fa
README.md (61 changed lines)
@@ -5,7 +5,7 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
 [](#projects-using-llama-factory)
 [](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)
@@ -46,7 +46,7 @@ Choose your path:
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -68,14 +68,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
+
+[24/04/19] We supported **Meta Llama 3** model series.
+
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
+
+[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
+<details><summary>Full Changelog</summary>
+
 [24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.

 [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!

 [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.

-<details><summary>Full Changelog</summary>
-
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.

 [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
@@ -129,20 +137,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | Model | Model size | Default module | Template |
 | --- | --- | --- | --- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
@@ -241,6 +251,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

 </details>
@@ -275,16 +286,15 @@ huggingface-cli login
 \* *estimated*

-| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
-| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
-| Full   | AMP  | 120GB | 240GB | 600GB | 1200GB | 900GB |
-| Full   | 16   | 60GB  | 120GB | 300GB | 600GB  | 400GB |
-| GaLore | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB |
-| Freeze | 16   | 20GB  | 40GB  | 80GB  | 200GB  | 160GB |
-| LoRA   | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB |
-| QLoRA  | 8    | 10GB  | 20GB  | 40GB  | 80GB   | 60GB  |
-| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 30GB  |
-| QLoRA  | 2    | 4GB   | 8GB   | 16GB  | 24GB   | 18GB  |
+| Method            | Bits | 7B    | 13B   | 30B   | 70B    | 8x7B  | 8x22B  |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
+| Full              | AMP  | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
+| Full              | 16   | 60GB  | 120GB | 300GB | 600GB  | 400GB | 1200GB |
+| Freeze            | 16   | 20GB  | 40GB  | 80GB  | 200GB  | 160GB | 400GB  |
+| LoRA/GaLore/BAdam | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB | 320GB  |
+| QLoRA             | 8    | 10GB  | 20GB  | 40GB  | 80GB   | 60GB  | 160GB  |
+| QLoRA             | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 30GB  | 96GB   |
+| QLoRA             | 2    | 4GB   | 8GB   | 16GB  | 24GB   | 18GB  | 48GB   |

 ## Getting Started
@@ -305,7 +315,7 @@ cd LLaMA-Factory
 pip install -e .[metrics]
 ```

-Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: deepspeed, metrics, unsloth, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality

 <details><summary>For Windows users</summary>
@@ -328,6 +338,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
 ```bash
 export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
+export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
 python src/train_web.py # or python -m llmtuner.webui.interface
 ```
@@ -413,8 +424,14 @@ If you have a project that should be incorporated, please contact via email or c
 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
+1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -427,7 +444,7 @@ If you have a project that should be incorporated, please contact via email or c
 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation
README_zh.md (63 changed lines; Chinese content translated to English below)
@@ -5,7 +5,7 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
 [](#使用了-llama-factory-的项目)
 [](#使用了-llama-factory-的项目)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)
@@ -46,7 +46,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (incremental) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
 - **Multiple precisions**: 32-bit full-parameter tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Fast inference**: OpenAI-style API, web UI and CLI based on vLLM.
@@ -68,14 +68,22 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## Changelog

+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** based on [AstraMindAI's repo](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
+
+[24/04/19] We supported the **Meta Llama 3** model series.
+
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
+
+[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). Compared with FlashAttention-2 it delivers **117%** training speed and **50%** memory savings; see more benchmarks on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
+<details><summary>Full Changelog</summary>
+
 [24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.

 [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available on arXiv!

 [24/03/20] We supported **FSDP+QLoRA**, which fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.

-<details><summary>Full Changelog</summary>
-
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.

 [24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
@@ -129,20 +137,22 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | Model | Model size | Default module | Template |
 | --- | --- | --- | --- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
-| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
 | [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
@@ -241,6 +251,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [DPO mix (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

 </details>
@@ -275,16 +286,15 @@ huggingface-cli login
 \* *estimated*

-| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
-| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
-| Full   | AMP  | 120GB | 240GB | 600GB | 1200GB | 900GB |
-| Full   | 16   | 60GB  | 120GB | 300GB | 600GB  | 400GB |
-| GaLore | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB |
-| Freeze | 16   | 20GB  | 40GB  | 80GB  | 200GB  | 160GB |
-| LoRA   | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB |
-| QLoRA  | 8    | 10GB  | 20GB  | 40GB  | 80GB   | 60GB  |
-| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 30GB  |
-| QLoRA  | 2    | 4GB   | 8GB   | 16GB  | 24GB   | 18GB  |
+| Method            | Bits | 7B    | 13B   | 30B   | 70B    | 8x7B  | 8x22B  |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ----- | ------ |
+| Full              | AMP  | 120GB | 240GB | 600GB | 1200GB | 900GB | 2400GB |
+| Full              | 16   | 60GB  | 120GB | 300GB | 600GB  | 400GB | 1200GB |
+| Freeze            | 16   | 20GB  | 40GB  | 80GB  | 200GB  | 160GB | 400GB  |
+| LoRA/GaLore/BAdam | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB | 320GB  |
+| QLoRA             | 8    | 10GB  | 20GB  | 40GB  | 80GB   | 60GB  | 160GB  |
+| QLoRA             | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 30GB  | 96GB   |
+| QLoRA             | 2    | 4GB   | 8GB   | 16GB  | 24GB   | 18GB  | 48GB   |

 ## Getting Started
@@ -305,7 +315,7 @@ cd LLaMA-Factory
 pip install -e .[metrics]
 ```

-Optional extra dependencies: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
+Optional extra dependencies: deepspeed, metrics, unsloth, galore, badam, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality

 <details><summary>For Windows users</summary>
@@ -328,6 +338,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 ```bash
 export CUDA_VISIBLE_DEVICES=0 # use `set CUDA_VISIBLE_DEVICES=0` on Windows
+export GRADIO_SERVER_PORT=7860 # use `set GRADIO_SERVER_PORT=7860` on Windows
 python src/train_web.py # or python -m llmtuner.webui.interface
 ```
@@ -388,7 +399,7 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows

 ## Projects using LLaMA Factory

-If you have a project that you would like to add to the list above, please contact us via email or create a PR.
+If you have a project that you would like to add to the list below, please contact us via email or create a PR.

 <details><summary>Click to show</summary>
@@ -413,8 +424,14 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
+1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
 1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: StarWhisper, a large language model for astronomy, fine-tuned on astronomical data from ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: DISC-LawLLM, a large language model for the Chinese legal domain fine-tuned from Baichuan-13B, capable of legal reasoning and knowledge retrieval.
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: Sunsimiao, a Chinese medical large language model fine-tuned on Chinese medical data from Baichuan-7B and ChatGLM-6B.
@@ -427,7 +444,7 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows
 The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

-Please follow the corresponding model licenses when using the model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the corresponding model licenses when using the model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation
@@ -1 +1 @@
-34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
+a97cf9475291591843976554878568e046d8a46d
@@ -1,5 +1,6 @@
-import os
 import json
+import os
+
 import datasets

@@ -22,31 +23,19 @@ _URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0
 class BelleMultiturn(datasets.GeneratorBasedBuilder):
-
     VERSION = datasets.Version("0.0.0")

     def _info(self):
-        features = datasets.Features({
-            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
-        })
+        features = datasets.Features(
+            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
         )

     def _split_generators(self, dl_manager: datasets.DownloadManager):
         file_path = dl_manager.download(_URL)
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepath": file_path
-                }
-            )
-        ]
+        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

     def _generate_examples(self, filepath: str):
         with open(filepath, "r", encoding="utf-8") as f:
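Builder scripts like `BelleMultiturn` are consumed through `datasets.load_dataset`, which picks up the single TRAIN split registered above. A minimal sketch of the call (the local script directory used here is an assumption based on the repo's `data/` layout, not something this diff shows):

```python
from datasets import load_dataset

# Hypothetical invocation: point load_dataset at the directory that holds the
# BelleMultiturn builder script; it downloads the shard from _URL and exposes
# the "conversations" feature defined in _info().
ds = load_dataset("data/belle_multiturn", split="train")  # path is illustrative
print(ds[0]["conversations"][0])  # e.g. {"from": "human", "value": "..."}
```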
@@ -1,7 +1,8 @@
 import json
-import datasets
 from typing import Any, Dict, Generator, List, Tuple
+
+import datasets

 _DESCRIPTION = "An example of dataset."
 _CITATION = ""

@@ -11,34 +12,24 @@ _URL = "examples.json"
 class ExampleDataset(datasets.GeneratorBasedBuilder):
-
     VERSION = datasets.Version("0.0.0")

     def _info(self) -> datasets.DatasetInfo:
-        features = datasets.Features({
+        features = datasets.Features(
+            {
                 "instruction": datasets.Value("string"),
                 "input": datasets.Value("string"),
                 "output": datasets.Value("string"),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
-        })
+                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
+            }
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
         )

     def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
         file_path = dl_manager.download(_URL)
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepath": file_path
-                }
-            )
-        ]
+        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

     def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
         example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
@@ -1,8 +1,10 @@
-import os
 import json
-import datasets
+import os
+from typing import List
+
+import datasets

 _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
 _DESCRIPTION = "Human preference data about helpfulness and harmlessness."
 _CITATION = ""

@@ -14,50 +16,37 @@ _URLS = {
         _URL + "harmless-base/train.jsonl.gz",
         _URL + "helpful-base/train.jsonl.gz",
         _URL + "helpful-online/train.jsonl.gz",
-        _URL + "helpful-rejection-sampled/train.jsonl.gz"
+        _URL + "helpful-rejection-sampled/train.jsonl.gz",
     ],
     "test": [
         _URL + "harmless-base/test.jsonl.gz",
         _URL + "helpful-base/test.jsonl.gz",
         _URL + "helpful-online/test.jsonl.gz",
-        _URL + "helpful-rejection-sampled/test.jsonl.gz"
-    ]
+        _URL + "helpful-rejection-sampled/test.jsonl.gz",
+    ],
 }


 class HhRlhfEn(datasets.GeneratorBasedBuilder):
-
     VERSION = datasets.Version("0.0.0")

     def _info(self) -> datasets.DatasetInfo:
-        features = datasets.Features({
+        features = datasets.Features(
+            {
                 "instruction": datasets.Value("string"),
                 "output": datasets.Sequence(datasets.Value("string")),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
-        })
+                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
+            }
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
         )

     def _split_generators(self, dl_manager: datasets.DownloadManager):
         file_path = dl_manager.download_and_extract(_URLS)
         return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepaths": file_path["train"]
-                }
-            ),
-            datasets.SplitGenerator(
-                name=datasets.Split.TEST,
-                gen_kwargs={
-                    "filepaths": file_path["test"]
-                }
-            )
+            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_path["train"]}),
+            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
         ]

     def _generate_examples(self, filepaths: List[str]):

@@ -90,9 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
                     break
                 prompt = prompt[:human_idx]

-            yield key, {
-                "instruction": query,
-                "output": [r_accept, r_reject],
-                "history": history
-            }
+            yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
             key += 1
@@ -1,8 +1,10 @@
-import os
 import json
-import datasets
+import os
+from typing import List
+
+import datasets

 _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

 _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."

@@ -24,31 +26,19 @@ _BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jso
 class UltraChat(datasets.GeneratorBasedBuilder):
-
     VERSION = datasets.Version("0.0.0")

     def _info(self):
-        features = datasets.Features({
-            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
-        })
+        features = datasets.Features(
+            {"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]}
+        )
         return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=features,
-            homepage=_HOMEPAGE,
-            license=_LICENSE,
-            citation=_CITATION
+            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

     def _split_generators(self, dl_manager: datasets.DownloadManager):
         file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)]  # multiple shards
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepaths": file_paths
-                }
-            )
-        ]
+        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]

     def _generate_examples(self, filepaths: List[str]):
         for filepath in filepaths:

@@ -56,7 +46,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
             for row in f:
                 try:
                     data = json.loads(row)
-                except:
+                except Exception:
                     continue
                 key: int = data["id"]
                 content: List[str] = data["data"]

@@ -64,8 +54,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
                 content.pop(-1)
             if len(content) < 2:
                 continue
-            conversations = [{
-                "from": "human" if i % 2 == 0 else "gpt",
-                "value": content[i]
-            } for i in range(len(content))]
+            conversations = [
+                {"from": "human" if i % 2 == 0 else "gpt", "value": content[i]} for i in range(len(content))
+            ]
             yield key, {"conversations": conversations}
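The `except:` to `except Exception:` change above is a behavioral fix, not just style: a bare `except` also swallows `KeyboardInterrupt` and `SystemExit`. A self-contained illustration of the patched pattern:

```python
import json

rows = ['{"id": 0, "data": ["hi", "hello"]}', "not valid json"]
for row in rows:
    try:
        data = json.loads(row)
    except Exception:  # unlike a bare `except:`, this still lets
        continue       # KeyboardInterrupt/SystemExit propagate
    print(data["id"])  # prints 0; the malformed row is skipped
```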
@@ -3,41 +3,46 @@ We provide diverse examples about fine-tuning LLMs.
 ```
 examples/
 ├── lora_single_gpu/
-│   ├── pretrain.sh: Do pre-training
-│   ├── sft.sh: Do supervised fine-tuning
-│   ├── reward.sh: Do reward modeling
-│   ├── ppo.sh: Do PPO training
-│   ├── dpo.sh: Do DPO training
-│   ├── orpo.sh: Do ORPO training
+│   ├── pretrain.sh: Do continuous pre-training using LoRA
+│   ├── sft.sh: Do supervised fine-tuning using LoRA
+│   ├── reward.sh: Do reward modeling using LoRA
+│   ├── ppo.sh: Do PPO training using LoRA
+│   ├── dpo.sh: Do DPO training using LoRA
+│   ├── orpo.sh: Do ORPO training using LoRA
 │   ├── prepare.sh: Save tokenized dataset
-│   └── predict.sh: Do batch predict
+│   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after LoRA tuning
 ├── qlora_single_gpu/
-│   ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models
-│   ├── gptq.sh: Fine-tune 4/8-bit GPTQ models
-│   ├── awq.sh: Fine-tune 4-bit AWQ models
-│   └── aqlm.sh: Fine-tune 2-bit AQLM models
+│   ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models using QLoRA
+│   ├── gptq.sh: Fine-tune 4/8-bit GPTQ models using QLoRA
+│   ├── awq.sh: Fine-tune 4-bit AWQ models using QLoRA
+│   └── aqlm.sh: Fine-tune 2-bit AQLM models using QLoRA
 ├── lora_multi_gpu/
-│   ├── single_node.sh: Fine-tune model with Accelerate on single node
-│   └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes
+│   ├── single_node.sh: Fine-tune model with Accelerate on single node using LoRA
+│   └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes using LoRA
 ├── full_multi_gpu/
-│   ├── single_node.sh: Fine-tune model with DeepSpeed on single node
-│   └── multi_node.sh: Fine-tune model with DeepSpeed on multiple nodes
+│   ├── single_node.sh: Full fine-tune model with DeepSpeed on single node
+│   ├── multi_node.sh: Full fine-tune model with DeepSpeed on multiple nodes
+│   └── predict.sh: Do batch predict and compute BLEU and ROUGE scores after full tuning
 ├── merge_lora/
 │   ├── merge.sh: Merge LoRA weights into the pre-trained models
-│   └── quantize.sh: Quantize fine-tuned model with AutoGPTQ
+│   └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
 ├── inference/
-│   ├── cli_demo.sh: Launch a command line interface
-│   ├── api_demo.sh: Launch an OpenAI-style API
-│   ├── web_demo.sh: Launch a web interface
-│   └── evaluate.sh: Evaluate model on the MMLU benchmark
+│   ├── cli_demo.sh: Launch a command line interface with LoRA adapters
+│   ├── api_demo.sh: Launch an OpenAI-style API with LoRA adapters
+│   ├── web_demo.sh: Launch a web interface with LoRA adapters
+│   └── evaluate.sh: Evaluate model on the MMLU/CMMLU/C-Eval benchmarks with LoRA adapters
 └── extras/
     ├── galore/
     │   └── sft.sh: Fine-tune model with GaLore
+    ├── badam/
+    │   └── sft.sh: Fine-tune model with BAdam
     ├── loraplus/
-    │   └── sft.sh: Fine-tune model with LoRA+
+    │   └── sft.sh: Fine-tune model using LoRA+
+    ├── mod/
+    │   └── sft.sh: Fine-tune model using Mixture-of-Depths
     ├── llama_pro/
     │   ├── expand.sh: Expand layers in the model
-    │   └── sft.sh: Fine-tune expanded model
+    │   └── sft.sh: Fine-tune the expanded model
    └── fsdp_qlora/
-        └── sft.sh: Fine-tune quantized model with FSDP
+        └── sft.sh: Fine-tune quantized model with FSDP+QLoRA
 ```
@@ -1,43 +1,48 @@
-We provide diverse example scripts.
+We provide diverse example scripts for fine-tuning large models.

 ```
 examples/
 ├── lora_single_gpu/
-│   ├── pretrain.sh: Do pre-training
-│   ├── sft.sh: Do supervised fine-tuning
-│   ├── reward.sh: Do reward model training
-│   ├── ppo.sh: Do PPO training
-│   ├── dpo.sh: Do DPO training
-│   ├── orpo.sh: Do ORPO training
+│   ├── pretrain.sh: Do incremental pre-training with LoRA
+│   ├── sft.sh: Do supervised fine-tuning with LoRA
+│   ├── reward.sh: Do reward model training with LoRA
+│   ├── ppo.sh: Do PPO training with LoRA
+│   ├── dpo.sh: Do DPO training with LoRA
+│   ├── orpo.sh: Do ORPO training with LoRA
 │   ├── prepare.sh: Save the pre-processed dataset
-│   └── predict.sh: Do batch prediction
+│   └── predict.sh: Do batch prediction with LoRA and compute BLEU and ROUGE scores
 ├── qlora_single_gpu/
-│   ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models
-│   ├── gptq.sh: Fine-tune 4/8-bit GPTQ models
-│   ├── awq.sh: Fine-tune 4-bit AWQ models
-│   └── aqlm.sh: Fine-tune 2-bit AQLM models
+│   ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models with QLoRA
+│   ├── gptq.sh: Fine-tune 4/8-bit GPTQ models with QLoRA
+│   ├── awq.sh: Fine-tune 4-bit AWQ models with QLoRA
+│   └── aqlm.sh: Fine-tune 2-bit AQLM models with QLoRA
 ├── lora_multi_gpu/
-│   ├── single_node.sh: Single-node training with Accelerate
-│   └── multi_node.sh: Multi-node training with Accelerate
+│   ├── single_node.sh: Single-node LoRA training with Accelerate
+│   └── multi_node.sh: Multi-node LoRA training with Accelerate
 ├── full_multi_gpu/
-│   ├── single_node.sh: Single-node training with DeepSpeed
-│   └── multi_node.sh: Multi-node training with DeepSpeed
+│   ├── single_node.sh: Single-node full-parameter training with DeepSpeed
+│   ├── multi_node.sh: Multi-node full-parameter training with DeepSpeed
+│   └── predict.sh: Do batch prediction after full-parameter training and compute BLEU and ROUGE scores
 ├── merge_lora/
 │   ├── merge.sh: Merge LoRA weights into the pre-trained model
-│   └── quantize.sh: Quantize the model with AutoGPTQ
+│   └── quantize.sh: Quantize the fine-tuned model with AutoGPTQ
 ├── inference/
-│   ├── cli_demo.sh: Launch a command-line inference interface
-│   ├── api_demo.sh: Launch an OpenAI-style API
-│   ├── web_demo.sh: Launch a web inference interface
-│   └── evaluate.sh: Evaluate the model on the MMLU benchmark
+│   ├── cli_demo.sh: Launch a command-line inference interface for the LoRA model
+│   ├── api_demo.sh: Launch an OpenAI-style API for the LoRA model
+│   ├── web_demo.sh: Launch a web inference interface for the LoRA model
+│   └── evaluate.sh: Evaluate the LoRA model on the MMLU/CMMLU/C-Eval benchmarks
 └── extras/
     ├── galore/
     │   └── sft.sh: Train the model with GaLore
+    ├── badam/
+    │   └── sft.sh: Train the model with BAdam
     ├── loraplus/
     │   └── sft.sh: Train the model with LoRA+
+    ├── mod/
+    │   └── sft.sh: Train the model with Mixture-of-Depths
     ├── llama_pro/
     │   ├── expand.sh: Expand the layers in the model
     │   └── sft.sh: Train the expanded model
     └── fsdp_qlora/
-        └── sft.sh: Fine-tune the quantized model with FSDP
+        └── sft.sh: Fine-tune the quantized model with FSDP+QLoRA
 ```
examples/extras/MoD/sft.sh (new file, 33 lines)

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --mixture_of_depths convert \
    --output_dir ../../../saves/LLaMA2-7B/mod/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --optim paged_adamw_8bit \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --pure_bf16
examples/extras/badam/sft.sh (new file, 35 lines)

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../../data \
    --template default \
    --finetuning_type full \
    --use_badam \
    --badam_switch_mode descending \
    --badam_switch_block_every 50 \
    --badam_verbose 2 \
    --output_dir ../../../saves/LLaMA2-7B/badam/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --pure_bf16
@@ -12,6 +12,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --galore_layerwise \
     --galore_target mlp,self_attn \
     --galore_rank 128 \
+    --galore_scale 2.0 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -9,6 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
     --template default \
     --finetuning_type lora \
     --lora_target q_proj,v_proj \
+    --loraplus_lr_ratio 16.0 \
     --output_dir ../../saves/LLaMA2-7B/loraplus/sft \
     --overwrite_cache \
     --overwrite_output_dir \

@@ -29,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
     --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
-    --fp16 \
-    --loraplus_lr_ratio 16.0
+    --fp16
examples/full_multi_gpu/predict.sh (new file, 18 lines)

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path ../../saves/LLaMA2-7B/full/sft \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type full \
    --output_dir ../../saves/LLaMA2-7B/full/predict \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --preprocessing_num_workers 16 \
    --per_device_eval_batch_size 1 \
    --max_samples 20 \
    --predict_with_generate
@@ -3,7 +3,7 @@
 CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
     --model_name_or_path meta-llama/Llama-2-7b-hf \
     --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
-    --template vanilla \
+    --template fewshot \
     --finetuning_type lora \
     --task mmlu \
     --split test \
@@ -1,7 +1,7 @@
 #!/bin/bash
 # DO NOT use quantized model or quantization_bit when merging lora weights

-CUDA_VISIBLE_DEVICES= python ../../src/export_model.py \
+CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
     --model_name_or_path meta-llama/Llama-2-7b-hf \
     --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
     --template default \
@@ -4,7 +4,7 @@ datasets>=2.14.3
 accelerate>=0.27.2
 peft>=0.10.0
 trl>=0.8.1
-gradio>=4.0.0,<=4.21.0
+gradio>=4.0.0
 scipy
 einops
 sentencepiece
setup.py (1 changed line)

@@ -24,6 +24,7 @@ extra_require = {
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
     "galore": ["galore-torch"],
+    "badam": ["badam"],
     "vllm": ["vllm>=0.3.3"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
@@ -7,5 +7,5 @@ from .train import export_model, run_exp
 from .webui import create_ui, create_web_demo


-__version__ = "0.6.2"
+__version__ = "0.6.3"
 __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
@@ -5,14 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
+    from vllm import AsyncLLMEngine

     from ..data import Template
-    from ..extras.packages import is_vllm_available
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

-if is_vllm_available():
-    from vllm import AsyncLLMEngine
-

 @dataclass
 class Response:
@@ -1,8 +1,6 @@
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence

-from transformers.utils.versions import require_version
-
 from ..data import get_template_and_fix_tokenizer
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available

@@ -25,7 +23,6 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
-        require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
         self.can_generate = finetuning_args.stage == "sft"
         engine_args = AsyncEngineArgs(
             model=model_args.model_name_or_path,
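Dropping the hard `require_version("vllm>=0.3.3", ...)` check leans on the `is_vllm_available` guard imported above. The general pattern, sketched with stdlib tools only (llmtuner's actual helper may be implemented differently):

```python
import importlib.util

def _is_package_available(name: str) -> bool:
    # True if the package could be imported in this environment.
    return importlib.util.find_spec(name) is not None

if _is_package_available("vllm"):
    from vllm import AsyncEngineArgs, AsyncLLMEngine  # heavy import, gated
else:
    AsyncLLMEngine = None  # callers must check availability before using it
```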
@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
     name: Optional[str] = None,
 ) -> Template:
     if name is None:
-        template = templates["vanilla"]  # placeholder
+        template = templates["empty"]  # placeholder
     else:
         template = templates.get(name, None)
         if template is None:

@@ -385,7 +385,8 @@ _register_template(
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
     format_separator=EmptyFormatter(slots=["\n\n"]),
     default_system=(
-        "Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
     ),
 )
@@ -526,6 +527,21 @@
 )


+_register_template(
+    name="cohere",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
+                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+            )
+        ]
+    ),
+    format_system=EmptyFormatter(slots=[{"bos_token"}]),
+    force_system=True,
+)
+
+
 _register_template(
     name="cpm",
     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
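Concretely, the `cohere` registration above concatenates each turn into the following prompt shape; this sketch just joins the slot strings by hand, with `<BOS_TOKEN>` standing in for whatever `bos_token` the tokenizer actually defines:

```python
BOS = "<BOS_TOKEN>"  # placeholder for the tokenizer's bos_token

user = "Hello!"
prompt = (
    f"{BOS}"  # format_system emits only the bos_token slot
    f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{user}<|END_OF_TURN_TOKEN|>"
    "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"  # generation starts here
)
print(prompt)
```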
@@ -566,6 +582,13 @@
 )


+_register_template(
+    name="empty",
+    format_user=StringFormatter(slots=["{{content}}"]),
+    format_assistant=StringFormatter(slots=["{{content}}"]),
+)
+
+
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -574,6 +597,13 @@
 )


+_register_template(
+    name="fewshot",
+    format_separator=EmptyFormatter(slots=["\n\n"]),
+    efficient_eos=True,
+)
+
+
 _register_template(
     name="gemma",
     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -635,6 +665,25 @@
 )


+_register_template(
+    name="llama3",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(
+        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
+    ),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+)
+
+
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
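For reference, the `llama3` template above renders a single-turn conversation like this; the sketch simply string-joins the registered slots (the example messages are made up, and `<|begin_of_text|>` is Llama 3's `bos_token`):

```python
BOS = "<|begin_of_text|>"  # Llama 3's bos_token

system = "You are a helpful assistant."
user = "What is LoRA?"

prompt = (
    f"{BOS}<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(prompt)  # generation stops at the registered <|eot_id|> stop word
```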
@@ -699,13 +748,6 @@
 )


-_register_template(
-    name="vanilla",
-    format_separator=EmptyFormatter(slots=["\n"]),
-    efficient_eos=True,
-)
-
-
 _register_template(
     name="vicuna",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
 METHODS = ["full", "freeze", "lora"]

+MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
+
 PEFT_METHODS = ["lora"]

 SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
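`MOD_SUPPORTED_MODELS` presumably gates the `--mixture_of_depths convert` path against unsupported architectures; a minimal sketch of such a guard (the function name is hypothetical, not taken from this diff):

```python
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]

def assert_mod_supported(model_type: str) -> None:
    # model_type is the `model_type` field of the Hugging Face config.
    if model_type not in MOD_SUPPORTED_MODELS:
        raise ValueError(f"Mixture-of-Depths does not support model type: {model_type}")
```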
@@ -242,6 +244,28 @@
 )


+register_model_group(
+    models={
+        "CommandR-35B-Chat": {
+            DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
+        },
+        "CommandR-Plus-104B-Chat": {
+            DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
+        },
+        "CommandR-35B-4bit-Chat": {
+            DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
+            DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
+        },
+        "CommandR-Plus-104B-4bit-Chat": {
+            DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
+        },
+    },
+    template="cohere",
+)
+
+
 register_model_group(
     models={
         "DeepSeek-LLM-7B-Base": {
@@ -363,6 +387,23 @@
 )


+register_model_group(
+    models={
+        "CodeGemma-2B": {
+            DownloadSource.DEFAULT: "google/codegemma-2b",
+        },
+        "CodeGemma-7B": {
+            DownloadSource.DEFAULT: "google/codegemma-7b",
+        },
+        "CodeGemma-7B-Chat": {
+            DownloadSource.DEFAULT: "google/codegemma-7b-it",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
+        },
+    },
+    template="gemma",
+)
+
+
 register_model_group(
     models={
         "InternLM-7B": {
@@ -474,6 +515,29 @@
 )


+register_model_group(
+    models={
+        "LLaMA3-8B": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
+        },
+        "LLaMA3-70B": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
+        },
+        "LLaMA3-8B-Chat": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
+        },
+        "LLaMA3-70B-Chat": {
+            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
+        },
+    },
+    template="llama3",
+)
+
+
 register_model_group(
     models={
         "Mistral-7B-v0.1": {
@@ -499,14 +563,20 @@
 register_model_group(
     models={
-        "Mixtral-8x7B": {
+        "Mixtral-8x7B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
         },
-        "Mixtral-8x7B-Chat": {
+        "Mixtral-8x7B-v0.1-Chat": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
         },
+        "Mixtral-8x22B-v0.1": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
+        },
+        "Mixtral-8x22B-v0.1-Chat": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
+        },
     },
     template="mistral",
 )
@@ -688,6 +758,10 @@ register_model_group(
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
        },
+        "Qwen1.5-Code-7B": {
+            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
+            DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
+        },
        "Qwen1.5-0.5B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",

@@ -720,6 +794,10 @@ register_model_group(
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
        },
+        "Qwen1.5-Code-7B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
+            DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
+        },
        "Qwen1.5-0.5B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",

@@ -776,6 +854,10 @@ register_model_group(
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
        },
+        "Qwen1.5-Code-7B-int4-Chat": {
+            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
+        },
    },
    template="qwen",
)
@@ -979,18 +1061,3 @@ register_model_group(
    },
    template="zephyr",
)
-
-
-register_model_group(
-    models={
-        "Atom-7B": {
-            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
-            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
-        },
-        "Atom-7B-Chat": {
-            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
-            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
-        },
-    },
-    template="atom",
-)
@@ -66,7 +66,6 @@ def check_dependencies() -> None:
    require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
    require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
    require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
-    require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

@@ -84,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize
            elif hasattr(param, "element_size"):  # for older pytorch version
                num_bytes = param.element_size()
            else:
                num_bytes = 1
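# For context, a rough sketch (helper name assumed, not part of this diff) of
# why the storage byte size matters in count_parameters: bitsandbytes packs
# two 4-bit weights into every byte of storage, so a Params4bit tensor
# represents about numel * num_bytes * 2 logical parameters.
def dequantized_param_count(storage_numel: int, num_bytes: int) -> int:
    return storage_numel * num_bytes * 2  # two 4-bit values per storage byte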
@@ -25,6 +25,10 @@ def is_galore_available():
    return _is_package_available("galore_torch")


+def is_gradio_available():
+    return _is_package_available("gradio")
+
+
def is_jieba_available():
    return _is_package_available("jieba")

@@ -49,10 +53,6 @@ def is_starlette_available():
    return _is_package_available("sse_starlette")


-def is_unsloth_available():
-    return _is_package_available("unsloth")
-
-
def is_uvicorn_available():
    return _is_package_available("uvicorn")
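# A minimal sketch, assuming _is_package_available probes installed
# distributions the same way transformers does, of how these guards defer a
# heavy import until the package is actually present:
import importlib.metadata
import importlib.util


def _is_package_available(name: str) -> bool:
    if importlib.util.find_spec(name) is None:  # not importable at all
        return False
    try:
        importlib.metadata.version(name)  # must also carry distribution metadata
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


if _is_package_available("gradio"):
    import gradio as gr  # imported lazily so gradio stays an optional dependency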
@@ -172,7 +172,7 @@ class GaloreArguments:

    use_galore: bool = field(
        default=False,
-        metadata={"help": "Whether or not to use gradient low-Rank projection."},
+        metadata={"help": "Whether or not to use the gradient low-rank projection (GaLore)."},
    )
    galore_target: str = field(
        default="all",

@@ -204,7 +204,54 @@ class GaloreArguments:

@dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "Whether to use the layer-wise or the ratio-wise BAdam optimizer."},
+    )
+    badam_start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for layer-wise BAdam."},
+    )
+    badam_switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch the model's block update. Set to -1 to disable the block update."},
+    )
+    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
+    )
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={
+            "help": """The mode of the mask for the BAdam optimizer. \
+                    `adjacent` means that the trainable parameters are adjacent to each other, \
+                    `scatter` means that the trainable parameters are randomly chosen from the weight."""
+        },
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={
+            "help": """The verbosity level of the BAdam optimizer. \
+                    0 for no print, 1 to print the block prefix, 2 to print the trainable parameters."""
+        },
+    )
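# A minimal sketch of driving these fields through the programmatic entry
# point (run_exp accepts a plain dict of arguments; the import path, model and
# dataset names below are assumed placeholders, not part of this diff):
from llmtuner.train import run_exp

run_exp(dict(
    stage="sft",
    model_name_or_path="meta-llama/Meta-Llama-3-8B",
    dataset="alpaca_gpt4_en",
    finetuning_type="full",
    use_badam=True,                  # switch on the BAdam optimizer
    badam_mode="layer",              # update one block of layers at a time
    badam_switch_mode="ascending",   # sweep the blocks from bottom to top
    badam_switch_block_every=50,     # rotate the active block every 50 steps
    output_dir="saves/llama3-8b-badam",
))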
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
    r"""
    Arguments pertaining to which techniques we are going to fine-tune with.
    """

@@ -256,11 +303,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")

        if self.use_llama_pro and self.finetuning_type == "full":
-            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
+            raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

        if self.use_galore and self.finetuning_type == "lora":
            raise ValueError("Cannot use LoRA and GaLore together.")

+        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
+            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+
    def save_to_json(self, json_path: str):
        r"""Saves the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
@@ -55,7 +55,7 @@ class ModelArguments:
    )
    quantization_device_map: Optional[Literal["auto"]] = field(
        default=None,
-        metadata={"help": "Device map used for loading the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
+        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
    )
    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
        default=None,

@@ -63,12 +63,16 @@ class ModelArguments:
    )
    flash_attn: bool = field(
        default=False,
-        metadata={"help": "Enable FlashAttention-2 for faster training."},
+        metadata={"help": "Enable FlashAttention for faster training."},
    )
    shift_attn: bool = field(
        default=False,
        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
    )
+    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+        default=None,
+        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
+    )
    use_unsloth: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},

@@ -129,6 +133,10 @@ class ModelArguments:
        default=1,
        metadata={"help": "The file shard size (in GB) of the exported model."},
    )
+    export_device: str = field(
+        default="cpu",
+        metadata={"help": "The device used in model export."},
+    )
    export_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the exported model."},
@@ -8,10 +8,10 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.versions import require_version

from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies
-from ..extras.packages import is_unsloth_available
+from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -74,6 +74,32 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
        raise ValueError("Quantized model only accepts a single adapter. Merge them first.")


+def _check_extra_dependencies(
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    training_args: Optional["Seq2SeqTrainingArguments"] = None,
+) -> None:
+    if model_args.use_unsloth:
+        require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
+
+    if model_args.mixture_of_depths is not None:
+        require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
+
+    if model_args.infer_backend == "vllm":
+        require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
+
+    if finetuning_args.use_galore:
+        require_version("galore_torch", "To fix: pip install galore_torch")
+
+    if finetuning_args.use_badam:
+        require_version("badam", "To fix: pip install badam")
+
+    if training_args is not None and training_args.predict_with_generate:
+        require_version("jieba", "To fix: pip install jieba")
+        require_version("nltk", "To fix: pip install nltk")
+        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
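# require_version (from transformers.utils.versions) raises an ImportError
# carrying the hint string when a requirement is missing or too old, so these
# checks fail fast at argument-parsing time. A rough equivalent of the BAdam
# check, written against importlib directly (illustration only):
from importlib.metadata import PackageNotFoundError, version

try:
    version("badam")
except PackageNotFoundError:
    raise ImportError("badam is required when `--use_badam` is set. To fix: pip install badam")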
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    parser = HfArgumentParser(_TRAIN_ARGS)
    return _parse_args(parser, args)
@@ -131,8 +157,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")

-    if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
-        raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
+    if training_args.do_train and model_args.quantization_device_map == "auto":
+        raise ValueError("Cannot use device map for quantized models in training.")

    if finetuning_args.use_dora and model_args.use_unsloth:
        raise ValueError("Unsloth does not support DoRA.")

@@ -151,13 +177,21 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    ):
        raise ValueError("Distributed training does not support layer-wise GaLore.")

-    if finetuning_args.use_galore and training_args.deepspeed is not None:
-        raise ValueError("GaLore is incompatible with DeepSpeed.")
+    if (
+        finetuning_args.use_badam
+        and finetuning_args.badam_mode == "layer"
+        and training_args.parallel_mode.value == "distributed"
+    ):
+        raise ValueError("Layer-wise BAdam does not yet support distributed training; use ratio-wise BAdam.")
+
+    if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
+        raise ValueError("GaLore and BAdam are not yet compatible with DeepSpeed.")

    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

    _verify_model_args(model_args, finetuning_args)
+    _check_extra_dependencies(model_args, finetuning_args, training_args)

    if (
        training_args.do_train

@@ -235,6 +269,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    elif training_args.fp16:
        model_args.compute_dtype = torch.float16

+    model_args.device_map = {"": get_current_device()}
    model_args.model_max_length = data_args.cutoff_len
    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

@@ -276,7 +311,11 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
        raise ValueError("vLLM engine does not support RoPE scaling.")

    _verify_model_args(model_args, finetuning_args)
+    _check_extra_dependencies(model_args, finetuning_args)

    if model_args.export_dir is not None:
+        model_args.device_map = {"": torch.device(model_args.export_device)}
+    else:
+        model_args.device_map = "auto"

    return model_args, data_args, finetuning_args, generating_args

@@ -294,6 +333,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

    _verify_model_args(model_args, finetuning_args)
+    _check_extra_dependencies(model_args, finetuning_args)

    model_args.device_map = "auto"
@@ -32,9 +32,12 @@ def init_adapter(
        logger.info("Adapter is not found at evaluation, load the base model.")
        return model

+    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
+        raise ValueError("You can only use lora for quantized models.")
+
    if finetuning_args.finetuning_type == "full" and is_trainable:
        logger.info("Fine-tuning method: Full")
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
            model = model.float()

    if finetuning_args.finetuning_type == "freeze" and is_trainable:

@@ -66,6 +69,8 @@ def init_adapter(
        for name, _ in model.named_modules():
            if ".0." in name:
                freeze_modules.add(name.split(".0.")[-1].split(".")[0])
+            elif ".1." in name:  # MoD starts from layer 1
+                freeze_modules.add(name.split(".1.")[-1].split(".")[0])

        trainable_layers = []
        for module_name in finetuning_args.name_module_trainable:

@@ -79,7 +84,7 @@ def init_adapter(

        for name, param in model.named_parameters():
            if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if not finetuning_args.pure_bf16:
+                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                    param.data = param.data.to(torch.float32)
            else:
                param.requires_grad_(False)

@@ -129,8 +134,11 @@ def init_adapter(
        if finetuning_args.use_llama_pro:
            target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)

-        if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
-            if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
+        if (
+            finetuning_args.use_dora
+            and getattr(model, "quantization_method", None) is not None
+            and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+        ):
            raise ValueError("DoRA is not compatible with PTQ-quantized models.")

        peft_kwargs = {

@@ -139,24 +147,28 @@ def init_adapter(
            "lora_alpha": finetuning_args.lora_alpha,
            "lora_dropout": finetuning_args.lora_dropout,
            "use_rslora": finetuning_args.use_rslora,
-            "modules_to_save": finetuning_args.additional_target,
        }

        if model_args.use_unsloth:
            from unsloth import FastLanguageModel  # type: ignore

-            unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
+            unsloth_peft_kwargs = {
+                "model": model,
+                "max_seq_length": model_args.model_max_length,
+                "use_gradient_checkpointing": "unsloth",
+            }
            model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
        else:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
+                modules_to_save=finetuning_args.additional_target,
                use_dora=finetuning_args.use_dora,
                **peft_kwargs,
            )
            model = get_peft_model(model, lora_config)

-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
            for param in filter(lambda p: p.requires_grad, model.parameters()):
                param.data = param.data.to(torch.float32)
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

+from ..extras.constants import MOD_SUPPORTED_MODELS
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter

@@ -36,6 +37,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
    Note: including inplace operation of model_args.
    """
    init_kwargs = _get_init_kwargs(model_args)
+    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=model_args.use_fast_tokenizer,

@@ -43,6 +45,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
            padding_side="right",
            **init_kwargs,
        )
+    except ValueError:  # try the fast one
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=True,
+            padding_side="right",
+            **init_kwargs,
+        )

    patch_tokenizer(tokenizer)
    return tokenizer
@@ -73,6 +83,8 @@ def load_model(
            "token": model_args.hf_hub_token,
            "device_map": {"": get_current_device()},
            "rope_scaling": getattr(config, "rope_scaling", None),
+            "fix_tokenizer": False,
+            "trust_remote_code": True,
        }
        try:
            model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)

@@ -85,7 +97,24 @@ def load_model(
            logger.warning("Unsloth does not support loading adapters.")

    if model is None:
-        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
+        init_kwargs["config"] = config
+        init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
+
+        if model_args.mixture_of_depths == "load":
+            from MoD import AutoMoDModelForCausalLM
+
+            model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+
+        if model_args.mixture_of_depths == "convert":
+            from MoD import apply_mod_to_hf
+
+            if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
+                raise ValueError("Current model is not supported by mixture-of-depths.")
+
+            model = apply_mod_to_hf(model)
+            model = model.to(model_args.compute_dtype)

    patch_model(model, tokenizer, model_args, is_trainable)
    register_autoclass(config, model, tokenizer)
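# The two MoD paths above, as a standalone sketch; AutoMoDModelForCausalLM and
# apply_mod_to_hf come from the mixture-of-depth package required by
# _check_extra_dependencies, and the checkpoint paths are assumed placeholders:
from transformers import AutoModelForCausalLM
from MoD import AutoMoDModelForCausalLM, apply_mod_to_hf

# mixture_of_depths == "load": the checkpoint was already converted and saved
mod_model = AutoMoDModelForCausalLM.from_pretrained("path/to/mod-checkpoint")

# mixture_of_depths == "convert": wrap a vanilla causal LM (its model_type
# must appear in MOD_SUPPORTED_MODELS)
base = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
mod_model = apply_mod_to_hf(base)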
@@ -93,7 +122,7 @@ def load_model(
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if add_valuehead:
-        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_valuehead_model(model)

    if model_args.adapter_name_or_path is not None:
@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable


if TYPE_CHECKING:

@@ -61,9 +61,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
    return samples


-def _configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
-) -> None:
+def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    if model_args.flash_attn:
        if not is_flash_attn2_available():
            logger.warning("FlashAttention2 is not installed.")

@@ -73,9 +71,9 @@ def _configure_attn_implementation(
        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
            setattr(config, "attn_implementation", "flash_attention_2")
        else:
-            init_kwargs["attn_implementation"] = "flash_attention_2"
+            setattr(config, "_attn_implementation", "flash_attention_2")
    else:
-        init_kwargs["attn_implementation"] = "eager"
+        setattr(config, "_attn_implementation", "eager")
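# The rewrite above records the choice on the config instead of the model init
# kwargs: transformers reads `config._attn_implementation` when instantiating
# the model, so the selection now travels with the config object itself. A
# minimal sketch of the effect (model name assumed):
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
setattr(config, "_attn_implementation", "flash_attention_2")  # or "eager"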
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:

@@ -133,12 +131,15 @@ def _configure_quantization(
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

+        if model_args.quantization_device_map != "auto":
+            init_kwargs["device_map"] = {"": get_current_device()}
+
        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")

        if quant_method == QuantizationMethod.GPTQ:
            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            quantization_config.pop("disable_exllama", None)  # remove deprecated args
            quantization_config["use_exllama"] = False  # disable exllama

        if quant_method == QuantizationMethod.AWQ:

@@ -266,8 +267,8 @@ def _prepare_model_for_training(
        else:
            # use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
+            model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-            model.enable_input_require_grads()
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")

@@ -293,7 +294,7 @@ def patch_config(
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    _configure_attn_implementation(config, model_args, init_kwargs)
+    _configure_attn_implementation(config, model_args)
    _configure_rope(config, model_args, is_trainable)
    _configure_longlora(config, model_args, is_trainable)
    _configure_quantization(config, tokenizer, model_args, init_kwargs)

@@ -316,15 +317,15 @@ def patch_config(
    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn

-    if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
+    if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
        setattr(config, "output_router_logits", True)

    init_kwargs["torch_dtype"] = model_args.compute_dtype
    if not is_deepspeed_zero3_enabled():
        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
-        if init_kwargs["low_cpu_mem_usage"]:
-            if "device_map" not in init_kwargs:
-                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+        if "device_map" not in init_kwargs and model_args.device_map:
+            init_kwargs["device_map"] = model_args.device_map

    if init_kwargs["device_map"] == "auto":
        init_kwargs["offload_folder"] = model_args.offload_folder
@@ -1,5 +1,7 @@
+import inspect
from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from transformers import PreTrainedModel

@@ -100,6 +102,42 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
    return module_names


+def gradient_checkpointing_enable(
+    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+) -> None:
+    r"""
+    Activates gradient checkpointing for the current model.
+
+    Modification of the original method to enable gradient checkpointing for the block-wise optimizer.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+        module: "torch.nn.Module" = func.__self__
+
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
+        self.apply(partial(self._set_gradient_checkpointing, value=True))
+        self.enable_input_require_grads()
+        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+    else:  # have already enabled input require gradients
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+
+
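# This function is not called directly; the patcher above binds it onto the
# model instance so the usual HF Trainer call path picks it up. A minimal
# sketch of the wiring, assuming any PreTrainedModel that supports gradient
# checkpointing ("gpt2" used as a small stand-in):
from types import MethodType

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})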
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.
@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import nullcontext
+from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union

import torch

@@ -63,6 +64,11 @@ class CustomDPOTrainer(DPOTrainer):
        else:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
@@ -1,4 +1,5 @@
from collections import defaultdict
+from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union

import torch

@@ -44,6 +45,10 @@ class CustomORPOTrainer(DPOTrainer):
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        Trainer.__init__(self, model=model, **kwargs)
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -1,6 +1,7 @@
import math
import os
import sys
+from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch

@@ -124,6 +125,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        else:
            self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+
    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""
        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -1,3 +1,4 @@
+from types import MethodType
from typing import TYPE_CHECKING, Optional

from transformers import Trainer

@@ -23,6 +24,10 @@ class CustomTrainer(Trainer):
    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -1,5 +1,6 @@
import json
import os
+from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch

@@ -28,6 +29,10 @@ class PairwiseTrainer(Trainer):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        self.can_return_loss = True  # override property to return eval_loss
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -2,7 +2,6 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union

import numpy as np
-from transformers.utils.versions import require_version

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available

@@ -33,10 +32,6 @@ class ComputeMetrics:
        r"""
        Uses the model predictions to compute metrics.
        """
-        require_version("jieba", "To fix: pip install jieba")
-        require_version("nltk", "To fix: pip install nltk")
-        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
-
        preds, labels = eval_preds
        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
@@ -1,5 +1,6 @@
import json
import os
+from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np

@@ -28,6 +29,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -65,8 +65,7 @@ def export_model(args: Optional[Dict[str, Any]] = None):
    if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
        output_dtype = getattr(model.config, "torch_dtype", torch.float16)
        setattr(model.config, "torch_dtype", output_dtype)
-        for param in model.parameters():
-            param.data = param.data.to(output_dtype)
+        model = model.to(output_dtype)

    model.save_pretrained(
        save_directory=model_args.export_dir,
@@ -5,7 +5,6 @@ from transformers import Trainer
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
-from transformers.utils.versions import require_version

from ..extras.logging import get_logger
from ..extras.packages import is_galore_available

@@ -57,9 +56,11 @@ def create_modelcard_and_push(
    kwargs = {
        "tasks": "text-generation",
        "finetuned_from": model_args.model_name_or_path,
-        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
        "tags": ["llama-factory", finetuning_args.finetuning_type],
    }
+    if data_args.dataset is not None:
+        kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]

    if not training_args.do_train:
        pass
    elif training_args.push_to_hub:
@@ -166,8 +167,6 @@ def _create_galore_optimizer(
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
-    require_version("galore_torch", "To fix: pip install galore_torch")
-
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
        galore_targets = find_all_linear_modules(model)
    else:

@@ -217,7 +216,7 @@ def _create_galore_optimizer(

        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
        for param in nodecay_params:
-            param_groups = [dict(params=[param])]
+            param_groups = [dict(params=[param], weight_decay=0.0)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]

@@ -237,7 +236,7 @@ def _create_galore_optimizer(
        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
    else:
        param_groups = [
-            dict(params=nodecay_params),
+            dict(params=nodecay_params, weight_decay=0.0),
            dict(params=decay_params, weight_decay=training_args.weight_decay),
            dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
        ]
@@ -252,11 +251,9 @@ def _create_loraplus_optimizer(
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
-    if finetuning_args.finetuning_type != "lora":
-        raise ValueError("You should use LoRA tuning to activate LoRA+.")
-
+    default_lr = training_args.learning_rate
    loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
-    decay_args = {"weight_decay": training_args.weight_decay}
+    embedding_lr = finetuning_args.loraplus_lr_embedding

    decay_param_names = _get_decay_parameter_names(model)
    param_dict: Dict[str, List["torch.nn.Parameter"]] = {

@@ -279,16 +276,76 @@ def _create_loraplus_optimizer(

    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    param_groups = [
-        dict(params=param_dict["lora_a"], **decay_args),
-        dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
-        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
-        dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
+        dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
+        dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
+        dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
    ]
    optimizer = optim_class(param_groups, **optim_kwargs)
    logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
    return optimizer
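# The grouping above encodes the LoRA+ rule: the B matrices train with a
# learning rate `loraplus_lr_ratio` times that of the A matrices. With assumed
# numbers:
base_lr = 5e-5                   # training_args.learning_rate, used for lora_A
lr_ratio = 16.0                  # finetuning_args.loraplus_lr_ratio
lora_b_lr = base_lr * lr_ratio   # 8e-4, used for lora_B and its no-decay group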
+def _create_badam_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    decay_params, nodecay_params = [], []
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if name in decay_param_names:
+                decay_params.append(param)
+            else:
+                nodecay_params.append(param)
+
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=nodecay_params, weight_decay=0.0),
+        dict(params=decay_params, weight_decay=training_args.weight_decay),
+    ]
+
+    if finetuning_args.badam_mode == "layer":
+        from badam import BlockOptimizer
+
+        base_optimizer = optim_class(param_groups, **optim_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.badam_switch_block_every,
+            start_block=finetuning_args.badam_start_block,
+            switch_mode=finetuning_args.badam_switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+            f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+            f"default start block is {finetuning_args.badam_start_block}"
+        )
+
+    elif finetuning_args.badam_mode == "ratio":
+        from badam import BlockOptimizerRatio
+
+        assert finetuning_args.badam_update_ratio > 1e-6
+        optimizer = BlockOptimizerRatio(
+            param_groups=param_groups,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            include_embedding=False,
+            **optim_kwargs,
+        )
+        logger.info(
+            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
+
+    return optimizer
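# How the two modes above differ, roughly: "layer" wraps the base optimizer in
# a BlockOptimizer that trains one block of layers at a time and rotates the
# active block, while "ratio" updates a fixed fraction of every trainable
# weight through a sparse mask. An assumed illustration of the "ascending"
# rotation schedule for a 32-block model:
def active_block(step: int, switch_every: int = 50, n_blocks: int = 32) -> int:
    return (step // switch_every) % n_blocks  # block 0, then 1, ..., then wrap around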
def create_custom_optimzer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",

@@ -300,6 +357,9 @@ def create_custom_optimzer(
    if finetuning_args.loraplus_lr_ratio is not None:
        return _create_loraplus_optimizer(model, training_args, finetuning_args)

+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+

def create_custom_scheduler(
    training_args: "Seq2SeqTrainingArguments",

@@ -314,12 +374,11 @@ def create_custom_scheduler(
        scheduler_dict[param] = get_scheduler(
            training_args.lr_scheduler_type,
            optimizer=optimizer_dict[param],
-            num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
-            num_training_steps=num_training_steps * 2,
+            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+            num_training_steps=num_training_steps,
        )

        def scheduler_hook(param: "torch.nn.Parameter"):
            if param.grad is not None:
                scheduler_dict[param].step()

        for param in optimizer_dict.keys():
@@ -1,13 +1,11 @@
import json
import os
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
-
-import gradio as gr
-from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple

from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
+from ..extras.packages import is_gradio_available
from .common import get_save_dir
from .locales import ALERTS

@@ -17,6 +15,10 @@ if TYPE_CHECKING:
    from .manager import Manager


+if is_gradio_available():
+    import gradio as gr
+
+
class WebChatModel(ChatModel):
    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager

@@ -35,7 +37,7 @@ class WebChatModel(ChatModel):
    def loaded(self) -> bool:
        return self.engine is not None

-    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+    def load_model(self, data) -> Generator[str, None, None]:
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
        lang = get("top.lang")
        error = ""

@@ -79,7 +81,7 @@ class WebChatModel(ChatModel):

        yield ALERTS["info_loaded"][lang]

-    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
+    def unload_model(self, data) -> Generator[str, None, None]:
        lang = data[self.manager.get_elem_by_id("top.lang")]

        if self.demo_mode:
@@ -3,7 +3,6 @@ import os
from collections import defaultdict
from typing import Any, Dict, Optional

-import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME

from ..extras.constants import (

@@ -17,6 +16,11 @@ from ..extras.constants import (
    DownloadSource,
)
from ..extras.misc import use_modelscope
+from ..extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr


ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Dict, Tuple

-import gradio as gr
-
from ...data import Role
+from ...extras.packages import is_gradio_available
from ..utils import check_json_schema


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component
@@ -1,10 +1,13 @@
import json
import os
-from typing import TYPE_CHECKING, Dict, Tuple
-
-import gradio as gr
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple

from ...extras.constants import DATA_CONFIG
+from ...extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr


if TYPE_CHECKING:

@@ -29,28 +32,38 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
    except Exception:
        return gr.Button(interactive=False)

-    if (
-        len(dataset) > 0
-        and "file_name" in dataset_info[dataset[0]]
-        and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
-    ):
+    if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
+        return gr.Button(interactive=False)
+
+    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+    if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
        return gr.Button(interactive=True)
    else:
        return gr.Button(interactive=False)


+def _load_data_file(file_path: str) -> List[Any]:
+    with open(file_path, "r", encoding="utf-8") as f:
+        if file_path.endswith(".json"):
+            return json.load(f)
+        elif file_path.endswith(".jsonl"):
+            return [json.loads(line) for line in f]
+        else:
+            return list(f)
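# Usage sketch for the helper above (file names assumed): it dispatches on the
# extension, so JSON arrays, JSON-lines files, and plain-text files all come
# back as a Python list.
rows = _load_data_file("data/alpaca_data_en_52k.json")  # list of dicts
lines = _load_data_file("data/example.jsonl")           # one dict per line
raw = _load_data_file("data/corpus.txt")                # list of raw lines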
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

-    data_file: str = dataset_info[dataset[0]]["file_name"]
-    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
-        if data_file.endswith(".json"):
-            data = json.load(f)
-        elif data_file.endswith(".jsonl"):
-            data = [json.loads(line) for line in f]
-        else:
-            data = [line for line in f]  # noqa: C416
+    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+    if os.path.isfile(data_path):
+        data = _load_data_file(data_path)
+    else:
+        data = []
+        for file_name in os.listdir(data_path):
+            data.extend(_load_data_file(os.path.join(data_path, file_name)))

    return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Dict

-import gradio as gr
-
+from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component

@@ -18,7 +21,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

    with gr.Row():
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
-        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
        preview_elems = create_preview_box(dataset_dir, dataset)

    input_elems.update({dataset_dir, dataset})
@@ -1,12 +1,15 @@
from typing import TYPE_CHECKING, Dict, Generator, List

-import gradio as gr
-
+from ...extras.packages import is_gradio_available
from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component
@@ -1,10 +1,13 @@
from typing import TYPE_CHECKING, Dict

-import gradio as gr
-
+from ...extras.packages import is_gradio_available
from .chatbot import create_chat_box


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component
@@ -1,13 +1,16 @@
from typing import TYPE_CHECKING, Dict

-import gradio as gr
-
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component
@@ -1,13 +1,17 @@
from typing import TYPE_CHECKING, Dict

-import gradio as gr
from transformers.trainer_utils import SchedulerType

from ...extras.constants import TRAINING_STAGES
+from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component

@@ -23,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
        )
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
-        dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, scale=4)
        preview_elems = create_preview_box(dataset_dir, dataset)

    input_elems.update({training_stage, dataset_dir, dataset})
@@ -1,6 +1,4 @@
-from typing import Any, Dict, Generator
-
-from gradio.components import Component  # cannot use TYPE_CHECKING here
+from typing import TYPE_CHECKING, Any, Dict

from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config

@@ -10,6 +8,10 @@ from .runner import Runner
from .utils import get_time


+if TYPE_CHECKING:
+    from gradio.components import Component
+
+
class Engine:
    def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
        self.demo_mode = demo_mode

@@ -29,7 +31,7 @@ class Engine:

        return output_dict

-    def resume(self) -> Generator[Dict[Component, Component], None, None]:
+    def resume(self):
        user_config = load_config() if not self.demo_mode else {}
        lang = user_config.get("lang", None) or "en"

@@ -55,7 +57,7 @@ class Engine:
        else:
            yield self._update_component({"eval.resume_btn": {"value": True}})

-    def change_lang(self, lang: str) -> Dict[Component, Component]:
+    def change_lang(self, lang: str):
        return {
            elem: elem.__class__(**LOCALES[elem_name][lang])
            for elem_name, elem in self.manager.get_elem_iter()
@@ -1,5 +1,4 @@
-import gradio as gr
-
+from ..extras.packages import is_gradio_available
from .common import save_config
from .components import (
    create_chat_box,

@@ -13,6 +12,10 @@ from .css import CSS
from .engine import Engine


+if is_gradio_available():
+    import gradio as gr
+
+
def create_ui(demo_mode: bool = False) -> gr.Blocks:
    engine = Engine(demo_mode=demo_mode, pure_chat=False)
@@ -4,9 +4,7 @@ import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator

-import gradio as gr
import transformers
-from gradio.components import Component  # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available

@@ -14,13 +12,20 @@ from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc
+from ..extras.packages import is_gradio_available
from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar


+if is_gradio_available():
+    import gradio as gr
+
+
if TYPE_CHECKING:
    from gradio.components import Component

    from .manager import Manager


@@ -239,7 +244,7 @@ class Runner:

        return args

-    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
+    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
        error = self._initialize(data, do_train, from_preview=True)
        if error:

@@ -249,7 +254,7 @@ class Runner:
        args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
        yield {output_box: gen_cmd(args)}

-    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
+    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
        error = self._initialize(data, do_train, from_preview=False)
        if error:

@@ -263,19 +268,19 @@ class Runner:
        self.thread.start()
        yield from self.monitor()

-    def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
+    def preview_train(self, data):
        yield from self._preview(data, do_train=True)

-    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
+    def preview_eval(self, data):
        yield from self._preview(data, do_train=False)

-    def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
+    def run_train(self, data):
        yield from self._launch(data, do_train=True)

-    def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
+    def run_eval(self, data):
        yield from self._launch(data, do_train=False)

-    def monitor(self) -> Generator[Dict[Component, Any], None, None]:
+    def monitor(self):
        get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
        self.aborted = False
        self.running = True

@@ -332,7 +337,7 @@ class Runner:

        yield return_dict

-    def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
+    def save_args(self, data):
        output_box = self.manager.get_elem_by_id("train.output_box")
        error = self._initialize(data, do_train=True, from_preview=True)
        if error:

@@ -351,7 +356,7 @@ class Runner:
        save_path = save_args(config_path, config_dict)
        return {output_box: ALERTS["info_config_saved"][lang] + save_path}

-    def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
+    def load_args(self, lang: str, config_path: str):
        output_box = self.manager.get_elem_by_id("train.output_box")
        config_dict = load_args(config_path)
        if config_dict is None:
@@ -3,21 +3,24 @@ import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional

-import gradio as gr
-
-from ..extras.packages import is_matplotlib_available
+from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import smooth
from .locales import ALERTS


-if TYPE_CHECKING:
-    from ..extras.callbacks import LogCallback
+if is_gradio_available():
+    import gradio as gr


if is_matplotlib_available():
    import matplotlib.figure
    import matplotlib.pyplot as plt


+if TYPE_CHECKING:
+    from ..extras.callbacks import LogCallback
+
+
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
    if not callback.max_steps:
        return gr.Slider(visible=False)