Compare commits

31 Commits

| Author | SHA1 | Date |
|---|---|---|
| | dad7ca6633 | |
| | a1468139a5 | |
| | 49c90044ce | |
| | 0f7cdac207 | |
| | c4e9694c6e | |
| | 2006a96570 | |
| | 5dcd95645f | |
| | 9b3304b054 | |
| | e580d4ef41 | |
| | 64db4abc68 | |
| | 5ba0b80e5c | |
| | 7a43ff3d89 | |
| | 7e1a1d141a | |
| | 6d881f161b | |
| | a02b3e6192 | |
| | bcdee9fc19 | |
| | 8b688251be | |
| | 718f3382ad | |
| | dc8283d3d7 | |
| | 35e76879f5 | |
| | 8e4ae0aaac | |
| | 5ed2a97056 | |
| | 03eba6f041 | |
| | ec166e736a | |
| | c85a6b83b3 | |
| | a864a7b395 | |
| | fd8c2d4aac | |
| | baf2e4e825 | |
| | eac7f97337 | |
| | c08ff734a7 | |
| | e9736b2ba0 | |

README.md (121 changes)
@@ -10,7 +10,11 @@

## Changelog

[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base`, `--padding_side right` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.

[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.

[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.

[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.

[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.

@@ -18,11 +22,11 @@

[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.

[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.

[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [Hugging Face Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.

[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.

[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model. If you want to train with RTX3090, use `git checkout baichuan-7b-rtx3090` to switch to the `baichuan-7b-rtx3090` branch and try the `--baichuan_rtx_gpu true` argument. (Other RTX series GPUs can also be tried)

[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.

[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)

@@ -31,6 +35,7 @@

## Supported Models

- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)

@@ -55,36 +60,36 @@

## Provided Datasets

- For pre-training:
- [Wiki Demo](data/wiki_demo.txt)
- [Wiki Demo (en)](data/wiki_demo.txt)
- For supervised fine-tuning:
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat](https://github.com/thunlp/UltraChat)
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [WebNovel (Chinese)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward model training:
- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

Please refer to [data/README.md](data/README.md) for details.

Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.

Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.

```bash
pip install --upgrade huggingface_hub
@@ -125,14 +130,10 @@ cd LLaMA-Efficient-Tuning

pip install -r requirements.txt
```

### LLaMA Weights Preparation (optional)

1. Download the weights of the LLaMA models.
2. Convert them to HF format using the following command.

### All-in-one Web UI

```bash
python -m transformers.models.llama.convert_llama_weights_to_hf \
--input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model
python src/train_web.py
```

### (Continually) Pre-Training

@@ -262,24 +263,64 @@ use_cpu: false

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--stage sft \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_en \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 50 \
--max_samples 100 \
--predict_with_generate
```

We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.

### API / CLI / Web Demo

### Predict

```bash
python src/xxx_demo.py \
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_en \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```

If you want to predict the samples with empty responses, please kindly fill the `response` column with **dummy tokens** to ensure the sample will not be discarded throughout the preprocessing phase.
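
As a rough illustration of the tip above (not from the repository — the file name, column name, and dummy string below are assumptions), empty responses in a custom JSON dataset could be padded before running prediction:

```python
import json

# Hypothetical helper: pad empty responses with a dummy token so the samples
# are not discarded during preprocessing. "data/my_dataset.json" and "<dummy>"
# are placeholders, not repository defaults.
with open("data/my_dataset.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

for sample in samples:
    if not sample.get("output"):      # "output" is the default response column
        sample["output"] = "<dummy>"

with open("data/my_dataset_filled.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)
```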

### API Demo

```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```

Visit `http://localhost:8000/docs` for API documentation.
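
Since the demo follows the OpenAI chat-completions format (see the `/v1/chat/completions` route in the `app.py` diff further down), a request can be sketched with the standard library; this is an illustrative client only, and the `model` value is a placeholder echoed back in the response:

```python
import json
from urllib import request

payload = {
    "model": "default",  # placeholder; the server echoes this name back
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
}
req = request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    body = json.loads(resp.read())
    print(body["choices"][0]["message"]["content"])
```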

### CLI Demo

```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```

### Web Demo

```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```

@@ -288,6 +329,7 @@ python src/xxx_demo.py \

```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```

@@ -299,6 +341,7 @@ This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights:

- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)

@@ -1,4 +1,5 @@

Data format in `dataset_info.json`:
If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.

```json
"dataset_name": {
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",

@@ -14,40 +15,4 @@ Data format in `dataset_info.json`:

}
```

Dataset definition format in `dataset_info.json`:

```json
"dataset_name": {
"hf_hub_url": "the repository address on the HuggingFace hub (if specified, the following three arguments are ignored)",
"script_url": "the name of the local folder containing the data loading script (if specified, the following two arguments are ignored)",
"file_name": "the name of the dataset file in this directory (required if the above arguments are not specified)",
"file_sha1": "the SHA-1 hash of the dataset file (optional)",
"columns": {
"prompt": "the column name for the prompt in the dataset (default: instruction)",
"query": "the column name for the query in the dataset (default: input)",
"response": "the column name for the response in the dataset (default: output)",
"history": "the column name for the chat history in the dataset (default: None)"
}
}
```

A brief overview of some of the provided datasets:

| Dataset | Size | Description |
| --- | --- | --- |
| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | Alpaca dataset open-sourced by Stanford, used to train early LLaMA-based models such as Alpaca |
| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | Alpaca dataset translated with ChatGPT |
| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | self-instruction dataset built with GPT-4 |
| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | about 2 million Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | about 1 million Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | about 500 thousand Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | about 400 thousand personalized role-play dialogue samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, including character profiles |
| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | about 250 thousand Chinese math problems generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, including solution steps |
| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | about 800 thousand multi-turn dialogues between users and the assistant generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | multilingual data covering Japanese, Simplified/Traditional Chinese, English and more, originally used to train the Guanaco model |
| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | Chinese dataset of the Chinese dialogue model Firefly (流萤), covering multiple NLP tasks |
| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | English code generation dataset |
| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | a collection of instruction datasets for fine-tuning |
| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | Chinese question answering dataset collected from Baidu Zhidao |
| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | large-scale multi-turn dialogue dataset released by Tsinghua NLP |

Note: the BELLE datasets are generated by ChatGPT and their accuracy is not guaranteed; the same caveat applies to all self-instruction datasets produced by GPT-like models.

where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
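
To make the column mapping concrete, here is a hypothetical entry following the format above; the dataset name, file name, and column values are illustrative only, not files shipped with the repository:

```json
"my_custom_dataset": {
  "file_name": "my_custom_dataset.json",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "history": "history"
  }
}
```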

@@ -1,2 +0,0 @@

{"id": 0,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利(David Clayton Henrie,),美国演员。近来在迪士尼频道原创电视影集《少年魔法师》(Wizards of Waverly Place)当中演出贾斯汀·鲁索(Justin Russo)一角。\n\n大卫·亨利出生在加州Mission Viejo,在凤凰城长大。他的胞弟劳伦斯·亨利(Lorenzo Henrie)也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔(Lucy Hale),之后与其交往,于2009年分手。\n\n10岁时,大卫·亨利和SAG在凤凰城签订了合约,并开始走出去试镜。 9岁的时候,在沙加缅度进行商业拍摄,SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天,他和他的家人搬到了好莱坞。他预定他的前2支商业试镜,扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁,大卫有了他的第一次重大突破,在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker,和琳达布莱儿、乔治甘迺迪共同演出,并要求回来Hallmark movie公司。 \n\n在18岁时,大卫得到了迪士尼频道原创系列演出机会,该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长,隔年,为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}
{"id": 1,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利(David Clayton Henrie,),美国演员。近来在迪士尼频道原创电视影集《少年魔法师》(Wizards of Waverly Place)当中演出贾斯汀·鲁索(Justin Russo)一角。\n\n大卫·亨利出生在加州Mission Viejo,在凤凰城长大。他的胞弟劳伦斯·亨利(Lorenzo Henrie)也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔(Lucy Hale),之后与其交往,于2009年分手。\n\n10岁时,大卫·亨利和SAG在凤凰城签订了合约,并开始走出去试镜。 9岁的时候,在沙加缅度进行商业拍摄,SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天,他和他的家人搬到了好莱坞。他预定他的前2支商业试镜,扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁,大卫有了他的第一次重大突破,在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker,和琳达布莱儿、乔治甘迺迪共同演出,并要求回来Hallmark movie公司。 \n\n在18岁时,大卫得到了迪士尼频道原创系列演出机会,该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长,隔年,为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}

data/refgpt_zh_50k_p1.json.REMOVED.git-id (new file, 1 line)

@@ -0,0 +1 @@

56405bb8f52727e52e99693739494b9b7b0d7ba6

data/refgpt_zh_50k_p2.json.REMOVED.git-id (new file, 1 line)

@@ -0,0 +1 @@

fa935248a5d40d2bdd5649af99a72a754d40ae7a

data/sharegpt_zh_27k.json.REMOVED.git-id (new file, 1 line)

@@ -0,0 +1 @@

38c89869c6aeca2a3af9ea1e09afe460f9b46810

@@ -1,16 +1,16 @@

torch>=1.13.1
transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.19.0
peft>=0.3.0
trl>=0.4.4
accelerate>=0.21.0
peft>=0.4.0
trl>=0.4.7
sentencepiece
jieba
rouge-chinese
nltk
gradio>=3.36.0
uvicorn
pydantic==1.10.7
fastapi
pydantic==1.10.11
fastapi==0.95.1
sse-starlette
matplotlib

@@ -5,9 +5,16 @@

import uvicorn

from llmtuner import create_app
from llmtuner import ChatModel
from llmtuner.api.app import create_app
from llmtuner.tuner import get_infer_args


def main():
chat_model = ChatModel(*get_infer_args())
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)


if __name__ == "__main__":
app = create_app()
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
main()

@@ -2,7 +2,8 @@

# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

from llmtuner import ChatModel, get_infer_args
from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args


def main():

@@ -2,7 +2,7 @@

# Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model

from llmtuner import get_train_args, load_model_and_tokenizer
from llmtuner.tuner import get_train_args, load_model_and_tokenizer


def main():

@@ -1,6 +1,4 @@

from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo

__version__ = "0.0.9"
__version__ = "0.1.3"

@@ -1 +0,0 @@

from llmtuner.api.app import create_app

@@ -9,6 +9,8 @@ from llmtuner.tuner import get_infer_args

from llmtuner.extras.misc import torch_gc
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
ModelCard,
ModelList,
ChatMessage,

@@ -28,9 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory

torch_gc()


def create_app():
chat_model = ChatModel(*get_infer_args())

def create_app(chat_model: ChatModel) -> FastAPI:
app = FastAPI(lifespan=lifespan)

app.add_middleware(

@@ -48,12 +48,12 @@ def create_app():

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != "user":
if request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content

prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
else:
prefix = None

@@ -61,7 +61,7 @@ def create_app():

history = []
if len(prev_messages) % 2 == 0:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])

if request.stream:

@@ -80,19 +80,19 @@ def create_app():

choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
message=ChatMessage(role=Role.ASSISTANT, content=response),
finish_reason=Finish.STOP
)

return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
delta=DeltaMessage(role=Role.ASSISTANT),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)

for new_text in chat_model.stream_chat(

@@ -106,15 +106,15 @@ def create_app():

delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)

choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield "[DONE]"

@@ -122,5 +122,6 @@ def create_app():

if __name__ == "__main__":
app = create_app()
chat_model = ChatModel(*get_infer_args())
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

@@ -1,6 +1,18 @@

import time
from enum import Enum
from pydantic import BaseModel, Field
from typing import List, Literal, Optional
from typing import List, Optional


class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"


class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"


class ModelCard(BaseModel):

@@ -19,12 +31,12 @@ class ModelList(BaseModel):

class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
role: Role
content: str


class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
role: Optional[Role] = None
content: Optional[str] = None

@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):

class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
finish_reason: Finish


class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
finish_reason: Optional[Finish] = None


class ChatCompletionResponseUsage(BaseModel):

@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):

class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion"]
object: Optional[str] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]

@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):

class ChatCompletionStreamResponse(BaseModel):
id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion.chunk"]
object: Optional[str] = "chat.completion.chunk"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]

@@ -1,9 +1,10 @@

import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer

from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import Template
from llmtuner.extras.template import get_template
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner import load_model_and_tokenizer

@@ -18,14 +19,14 @@ class ChatModel:

generating_args: GeneratingArguments
) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.template = Template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
self.template = get_template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix or ""
self.generating_args = generating_args

def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix
prefix = prefix or self.source_prefix

inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)

@@ -41,10 +42,10 @@ class ChatModel:

gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
temperature=temperature if temperature else gen_kwargs["temperature"],
top_p=top_p if top_p else gen_kwargs["top_p"],
top_k=top_k if top_k else gen_kwargs["top_k"],
repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
logits_processor=get_logits_processor()
))

@@ -58,6 +59,7 @@ class ChatModel:

return gen_kwargs, prompt_length

@torch.inference_mode()
def chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[str, Tuple[int, int]]:

@@ -68,6 +70,7 @@ class ChatModel:

response_length = len(outputs)
return response, (prompt_length, response_length)

@torch.inference_mode()
def stream_chat(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Generator[str, None, None]:

@@ -78,5 +81,4 @@ class ChatModel:

thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()

for new_text in streamer:
yield new_text
yield from streamer

@@ -1,2 +1,3 @@

from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

@@ -1,63 +0,0 @@

import os
import json
import time
from datetime import timedelta

from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)


class LogCallback(TrainerCallback):

def __init__(self, runner=None):
self.runner = runner
self.start_time = time.time()
self.tracker = {}

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True

def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True

def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if "loss" not in state.log_history[-1]:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")
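
For reference, a record appended to `trainer_log.jsonl` by this callback carries exactly the tracker fields shown above; the values in this sample line are made up for illustration:

```json
{"current_steps": 120, "total_steps": 1000, "loss": 1.37, "reward": null, "learning_rate": 4.8e-05, "epoch": 0.36, "percentage": 12.0, "elapsed_time": "0:05:12", "remaining_time": "0:38:09"}
```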

@@ -6,7 +6,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer

from datasets import Dataset

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import Template
from llmtuner.extras.template import get_template
from llmtuner.hparams import DataArguments

@@ -19,7 +19,7 @@ def preprocess_dataset(

) -> Dataset:

column_names = list(dataset.column_names)
prompt_template = Template(data_args.prompt_template)
prompt_template = get_template(data_args.prompt_template)

# support question with a single answer or multiple answers
def get_dialog(examples):

@@ -143,8 +143,10 @@ def preprocess_dataset(

if stage == "pt":
preprocess_function = preprocess_pretrain_dataset
elif stage == "sft":
preprocess_function = preprocess_unsupervised_dataset \
if training_args.predict_with_generate else preprocess_supervised_dataset
if not training_args.predict_with_generate:
preprocess_function = preprocess_supervised_dataset
else:
preprocess_function = preprocess_unsupervised_dataset
elif stage == "rm":
preprocess_function = preprocess_pairwise_dataset
elif stage == "ppo":

src/llmtuner/dsets/utils.py (new file, 16 lines)

@@ -0,0 +1,16 @@

from typing import Dict
from datasets import Dataset


def split_dataset(
dataset: Dataset, dev_ratio: float, do_train: bool
) -> Dict[str, Dataset]:
# Split the dataset
if do_train:
if dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}
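
A minimal sketch of how this helper behaves (the toy dataset below is illustrative, not repository code):

```python
from datasets import Dataset
from llmtuner.dsets.utils import split_dataset

# Hypothetical toy dataset; real runs build it with get_dataset()/preprocess_dataset().
toy = Dataset.from_dict({"input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]]})

# Training with a 25% dev split returns both splits as Trainer keyword arguments.
kwargs = split_dataset(toy, dev_ratio=0.25, do_train=True)
print(sorted(kwargs))  # ['eval_dataset', 'train_dataset']

# Evaluation or prediction passes the whole dataset through as eval_dataset.
kwargs = split_dataset(toy, dev_ratio=0.25, do_train=False)
print(sorted(kwargs))  # ['eval_dataset']
```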

@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):

r"""
Event called after logging the last logs.
"""
if not state.is_world_process_zero:
return

cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time

@@ -5,3 +5,43 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"

FINETUNING_ARGS_NAME = "finetuning_args.json"

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings

METHODS = ["full", "freeze", "lora"]

SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
}

DEFAULT_MODULE = {
"LLaMA": "q_proj,v_proj",
"LLaMA2": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"InternLM": "q_proj,v_proj"
}

@@ -2,6 +2,20 @@ import sys

import logging


class LoggerHandler(logging.Handler):

def __init__(self):
super().__init__()
self.log = ""

def emit(self, record):
if record.name == "httpx":
return
log_entry = self.format(record)
self.log += log_entry
self.log += "\n\n"


def get_logger(name: str) -> logging.Logger:

formatter = logging.Formatter(

@@ -1,4 +1,5 @@

import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional

@@ -10,12 +11,13 @@ from llmtuner.extras.logging import get_logger

logger = get_logger(__name__)


def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)

@@ -1,6 +1,6 @@

import os
import torch
from typing import Dict
from typing import Dict, Optional

from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint

@@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger

logger = get_logger(__name__)


def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
state_dict = model.state_dict()
filtered_state_dict = {}

for k, v in model.named_parameters():
if v.requires_grad:
if (not trainable_only) or v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

return filtered_state_dict

@@ -1,143 +1,14 @@

from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass


@dataclass
class Template:

name: str

def __post_init__(self):

if self.name == "vanilla":
r"""
Supports language model inference without histories.
"""
self._register_template(
prefix="",
prompt="{query}",
sep="",
use_history=False
)

elif self.name == "default":
r"""
Default template.
"""
self._register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)

elif self.name == "alpaca":
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
self._register_template(
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)

elif self.name == "vicuna":
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
self._register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)

elif self.name == "belle":
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
self._register_template(
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)

elif self.name == "linly":
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
self._register_template(
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)

elif self.name == "billa":
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
self._register_template(
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)

elif self.name == "ziya":
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
self._register_template(
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)

elif self.name == "aquila":
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
self._register_template(
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)

elif self.name == "intern":
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
self._register_template(
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)

elif self.name == "baichuan":
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
self._register_template(
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True
)

else:
raise ValueError("Template {} does not exist.".format(self.name))

prefix: str
prompt: str
sep: str
use_history: bool

def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""

@@ -155,18 +26,10 @@ class Template:

"""
return self._format_example(query, history, prefix) + [resp]

def _register_template(
self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
) -> None:
self.prefix = prefix
self.prompt = prompt
self.sep = sep
self.use_history = use_history

def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]:
prefix = prefix if prefix else self.prefix # use prefix if provided
prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, "<dummy>")]

@@ -179,3 +42,193 @@ class Template:

convs.append(self.sep + self.prompt.format(query=user_query))
convs.append(bot_resp)
return convs[:-1] # drop last


templates: Dict[str, Template] = {}


def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Template(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history
)


def get_template(name: str) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
return template


r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
use_history=False
)


r"""
Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)


r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt=" [INST] {query} [/INST] ",
sep="</s>",
use_history=True
)


r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)


r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)


r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)


r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)


r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)


r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)


r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)


r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)


r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="</s>",
use_history=True
)


r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix="<|system|>\n",
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
sep="<|end|>\n",
use_history=True
)
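
As a rough usage sketch (not repository code — the query and history strings are made up), the registry above resolves a name to a `Template` and builds a prompt string:

```python
from llmtuner.extras.template import get_template

# Hypothetical call; the rendered string follows the "default" template
# registered above (Human/Assistant turns joined with its separator).
template = get_template("default")
prompt = template.get_prompt(
    query="What is LoRA?",
    history=[("Hi!", "Hello, how can I help you?")],
    prefix=""  # an empty prefix falls back to the template's own prefix
)
print(prompt)

# Unknown names fail fast with an AssertionError ("Template xxx does not exist.").
```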

@@ -16,9 +16,10 @@ class FinetuningArguments:

default=32,
metadata={"help": "Number of decoder blocks in the model. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
Falcon choices: [\"32\", \"60\"], \
Baichuan choices: [\"32\"]"}
Baichuan choices: [\"32\", \"40\"]"}
)
num_layer_trainable: Optional[int] = field(
default=3,

@@ -27,7 +28,7 @@ class FinetuningArguments:

name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
)

@@ -46,7 +47,7 @@ class FinetuningArguments:

lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
)

@@ -10,8 +10,9 @@ from transformers import (

)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.logging import get_logger

@@ -26,9 +27,9 @@ logger = get_logger(__name__)

check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")


def load_model_and_tokenizer(

@@ -36,7 +37,7 @@ def load_model_and_tokenizer(

finetuning_args: FinetuningArguments,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
r"""
Loads pretrained model and tokenizer.

@@ -80,9 +81,6 @@ def load_model_and_tokenizer(

elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,

@@ -108,17 +106,17 @@ def load_model_and_tokenizer(

model_to_load,
config=config,
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
low_cpu_mem_usage=True,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)

# Register auto class to save the custom code files.
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
tokenizer.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model

@@ -54,7 +54,7 @@ def get_train_args(

assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."

assert (not training_args.do_predict) or training_args.predict_with_generate, \
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."

assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \

@@ -4,7 +4,8 @@ from typing import Dict, Optional

from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel

from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger

@@ -49,18 +50,20 @@ class PeftTrainer(Seq2SeqTrainer):

else:
backbone_model = model

if self.finetuning_args.finetuning_type == "lora":
if isinstance(backbone_model, PeftModel): # LoRA tuning
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
state_dict=get_state_dict(backbone_model),
state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
safe_serialization=self.args.save_safetensors
)
backbone_model.config.use_cache = False
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
else:
logger.warning("No model to save.")

with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")

@@ -77,8 +80,8 @@ class PeftTrainer(Seq2SeqTrainer):

model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

if self.finetuning_args.finetuning_type == "lora":
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if isinstance(backbone_model, PeftModel):
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
@@ -25,7 +25,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
@@ -46,12 +45,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Implements the training loop for the PPO stage, like _inner_training_loop() in Hugging Face's Trainer.
|
||||
"""
|
||||
total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
|
||||
total_train_batch_size = (
|
||||
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
|
||||
)
|
||||
len_dataloader = len(self.dataloader)
|
||||
num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
|
||||
num_examples = len(self.dataset)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
|
||||
max_steps = math.ceil(num_train_epochs * len_dataloader)
|
||||
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
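As a quick sanity check on the bookkeeping above, a worked example with invented numbers (not the repo's defaults), treating the trainer's and the PPO config's gradient_accumulation_steps as equal, as run_ppo sets both from the same argument:

```python
# Worked example of the quantities in the hunk above (values are illustrative).
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 4
world_size = 2
num_train_epochs = 3.0
len_dataloader = 250          # number of batches yielded by the PPO dataloader

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size  # 32
num_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)                      # 62
max_steps = math.ceil(num_train_epochs * len_dataloader)                                         # 750

print(total_train_batch_size, num_steps_per_epoch, max_steps)
```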
@@ -62,9 +62,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {num_examples}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
|
||||
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
|
||||
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
|
||||
|
||||
@@ -77,7 +77,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
|
||||
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
|
||||
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
dataiter = iter(self.dataloader)
|
||||
@@ -87,59 +87,49 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
self.log_callback.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
|
||||
batch = next(dataiter)
|
||||
steps_trained += 1
|
||||
|
||||
for _ in range(self.config.gradient_accumulation_steps):
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
|
||||
batch = next(dataiter)
|
||||
steps_trained += 1
|
||||
# Get responses
|
||||
query_tensors = batch["input_ids"]
|
||||
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
|
||||
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
queries, responses = [], []
|
||||
for i in range(len(query_tensors)):
|
||||
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
queries.append(query_tensors[i, query_length:]) # remove padding from left
|
||||
responses.append(response_tensors[i, :response_length]) # remove padding from right
|
||||
|
||||
# Get response from model
|
||||
query_tensors: torch.Tensor = batch["input_ids"]
|
||||
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
|
||||
# Compute rewards
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
with torch.no_grad():
|
||||
_, _, values = self.model(
|
||||
**self.prepare_model_inputs(queries, responses),
|
||||
output_hidden_states=True,
|
||||
return_dict=True
|
||||
)
|
||||
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
queries: List[torch.Tensor] = []
|
||||
responses: List[torch.Tensor] = []
|
||||
for i in range(len(query_tensors)):
|
||||
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
queries.append(query_tensors[i, query_length:]) # remove padding from left
|
||||
if response_length < 2: # make response have at least 2 tokens
|
||||
responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
|
||||
else:
|
||||
responses.append(response_tensors[i, :response_length]) # remove padding from right
|
||||
# Run PPO step
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
unwrapped_model.config.use_cache = False
|
||||
stats = self.step(queries, responses, rewards)
|
||||
|
||||
# Compute rewards
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
|
||||
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
||||
replace_model(unwrapped_model, target="default") # make sure the model is default at the end
|
||||
|
||||
# Run PPO step
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
unwrapped_model.config.use_cache = False
|
||||
|
||||
stats = self.step(queries, responses, rewards)
|
||||
|
||||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if steps_trained == len_dataloader:
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
|
||||
logs = {
|
||||
"loss": round(loss_meter.avg, 4),
|
||||
"reward": round(reward_meter.avg, 4),
|
||||
"learning_rate": stats["ppo/learning_rate"],
|
||||
"epoch": round(step / num_steps_per_epoch, 2)
|
||||
}
|
||||
logs = dict(
|
||||
loss=round(loss_meter.avg, 4),
|
||||
reward=round(reward_meter.avg, 4),
|
||||
learning_rate=stats["ppo/learning_rate"],
|
||||
epoch=round(step / len_dataloader, 2)
|
||||
)
|
||||
print(logs)
|
||||
logs["step"] = step
|
||||
self.state.log_history.append(logs)
|
||||
@@ -150,9 +140,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
|
||||
|
||||
if self.control.should_training_stop:
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if steps_trained == len_dataloader:
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
import math
|
||||
from trl import PPOConfig
|
||||
from torch.optim import AdamW
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||
from typing import Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
@@ -19,7 +20,8 @@ def run_ppo(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
@@ -30,7 +32,7 @@ def run_ppo(
|
||||
model_name=model_args.model_name_or_path,
|
||||
learning_rate=training_args.learning_rate,
|
||||
mini_batch_size=training_args.per_device_train_batch_size,
|
||||
batch_size=training_args.per_device_train_batch_size,
|
||||
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
|
||||
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
||||
ppo_epochs=1,
|
||||
max_grad_norm=training_args.max_grad_norm
|
||||
@@ -50,7 +52,7 @@ def run_ppo(
|
||||
ppo_trainer = PPOPeftTrainer(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[LogCallback()],
|
||||
callbacks=callbacks,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
ref_model=None,
|
||||
|
||||
@@ -4,7 +4,7 @@ import math
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
@@ -28,16 +28,6 @@ def run_pt(
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -46,7 +36,7 @@ def run_pt(
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
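The deleted "Split the dataset" block above is replaced by a single `split_dataset(...)` call; based only on the removed lines, a plausible (hedged) reconstruction of such a helper looks like this; the real implementation in llmtuner.dsets is not shown in this diff and may differ:

```python
# Hedged reconstruction inferred from the deleted block above; not the repo's actual code.
def split_dataset(dataset, dev_ratio, do_train):
    if do_train:
        if dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=dev_ratio)
            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        return {"train_dataset": dataset}
    return {"eval_dataset": dataset}  # do_eval or do_predict
```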
@@ -1,10 +1,17 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute pairwise loss.
|
||||
@@ -32,7 +39,30 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
|
||||
"""
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
_, _, values = model(**inputs)
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
|
||||
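A self-contained sketch of the pairwise loss computed in compute_loss above, assuming the batch stacks accepted examples first and rejected examples second along the batch dimension:

```python
# Minimal sketch of the pairwise reward-model loss shown above.
# `values` stands in for the value-head output, shape (2 * batch_size, seq_len).
import torch

batch_size = 2
values = torch.randn(2 * batch_size, 16)                    # dummy value-head outputs
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss.item())
```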
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that is not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
acc_scores, rej_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for acc_score, rej_score in zip(acc_scores, rej_scores):
|
||||
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
|
||||
writer.write("\n".join(res))
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
@@ -18,7 +19,8 @@ def run_rm(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
@@ -27,16 +29,6 @@ def run_rm(
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PairwisePeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -44,9 +36,9 @@ def run_rm(
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=[LogCallback()],
|
||||
callbacks=callbacks,
|
||||
compute_metrics=compute_accuracy,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -64,3 +56,10 @@ def run_rm(
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
@@ -23,7 +23,7 @@ class ComputeMetrics:
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
@@ -47,5 +47,6 @@ class ComputeMetrics:
|
||||
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
|
||||
|
||||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
||||
|
||||
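To illustrate the new "accuracy" term, a tiny example with invented token lists, assuming `pred` and `label` are the token lists the metric loop iterates over; the metric scores 1.0 only when the prediction begins with the entire label:

```python
# Illustrative only: exact-prefix match, as appended to score_dict["accuracy"] above.
label = ["The", "answer", "is", "42"]
pred = ["The", "answer", "is", "42", "."]   # an extra trailing token does not hurt

accuracy = float(len(label) != 0 and pred[:len(label)] == label)
print(accuracy)  # 1.0
```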
@@ -32,17 +32,53 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
|
||||
inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
|
||||
else:
|
||||
inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
|
||||
if prompt_len > label_len:
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len:
|
||||
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = self._pad_tensors_to_target_len(
|
||||
inputs["attention_mask"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
if "position_ids" in inputs:
|
||||
inputs["position_ids"] = self._pad_tensors_to_target_len(
|
||||
inputs["position_ids"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
|
||||
generated_tokens = (
|
||||
generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
|
||||
)
|
||||
|
||||
return (loss, generated_tokens, labels)
|
||||
|
||||
def _pad_tensors_to_target_len(
|
||||
self,
|
||||
src_tensor: torch.Tensor,
|
||||
tgt_tensor: torch.Tensor,
|
||||
pad_token_id: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Pads the tensor to the same length as the target tensor.
|
||||
|
||||
Should only be called when predict_with_generate=True.
|
||||
"""
|
||||
if pad_token_id is None:
|
||||
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
|
||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
else:
|
||||
if self.model.config.pad_token_id is not None:
|
||||
pad_token_id = self.model.config.pad_token_id
|
||||
else:
|
||||
raise ValueError("Pad_token_id must be set in the configuration of the model.")
|
||||
|
||||
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
return padded_tensor
|
||||
|
||||
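A standalone demonstration of the left-padding behaviour implemented by `_pad_tensors_to_target_len` above, using dummy tensors:

```python
# Standalone sketch: copy the source tensor into the rightmost positions of a
# pad-filled tensor shaped like the target (left-padding).
import torch

pad_token_id = 0
src_tensor = torch.tensor([[11, 12, 13]])             # shorter sequence
tgt_tensor = torch.zeros(1, 6, dtype=torch.long)      # defines the target length

padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor
print(padded_tensor)  # tensor([[ 0,  0,  0, 11, 12, 13]])
```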
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
@@ -35,16 +35,6 @@ def run_sft(
|
||||
training_args.generation_num_beams = data_args.eval_num_beams if \
|
||||
data_args.eval_num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -54,7 +44,7 @@ def run_sft(
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
|
||||
0
src/llmtuner/webui/__init__.py
Normal file
95
src/llmtuner/webui/chat.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, *args):
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if len(args) != 0:
|
||||
super().__init__(*args)
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str
|
||||
):
|
||||
if self.model is not None:
|
||||
yield ALERTS["err_exists"][lang]
|
||||
return
|
||||
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
super().__init__(*get_infer_args(args))
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, lang: str):
|
||||
yield ALERTS["info_unloading"][lang]
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
torch_gc()
|
||||
yield ALERTS["info_unloaded"][lang]
|
||||
|
||||
def predict(
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
prefix: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
):
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
response = self.postprocess(response)
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, response]
|
||||
yield chatbot, new_history
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
response = response.replace("<", "<")
|
||||
response = response.replace(">", ">")
|
||||
return response
|
||||
75
src/llmtuner/webui/common.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
|
||||
from llmtuner.extras.constants import SUPPORTED_MODELS
|
||||
|
||||
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
|
||||
|
||||
def get_save_dir(model_name: str) -> str:
|
||||
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
|
||||
|
||||
|
||||
def get_config_path() -> os.PathLike:
|
||||
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
return {"last_model": "", "path_dict": {}}
|
||||
|
||||
|
||||
def save_config(model_name: str, model_path: str) -> None:
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
user_config = load_config()
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> str:
|
||||
user_config = load_config()
|
||||
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([
|
||||
os.path.isfile(os.path.join(save_dir, checkpoint, name))
|
||||
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
|
||||
])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
return {}
|
||||
|
||||
|
||||
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
|
||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||
return gr.update(value=[], choices=list(dataset_info.keys()))
|
||||
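A short usage sketch of the user-config helpers defined in common.py above; the model name and path are placeholders, not values from the repo:

```python
# Hedged usage sketch of the helpers above: persist the last-used model path,
# then resolve it again on the next launch.
from llmtuner.webui.common import save_config, load_config, get_model_path

save_config("LLaMA-7B", "/data/models/llama-7b")   # placeholder name and path
print(load_config()["last_model"])                 # "LLaMA-7B"
print(get_model_path("LLaMA-7B"))                  # "/data/models/llama-7b"
```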
5
src/llmtuner/webui/components/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from llmtuner.webui.components.top import create_top
|
||||
from llmtuner.webui.components.sft import create_sft_tab
|
||||
from llmtuner.webui.components.eval import create_eval_tab
|
||||
from llmtuner.webui.components.infer import create_infer_tab
|
||||
from llmtuner.webui.components.export import create_export_tab
|
||||
50
src/llmtuner/webui/components/chatbot.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
chat_model: WebChatModel,
|
||||
visible: Optional[bool] = False
|
||||
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
prefix = gr.Textbox(show_label=False)
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
clear_btn = gr.Button()
|
||||
max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
|
||||
|
||||
history = gr.State([])
|
||||
|
||||
submit_btn.click(
|
||||
chat_model.predict,
|
||||
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
show_progress=True
|
||||
).then(
|
||||
lambda: gr.update(value=""), outputs=[query]
|
||||
)
|
||||
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
|
||||
return chat_box, chatbot, history, dict(
|
||||
prefix=prefix,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
clear_btn=clear_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
19
src/llmtuner/webui/components/data.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
|
||||
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
preview_count = gr.Number(interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
|
||||
close_btn = gr.Button()
|
||||
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
|
||||
|
||||
return preview_box, preview_count, preview_samples, close_btn
|
||||
74
src/llmtuner/webui/components/eval.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Dict
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import can_preview, get_preview
|
||||
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
start_btn.click(
|
||||
runner.run_eval,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
max_samples,
|
||||
batch_size,
|
||||
predict
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
preview_btn=preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
predict=predict,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_box=output_box
|
||||
)
|
||||
34
src/llmtuner/webui/components/export.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Dict
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.utils import export_model
|
||||
|
||||
|
||||
def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
save_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
|
||||
export_btn = gr.Button()
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
export_btn.click(
|
||||
export_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
max_shard_size,
|
||||
save_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
save_dir=save_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
)
|
||||
49
src/llmtuner/webui/components/infer.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Dict
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
chat_model = WebChatModel()
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
|
||||
load_btn.click(
|
||||
chat_model.load_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"]
|
||||
],
|
||||
[info_box]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(
|
||||
chat_model.unload_model, [top_elems["lang"]], [info_box]
|
||||
).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
info_box=info_box,
|
||||
load_btn=load_btn,
|
||||
unload_btn=unload_btn,
|
||||
**chat_elems
|
||||
)
|
||||
138
src/llmtuner/webui/components/sft.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
|
||||
)
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
start_btn.click(
|
||||
runner.run_train,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
learning_rate,
|
||||
num_train_epochs,
|
||||
max_samples,
|
||||
batch_size,
|
||||
gradient_accumulation_steps,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
dev_ratio,
|
||||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
compute_type,
|
||||
lora_rank,
|
||||
lora_dropout,
|
||||
lora_target,
|
||||
output_dir
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
output_box.change(
|
||||
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
preview_btn=preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=learning_rate,
|
||||
num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
max_grad_norm=max_grad_norm,
|
||||
dev_ratio=dev_ratio,
|
||||
advanced_tab=advanced_tab,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
compute_type=compute_type,
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer
|
||||
)
|
||||
57
src/llmtuner/webui/components/top.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Dict
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
|
||||
def create_top() -> Dict[str, Component]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown([8, 4], scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
|
||||
source_prefix = gr.Textbox(scale=2)
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path]
|
||||
) # do not save the config here; the change handler below will save it
|
||||
model_path.change(save_config, [model_name, model_path])
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
)
|
||||
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
checkpoints=checkpoints,
|
||||
refresh_btn=refresh_btn,
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
18
src/llmtuner/webui/css.py
Normal file
@@ -0,0 +1,18 @@
|
||||
CSS = r"""
|
||||
.modal-box {
|
||||
position: fixed !important;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%); /* center horizontally */
|
||||
max-width: 1000px;
|
||||
max-height: 750px;
|
||||
overflow-y: scroll !important;
|
||||
background-color: var(--input-background-fill);
|
||||
border: 2px solid black !important;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.dark .modal-box {
|
||||
border: 2px solid white !important;
|
||||
}
|
||||
"""
|
||||
58
src/llmtuner/webui/interface.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import gradio as gr
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner.webui.components import (
|
||||
create_top,
|
||||
create_sft_tab,
|
||||
create_eval_tab,
|
||||
create_infer_tab,
|
||||
create_export_tab
|
||||
)
|
||||
from llmtuner.webui.css import CSS
|
||||
from llmtuner.webui.manager import Manager
|
||||
from llmtuner.webui.runner import Runner
|
||||
|
||||
|
||||
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
runner = Runner()
|
||||
|
||||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
top_elems = create_top()
|
||||
|
||||
with gr.Tab("SFT"):
|
||||
sft_elems = create_sft_tab(top_elems, runner)
|
||||
|
||||
with gr.Tab("Evaluate"):
|
||||
eval_elems = create_eval_tab(top_elems, runner)
|
||||
|
||||
with gr.Tab("Chat"):
|
||||
infer_elems = create_infer_tab(top_elems)
|
||||
|
||||
with gr.Tab("Export"):
|
||||
export_elems = create_export_tab(top_elems)
|
||||
|
||||
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
|
||||
manager = Manager(elem_list)
|
||||
|
||||
demo.load(
|
||||
manager.gen_label,
|
||||
[top_elems["lang"]],
|
||||
[elem for elems in elem_list for elem in elems.values()],
|
||||
)
|
||||
|
||||
top_elems["lang"].change(
|
||||
manager.gen_label,
|
||||
[top_elems["lang"]],
|
||||
[elem for elems in elem_list for elem in elems.values()],
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
552
src/llmtuner/webui/locales.py
Normal file
@@ -0,0 +1,552 @@
|
||||
LOCALES = {
|
||||
"lang": {
|
||||
"en": {
|
||||
"label": "Lang"
|
||||
},
|
||||
"zh": {
|
||||
"label": "语言"
|
||||
}
|
||||
},
|
||||
"model_name": {
|
||||
"en": {
|
||||
"label": "Model name"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型名称"
|
||||
}
|
||||
},
|
||||
"model_path": {
|
||||
"en": {
|
||||
"label": "Model path",
|
||||
"info": "Path to pretrained model or model identifier from Hugging Face."
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型路径",
|
||||
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
|
||||
}
|
||||
},
|
||||
"finetuning_type": {
|
||||
"en": {
|
||||
"label": "Finetuning method"
|
||||
},
|
||||
"zh": {
|
||||
"label": "微调方法"
|
||||
}
|
||||
},
|
||||
"checkpoints": {
|
||||
"en": {
|
||||
"label": "Checkpoints"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型断点"
|
||||
}
|
||||
},
|
||||
"refresh_btn": {
|
||||
"en": {
|
||||
"value": "Refresh checkpoints"
|
||||
},
|
||||
"zh": {
|
||||
"value": "刷新断点"
|
||||
}
|
||||
},
|
||||
"advanced_tab": {
|
||||
"en": {
|
||||
"label": "Advanced configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "高级设置"
|
||||
}
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit (optional)",
|
||||
"info": "Enable 4/8-bit model quantization."
|
||||
},
|
||||
"zh": {
|
||||
"label": "量化等级(非必填)",
|
||||
"info": "启用 4/8 比特模型量化。"
|
||||
}
|
||||
},
|
||||
"template": {
|
||||
"en": {
|
||||
"label": "Prompt template",
|
||||
"info": "The template used in constructing prompts."
|
||||
},
|
||||
"zh": {
|
||||
"label": "提示模板",
|
||||
"info": "构建提示词时使用的模板"
|
||||
}
|
||||
},
|
||||
"source_prefix": {
|
||||
"en": {
|
||||
"label": "System prompt (optional)",
|
||||
"info": "A sequence used as the default system prompt."
|
||||
},
|
||||
"zh": {
|
||||
"label": "系统提示词(非必填)",
|
||||
"info": "默认使用的系统提示词"
|
||||
}
|
||||
},
|
||||
"dataset_dir": {
|
||||
"en": {
|
||||
"label": "Data dir",
|
||||
"info": "Path of the data directory."
|
||||
},
|
||||
"zh": {
|
||||
"label": "数据路径",
|
||||
"info": "数据文件夹的路径。"
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"en": {
|
||||
"label": "Dataset"
|
||||
},
|
||||
"zh": {
|
||||
"label": "数据集"
|
||||
}
|
||||
},
|
||||
"preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview"
|
||||
},
|
||||
"zh": {
|
||||
"value": "预览"
|
||||
}
|
||||
},
|
||||
"preview_count": {
|
||||
"en": {
|
||||
"label": "Count"
|
||||
},
|
||||
"zh": {
|
||||
"label": "数量"
|
||||
}
|
||||
},
|
||||
"preview_samples": {
|
||||
"en": {
|
||||
"label": "Samples"
|
||||
},
|
||||
"zh": {
|
||||
"label": "样例"
|
||||
}
|
||||
},
|
||||
"close_btn": {
|
||||
"en": {
|
||||
"value": "Close"
|
||||
},
|
||||
"zh": {
|
||||
"value": "关闭"
|
||||
}
|
||||
},
|
||||
"max_source_length": {
|
||||
"en": {
|
||||
"label": "Max source length",
|
||||
"info": "Max tokens in source sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "输入序列最大长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
}
|
||||
},
|
||||
"max_target_length": {
|
||||
"en": {
|
||||
"label": "Max target length",
|
||||
"info": "Max tokens in target sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "输出序列最大长度",
|
||||
"info": "输出序列分词后的最大长度。"
|
||||
}
|
||||
},
|
||||
"learning_rate": {
|
||||
"en": {
|
||||
"label": "Learning rate",
|
||||
"info": "Initial learning rate for AdamW."
|
||||
},
|
||||
"zh": {
|
||||
"label": "学习率",
|
||||
"info": "AdamW 优化器的初始学习率。"
|
||||
}
|
||||
},
|
||||
"num_train_epochs": {
|
||||
"en": {
|
||||
"label": "Epochs",
|
||||
"info": "Total number of training epochs to perform."
|
||||
},
|
||||
"zh": {
|
||||
"label": "训练轮数",
|
||||
"info": "需要执行的训练总轮数。"
|
||||
}
|
||||
},
|
||||
"max_samples": {
|
||||
"en": {
|
||||
"label": "Max samples",
|
||||
"info": "Maximum samples per dataset."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大样本数",
|
||||
"info": "每个数据集最多使用的样本数。"
|
||||
}
|
||||
},
|
||||
"batch_size": {
|
||||
"en": {
|
||||
"label": "Batch size",
|
||||
"info": "Number of samples to process per GPU."
|
||||
},
|
||||
"zh":{
|
||||
"label": "批处理大小",
|
||||
"info": "每块 GPU 上处理的样本数量。"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": {
|
||||
"en": {
|
||||
"label": "Gradient accumulation",
|
||||
"info": "Number of gradient accumulation steps."
|
||||
},
|
||||
"zh": {
|
||||
"label": "梯度累积",
|
||||
"info": "梯度累积的步数。"
|
||||
}
|
||||
},
|
||||
"lr_scheduler_type": {
|
||||
"en": {
|
||||
"label": "LR Scheduler",
|
||||
"info": "Name of learning rate scheduler.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "学习率调节器",
|
||||
"info": "采用的学习率调节器名称。"
|
||||
}
|
||||
},
|
||||
"max_grad_norm": {
|
||||
"en": {
|
||||
"label": "Maximum gradient norm",
|
||||
"info": "Norm for gradient clipping.."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大梯度范数",
|
||||
"info": "用于梯度裁剪的范数。"
|
||||
}
|
||||
},
|
||||
"dev_ratio": {
|
||||
"en": {
|
||||
"label": "Dev ratio",
|
||||
"info": "Proportion of data in the dev set."
|
||||
},
|
||||
"zh": {
|
||||
"label": "验证集比例",
|
||||
"info": "验证集占全部样本的百分比。"
|
||||
}
|
||||
},
|
||||
"logging_steps": {
|
||||
"en": {
|
||||
"label": "Logging steps",
|
||||
"info": "Number of steps between two logs."
|
||||
},
|
||||
"zh": {
|
||||
"label": "日志间隔",
|
||||
"info": "每两次日志输出间的更新步数。"
|
||||
}
|
||||
},
|
||||
"save_steps": {
|
||||
"en": {
|
||||
"label": "Save steps",
|
||||
"info": "Number of steps between two checkpoints."
|
||||
},
|
||||
"zh": {
|
||||
"label": "保存间隔",
|
||||
"info": "每两次断点保存间的更新步数。"
|
||||
}
|
||||
},
|
||||
"warmup_steps": {
|
||||
"en": {
|
||||
"label": "Warmup steps",
|
||||
"info": "Number of steps used for warmup."
|
||||
},
|
||||
"zh": {
|
||||
"label": "预热步数",
|
||||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
},
|
||||
"compute_type": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 参数设置"
|
||||
}
|
||||
},
|
||||
"lora_rank": {
|
||||
"en": {
|
||||
"label": "LoRA rank",
|
||||
"info": "The rank of LoRA matrices."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 秩",
|
||||
"info": "LoRA 矩阵的秩。"
|
||||
}
|
||||
},
|
||||
"lora_dropout": {
|
||||
"en": {
|
||||
"label": "LoRA Dropout",
|
||||
"info": "Dropout ratio of LoRA weights."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 随机丢弃",
|
||||
"info": "LoRA 权重随机丢弃的概率。"
|
||||
}
|
||||
},
|
||||
"lora_target": {
|
||||
"en": {
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 作用层(非必填)",
|
||||
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始"
|
||||
}
|
||||
},
|
||||
"stop_btn": {
|
||||
"en": {
|
||||
"value": "Abort"
|
||||
},
|
||||
"zh": {
|
||||
"value": "中断"
|
||||
}
|
||||
},
|
||||
"output_dir": {
|
||||
"en": {
|
||||
"label": "Checkpoint name",
|
||||
"info": "Directory to save checkpoint."
|
||||
},
|
||||
"zh": {
|
||||
"label": "断点名称",
|
||||
"info": "保存模型断点的文件夹名称。"
|
||||
}
|
||||
},
|
||||
"output_box": {
|
||||
"en": {
|
||||
"value": "Ready."
|
||||
},
|
||||
"zh": {
|
||||
"value": "准备就绪。"
|
||||
}
|
||||
},
|
||||
"loss_viewer": {
|
||||
"en": {
|
||||
"label": "Loss"
|
||||
},
|
||||
"zh": {
|
||||
"label": "损失"
|
||||
}
|
||||
},
|
||||
"predict": {
|
||||
"en": {
|
||||
"label": "Save predictions"
|
||||
},
|
||||
"zh": {
|
||||
"label": "保存预测结果"
|
||||
}
|
||||
},
|
||||
"load_btn": {
|
||||
"en": {
|
||||
"value": "Load model"
|
||||
},
|
||||
"zh": {
|
||||
"value": "加载模型"
|
||||
}
|
||||
},
|
||||
"unload_btn": {
|
||||
"en": {
|
||||
"value": "Unload model"
|
||||
},
|
||||
"zh": {
|
||||
"value": "卸载模型"
|
||||
}
|
||||
},
|
||||
"info_box": {
|
||||
"en": {
|
||||
"value": "Model unloaded, please load a model first."
|
||||
},
|
||||
"zh": {
|
||||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"prefix": {
|
||||
"en": {
|
||||
"placeholder": "System prompt (optional)"
|
||||
},
|
||||
"zh": {
|
||||
"placeholder": "系统提示词(非必填)"
|
||||
}
|
||||
},
|
||||
"query": {
|
||||
"en": {
|
||||
"placeholder": "Input..."
|
||||
},
|
||||
"zh": {
|
||||
"placeholder": "输入..."
|
||||
}
|
||||
},
|
||||
"submit_btn": {
|
||||
"en": {
|
||||
"value": "Submit"
|
||||
},
|
||||
"zh": {
|
||||
"value": "提交"
|
||||
}
|
||||
},
|
||||
"clear_btn": {
|
||||
"en": {
|
||||
"value": "Clear history"
|
||||
},
|
||||
"zh": {
|
||||
"value": "清空历史"
|
||||
}
|
||||
},
|
||||
"max_length": {
|
||||
"en": {
|
||||
"label": "Maximum length"
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大长度"
|
||||
}
|
||||
},
|
||||
"max_new_tokens": {
|
||||
"en": {
|
||||
"label": "Maximum new tokens"
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大生成长度"
|
||||
}
|
||||
},
|
||||
"top_p": {
|
||||
"en": {
|
||||
"label": "Top-p"
|
||||
},
|
||||
"zh": {
|
||||
"label": "Top-p 采样值"
|
||||
}
|
||||
},
|
||||
"temperature": {
|
||||
"en": {
|
||||
"label": "Temperature"
|
||||
},
|
||||
"zh": {
|
||||
"label": "温度系数"
|
||||
}
|
||||
},
|
||||
"save_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
"info": "Directory to save exported model."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出目录",
|
||||
"info": "保存导出模型的文件夹路径。"
|
||||
}
|
||||
},
|
||||
"max_shard_size": {
|
||||
"en": {
|
||||
"label": "Max shard size (GB)",
|
||||
"info": "The maximum size for a model file."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大分块大小(GB)",
|
||||
"info": "模型文件的最大大小。"
|
||||
}
|
||||
},
|
||||
"export_btn": {
|
||||
"en": {
|
||||
"value": "Export"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始导出"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ALERTS = {
|
||||
"err_conflict": {
|
||||
"en": "A process is in running, please abort it firstly.",
|
||||
"zh": "任务已存在,请先中断训练。"
|
||||
},
|
||||
"err_exists": {
|
||||
"en": "You have loaded a model, please unload it first.",
|
||||
"zh": "模型已存在,请先卸载模型。"
|
||||
},
|
||||
"err_no_model": {
|
||||
"en": "Please select a model.",
|
||||
"zh": "请选择模型。"
|
||||
},
|
||||
"err_no_path": {
|
||||
"en": "Model not found.",
|
||||
"zh": "模型未找到。"
|
||||
},
|
||||
"err_no_dataset": {
|
||||
"en": "Please choose a dataset.",
|
||||
"zh": "请选择数据集。"
|
||||
},
|
||||
"err_no_checkpoint": {
|
||||
"en": "Please select a checkpoint.",
|
||||
"zh": "请选择断点。"
|
||||
},
|
||||
"err_no_save_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"zh": "训练中断,正在等待线程结束……"
|
||||
},
|
||||
"info_aborted": {
|
||||
"en": "Ready.",
|
||||
"zh": "准备就绪。"
|
||||
},
|
||||
"info_finished": {
|
||||
"en": "Finished.",
|
||||
"zh": "训练完毕。"
|
||||
},
|
||||
"info_loading": {
|
||||
"en": "Loading model...",
|
||||
"zh": "加载中……"
|
||||
},
|
||||
"info_unloading": {
|
||||
"en": "Unloading model...",
|
||||
"zh": "卸载中……"
|
||||
},
|
||||
"info_loaded": {
|
||||
"en": "Model loaded, now you can chat with your model!",
|
||||
"zh": "模型已加载,可以开始聊天了!"
|
||||
},
|
||||
"info_unloaded": {
|
||||
"en": "Model unloaded.",
|
||||
"zh": "模型已卸载。"
|
||||
},
|
||||
"info_exporting": {
|
||||
"en": "Exporting model...",
|
||||
"zh": "正在导出模型……"
|
||||
},
|
||||
"info_exported": {
|
||||
"en": "Model exported.",
|
||||
"zh": "模型导出完成。"
|
||||
}
|
||||
}
|
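For reference, the tables above are keyed by component name and then by language code, and each entry holds the keyword arguments that are later expanded into `gr.update`. A self-contained sketch of how an entry resolves (the `LOCALES` stand-in and the `localized_kwargs` helper below are illustrative, not part of the diff):

```python
# Sketch only: LOCALES here is a one-entry stand-in for llmtuner.webui.locales.LOCALES.
LOCALES = {
    "submit_btn": {
        "en": {"value": "Submit"},
        "zh": {"value": "提交"},
    }
}

def localized_kwargs(name: str, lang: str) -> dict:
    # Manager.gen_label (next file) expands such entries via gr.update(**kwargs).
    return LOCALES[name][lang]

print(localized_kwargs("submit_btn", "zh"))  # -> {'value': '提交'}
```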
35  src/llmtuner/webui/manager.py  Normal file
@@ -0,0 +1,35 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component

from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time


class Manager:

    def __init__(self, elem_list: List[Dict[str, Component]]):
        self.elem_list = elem_list

    def gen_refresh(self) -> Dict[str, Any]:
        refresh_dict = {
            "dataset": {"choices": list_dataset()["choices"]},
            "output_dir": {"value": get_time()}
        }

        user_config = load_config()
        if user_config["last_model"]:
            refresh_dict["model_name"] = {"value": user_config["last_model"]}
            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}

        return refresh_dict

    def gen_label(self, lang: str) -> Dict[Component, dict]:
        update_dict = {}
        refresh_dict = self.gen_refresh()

        for elems in self.elem_list:
            for name, component in elems.items():
                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))

        return update_dict
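`Manager.gen_label` is intended to be bound to the language dropdown so that every registered component is relabeled when the language changes, as `src/web_demo.py` further below does. A minimal wiring sketch (assumes gradio and the llmtuner package are installed and that it runs from the repository root so `list_dataset`/`load_config` can find their files; the registered names must exist in `LOCALES`):

```python
# Sketch only: relabel registered components whenever the language dropdown changes.
import gradio as gr
from llmtuner.webui.manager import Manager

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"], value="en")
    submit_btn = gr.Button()
    manager = Manager([{"lang": lang, "submit_btn": submit_btn}])
    # gen_label returns {component: gr.update(...)} for every registered component.
    lang.change(manager.gen_label, [lang], [lang, submit_btn])

demo.launch()
```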
238  src/llmtuner/webui/runner.py  Normal file
@@ -0,0 +1,238 @@
import logging
import os
import threading
import time
import transformers
from typing import Generator, List, Optional, Tuple

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results


class Runner:

    def __init__(self):
        self.aborted = False
        self.running = False

    def set_abort(self):
        self.aborted = True
        self.running = False

    def initialize(
        self, lang: str, model_name: str, dataset: List[str]
    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
        if self.running:
            return None, ALERTS["err_conflict"][lang], None, None

        if not model_name:
            return None, ALERTS["err_no_model"][lang], None, None

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            return None, ALERTS["err_no_path"][lang], None, None

        if len(dataset) == 0:
            return None, ALERTS["err_no_dataset"][lang], None, None

        self.aborted = False
        self.running = True

        logger_handler = LoggerHandler()
        logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(logger_handler)
        transformers.logging.add_handler(logger_handler)
        trainer_callback = LogCallback(self)

        return model_name_or_path, "", logger_handler, trainer_callback

    def finalize(
        self, lang: str, finish_info: Optional[str] = None
    ) -> str:
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        else:
            return finish_info if finish_info is not None else ALERTS["info_finished"][lang]

    def run_train(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        source_prefix: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
        max_target_length: int,
        learning_rate: str,
        num_train_epochs: str,
        max_samples: str,
        batch_size: int,
        gradient_accumulation_steps: int,
        lr_scheduler_type: str,
        max_grad_norm: str,
        dev_ratio: float,
        logging_steps: int,
        save_steps: int,
        warmup_steps: int,
        compute_type: str,
        lora_rank: int,
        lora_dropout: float,
        lora_target: str,
        output_dir: str
    ) -> Generator[str, None, None]:
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
            return

        if checkpoints:
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
        else:
            checkpoint_dir = None

        args = dict(
            model_name_or_path=model_name_or_path,
            do_train=True,
            overwrite_cache=True,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit else None,
            prompt_template=template,
            source_prefix=source_prefix,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            learning_rate=float(learning_rate),
            num_train_epochs=float(num_train_epochs),
            max_samples=int(max_samples),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            lr_scheduler_type=lr_scheduler_type,
            max_grad_norm=float(max_grad_norm),
            logging_steps=logging_steps,
            save_steps=save_steps,
            warmup_steps=warmup_steps,
            fp16=(compute_type == "fp16"),
            bf16=(compute_type == "bf16"),
            lora_rank=lora_rank,
            lora_dropout=lora_dropout,
            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
        )

        if dev_ratio > 1e-6:
            args["dev_ratio"] = dev_ratio
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = save_steps
            args["load_best_model_at_end"] = True

        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

        run_args = dict(
            model_args=model_args,
            data_args=data_args,
            training_args=training_args,
            finetuning_args=finetuning_args,
            callbacks=[trainer_callback]
        )
        thread = threading.Thread(target=run_sft, kwargs=run_args)
        thread.start()

        while thread.is_alive():
            time.sleep(1)
            if self.aborted:
                yield ALERTS["info_aborting"][lang]
            else:
                yield format_info(logger_handler.log, trainer_callback.tracker)

        yield self.finalize(lang)

    def run_eval(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        source_prefix: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
        max_target_length: int,
        max_samples: str,
        batch_size: int,
        predict: bool
    ) -> Generator[str, None, None]:
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
            return

        if checkpoints:
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
        else:
            checkpoint_dir = None
            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

        args = dict(
            model_name_or_path=model_name_or_path,
            do_eval=True,
            overwrite_cache=True,
            predict_with_generate=True,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit else None,
            prompt_template=template,
            source_prefix=source_prefix,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            max_samples=int(max_samples),
            per_device_eval_batch_size=batch_size,
            output_dir=output_dir
        )

        if predict:
            args.pop("do_eval", None)
            args["do_predict"] = True

        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

        run_args = dict(
            model_args=model_args,
            data_args=data_args,
            training_args=training_args,
            finetuning_args=finetuning_args,
            callbacks=[trainer_callback]
        )
        thread = threading.Thread(target=run_sft, kwargs=run_args)
        thread.start()

        while thread.is_alive():
            time.sleep(1)
            if self.aborted:
                yield ALERTS["info_aborting"][lang]
            else:
                yield format_info(logger_handler.log, trainer_callback.tracker)

        yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
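For orientation, `run_train` and `run_eval` are generators of status strings: each `yield` is either an alert from `ALERTS` or a progress line built by `format_info`, which is what lets the Web UI stream training status while the worker thread runs. A minimal consumption sketch (not from the repo; all argument values, the model name `LLaMA-7B` and the dataset name are illustrative and must match what the Web UI has registered):

```python
# Sketch only: stream Runner.run_train status lines from a plain script.
from llmtuner.webui.runner import Runner

runner = Runner()
for status in runner.run_train(
    lang="en", model_name="LLaMA-7B", checkpoints=[], finetuning_type="lora",
    quantization_bit="", template="default", source_prefix="", dataset_dir="data",
    dataset=["alpaca_gpt4_en"], max_source_length=512, max_target_length=512,
    learning_rate="5e-5", num_train_epochs="3.0", max_samples="10000", batch_size=4,
    gradient_accumulation_steps=4, lr_scheduler_type="cosine", max_grad_norm="1.0",
    dev_ratio=0.0, logging_steps=5, save_steps=100, warmup_steps=0, compute_type="fp16",
    lora_rank=8, lora_dropout=0.1, lora_target="q_proj,v_proj", output_dir="test-run"
):
    print(status)  # each yield is an ALERTS message or a formatted progress line
```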
125  src/llmtuner/webui/utils.py  Normal file
@@ -0,0 +1,125 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Generator, List, Tuple
from datetime import datetime

from llmtuner.extras.ploting import smooth
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS


def format_info(log: str, tracker: dict) -> str:
    info = log
    if "current_steps" in tracker:
        info += "Running **{:d}/{:d}**: {} < {}\n".format(
            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
        )
    return info


def get_time() -> str:
    return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')


def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)
    if (
        len(dataset) > 0
        and "file_name" in dataset_info[dataset[0]]
        and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
    ):
        return gr.update(interactive=True)
    else:
        return gr.update(interactive=False)


def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)
    data_file = dataset_info[dataset[0]]["file_name"]
    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
        data = json.load(f)
    return len(data), data[:2], gr.update(visible=True)


def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    if finetuning_type != "lora":
        return gr.update(value="", interactive=False)
    else:
        return gr.update(interactive=True)


def get_eval_results(path: os.PathLike) -> str:
    with open(path, "r", encoding="utf-8") as f:
        result = json.dumps(json.load(f), indent=4)
    return "```json\n{}\n```\n".format(result)


def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
    log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
    if not os.path.isfile(log_file):
        return None

    plt.close("all")
    fig = plt.figure()
    ax = fig.add_subplot(111)
    steps, losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:
            log_info = json.loads(line)
            if log_info.get("loss", None):
                steps.append(log_info["current_steps"])
                losses.append(log_info["loss"])

    if len(losses) == 0:
        return None

    ax.plot(steps, losses, alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig


def export_model(
    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
) -> Generator[str, None, None]:
    if not model_name:
        yield ALERTS["err_no_model"][lang]
        return

    model_name_or_path = get_model_path(model_name)
    if not model_name_or_path:
        yield ALERTS["err_no_path"][lang]
        return

    if not checkpoints:
        yield ALERTS["err_no_checkpoint"][lang]
        return

    checkpoint_dir = ",".join(
        [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
    )

    if not save_dir:
        yield ALERTS["err_no_save_dir"][lang]
        return

    args = dict(
        model_name_or_path=model_name_or_path,
        checkpoint_dir=checkpoint_dir,
        finetuning_type=finetuning_type
    )

    yield ALERTS["info_exporting"][lang]
    model_args, _, finetuning_args, _ = get_infer_args(args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
    tokenizer.save_pretrained(save_dir)
    yield ALERTS["info_exported"][lang]
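`gen_plot` above expects `trainer_log.jsonl` to hold one JSON object per line with at least `current_steps` and, for training steps, `loss`. A self-contained sketch of the same filtering it performs (the example records below are invented for illustration):

```python
# Sketch only: extract (step, loss) pairs from JSONL records shaped like the ones
# gen_plot reads, e.g. {"current_steps": 10, "total_steps": 1000, "loss": 1.23, ...}.
import json

example_lines = [
    '{"current_steps": 5, "total_steps": 100, "loss": 1.41}',
    '{"current_steps": 10, "total_steps": 100, "loss": 1.27}',
]
steps, losses = [], []
for line in example_lines:
    record = json.loads(line)
    if record.get("loss", None):  # skip eval-only records without a loss
        steps.append(record["current_steps"])
        losses.append(record["loss"])
print(steps, losses)  # [5, 10] [1.41, 1.27]
```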
@@ -1,4 +1,4 @@
from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo


def main():
11  src/train_web.py  Normal file
@@ -0,0 +1,11 @@
from llmtuner.webui.interface import create_ui


def main():
    demo = create_ui()
    demo.queue()
    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)


if __name__ == "__main__":
    main()
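A small variant of the entry point above (a sketch, not from the repo; `server_port` is a standard parameter of Gradio's `launch`) for pinning the Web UI to an explicit port:

```python
# Sketch only: same entry point with an explicit port; create_ui comes from the new webui package.
from llmtuner.webui.interface import create_ui

def main():
    demo = create_ui()
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)

if __name__ == "__main__":
    main()
```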
103  src/web_demo.py
@@ -3,93 +3,34 @@
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version

from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
from llmtuner.tuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager


require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")


model_args, data_args, finetuning_args, generating_args = get_infer_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
def main():
    chat_model = WebChatModel(*get_infer_args())

prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
    with gr.Blocks(title="Web Demo") as demo:
        lang = gr.Dropdown(choices=["en", "zh"], value="en")

        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)

        manager = Manager([{"lang": lang}, chat_elems])

        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))

        lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))

    demo.queue()
    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)


def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
    chatbot.append((query, ""))

    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = generating_args.to_dict()
    gen_kwargs.update({
        "input_ids": input_ids,
        "top_p": top_p,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "logits_processor": get_logits_processor(),
        "streamer": streamer
    })

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        new_history = history + [(query, response)]
        chatbot[-1] = (query, response)
        yield chatbot, new_history


def reset_user_input():
    return gr.update(value="")


def reset_state():
    return [], []


with gr.Blocks() as demo:

    gr.HTML("""
    <h1 align="center">
        <a href="https://github.com/hiyouga/LLaMA-Efficient-Tuning" target="_blank">
            LLaMA Efficient Tuning
        </a>
    </h1>
    """)

    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")

        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
                                       label="Maximum new tokens", interactive=True)
            top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
                                    label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)
if __name__ == "__main__":
    main()
@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
        if isinstance(module, BaichuanModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _convert_to_standard_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
        num_heads, ...]))
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_baichuan_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
        """
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )


class BaichuanModel(BaichuanPreTrainedModel):

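The two helpers added above are inverse `view()` reshapes between the Baichuan cache layout (`[batch_size * num_heads, ...]`) and the standard layout (`[batch_size, num_heads, ...]`). A small shape round-trip check (not part of the diff; assumes torch is installed; the tensor sizes are arbitrary):

```python
# Sketch only: keys are cached as [batch*heads, head_dim, seq] and values as
# [batch*heads, seq, head_dim]; the conversions above split/merge the first axis.
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 16
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
value = torch.randn(batch_size * num_heads, seq_length, head_dim)

# to the standard format
std_key = key.view(batch_size, num_heads, head_dim, seq_length)
std_value = value.view(batch_size, num_heads, seq_length, head_dim)

# back to the Baichuan format: the round trip is lossless
assert torch.equal(std_key.view(batch_size * num_heads, head_dim, seq_length), key)
assert torch.equal(std_value.view(batch_size * num_heads, seq_length, head_dim), value)
```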
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value
        self.embed_tokens = value

    def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
        return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class BaichuanForCausalLM(BaichuanPreTrainedModel):

@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
@@ -559,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs
    ) -> dict:
        if past_key_values:
            input_ids = input_ids[:, -1:]

            # the cache may be in the standard format (e.g. in contrastive search)
            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                past_key_values = self._convert_to_baichuan_cache(past_key_values)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
        {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )
    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in standardized_past
        )
        return self._convert_to_baichuan_cache(reordered_past)

    def quantize(self, bits: int):
        try:
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
            raise ImportError(
                f"Needs QLinear to run quantize."
            )

        for layer in self.model.layers:
            layer.self_attn.W_pack = QLinear(
                bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                weight=layer.mlp.up_proj.weight,
                bias = None,
            )
        return self
        return self

    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens