Compare commits
15 Commits
| Author | SHA1 | Date |
|---|---|---|
| | c4e9694c6e | |
| | 2006a96570 | |
| | 5dcd95645f | |
| | 9b3304b054 | |
| | e580d4ef41 | |
| | 64db4abc68 | |
| | 5ba0b80e5c | |
| | 7a43ff3d89 | |
| | 7e1a1d141a | |
| | 6d881f161b | |
| | a02b3e6192 | |
| | bcdee9fc19 | |
| | 8b688251be | |
| | 718f3382ad | |
| | dc8283d3d7 | |
README.md (89)
@@ -10,6 +10,8 @@
 
 ## Changelog
 
+[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--prompt_template llama2` argument when you are using the LLaMA-2-chat model.
+
 [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
 
 [23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
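
To make the new changelog entry concrete, a minimal LoRA fine-tuning command for LLaMA-2 might look like the sketch below. This is only an illustration assembled from flags that appear elsewhere in this diff; the dataset name and output path are placeholders.

```bash
# Add --prompt_template llama2 when fine-tuning the LLaMA-2-chat variant.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint
```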
@@ -20,11 +22,11 @@
 
 [23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
 
-[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
+[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [Hugging Face Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
 
 [23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
 
-[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model. If you want to train with RTX3090, use `git checkout baichuan-7b-rtx3090` to switch to the `baichuan-7b-rtx3090` branch and try the `--baichuan_rtx_gpu true` argument. (Other RTX series GPUs can also be tried)
+[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
 
 [23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
 
@@ -33,6 +35,7 @@
 ## Supported Models
 
 - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
+- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
 - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
 - [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
@@ -57,36 +60,36 @@
 ## Provided Datasets
 
 - For pre-training:
-  - [Wiki Demo](data/wiki_demo.txt)
+  - [Wiki Demo (en)](data/wiki_demo.txt)
 - For supervised fine-tuning:
-  - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
-  - [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
-  - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
-  - [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
-  - [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
-  - [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
-  - [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
-  - [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
-  - [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
-  - [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
-  - [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
-  - [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
-  - [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
-  - [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
-  - [UltraChat](https://github.com/thunlp/UltraChat)
-  - [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
-  - [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
-  - [WebNovel (Chinese)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward model training:
-  - [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
-  - [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
-  - [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
-  - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
-  - [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+  - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+  - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+  - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+  - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+  - [Self-cognition (zh)](data/self_cognition.json)
+  - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
+  - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
+  - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+  - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+  - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+  - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+  - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+  - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+  - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+  - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+  - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+  - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+  - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+  - [UltraChat (en)](https://github.com/thunlp/UltraChat)
+  - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- For reward modelling:
+  - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+  - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+  - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 
 Please refer to [data/README.md](data/README.md) for details.
 
-Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.
+Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
 
 ```bash
 pip install --upgrade huggingface_hub
@@ -260,34 +263,55 @@ use_cpu: false
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
-    --stage pt \
+    --stage sft \
     --model_name_or_path path_to_your_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
+    --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_eval_result \
     --per_device_eval_batch_size 8 \
-    --max_samples 50 \
+    --max_samples 100 \
     --predict_with_generate
 ```
 
 We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
 
+### Predict
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage sft \
+    --model_name_or_path path_to_your_model \
+    --do_predict \
+    --dataset alpaca_gpt4_en \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --output_dir path_to_predict_result \
+    --per_device_eval_batch_size 8 \
+    --max_samples 100 \
+    --predict_with_generate
+```
+
+If you want to predict the samples with empty responses, please kindly fill the `response` column with **dummy tokens** to ensure the sample will not be discarded throughout the preprocessing phase.
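
As an illustration of the dummy-token advice above, a prediction sample in the default alpaca-style columns could look like the following hypothetical record, where the `output` field only holds a placeholder:

```json
{"instruction": "Summarize the paragraph.", "input": "LLaMA-2 is a family of open foundation models.", "output": "dummy"}
```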
 
 ### API Demo
 
 ```bash
 python src/api_demo.py \
     --model_name_or_path path_to_your_model \
+    --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
 
-See `http://localhost:8000/docs` for API documentation.
+Visit `http://localhost:8000/docs` for API documentation.
 
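Since the demo API follows the OpenAI chat format, a request against the running server might look like the sketch below. The `/v1/chat/completions` route and the payload fields are assumptions based on that format; check the docs page above for the exact schema.

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "default",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
        "max_tokens": 128
      }'
```
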
 ### CLI Demo
 
 ```bash
 python src/cli_demo.py \
     --model_name_or_path path_to_your_model \
+    --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
 
@@ -296,6 +320,7 @@ python src/cli_demo.py \
 ```bash
 python src/web_demo.py \
     --model_name_or_path path_to_your_model \
+    --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint
 ```
 
@@ -304,6 +329,7 @@ python src/web_demo.py \
 ```bash
 python src/export_model.py \
     --model_name_or_path path_to_your_model \
+    --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export
 ```
@@ -315,6 +341,7 @@ This repository is licensed under the [Apache-2.0 License](LICENSE).
 Please follow the model licenses to use the corresponding model weights:
 
 - [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
+- [LLaMA-2](https://ai.meta.com/llama/license/)
 - [BLOOM](https://huggingface.co/spaces/bigscience/license)
 - [Falcon](LICENSE)
 - [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)

@@ -1,4 +1,5 @@
-Data format in `dataset_info.json`:
+If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
 
 ```json
 "dataset_name": {
   "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
@@ -14,40 +15,4 @@ Data format in `dataset_info.json`:
 }
 ```
 
-Dataset definition format in `dataset_info.json`:
+where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
-```json
-"dataset_name": {
-  "hf_hub_url": "the repository address on the HuggingFace hub (if specified, the three arguments below are ignored)",
-  "script_url": "the name of the local folder containing the data loading script (if specified, the two arguments below are ignored)",
-  "file_name": "the name of the dataset file in this directory (required if the arguments above are not specified)",
-  "file_sha1": "the SHA-1 hash of the dataset file (optional)",
-  "columns": {
-    "prompt": "the column name for the prompt text (default: instruction)",
-    "query": "the column name for the query text (default: input)",
-    "response": "the column name for the response text (default: output)",
-    "history": "the column name for the chat history (default: None)"
-  }
-}
-```
-
-A brief introduction to some of the provided datasets:
-
-| Dataset | Size | Description |
-| --- | --- | --- |
-| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | The Alpaca dataset released by Stanford, used to train early LLaMA-based models such as Alpaca |
-| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | The Alpaca dataset translated with ChatGPT |
-| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | A self-instruction dataset built with GPT-4 |
-| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | About 2 million Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
-| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | About 1 million Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
-| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | About 500 thousand Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
-| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | About 400 thousand personalized role-play dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, including character profiles |
-| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | About 250 thousand Chinese math problems generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, including solution steps |
-| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | About 800 thousand multi-turn user-assistant dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
-| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | Data in Japanese, Simplified and Traditional Chinese, English and more, originally used to train the Guanaco model |
-| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | The Chinese dataset of the Chinese conversational model Firefly, covering multiple NLP tasks |
-| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | An English dataset for code generation tasks |
-| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | A collection of instruction datasets for fine-tuning |
-| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | A Chinese question answering dataset collected from Baidu Zhidao |
-| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | A large-scale multi-turn dialogue dataset released by Tsinghua NLP |
-
-Note: the BELLE datasets are generated by ChatGPT and their accuracy is not guaranteed; no self-instruction dataset produced by GPT-like models can guarantee accuracy.
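
To make the format described above concrete, a custom entry in `dataset_info.json` might look like the hypothetical example below; the dataset name, file name, and hash are placeholders, and the column mapping uses the default alpaca-style names.

```json
"my_custom_dataset": {
  "file_name": "my_data.json",
  "file_sha1": "0123456789abcdef0123456789abcdef01234567",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "history": "history"
  }
}
```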

@@ -1,2 +0,0 @@
-{"id": 0,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利(David Clayton Henrie,),美国演员。近来在迪士尼频道原创电视影集《少年魔法师》(Wizards of Waverly Place)当中演出贾斯汀·鲁索(Justin Russo)一角。\n\n大卫·亨利出生在加州Mission Viejo,在凤凰城长大。他的胞弟劳伦斯·亨利(Lorenzo Henrie)也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔(Lucy Hale),之后与其交往,于2009年分手。\n\n10岁时,大卫·亨利和SAG在凤凰城签订了合约,并开始走出去试镜。 9岁的时候,在沙加缅度进行商业拍摄,SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天,他和他的家人搬到了好莱坞。他预定他的前2支商业试镜,扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁,大卫有了他的第一次重大突破,在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker,和琳达布莱儿、乔治甘迺迪共同演出,并要求回来Hallmark movie公司。 \n\n在18岁时,大卫得到了迪士尼频道原创系列演出机会,该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长,隔年,为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}
-{"id": 1,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利(David Clayton Henrie,),美国演员。近来在迪士尼频道原创电视影集《少年魔法师》(Wizards of Waverly Place)当中演出贾斯汀·鲁索(Justin Russo)一角。\n\n大卫·亨利出生在加州Mission Viejo,在凤凰城长大。他的胞弟劳伦斯·亨利(Lorenzo Henrie)也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔(Lucy Hale),之后与其交往,于2009年分手。\n\n10岁时,大卫·亨利和SAG在凤凰城签订了合约,并开始走出去试镜。 9岁的时候,在沙加缅度进行商业拍摄,SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天,他和他的家人搬到了好莱坞。他预定他的前2支商业试镜,扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁,大卫有了他的第一次重大突破,在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker,和琳达布莱儿、乔治甘迺迪共同演出,并要求回来Hallmark movie公司。 \n\n在18岁时,大卫得到了迪士尼频道原创系列演出机会,该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长,隔年,为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}

data/refgpt_zh_50k_p1.json.REMOVED.git-id (1, new file)
@@ -0,0 +1 @@
+56405bb8f52727e52e99693739494b9b7b0d7ba6

data/refgpt_zh_50k_p2.json.REMOVED.git-id (1, new file)
@@ -0,0 +1 @@
+fa935248a5d40d2bdd5649af99a72a754d40ae7a

data/sharegpt_zh_27k.json.REMOVED.git-id (1, new file)
@@ -0,0 +1 @@
+38c89869c6aeca2a3af9ea1e09afe460f9b46810
@@ -10,7 +10,7 @@ rouge-chinese
 nltk
 gradio>=3.36.0
 uvicorn
-pydantic
+pydantic==1.10.11
-fastapi
+fastapi==0.95.1
 sse-starlette
 matplotlib

@@ -5,9 +5,16 @@
 
 import uvicorn
 
-from llmtuner import create_app
+from llmtuner import ChatModel
+from llmtuner.api.app import create_app
+from llmtuner.tuner import get_infer_args
 
 
+def main():
+    chat_model = ChatModel(*get_infer_args())
+    app = create_app(chat_model)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+
+
 if __name__ == "__main__":
-    app = create_app()
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    main()

@@ -2,7 +2,8 @@
 # Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
-from llmtuner import ChatModel, get_infer_args
+from llmtuner import ChatModel
+from llmtuner.tuner import get_infer_args
 
 
 def main():

@@ -2,7 +2,7 @@
 # Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
 
-from llmtuner import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner import get_train_args, load_model_and_tokenizer
 
 
 def main():

@@ -1,7 +1,4 @@
-from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
-from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
-from llmtuner.webui import create_ui
 
 
-__version__ = "0.1.0"
+__version__ = "0.1.2"

@@ -1 +0,0 @@
-from llmtuner.api.app import create_app

@@ -1,4 +1,3 @@
-import json
 import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware

@@ -31,9 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
     torch_gc()
 
 
-def create_app():
-    chat_model = ChatModel(*get_infer_args())
-
+def create_app(chat_model: ChatModel) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
 
     app.add_middleware(

@@ -96,7 +93,7 @@ def create_app():
                 finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield json.dumps(chunk, ensure_ascii=False)
+            yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
             for new_text in chat_model.stream_chat(
                 query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens

@@ -110,7 +107,7 @@ def create_app():
                 finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield json.dumps(chunk, ensure_ascii=False)
+            yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
             choice_data = ChatCompletionResponseStreamChoice(
                 index=0,

@@ -118,12 +115,13 @@ def create_app():
                 finish_reason=Finish.STOP
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield json.dumps(chunk, ensure_ascii=False)
+            yield chunk.json(exclude_unset=True, ensure_ascii=False)
             yield "[DONE]"
 
     return app
 
 
 if __name__ == "__main__":
-    app = create_app()
+    chat_model = ChatModel(*get_infer_args())
+    app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

@@ -4,7 +4,7 @@ from threading import Thread
 from transformers import TextIteratorStreamer
 
 from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.template import Template
+from llmtuner.extras.template import get_template
 from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 from llmtuner.tuner import load_model_and_tokenizer
 

@@ -19,14 +19,14 @@ class ChatModel:
         generating_args: GeneratingArguments
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-        self.template = Template(data_args.prompt_template)
-        self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
+        self.template = get_template(data_args.prompt_template)
+        self.source_prefix = data_args.source_prefix or ""
         self.generating_args = generating_args
 
     def process_args(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
-        prefix = prefix if prefix else self.source_prefix
+        prefix = prefix or self.source_prefix
 
         inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
         inputs = inputs.to(self.model.device)

@@ -81,5 +81,4 @@ class ChatModel:
         thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
         thread.start()
 
-        for new_text in streamer:
-            yield new_text
+        yield from streamer

@@ -1,2 +1,3 @@
 from llmtuner.dsets.loader import get_dataset
 from llmtuner.dsets.preprocess import preprocess_dataset
+from llmtuner.dsets.utils import split_dataset

@@ -1,63 +0,0 @@
-import os
-import json
-import time
-from datetime import timedelta
-
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments
-)
-
-
-class LogCallback(TrainerCallback):
-
-    def __init__(self, runner=None):
-        self.runner = runner
-        self.start_time = time.time()
-        self.tracker = {}
-
-    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        r"""
-        Event called at the beginning of a training step. If using gradient accumulation, one training step
-        might take several inputs.
-        """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
-
-    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
-        r"""
-        Event called at the end of an substep during gradient accumulation.
-        """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
-
-    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
-        r"""
-        Event called after logging the last logs.
-        """
-        if "loss" not in state.log_history[-1]:
-            return
-        cur_time = time.time()
-        cur_steps = state.log_history[-1].get("step")
-        elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
-        remaining_steps = state.max_steps - cur_steps
-        remaining_time = remaining_steps * avg_time_per_step
-        self.tracker = {
-            "current_steps": cur_steps,
-            "total_steps": state.max_steps,
-            "loss": state.log_history[-1].get("loss", None),
-            "reward": state.log_history[-1].get("reward", None),
-            "learning_rate": state.log_history[-1].get("learning_rate", None),
-            "epoch": state.log_history[-1].get("epoch", None),
-            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
-            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
-            "remaining_time": str(timedelta(seconds=int(remaining_time)))
-        }
-        os.makedirs(args.output_dir, exist_ok=True)
-        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
-            f.write(json.dumps(self.tracker) + "\n")

@@ -6,7 +6,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 from datasets import Dataset
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.template import Template
+from llmtuner.extras.template import get_template
 from llmtuner.hparams import DataArguments
 
 

@@ -19,7 +19,7 @@ def preprocess_dataset(
 ) -> Dataset:
 
     column_names = list(dataset.column_names)
-    prompt_template = Template(data_args.prompt_template)
+    prompt_template = get_template(data_args.prompt_template)
 
     # support question with a single answer or multiple answers
     def get_dialog(examples):

src/llmtuner/dsets/utils.py (16, new file)
@@ -0,0 +1,16 @@
+from typing import Dict
+from datasets import Dataset
+
+
+def split_dataset(
+    dataset: Dataset, dev_ratio: float, do_train: bool
+) -> Dict[str, Dataset]:
+    # Split the dataset
+    if do_train:
+        if dev_ratio > 1e-6:
+            dataset = dataset.train_test_split(test_size=dev_ratio)
+            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+        else:
+            return {"train_dataset": dataset}
+    else: # do_eval or do_predict
+        return {"eval_dataset": dataset}

@@ -13,6 +13,12 @@ SUPPORTED_MODELS = {
     "LLaMA-13B": "huggyllama/llama-13b",
     "LLaMA-30B": "huggyllama/llama-30b",
     "LLaMA-65B": "huggyllama/llama-65b",
+    "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+    "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+    "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+    "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+    "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+    "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
     "BLOOM-560M": "bigscience/bloom-560m",
     "BLOOM-3B": "bigscience/bloom-3b",
     "BLOOM-7B1": "bigscience/bloom-7b1",

@@ -30,8 +36,9 @@ SUPPORTED_MODELS = {
     "InternLM-7B-Chat": "internlm/internlm-chat-7b"
 }
 
-DEFAULT_MODULE = { # will be deprecated
+DEFAULT_MODULE = {
     "LLaMA": "q_proj,v_proj",
+    "LLaMA2": "q_proj,v_proj",
     "BLOOM": "query_key_value",
     "BLOOMZ": "query_key_value",
     "Falcon": "query_key_value",

@@ -3,30 +3,13 @@ from dataclasses import dataclass
 
 
 @dataclass
-class Format:
+class Template:
 
     prefix: str
     prompt: str
     sep: str
     use_history: bool
 
-
-templates: Dict[str, Format] = {}
-
-
-@dataclass
-class Template:
-
-    name: str
-
-    def __post_init__(self):
-        if self.name in templates:
-            self.prefix = templates[self.name].prefix
-            self.prompt = templates[self.name].prompt
-            self.sep = templates[self.name].sep
-            self.use_history = templates[self.name].use_history
-        else:
-            raise ValueError("Template {} does not exist.".format(self.name))
-
     def get_prompt(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
     ) -> str:

@@ -46,7 +29,7 @@ class Template:
     def _format_example(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
    ) -> List[str]:
-        prefix = prefix if prefix else self.prefix # use prefix if provided
+        prefix = prefix or self.prefix # use prefix if provided
         prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
         history = history + [(query, "<dummy>")]

@@ -61,8 +44,11 @@ class Template:
         return convs[:-1] # drop last
 
 
+templates: Dict[str, Template] = {}
+
+
 def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
-    templates[name] = Format(
+    templates[name] = Template(
         prefix=prefix,
         prompt=prompt,
         sep=sep,

@@ -70,6 +56,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
     )
 
 
+def get_template(name: str) -> Template:
+    template = templates.get(name, None)
+    assert template is not None, "Template {} does not exist.".format(name)
+    return template
+
+
 r"""
 Supports language model inference without histories.
 """

@@ -95,6 +87,27 @@ register_template(
 )
 
 
+r"""
+Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
+          https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
+"""
+register_template(
+    name="llama2",
+    prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
+           "Always answer as helpfully as possible, while being safe. "
+           "Your answers should not include any harmful, unethical, "
+           "racist, sexist, toxic, dangerous, or illegal content. "
+           "Please ensure that your responses are socially unbiased and positive in nature.\n"
+           "If a question does not make any sense, or is not factually coherent, "
+           "explain why instead of answering something not correct. "
+           "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
+    prompt=" [INST] {query} [/INST] ",
+    sep="</s>",
+    use_history=True
+)
+
+
 r"""
 Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
           https://github.com/ymcui/Chinese-LLaMA-Alpaca

@@ -202,7 +215,7 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 register_template(
     name="baichuan",
     prefix="",
-    prompt=" <reserved_102> {query} <reserved_103> ",
+    prompt="<reserved_102>{query}<reserved_103>",
     sep="</s>",
     use_history=True
 )
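
A small usage sketch of the reworked template registry, assuming only the names visible in the diff above (the query string and keyword values are illustrative):

```python
from llmtuner.extras.template import get_template

template = get_template("llama2")   # asserts that the template name was registered
prompt = template.get_prompt("Hello!", history=[], prefix="")
# get_prompt() joins the registered prefix, separator, history and the new query into one string.
```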
@@ -16,9 +16,10 @@ class FinetuningArguments:
         default=32,
         metadata={"help": "Number of decoder blocks in the model. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
+                  LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
                   Falcon choices: [\"32\", \"60\"], \
-                  Baichuan choices: [\"32\"]"}
+                  Baichuan choices: [\"32\", \"40\"]"}
     )
     num_layer_trainable: Optional[int] = field(
         default=3,

@@ -27,7 +28,7 @@ class FinetuningArguments:
     name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
         metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
-                  LLaMA choices: [\"mlp\", \"self_attn\"], \
+                  LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                   Baichuan choices: [\"mlp\", \"self_attn\"]"}
     )

@@ -46,7 +47,7 @@ class FinetuningArguments:
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
     )

@@ -10,6 +10,7 @@ from transformers import (
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from transformers.deepspeed import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

@@ -108,7 +109,7 @@ def load_model_and_tokenizer(
         model_to_load,
         config=config,
         torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
-        low_cpu_mem_usage=True,
+        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
 
@@ -107,7 +107,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         # Compute rewards
         replace_model(unwrapped_model, target="reward")
         with torch.no_grad():
-            _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
+            _, _, values = self.model(
+                **self.prepare_model_inputs(queries, responses),
+                output_hidden_states=True,
+                return_dict=True
+            )
         rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
         replace_model(unwrapped_model, target="default")
 
@@ -4,7 +4,7 @@ import math
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss

@@ -28,16 +28,6 @@ def run_pt(
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = PeftTrainer(
         finetuning_args=finetuning_args,

@@ -46,7 +36,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Training

@@ -32,7 +32,7 @@ class PairwisePeftTrainer(PeftTrainer):
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
         batch_size = inputs["input_ids"].size(0) // 2
-        _, _, values = model(**inputs)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
         return (loss, [loss, r_accept, r_reject]) if return_outputs else loss

@@ -5,7 +5,7 @@
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments

@@ -29,16 +29,6 @@ def run_rm(
 
     training_args.remove_unused_columns = False # important for pairwise dataset
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = PairwisePeftTrainer(
         finetuning_args=finetuning_args,

@@ -48,7 +38,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Training

|
         inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
         if label_len > prompt_len:
             inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = self._pad_tensors_to_target_len(
+                    inputs["attention_mask"], inputs["labels"], pad_token_id=0
+                )
+            if "position_ids" in inputs:
+                inputs["position_ids"] = self._pad_tensors_to_target_len(
+                    inputs["position_ids"], inputs["labels"], pad_token_id=0
+                )
 
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
+        generated_tokens = (
+            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
+        )
 
         return (loss, generated_tokens, labels)
 
-    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+    def _pad_tensors_to_target_len(
+        self,
+        src_tensor: torch.Tensor,
+        tgt_tensor: torch.Tensor,
+        pad_token_id: Optional[int] = None
+    ) -> torch.Tensor:
         r"""
         Pads the tensor to the same length as the target tensor.
 
         Should only be called when predict_with_generate=True.
         """
-        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-            # If PAD token is not defined at least EOS token has to be defined
-            pad_token_id = (
-                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            )
-        else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+        if pad_token_id is None:
+            if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+                assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+                pad_token_id = self.tokenizer.pad_token_id
+            else:
+                if self.model.config.pad_token_id is not None:
+                    pad_token_id = self.model.config.pad_token_id
+                else:
+                    raise ValueError("Pad_token_id must be set in the configuration of the model.")
 
         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
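
To see what the left-padding in `_pad_tensors_to_target_len` does, here is a small self-contained illustration; the tensor values are made up, but the two `padded_tensor` lines mirror the code above.

```python
import torch

pad_token_id = 0
src_tensor = torch.tensor([[5, 6, 7]])            # e.g. a 3-token sequence
tgt_tensor = torch.zeros(1, 5, dtype=torch.long)  # target length of 5

padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
print(padded_tensor)  # tensor([[0, 0, 5, 6, 7]])
```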
@@ -3,7 +3,7 @@
 from typing import Optional, List
 from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor

@@ -35,16 +35,6 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams if \
         data_args.eval_num_beams is not None else training_args.generation_num_beams
 
-    # Split the dataset
-    if training_args.do_train:
-        if data_args.dev_ratio > 1e-6:
-            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
-            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
-        else:
-            trainer_kwargs = {"train_dataset": dataset}
-    else: # do_eval or do_predict
-        trainer_kwargs = {"eval_dataset": dataset}
-
     # Initialize our Trainer
     trainer = Seq2SeqPeftTrainer(
         finetuning_args=finetuning_args,

@@ -54,7 +44,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **trainer_kwargs
+        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
     )
 
     # Keyword arguments for `model.generate`

@@ -1 +0,0 @@
-from llmtuner.webui.interface import create_ui

@@ -73,6 +73,7 @@ class WebChatModel(ChatModel):
         chatbot: List[Tuple[str, str]],
         query: str,
         history: List[Tuple[str, str]],
+        prefix: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float

@@ -80,7 +81,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             new_history = history + [(query, response)]

```diff
@@ -16,11 +16,9 @@ def create_chat_box(
 
         with gr.Row():
             with gr.Column(scale=4):
-                with gr.Column(scale=12):
-                    query = gr.Textbox(show_label=False, lines=8)
-
-                with gr.Column(min_width=32, scale=1):
-                    submit_btn = gr.Button(variant="primary")
+                prefix = gr.Textbox(show_label=False)
+                query = gr.Textbox(show_label=False, lines=8)
+                submit_btn = gr.Button(variant="primary")
 
             with gr.Column(scale=1):
                 clear_btn = gr.Button()
@@ -36,7 +34,7 @@ def create_chat_box(
 
     submit_btn.click(
         chat_model.predict,
-        [chatbot, query, history, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -46,6 +44,7 @@ def create_chat_box(
     clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
 
     return chat_box, chatbot, history, dict(
+        prefix=prefix,
        query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,
```
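Taken together, the chat changes thread a system-prompt `prefix` from the Web UI into generation: `create_chat_box` adds a prefix textbox and returns it, and `WebChatModel.predict` forwards it to `stream_chat`. A condensed, self-contained sketch of that flow (`stream_chat` is stubbed here; the real implementation belongs to the `ChatModel` base class, which this diff does not touch):

```python
from typing import Generator, List, Tuple

class WebChatModel:
    def stream_chat(self, query: str, history: List[Tuple[str, str]], prefix: str,
                    **gen_kwargs) -> Generator[str, None, None]:
        # Stub: the real model prepends `prefix` as the system prompt before generating.
        yield f"[system: {prefix or 'none'}] echo: {query}"

    def predict(self, chatbot: List[Tuple[str, str]], query: str,
                history: List[Tuple[str, str]], prefix: str,
                max_new_tokens: int, top_p: float, temperature: float):
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
            query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            chatbot[-1] = [query, response]
        return chatbot, history + [(query, response)]

model = WebChatModel()
print(model.predict([], "Hello", [], "You are a helpful assistant.", 128, 0.7, 0.95))
```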
```diff
@@ -31,7 +31,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
         start_btn = gr.Button()
         stop_btn = gr.Button()
 
-    output_box = gr.Markdown()
+    with gr.Box():
+        output_box = gr.Markdown()
 
     start_btn.click(
         runner.run_eval,
```
```diff
@@ -12,7 +12,7 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
         load_btn = gr.Button()
         unload_btn = gr.Button()
 
-    info_box = gr.Markdown()
+    info_box = gr.Textbox(show_label=False, interactive=False)
 
     chat_model = WebChatModel()
     chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
```
```diff
@@ -35,12 +35,21 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
         lr_scheduler_type = gr.Dropdown(
             value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
         )
+        max_grad_norm = gr.Textbox(value="1.0")
         dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
-        fp16 = gr.Checkbox(value=True)
 
-    with gr.Row():
-        logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
-        save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
+    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+        with gr.Row():
+            logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+            save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
+            warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
+            compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+
+    with gr.Accordion(label="LoRA config", open=False) as lora_tab:
+        with gr.Row():
+            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+            lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
+            lora_target = gr.Textbox(scale=2)
 
     with gr.Row():
         start_btn = gr.Button()
@@ -49,7 +58,9 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     with gr.Row():
         with gr.Column(scale=4):
             output_dir = gr.Textbox(interactive=True)
-            output_box = gr.Markdown()
+
+            with gr.Box():
+                output_box = gr.Markdown()
 
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
@@ -74,10 +85,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
             batch_size,
             gradient_accumulation_steps,
             lr_scheduler_type,
+            max_grad_norm,
             dev_ratio,
-            fp16,
             logging_steps,
             save_steps,
+            warmup_steps,
+            compute_type,
+            lora_rank,
+            lora_dropout,
+            lora_target,
             output_dir
         ],
         [output_box]
@@ -103,10 +119,17 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
         batch_size=batch_size,
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
+        max_grad_norm=max_grad_norm,
         dev_ratio=dev_ratio,
-        fp16=fp16,
+        advanced_tab=advanced_tab,
         logging_steps=logging_steps,
         save_steps=save_steps,
+        warmup_steps=warmup_steps,
+        compute_type=compute_type,
+        lora_tab=lora_tab,
+        lora_rank=lora_rank,
+        lora_dropout=lora_dropout,
+        lora_target=lora_target,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_dir=output_dir,
```
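Because Gradio passes the `inputs` components to the click callback positionally, the component list wired to `start_btn.click` above has to stay in the same order as the `Runner.run_sft` signature changed further down in this compare. A small standalone sketch of that positional mapping (function and component names here are illustrative):

```python
import gradio as gr

def run_sft(learning_rate: str, compute_type: str) -> str:
    # Gradio passes the values of the `inputs` components positionally,
    # so this signature must mirror the order of the list below.
    return f"training with lr={learning_rate}, precision={compute_type}"

with gr.Blocks() as demo:
    learning_rate = gr.Textbox(value="5e-5", label="Learning rate")
    compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16", label="Compute type")
    start_btn = gr.Button("Start")
    output_box = gr.Markdown()

    start_btn.click(run_sft, [learning_rate, compute_type], [output_box])

if __name__ == "__main__":
    demo.launch()
```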
```diff
@@ -22,10 +22,11 @@ def create_top() -> Dict[str, Component]:
         checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
-    with gr.Row():
-        quantization_bit = gr.Dropdown([8, 4], scale=1)
-        template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
-        source_prefix = gr.Textbox(scale=4)
+    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+        with gr.Row():
+            quantization_bit = gr.Dropdown([8, 4], scale=1)
+            template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
+            source_prefix = gr.Textbox(scale=2)
 
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -47,9 +48,10 @@ def create_top() -> Dict[str, Component]:
         model_name=model_name,
         model_path=model_path,
         finetuning_type=finetuning_type,
-        template=template,
         checkpoints=checkpoints,
         refresh_btn=refresh_btn,
+        advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
+        template=template,
         source_prefix=source_prefix
     )
```
```diff
@@ -27,7 +27,7 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Evaluate"):
             eval_elems = create_eval_tab(top_elems, runner)
 
-        with gr.Tab("Inference"):
+        with gr.Tab("Chat"):
             infer_elems = create_infer_tab(top_elems)
 
     elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
```
```diff
@@ -49,6 +49,14 @@ LOCALES = {
             "value": "刷新断点"
         }
     },
+    "advanced_tab": {
+        "en": {
+            "label": "Advanced configurations"
+        },
+        "zh": {
+            "label": "高级设置"
+        }
+    },
     "quantization_bit": {
         "en": {
             "label": "Quantization bit (optional)",
@@ -71,12 +79,12 @@ LOCALES = {
     },
     "source_prefix": {
         "en": {
-            "label": "Source prefix (optional)",
-            "info": "A sequence used as the prefix of each samples."
+            "label": "System prompt (optional)",
+            "info": "A sequence used as the default system prompt."
         },
         "zh": {
-            "label": "前缀序列(非必填)",
-            "info": "作为每个输入样本前缀的序列"
+            "label": "系统提示词(非必填)",
+            "info": "默认使用的系统提示词"
         }
     },
     "dataset_dir": {
@@ -209,6 +217,16 @@ LOCALES = {
             "info": "采用的学习率调节器名称。"
         }
     },
+    "max_grad_norm": {
+        "en": {
+            "label": "Maximum gradient norm",
+            "info": "Norm for gradient clipping."
+        },
+        "zh": {
+            "label": "最大梯度范数",
+            "info": "用于梯度裁剪的范数。"
+        }
+    },
     "dev_ratio": {
         "en": {
             "label": "Dev ratio",
@@ -219,20 +237,10 @@ LOCALES = {
             "info": "验证集占全部样本的百分比。"
         }
     },
-    "fp16": {
-        "en": {
-            "label": "fp16",
-            "info": "Whether to use fp16 mixed precision training."
-        },
-        "zh": {
-            "label": "fp16",
-            "info": "是否启用 FP16 混合精度训练。"
-        }
-    },
     "logging_steps": {
         "en": {
             "label": "Logging steps",
-            "info": "Number of update steps between two logs."
+            "info": "Number of steps between two logs."
         },
         "zh": {
             "label": "日志间隔",
@@ -242,13 +250,71 @@ LOCALES = {
     "save_steps": {
         "en": {
             "label": "Save steps",
-            "info": "Number of updates steps between two checkpoints."
+            "info": "Number of steps between two checkpoints."
         },
         "zh": {
             "label": "保存间隔",
             "info": "每两次断点保存间的更新步数。"
         }
     },
+    "warmup_steps": {
+        "en": {
+            "label": "Warmup steps",
+            "info": "Number of steps used for warmup."
+        },
+        "zh": {
+            "label": "预热步数",
+            "info": "学习率预热采用的步数。"
+        }
+    },
+    "compute_type": {
+        "en": {
+            "label": "Compute type",
+            "info": "Whether to use fp16 or bf16 mixed precision training."
+        },
+        "zh": {
+            "label": "计算类型",
+            "info": "是否启用 FP16 或 BF16 混合精度训练。"
+        }
+    },
+    "lora_tab": {
+        "en": {
+            "label": "LoRA configurations"
+        },
+        "zh": {
+            "label": "LoRA 参数设置"
+        }
+    },
+    "lora_rank": {
+        "en": {
+            "label": "LoRA rank",
+            "info": "The rank of LoRA matrices."
+        },
+        "zh": {
+            "label": "LoRA 秩",
+            "info": "LoRA 矩阵的秩。"
+        }
+    },
+    "lora_dropout": {
+        "en": {
+            "label": "LoRA Dropout",
+            "info": "Dropout ratio of LoRA weights."
+        },
+        "zh": {
+            "label": "LoRA 随机丢弃",
+            "info": "LoRA 权重随机丢弃的概率。"
+        }
+    },
+    "lora_target": {
+        "en": {
+            "label": "LoRA modules (optional)",
+            "info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
+        },
+        "zh": {
+            "label": "LoRA 作用层(非必填)",
+            "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
+        }
+    },
     "start_btn": {
         "en": {
             "value": "Start"
@@ -323,6 +389,14 @@ LOCALES = {
             "value": "模型未加载,请先加载模型。"
         }
     },
+    "prefix": {
+        "en": {
+            "placeholder": "System prompt (optional)"
+        },
+        "zh": {
+            "placeholder": "系统提示词(非必填)"
+        }
+    },
     "query": {
         "en": {
             "placeholder": "Input..."
```
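`LOCALES` keys each named component to its per-language labels, which the Web UI re-applies when the language dropdown changes (see the `demo.load` / `lang.change` wiring in the web demo diff below). A rough sketch of how such a lookup could be applied, assuming a `gen_label` helper that emits one `gr.update` per component; the actual logic lives in the `llmtuner.webui` manager and is not shown in this compare:

```python
import gradio as gr

# A tiny excerpt of the mapping, mirroring entries added above.
LOCALES = {
    "prefix": {
        "en": {"placeholder": "System prompt (optional)"},
        "zh": {"placeholder": "系统提示词(非必填)"},
    },
    "compute_type": {
        "en": {"label": "Compute type"},
        "zh": {"label": "计算类型"},
    },
}

def gen_label(lang: str, names):
    # For each named component, emit a gr.update carrying the localized attributes.
    return [gr.update(**LOCALES[name][lang]) for name in names]

print(gen_label("zh", ["prefix", "compute_type"]))
```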
```diff
@@ -6,7 +6,7 @@ import transformers
 from typing import List, Optional, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
+from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import get_train_args, run_sft
@@ -77,10 +77,15 @@ class Runner:
         batch_size: int,
         gradient_accumulation_steps: int,
         lr_scheduler_type: str,
+        max_grad_norm: str,
         dev_ratio: float,
-        fp16: bool,
         logging_steps: int,
         save_steps: int,
+        warmup_steps: int,
+        compute_type: str,
+        lora_rank: int,
+        lora_dropout: float,
+        lora_target: str,
         output_dir: str
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
@@ -99,7 +104,6 @@ class Runner:
             model_name_or_path=model_name_or_path,
             do_train=True,
             overwrite_cache=True,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit else None,
@@ -115,9 +119,15 @@ class Runner:
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             lr_scheduler_type=lr_scheduler_type,
-            fp16=fp16,
+            max_grad_norm=float(max_grad_norm),
             logging_steps=logging_steps,
             save_steps=save_steps,
+            warmup_steps=warmup_steps,
+            fp16=(compute_type == "fp16"),
+            bf16=(compute_type == "bf16"),
+            lora_rank=lora_rank,
+            lora_dropout=lora_dropout,
+            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
             output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
         )
 
```
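On the runner side, the removed `fp16` checkbox becomes a `compute_type` string that is mapped back onto the `fp16`/`bf16` flags, and `max_grad_norm` arrives as textbox text that must be cast to `float`. A reduced sketch of just that mapping (the surrounding argument dict is abbreviated):

```python
def build_training_kwargs(compute_type: str, max_grad_norm: str) -> dict:
    # compute_type is "fp16" or "bf16" from the radio button; exactly one flag is set.
    # max_grad_norm comes from a gr.Textbox, so it is a string until cast here.
    return {
        "fp16": compute_type == "fp16",
        "bf16": compute_type == "bf16",
        "max_grad_norm": float(max_grad_norm),
    }

print(build_training_kwargs("bf16", "1.0"))  # {'fp16': False, 'bf16': True, 'max_grad_norm': 1.0}
```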
```diff
@@ -1,4 +1,4 @@
-from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
 
 
 def main():
```
```diff
@@ -1,4 +1,4 @@
-from llmtuner import create_ui
+from llmtuner.webui.interface import create_ui
 
 
 def main():
```
```diff
@@ -5,7 +5,7 @@
 import gradio as gr
 from transformers.utils.versions import require_version
 
-from llmtuner import get_infer_args
+from llmtuner.tuner import get_infer_args
 from llmtuner.webui.chat import WebChatModel
 from llmtuner.webui.components.chatbot import create_chat_box
 from llmtuner.webui.manager import Manager
@@ -24,17 +24,9 @@ def main():
 
     manager = Manager([{"lang": lang}, chat_elems])
 
-    demo.load(
-        manager.gen_label,
-        [lang],
-        [lang] + [elem for elem in chat_elems.values()],
-    )
+    demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
 
-    lang.change(
-        manager.gen_label,
-        [lang],
-        [lang] + [elem for elem in chat_elems.values()],
-    )
+    lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
 
     demo.queue()
     demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
```