105 Commits

Author SHA1 Message Date
hiyouga
95d0f77fc2 release v0.3.0
Former-commit-id: de7f5b622340ab09ebbe57ad2703e63d06dfdeea
2023-11-16 16:00:11 +08:00
hiyouga
9b2654277b update readme
Former-commit-id: 4018aabc5d1623033d27a8aced25804de79b7e7b
2023-11-16 15:58:37 +08:00
hoshi-hiyouga
f1b3bdac3f Merge #1525 from hiyouga/dev, fix #224 #336 #931 #936 #1011
Refactor llmtuner, support full-parameter RLHF

Former-commit-id: 3b92826803dc69471827b4f8204c2c3dc5310619
2023-11-16 15:47:13 +08:00
hiyouga
595fdbd95d fix css
Former-commit-id: 7afec127f60257462828298b25a5f6fd9c6f42c5
2023-11-16 15:45:38 +08:00
hiyouga
dab9385297 fix bug in web ui
Former-commit-id: a598f145ec903dd2b2c984d951b6c450b142ece5
2023-11-16 15:21:24 +08:00
hiyouga
df83def566 update ppo and demo in webui
Former-commit-id: de7571704c82121db13e3fc907379d2453100191
2023-11-16 14:55:26 +08:00
hiyouga
f9d4e37b3c fix bug in freeze tuning
Former-commit-id: f6b436a08421ca17d64abc51497f4aa43729a43b
2023-11-16 14:25:11 +08:00
hiyouga
e59a3d71e0 tiny fix
Former-commit-id: d65519d8a44b73bbb713741c23465f13c35c83f5
2023-11-16 03:27:19 +08:00
hiyouga
de3a84ac59 fix rlhf callback
Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
2023-11-16 03:26:19 +08:00
hiyouga
e017266b98 fix bug in PPO training
Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
2023-11-16 02:32:54 +08:00
hiyouga
f81a8a5e5c fix import bug
Former-commit-id: 2356029cdd120d5f7bf630b80681ce8c53bff90d
2023-11-16 02:27:03 +08:00
hiyouga
7a3a0144a5 support full-parameter PPO
Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
2023-11-16 02:08:04 +08:00
hiyouga
8263b2d32d add demo mode for web UI
Former-commit-id: 5ad34f08b4e1505d7933b973497347f126b2e818
2023-11-15 23:51:26 +08:00
hoshi-hiyouga
833cd490b8 Create CODE_OF_CONDUCT.md
Former-commit-id: 6bee64cdf9c75488033e600fb5b48738daa1ed3b
2023-11-15 20:42:15 +08:00
hiyouga
2162c37e41 update readme and constants
Former-commit-id: 7d83e3dd9101a4fdd0b589d0c1f7b609c0feecd1
2023-11-15 18:04:37 +08:00
hiyouga
b2ac8376e1 support multiple modules in freeze training #1514
Former-commit-id: 60abac70dfd778df2ae8b3a2e960ed8b607d7ab6
2023-11-15 17:08:18 +08:00
hiyouga
8079584143 fix imports
Former-commit-id: 6156f1abef631c675d150dd1cb0325cfc3820c91
2023-11-15 16:47:45 +08:00
hiyouga
09a4474e7f disentangle model from tuner and rename modules
Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
2023-11-15 16:29:09 +08:00
hiyouga
81530133ff fix #1507
Former-commit-id: 1ba9c53bd9743fa95fca1516c0ed9da352dbe9a1
2023-11-15 16:22:32 +08:00
hiyouga
cc4b384ac3 Update cal_lr.py
Former-commit-id: b92ef6c80ae108982046ec1419efb67c8b10b250
2023-11-14 21:14:42 +08:00
hiyouga
3852daf447 Update cal_lr.py
Former-commit-id: b6c3f9b24324403db41c5680a00aabc6d53bbeb9
2023-11-14 21:13:01 +08:00
hiyouga
5c97111f9d Update cal_lr.py
Former-commit-id: 1258eec806f6f4580a6eb7d9eb44f431f4c0da4f
2023-11-14 21:09:30 +08:00
hiyouga
75dd1f0f7e add cal_lr.py
Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
2023-11-14 20:58:37 +08:00
hiyouga
c9a4551012 fix #1494
Former-commit-id: 07c8d734529f03e47ef638a1bda222e8824d3d38
2023-11-14 18:07:20 +08:00
hiyouga
87197ba91d fix #1489
Former-commit-id: ebdeaca9cdfd6138c690a0fcb9f676deaddff177
2023-11-14 15:27:05 +08:00
hiyouga
7461bf84e5 support eval remote dataset
Former-commit-id: 71dd2698bf8c0b9ef7af995fb1e49e39fa66074e
2023-11-14 02:42:30 +08:00
hiyouga
fbc0357b2e fix dc link
Former-commit-id: 04c3a1f1c98d8f191102e359def0c8dcdc9621e3
2023-11-13 23:22:56 +08:00
hiyouga
ec334f5891 release v0.2.2, fix #1478 #1466
Former-commit-id: c9534c411716e1dceb54c5eb35fe845c93ee2973
2023-11-13 23:09:05 +08:00
hiyouga
885efe772e fix #424
Former-commit-id: ca24d445f825e120e659f5cd080a954c2243b8f2
2023-11-13 22:42:23 +08:00
hiyouga
64fc9ba678 refactor evaluation, upgrade trl to 074
Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
2023-11-13 22:20:35 +08:00
hiyouga
989eccd286 fix flashattn warning
Former-commit-id: 6eb095d39bd82fdbdb729a0ea57fc7246e3a60d6
2023-11-10 18:34:54 +08:00
hiyouga
f0766a2ab0 add todo
Former-commit-id: 0bd884feb11736d0ab24ca19885151cb47d9dcd3
2023-11-10 14:38:18 +08:00
hiyouga
178b85ff9a refactor constants
Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
2023-11-10 14:16:10 +08:00
hiyouga
68dd1ef121 tiny fix
Former-commit-id: 97ba2027bb1ddc01a3c824c40d5a180828810c2c
2023-11-09 17:20:49 +08:00
hoshi-hiyouga
b222cffe98 Merge pull request #1454 from yyq/main
Update finetuning_args.py

Former-commit-id: e67d8b93705383a8590f99e26e9fe8f663712aef
2023-11-09 17:12:18 +08:00
Yanqing
b4f1ab93d1 Update finetuning_args.py
更新 chatglm/falcon/bloom 的 lora_target 的名称

Former-commit-id: 06606739af035a80ae9ddba9d12c965ed289305d
2023-11-09 17:04:40 +08:00
hiyouga
f2e139f5cd fix #1452
Former-commit-id: 4d16214467715df458e24d03bb7d303d62b8bdcd
2023-11-09 16:41:32 +08:00
hiyouga
a9cbca1604 update readme
Former-commit-id: f7ead54042868550a3e8a6928ea3c0e2673f15b3
2023-11-09 16:00:24 +08:00
hiyouga
3a30ce6c16 release v0.2.1
Former-commit-id: 1c30f2be0140f5ab47c2bc811170d0271a0cdad6
2023-11-09 15:54:16 +08:00
hiyouga
48ec5355f9 add template, modify datasets
Former-commit-id: 81e54beb4d0f792f4fd7f450643caaf10f2f0b7d
2023-11-09 15:53:23 +08:00
hoshi-hiyouga
11859bc322 Merge pull request #1436 from lvzii/main
fix tokenizer config changed after pretrain

Former-commit-id: f485c3983e413fd3a3a57b451800705b072869a7
2023-11-09 14:30:50 +08:00
hiyouga
28c67a5be8 support parquet format #1446
Former-commit-id: 44a3b9ac9f10d2012b8ad3d8c48123db9a0da2f1
2023-11-09 14:17:40 +08:00
hiyouga
44fe93e9b0 fix #1438 #1439
Former-commit-id: 84260d58dda22adc32c26bc943ed2a36fd01341d
2023-11-09 13:45:10 +08:00
lvzi
09a1681b63 fix tokenizer config changed after pretrain
Changing tokenizer's attribute at preprocessing stage will result in saving a wrong tokenizer.
for example, baichuan2

Former-commit-id: 19942b5314b84267691f0a5657d0679f2ddbe58b
2023-11-08 15:50:46 +08:00
hiyouga
f5ba2190fb fix ppo train and dpo eval
Former-commit-id: ced863031836632cb5920e22ae6991f251372118
2023-11-07 22:48:51 +08:00
hiyouga
14a38b5069 fix #1422
Former-commit-id: 25d7bbd0a5142f001bd2ff498df07b24137050a9
2023-11-07 19:42:01 +08:00
hiyouga
f23e5b602a fix reward model loading
Former-commit-id: 9709ca501180a1afce32e9043aedb359762b437d
2023-11-07 17:20:51 +08:00
hiyouga
857696ed9c fix args
Former-commit-id: 44d0fa2ac6a6423c7ddaf91eb8998c1b9248c04e
2023-11-07 16:36:06 +08:00
hiyouga
2084133058 update info
Former-commit-id: 89643b8ac1e3fa8d2f29f1c88e4d4503410c0d05
2023-11-07 16:28:21 +08:00
hiyouga
f7f0c3070e delete file
Former-commit-id: 7d6355db0fd5809b99f3fa42753cf4dffd251fd1
2023-11-07 16:20:12 +08:00
hiyouga
46235aa514 fix #1418
Former-commit-id: 9bfecc72c53cf95fea4a9ff02ec40a65da6d4f54
2023-11-07 16:17:22 +08:00
hiyouga
2eb65d21ac upgrade peft, fix #1088 #1411
Former-commit-id: aa7d104f8e050d12cb8f585bc8a52c850995500f
2023-11-07 16:13:36 +08:00
hiyouga
37a0d62a82 update requirements
Former-commit-id: 82ebbbbb80b3f3f616274210970738d0f44b5a0a
2023-11-06 19:01:21 +08:00
hiyouga
21ac46e439 use seed in evaluate.py
Former-commit-id: ab5cac1dfa681933f3266827f80068ce798b4c56
2023-11-06 18:17:51 +08:00
hiyouga
ba3e8ba20c update readme (list in alphabetical order)
Former-commit-id: e6a67b5477ee095bd92764581cfe6af57e799a69
2023-11-06 17:18:12 +08:00
hiyouga
2c48e798ca update templates
Former-commit-id: 85be2e242b062283f192c4c4d0715dc1e8a68589
2023-11-06 12:25:47 +08:00
hiyouga
4e40f5b62b fix #1383
Former-commit-id: 9b8a782aa80f27c3e2a2e2621f9be17cae1a27e8
2023-11-06 11:42:23 +08:00
hiyouga
2a8892b785 fix deepseek template
Former-commit-id: 1fdbcdad9a1cdb20299350efd87a8e5cb8c625a3
2023-11-05 13:08:46 +08:00
hiyouga
ee3b33ff03 support deepseek coder #1378
Former-commit-id: ae0c829917b9de10e71199c85c77a52cdcd2b7b3
2023-11-05 12:51:03 +08:00
hiyouga
b2c3001f8e fix #1365
Former-commit-id: 0277d120e62164bb7fa1d6043b8fcc52c881fe96
2023-11-05 12:21:07 +08:00
hiyouga
6cfe1e1ac2 tiny fix
Former-commit-id: 594c510a20d6c2782d7b7ffff18931e3003e6c22
2023-11-03 01:26:06 +08:00
hiyouga
52326870e4 fix #1290
Former-commit-id: ad911d258c4cea16f54d09bc192e076c21d26394
2023-11-03 00:44:53 +08:00
hiyouga
217fde0918 fix bug in data loader, support dpo eval
Former-commit-id: f4f3dcff990468a2fa864b7176adcebbcf16dac9
2023-11-03 00:34:26 +08:00
hiyouga
065021d82a update data readme
Former-commit-id: 6a65ef44ed58714c611da60b5af96b85352e8735
2023-11-03 00:15:23 +08:00
hiyouga
4bb643e685 update data readme (zh)
Former-commit-id: b32fb3a984c681732b82f6544d6c05a98c34cf4c
2023-11-02 23:42:49 +08:00
hiyouga
b77c745b1a support sharegpt format, add datasets
Former-commit-id: 202daf8987ccb7523be03ca535b572b5c9e65994
2023-11-02 23:10:04 +08:00
hiyouga
7d13501b94 support pagination in webui preview
Former-commit-id: f2307e26b9c2ce5d60917cce5a9638466ea676c8
2023-11-02 21:21:45 +08:00
hiyouga
ac74639b32 fix webui
Former-commit-id: 9192948fa221c0275ddfa579ef6b3442d45b8962
2023-11-02 18:03:14 +08:00
hiyouga
12fa56ae68 support warning in webui
Former-commit-id: 9903b523fad2f0ec0e66c3d313823bd4674bfa2b
2023-11-02 17:57:04 +08:00
hiyouga
f11b863f4b fix #1349
Former-commit-id: 556c023eab2a68560b26a7d5318a79410fb0c700
2023-11-02 17:02:44 +08:00
hiyouga
f3e4b72957 fix #1356
Former-commit-id: d2ed436108a339d405dad1be1ca15baca3d6d3e4
2023-11-02 16:51:52 +08:00
hiyouga
8d52fb46ca fix #1325
Former-commit-id: 59f2cbbd52d4646fbd1ba83032bf522ecc49a50f
2023-11-01 23:38:49 +08:00
hiyouga
dab8f45033 fix chat
Former-commit-id: 68f2b3df09c4c8638b9e225fd5b8aed3541e97a0
2023-11-01 23:07:58 +08:00
hiyouga
bff8b02543 update gradio, support multiple resp in api
Former-commit-id: a34263e7c0e07a080276d164cdab9f12f1d767d2
2023-11-01 23:02:16 +08:00
hiyouga
2406200914 fix SFT trainer
Former-commit-id: bf09b6a6cd75cc2738d9af6b8c30bcbba77fa9b5
2023-10-31 21:52:52 +08:00
hiyouga
db06fcfc84 fix #1316
Former-commit-id: 88a753fe80e277007bac2264aee24024e18f2314
2023-10-31 11:32:08 +08:00
hiyouga
93b9f74e9f update projects
Former-commit-id: 33d58e9171ad2693b9d54715eb61a6f4326c59f4
2023-10-29 22:53:47 +08:00
hiyouga
33ec844f76 add projects
Former-commit-id: 495a68cd5962dd3b3af7e4a920d91ac25531a862
2023-10-29 22:07:13 +08:00
hiyouga
0f727b393e update constants
Former-commit-id: ebacbb1072045924a7e335cc9dda488d6f0be8b3
2023-10-29 13:30:20 +08:00
hiyouga
7da2aad6ee fix vicuna template
Former-commit-id: a98eda0803e4b73a24f12d848e14161451921e98
2023-10-27 22:15:25 +08:00
hiyouga
6f09f50d02 fix chatglm3 template
Former-commit-id: 69bcbc9f6c98e4f4ad97ec0306b33ab21923d311
2023-10-27 21:12:06 +08:00
hiyouga
5919832059 update readme
Former-commit-id: 6fb92c7088316c56ce8656e540fc47b0a5a1bf18
2023-10-27 19:19:03 +08:00
hiyouga
f7635c1afc support chatglm3
Former-commit-id: ba82e13bbeed3b262d301196b1860d73f319401d
2023-10-27 19:16:28 +08:00
hiyouga
c762168ed0 support dataset cache
Former-commit-id: f79ee62eb4a2a4a01cb4e2a6aa2d07158cf8eb59
2023-10-26 21:48:45 +08:00
hiyouga
67a46e553f fix #1287
Former-commit-id: d885aca472c6448bbf9a9e8d16bead92038825e3
2023-10-26 17:49:41 +08:00
hiyouga
e406f37b54 fix #1285
Former-commit-id: 2f8fe4439506e844b147fe38b5eb878c5748c31c
2023-10-26 16:34:52 +08:00
hiyouga
62fe877124 remove filter in preprocess
Former-commit-id: 9eac08b35fec47129a29c401ca265343f8388ab0
2023-10-23 23:46:02 +08:00
hiyouga
a0e682ba79 update neftune logic
Former-commit-id: bb4f0589ed23bf0236d3e918272ad64f0a05ef39
2023-10-22 17:42:13 +08:00
hiyouga
49e8a87383 fix webui
Former-commit-id: a5a5a7bc1f53d36e1b26e418999465903cb7d9ed
2023-10-22 17:24:56 +08:00
hiyouga
b2764b49ca add new options in webui
Former-commit-id: 6698b832dd9cc2d7d60be4fa5ab90e34a7e9d8e0
2023-10-22 17:17:58 +08:00
hiyouga
06b810de8f fix recursion error
Former-commit-id: c7938188c36a71a878bca982b7dd151195164986
2023-10-22 16:28:37 +08:00
hiyouga
6da51565f5 reimplement neftune
Former-commit-id: efe9e5a194d3a9f052701d904715238816e4c09e
2023-10-22 16:15:08 +08:00
hoshi-hiyouga
1f69965239 Merge pull request #1252 from anvie/neftune
add NEFTune optimization

Former-commit-id: 85d5c5fbe731f486c3e83812227fa05edc131487
2023-10-22 15:59:20 +08:00
anvie
af2d61178d add NEFTune optimization
Former-commit-id: 603e0298af64116ac07130fe6661a9ba823c186c
2023-10-21 13:24:10 +07:00
hiyouga
6a955ccf4f fix openchat template
Former-commit-id: 88b9b657bc50495ac4c42f64195fc652fe4ca3df
2023-10-21 01:25:42 +08:00
hiyouga
c0658711ca fix tokenizer padding side in evaluate.py
Former-commit-id: bcb43ff8ba1946c1f7e7865c9d0fb47ba276935d
2023-10-21 00:30:04 +08:00
hiyouga
d602f06882 fix #1232
Former-commit-id: 49975755d47344e362145c52548fdda8783f2c0c
2023-10-20 23:28:52 +08:00
hiyouga
1cb9a38ac2 fix #1215
Former-commit-id: d91b43a8afbea4859357f2224e3d9b9d71160e6d
2023-10-19 16:19:21 +08:00
hiyouga
47a1f73d0f fix #1218
Former-commit-id: b301f35bd4a3bf368159c8f5fb4e2736f922115b
2023-10-19 16:17:41 +08:00
hiyouga
142dd63b47 fix #1228
Former-commit-id: e4e0cae3f55da2f1b566c97dbfdd7fc5b7b728a4
2023-10-19 15:54:10 +08:00
hiyouga
b1bd8370c2 fix #1217
Former-commit-id: 065fc0a6f3f005bb87e1c5c126c8b6bb470ce700
2023-10-19 15:52:24 +08:00
hiyouga
215660c8da rename webui
Former-commit-id: 26feaf80fff6177d9eb4e28ad18feb6d34d3ea27
2023-10-16 15:16:24 +08:00
hiyouga
0cafe67efe fix #1197
Former-commit-id: 00100e23fcfef9587fda4cf01c62599d996e1176
2023-10-16 15:13:46 +08:00
hoshi-hiyouga
ea83b3222b Update README_zh.md
Former-commit-id: 3450404bb9a33c3bd4b45ac4afcf51062f8c7d1d
2023-10-16 00:28:27 +08:00
hoshi-hiyouga
725087a04f Update README.md
Former-commit-id: d84896597eded79f78224faed81cc9f2df222978
2023-10-16 00:23:37 +08:00
96 changed files with 3096 additions and 1681 deletions

128
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
`hoshihiyouga AT gmail DOT com`.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

158
README.md
View File

@@ -6,20 +6,27 @@
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd)
[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
👋 Join our [WeChat](assets/wechat.jpg).
\[ English | [中文](README_zh.md) \]
## Example: Fine-tuning large language model within 10 minutes
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
Launch an **all-in-one Web UI** via `python src/train_web.py`.
Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**.
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
## Changelog
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
@@ -48,72 +55,98 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
> [!NOTE]
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
## Supported Training Approaches
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use `--quantization_bit 4/8` argument to enable QLoRA.
## Provided Datasets
- For pre-training:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised fine-tuning:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
<details><summary>Pre-training datasets</summary>
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
</details>
<details><summary>Supervised fine-tuning datasets</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
</details>
<details><summary>Preference datasets</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
</details>
Please refer to [data/README.md](data/README.md) for details.
@@ -129,9 +162,9 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- sentencepiece, protobuf and tiktoken
- fire, jieba, rouge-chinese and nltk (used at evaluation and predict)
- gradio and matplotlib (used in web_demo.py)
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
- jieba, rouge-chinese and nltk (used at evaluation and predict)
- gradio and matplotlib (used in web UI)
- uvicorn, fastapi and sse-starlette (used in API)
And **powerful GPUs**!
@@ -139,7 +172,7 @@ And **powerful GPUs**!
### Data Preparation (optional)
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
@@ -160,17 +193,6 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### All-in-one Web UI
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
We **strongly recommend** using the all-in-one Web UI for newcomers since it can also generate training scripts automatically, even without a GPU environment.
> [!WARNING]
> Currently the web UI only supports training on **a single GPU**.
### Train on a single GPU
> [!IMPORTANT]
@@ -377,8 +399,7 @@ python src/export_model.py \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--export_dir path_to_export \
--fp16
--export_dir path_to_export
```
### API Demo
@@ -449,11 +470,18 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
> [!NOTE]
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
## Projects using LLaMA Factory
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
## Citation

View File

@@ -6,20 +6,27 @@
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd)
[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
👋 加入我们的[微信群](assets/wechat.jpg)。
\[ [English](README.md) | 中文 \]
## 示例:在十分钟内微调一个大模型
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
通过 `python src/train_web.py` 开启**训练推理一体化界面**
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。该模式目前仅支持单卡训练
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
## 更新日志
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune例如 `--neft_alpha 5`
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
@@ -34,7 +41,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
@@ -42,80 +49,106 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
## 模型
| 模型名 | 模型大小 | 默认模块 | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
> [!NOTE]
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
## 训练方法
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
## 数据集
- 用于预训练:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- 用于指令监督微调:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- 用于训练奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
<details><summary>预训练数据集</summary>
使用方法请参考 [data/README.md](data/README_zh.md) 文件。
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
</details>
<details><summary>指令微调数据集</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
</details>
<details><summary>偏好数据集</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
</details>
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
@@ -129,7 +162,7 @@ huggingface-cli login
- Python 3.8+ 和 PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
- sentencepiece, protobuf 和 tiktoken
- fire, jieba, rouge-chinese 和 nltk (用于评估及预测)
- jieba, rouge-chinese 和 nltk (用于评估及预测)
- gradio 和 matplotlib (用于网页端交互)
- uvicorn, fastapi 和 sse-starlette (用于 API)
@@ -139,10 +172,10 @@ huggingface-cli login
### 数据准备(可跳过)
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
> [!NOTE]
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`。
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
### 环境搭建(可跳过)
@@ -160,17 +193,6 @@ pip install -r requirements.txt
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### 浏览器一体化界面
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```
我们**极力推荐**新手使用浏览器一体化界面,因为它还可以不依赖 GPU 环境自动生成在 GPU 上运行的命令行脚本。
> [!WARNING]
> 目前网页 UI 仅支持**单卡训练**。
### 单 GPU 训练
> [!IMPORTANT]
@@ -376,8 +398,7 @@ python src/export_model.py \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export \
--fp16
--export_dir path_to_export
```
### API 服务
@@ -448,11 +469,18 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
> [!NOTE]
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
## 使用了 LLaMA Factory 的项目
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [Falcon](LICENSE) / [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [InternLM](https://github.com/InternLM/InternLM#open-source-license) / [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx)
使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
## 引用

View File

@@ -2,31 +2,106 @@ If you are using a custom dataset, please provide your dataset definition in the
```json
"dataset_name": {
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"ranking": "whether the examples contains ranked responses or not. (default: false)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
"subset": "the name of the subset. (optional, default: None)",
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
"content": "the key in the message represents the content. (default: value, for sharegpt)"
}
}
```
where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
For datasets used in reward modeling or DPO training, the `response` column should be a string list, with the preferred answers appearing first, for example:
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
```json
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"output": "model response (required)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
```json
"dataset_name": {
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
```
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
For the pre-training datasets, only the `prompt` column will be used for training.
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
```json
{
"instruction": "Question",
"input": "",
"instruction": "user instruction",
"input": "user input",
"output": [
"Chosen answer",
"Rejected answer"
"chosen answer",
"rejected answer"
]
}
```
The dataset in sharegpt format should follow the below format:
```json
[
{
"conversations": [
{
"from": "human",
"value": "user instruction"
},
{
"from": "gpt",
"value": "model response"
}
]
}
]
```
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
```json
"dataset_name": {
"columns": {
"messages": "conversations",
"role": "from",
"content": "value"
}
}
```
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.

View File

@@ -1,32 +1,107 @@
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以下格式提供您的数据集定义。
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
```json
"数据集名称": {
"hf_hub_url": "HuggingFace上的项目地址若指定则忽略下列三个参数",
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选",
"ranking": "数据集是否包含排序后的回答默认false",
"file_sha1": "数据集文件的SHA-1哈希值可选,留空不影响训练",
"subset": "数据集子集的名称可选默认None",
"ranking": "是否为偏好数据集可选默认False",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns": {
"prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None"
"prompt": "数据集代表提示词的表头名称默认instruction,用于 alpaca 格式",
"query": "数据集代表请求的表头名称默认input,用于 alpaca 格式",
"response": "数据集代表回答的表头名称默认output,用于 alpaca 格式",
"history": "数据集代表历史对话的表头名称默认None,用于 alpaca 格式)",
"messages": "数据集代表消息列表的表头名称默认conversations用于 sharegpt 格式)",
"role": "消息中代表发送者身份的键名默认from用于 sharegpt 格式)",
"content": "消息中代表文本内容的键名默认value用于 sharegpt 格式)"
}
}
```
其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集
对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"output": "模型回答(必填)",
"history": [
["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"]
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
```json
"数据集名称": {
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
}
```
其中 `prompt``response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
```json
{
"instruction": "Question",
"input": "",
"instruction": "用户指令",
"input": "用户输入",
"output": [
"Chosen answer",
"Rejected answer"
"优质回答",
"劣质回答"
]
}
```
而 sharegpt 格式的数据集按照以下方式组织:
```json
[
{
"conversations": [
{
"from": "human",
"value": "用户指令"
},
{
"from": "gpt",
"value": "模型回答"
}
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
```json
"数据集名称": {
"columns": {
"messages": "conversations",
"role": "from",
"content": "value"
}
}
```
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
预训练数据集和偏好数据集尚不支持 sharegpt 格式。

View File

@@ -1,6 +1,5 @@
import json
import datasets
from typing import Any, Dict, List
_DESCRIPTION = "BELLE multiturn chat dataset."
@@ -23,11 +22,9 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
def _info(self):
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
@@ -37,7 +34,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download(_URL)
return [
datasets.SplitGenerator(
@@ -48,10 +45,11 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
)
]
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat with history
def _generate_examples(self, filepath: str):
with open(filepath, "r", encoding="utf-8") as f:
for key, row in enumerate(f):
data = json.loads(row)
conversations = []
prompt = data["instruction"].strip()
response = data["output"].strip()
@@ -59,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip()
prompt = prompt[:human_idx].strip()
history = []
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
while prompt.rfind("Assistant:") != -1:
assist_idx = prompt.rfind("Assistant:")
@@ -67,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip()
history.insert(0, (old_query, old_resp))
conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else:
break
prompt = prompt[:human_idx].strip()
yield key, {
"instruction": query,
"output": response,
"history": history
}
yield key, {"conversations": conversations}

View File

@@ -3,7 +3,7 @@ import datasets
from typing import Any, Dict, List
_DESCRIPTION = "An example of dataset for LLaMA."
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""

View File

@@ -1,9 +1,9 @@
import json
import datasets
from typing import Any, Dict, List
from typing import List
_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
@@ -42,7 +42,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_path = dl_manager.download_and_extract(_URLS)
return [
datasets.SplitGenerator(
@@ -59,7 +59,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
)
]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
def _generate_examples(self, filepaths: List[str]):
key = 0
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:

View File

@@ -1,6 +1,6 @@
import json
import datasets
from typing import Any, Dict, List
from typing import List
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
@@ -21,15 +21,13 @@ _LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
class BelleMultiturn(datasets.GeneratorBasedBuilder):
class UltraChat(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
def _info(self):
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
@@ -39,8 +37,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
def _split_generators(self, dl_manager: datasets.DownloadManager):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
@@ -50,7 +48,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
)
]
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
def _generate_examples(self, filepaths: List[str]):
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
for row in f:
@@ -58,19 +56,14 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
data = json.loads(row)
except:
continue
key = data["id"]
content = data["data"]
key: int = data["id"]
content: List[str] = data["data"]
if len(content) % 2 == 1:
content.pop(-1)
if len(content) < 2:
continue
query = content[-2]
response = content[-1]
history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
yield key, {
"instruction": query,
"output": response,
"history": history
}
conversations = [{
"from": "human" if i % 2 == 0 else "gpt",
"value": content[i]
} for i in range(len(content))]
yield key, {"conversations": conversations}

View File

@@ -1,20 +1,19 @@
torch>=1.13.1
transformers>=4.31.0
datasets>=2.12.0
transformers>=4.31.0,<4.35.0
datasets>=2.14.0
accelerate>=0.21.0
peft>=0.4.0
trl>=0.7.1
peft>=0.6.0
trl>=0.7.4
gradio>=3.38.0,<4.0.0
scipy
sentencepiece
protobuf
tiktoken
fire
jieba
rouge-chinese
nltk
gradio==3.38.0
uvicorn
pydantic==1.10.11
fastapi==0.95.1
pydantic
fastapi
sse-starlette
matplotlib

View File

@@ -6,8 +6,8 @@ from llmtuner import ChatModel, create_app
def main():
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
print("Visit http://localhost:8000/docs for API document.")
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
if __name__ == "__main__":

View File

@@ -1,4 +1,12 @@
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline
except ImportError:
print("Install `readline` for a better experience.")
def main():
@@ -20,6 +28,7 @@ def main():
if query.strip() == "clear":
history = []
torch_gc()
print("History has been removed.")
continue

View File

@@ -1,185 +1,10 @@
# coding=utf-8
# Evaluates the performance of pre-trained models.
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import fire
import json
import torch
import numpy as np
from collections import Counter
from datasets import load_dataset
from dataclasses import dataclass
from tqdm import tqdm, trange
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from llmtuner import ChatModel
if TYPE_CHECKING:
from datasets import Dataset
from llmtuner import Evaluator
choices = ["A", "B", "C", "D"]
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(
self,
example: Dict[str, str]
) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
eval_templates = {
"en": EvalTemplate(
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" "
),
"zh": EvalTemplate(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n"
)
}
@torch.inference_mode()
def batch_inference(
chat_model: ChatModel,
batch_input: Dict[str, torch.Tensor],
prefix_char: str
) -> List[str]:
logits = chat_model.model(**batch_input).logits
probs = torch.nn.functional.softmax(
torch.stack(
[
logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
for choice in choices
],
dim=-1
),
dim=-1
).detach()
return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
def evaluate(
model_name_or_path: str,
finetuning_type: Optional[str] = "lora",
checkpoint_dir: Optional[str] = None,
template: Optional[str] = "vanilla",
task: Optional[str] = "ceval",
dataset_dir: Optional[str] = "evaluation",
split: Optional[Literal["validation", "test"]] = "validation",
lang: Optional[Literal["zh", "en"]] = "zh",
n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4,
save_name: Optional[str] = None
):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,
checkpoint_dir=checkpoint_dir,
template=template
))
eval_template = eval_templates[lang]
assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
category_corrects: Dict[str, np.ndarray] = {
subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
dataset = load_dataset(os.path.join(dataset_dir, task), subject)
labels, answers, all_outputs = [], [], []
for epoch in range(n_avg):
pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
inputs, outputs = [], []
for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
query, resp, history = eval_template.format_example(
target_data=dataset[split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=chat_model.template.use_history
)
input_ids, _ = chat_model.template.encode_oneturn(
tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
if epoch == 0:
labels.append(resp)
for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = chat_model.tokenizer.pad(
inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
).to(chat_model.model.device)
preds = batch_inference(chat_model, batch_input, eval_template.prefix)
outputs += preds
all_outputs.append(outputs)
for i in range(len(all_outputs[0])):
count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
answers.append(count.most_common(1)[0][0])
corrects = (np.array(answers) == np.array(labels))
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
results[subject] = {str(i): answers[i] for i in range(len(answers))}
score_info = "\n".join([
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items() if len(category_correct)
])
print(score_info)
if save_name is not None:
with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
json.dump(results, f, indent=2)
with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
f.write(score_info)
def main():
evaluator = Evaluator()
evaluator.eval()
if __name__ == "__main__":
fire.Fire(evaluate)
main()

View File

@@ -1,9 +1,10 @@
# Level: api, webui > chat > tuner > dsets > extras, hparams
# Level: api, webui > chat, eval, train > data, model > extras, hparams
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import export_model, run_exp
from llmtuner.eval import Evaluator
from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.2.0"
__version__ = "0.3.0"

View File

@@ -1,12 +1,8 @@
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
import json
from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
from llmtuner.extras.misc import torch_gc
from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
@@ -21,15 +17,40 @@ from llmtuner.api.protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
def create_app(chat_model: ChatModel) -> FastAPI:
def to_json(data: BaseModel) -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -45,14 +66,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
system = prev_messages.pop(0).content
else:
system = None
@@ -62,32 +83,42 @@ def create_app(chat_model: ChatModel) -> FastAPI:
for i in range(0, len(prev_messages), 2):
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
if request.stream:
generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
responses = chat_model.chat(
query, history, system,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens
max_new_tokens=request.max_tokens,
num_return_sequences=request.n
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
choices.append(ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length+response_length
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role=Role.ASSISTANT, content=response),
finish_reason=Finish.STOP
)
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
@@ -96,7 +127,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield to_json(chunk)
for new_text in chat_model.stream_chat(
query, history, system,
@@ -114,7 +145,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=None
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield to_json(chunk)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@@ -122,7 +153,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
finish_reason=Finish.STOP
)
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
yield chunk.json(exclude_unset=True, ensure_ascii=False)
yield to_json(chunk)
yield "[DONE]"
return app

View File

@@ -20,9 +20,6 @@ class ModelCard(BaseModel):
object: Optional[str] = "model"
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
owned_by: Optional[str] = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = []
class ModelList(BaseModel):

View File

@@ -1 +1 @@
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat.chat_model import ChatModel

View File

@@ -1,11 +1,21 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class ChatModel:
@@ -18,7 +28,7 @@ class ChatModel:
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.system_prompt = data_args.system_prompt
def process_args(
def _process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
@@ -26,17 +36,17 @@ class ChatModel:
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
system = system or self.system_prompt
prompt, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
)
prompt_length = len(prompt)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
@@ -47,11 +57,15 @@ class ChatModel:
temperature=temperature or generating_args["temperature"],
top_p=top_p or generating_args["top_p"],
top_k=top_k or generating_args["top_k"],
num_return_sequences=num_return_sequences or 1,
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id
))
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
generating_args["do_sample"] = True
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
@@ -75,13 +89,30 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
response_length = len(outputs)
return response, (prompt_length, response_length)
) -> List[Response]:
r"""
Args: query, history, system, **input_kwargs
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length"
))
return results
@torch.inference_mode()
def stream_chat(
@@ -91,7 +122,7 @@ class ChatModel:
system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

View File

@@ -0,0 +1,4 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset

145
src/llmtuner/data/loader.py Normal file
View File

@@ -0,0 +1,145 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_files = None
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_files = None
elif dataset_attr.load_from == "file":
data_path, data_name = None, None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
dataset = load_dataset(
path=data_path,
name=data_name,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming
)
if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# convert dataset from sharegpt format to alpaca format
outputs = {"prompt": [], "query": [], "response": [], "history": []}
for msg_list in examples[dataset_attr.messages]:
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
if len(msg_list) == 0:
continue
msg_pairs = []
user_role, assistant_role = None, None
for idx in range(0, len(msg_list), 2):
if user_role is None and assistant_role is None:
user_role = msg_list[idx][dataset_attr.role]
assistant_role = msg_list[idx + 1][dataset_attr.role]
else:
if (
msg_list[idx][dataset_attr.role] != user_role
or msg_list[idx+1][dataset_attr.role] != assistant_role
):
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
if len(msg_pairs) != 0:
outputs["prompt"].append(msg_pairs[-1][0])
outputs["query"].append("")
outputs["response"].append(msg_pairs[-1][1])
outputs["history"].append(msg_pairs[:-1])
return outputs
if dataset_attr.formatting == "sharegpt": # convert format
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
dataset = dataset.map(
convert_format,
batched=True,
remove_columns=column_names,
**kwargs
)
else:
for column_name in ["prompt", "query", "response", "history"]: # align dataset
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.system_prompt: # add system prompt
system_prompt = dataset_attr.system_prompt
if data_args.streaming:
dataset = dataset.map(lambda _: {"system": system_prompt})
else:
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=data_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
)
else:
raise ValueError("Unknown mixing strategy.")

View File

@@ -1,9 +1,13 @@
import os
import tiktoken
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
from datasets import load_from_disk
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
@@ -12,6 +16,25 @@ if TYPE_CHECKING:
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def preprocess_dataset(
dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer",
@@ -19,21 +42,12 @@ def preprocess_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Union["Dataset", "IterableDataset"]:
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
@@ -41,6 +55,7 @@ def preprocess_dataset(
kwargs = dict(add_special_tokens=True)
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
@@ -54,26 +69,29 @@ def preprocess_dataset(
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
# make sure the saved tokenizer is the same as the original one
if hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
total_len = len(source_ids) + len(target_ids)
max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
if len(source_ids) > max_source_len:
source_len, target_len = len(source_ids), len(target_ids)
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
source_ids = source_ids[:max_source_len]
if len(target_ids) > max_target_len:
if target_len > max_target_len:
target_ids = target_ids[:max_target_len]
if data_args.train_on_prompt:
@@ -100,12 +118,15 @@ def preprocess_dataset(
return model_inputs
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
@@ -134,11 +155,14 @@ def preprocess_dataset(
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and query != ""):
continue
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
if template.efficient_eos:
@@ -155,10 +179,13 @@ def preprocess_dataset(
return model_inputs
def preprocess_pairwise_dataset(examples):
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
continue
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
@@ -166,23 +193,21 @@ def preprocess_dataset(
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
if len(prompt_ids) > max_source_len:
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if len(chosen_ids) > max_target_len:
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
if len(rejected_ids) > max_target_len:
rejected_ids = rejected_ids[:max_target_len]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example):
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
@@ -190,7 +215,7 @@ def preprocess_dataset(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example):
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
@@ -198,46 +223,53 @@ def preprocess_dataset(
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
return load_from_disk(data_args.cache_path)
with training_args.main_process_first(desc="dataset map pre-processing"):
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
batched=True,
remove_columns=column_names,
**kwargs
)
try:
print_function(next(iter(dataset)))
except StopIteration:
raise ValueError("Empty dataset!")
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
return dataset

View File

@@ -225,90 +225,6 @@ def get_template_and_fix_tokenizer(
return template
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix=[],
prompt=[
"{{query}}"
],
system="",
sep=[],
use_history=False
)
r"""
Default template.
"""
register_template(
name="default",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[
"\n"
]
)
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
r"""
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
"""
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
)
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
prefix=[
@@ -327,111 +243,13 @@ register_template(
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_template(
name="vicuna",
prefix=[
"{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT: "
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
"\n\n"
]
)
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nBot: "
],
system="",
sep=[
"\n"
]
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
system="",
sep=[
"\n"
]
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
"\n"
]
)
r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B
"""
register_template(
name="aquila",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant: "
"Human: {{query}}###Assistant:"
],
system=(
"A chat between a curious human and an artificial intelligence assistant. "
@@ -447,9 +265,169 @@ register_template(
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="baichuan",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
efficient_eos=True
)
register_template(
name="baichuan2",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[],
efficient_eos=True
)
register_template(
name="belle",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
"\n\n"
]
)
register_template(
name="bluelm",
prefix=[
"{{system}}"
],
prompt=[
{"token": "[|Human|]:"},
"{{query}}",
{"token": "[|AI|]:"}
],
system="",
sep=[]
)
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"\n\n"
],
efficient_eos=True
)
register_template(
name="chatglm3",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
{"token": "<|user|>"},
"\n",
"{{query}}",
{"token": "<|assistant|>"}
],
system="",
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
],
efficient_eos=True
)
register_template(
name="deepseek",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer."
),
sep=[
"\n",
{"token": "<|EOT|>"},
"\n\n"
],
stop_words=[
"<|EOT|>"
],
efficient_eos=True
)
register_template(
name="default",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant:"
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[
"\n"
]
)
register_template(
name="falcon",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nFalcon:"
],
system="",
sep=[
"\n"
],
efficient_eos=True
)
register_template(
name="intern",
prefix=[
@@ -472,49 +450,101 @@ register_template(
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
name="llama2",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
)
register_template(
name="mistral",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
"[INST] {{query}} [/INST]"
],
system="",
sep=[],
efficient_eos=True
sep=[]
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
"""
register_template(
name="baichuan2",
name="openchat",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
"GPT4 Correct User: {{query}}",
{"token": "<|end_of_turn|>"},
"GPT4 Correct Assistant:"
],
system="",
sep=[],
sep=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
"<|end_of_turn|>"
],
efficient_eos=True
)
register_template(
name="qwen",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}"
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
system="You are a helpful assistant.",
sep=[
{"token": "<|im_end|>"},
"\n"
],
stop_words=[
"<|im_end|>"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix=[
@@ -541,58 +571,36 @@ register_template(
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
Supports language model inference without histories.
"""
register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}"
],
name="vanilla",
prefix=[],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
"{{query}}"
],
system="You are a helpful assistant.",
sep=[
{"token": "<|im_end|>"},
"\n"
],
stop_words=[
"<|im_end|>"
],
efficient_eos=True
system="",
sep=[],
use_history=False
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
name="vicuna",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
"USER: {{query}} ASSISTANT:"
],
system="",
sep=[
"\n\n"
],
efficient_eos=True
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
r"""
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_template(
name="xverse",
prefix=[
@@ -604,3 +612,71 @@ register_template(
system="",
sep=[]
)
register_template(
name="yayi",
prefix=[
{"token": "<|System|>"},
":\n{{system}}"
],
prompt=[
{"token": "<|Human|>"},
":\n{{query}}\n\n",
{"token": "<|YaYi|>"},
":"
],
system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[
"\n\n"
],
stop_words=[
"<|End|>"
]
)
register_template(
name="zephyr",
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
{"token": "</s>"}
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "</s>"},
{"token": "<|assistant|>"}
],
system="You are a friendly chatbot who always responds in the style of a pirate",
sep=[]
)
register_template(
name="ziya",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
"\n"
]
)

View File

@@ -13,9 +13,11 @@ logger = get_logger(__name__)
EXT2TYPE = {
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text"
}

View File

@@ -1,3 +0,0 @@
from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

View File

@@ -1,94 +0,0 @@
import os
from typing import TYPE_CHECKING, List, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
data_path = dataset_attr.dataset_name
data_files = None
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_files = None
elif dataset_attr.load_from == "file":
data_path = None
data_files: List[str] = []
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
dataset = load_dataset(
data_path,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
)
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
# TODO: adapt to the sharegpt format
for column_name in ["prompt", "query", "response", "history"]: # align datasets
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.system_prompt: # add system prompt
if data_args.streaming:
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
else:
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
else:
raise ValueError("Unknown mixing strategy.")

View File

@@ -0,0 +1 @@
from llmtuner.eval.evaluator import Evaluator

View File

@@ -0,0 +1,116 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import json
import torch
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = self._encode_choices()
def _encode_choices(self) -> List[int]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None:
mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token,
revision=self.model_args.model_revision
)
with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
download_mode="force_redownload"
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
query, resp, history = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=self.template.use_history
)
input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp=resp, history=history
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(resp)
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad(
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
).to(self.model.device)
preds = self.batch_inference(batch_input)
outputs += preds
corrects = (np.array(outputs) == np.array(labels))
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join([
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items() if len(category_correct)
])
print(score_info)
if self.eval_args.save_dir is not None:
os.makedirs(self.eval_args.save_dir, exist_ok=False)
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
json.dump(results, f, indent=2)
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from llmtuner.extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(
self,
example: Dict[str, str]
) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
eval_templates: Dict[str, EvalTemplate] = {}
def register_eval_template(
name: str,
system: str,
choice: str,
answer: str,
prefix: str
) -> None:
eval_templates[name] = EvalTemplate(
system=system,
choice=choice,
answer=answer,
prefix=prefix
)
def get_eval_template(name: str) -> EvalTemplate:
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template
register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" "
)
register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n"
)

View File

@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
@@ -25,18 +26,24 @@ class SavePeftModelCallback(TrainerCallback):
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
model = kwargs.pop("model")
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(output_dir)
model.pretrained_model.save_pretrained(output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(args.output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
model.pretrained_model.save_pretrained(args.output_dir)
class LogCallback(TrainerCallback):

View File

@@ -1,11 +1,25 @@
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
METHODS = ["full", "freeze", "lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
@@ -14,75 +28,251 @@ TRAINING_STAGES = {
"Pre-Training": "pt"
}
SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
"Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
"InternLM-7B": "internlm/internlm-7b",
"InternLM-20B": "internlm/internlm-20b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"InternLM-20B-Chat": "internlm/internlm-chat-20b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"XVERSE-13B": "xverse/XVERSE-13B",
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
"Phi1.5-1.3B": "microsoft/phi-1_5"
}
DEFAULT_MODULE = {
"LLaMA": "q_proj,v_proj",
"LLaMA2": "q_proj,v_proj",
"ChineseLLaMA2": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"Baichuan2": "W_pack",
"InternLM": "q_proj,v_proj",
"Qwen": "c_attn",
"XVERSE": "q_proj,v_proj",
"ChatGLM2": "query_key_value",
"Phi1.5": "Wqkv"
}
def register_model_group(
models: Dict[str, str],
module: Optional[str] = None,
template: Optional[str] = None
) -> None:
prefix = None
for name, path in models.items():
if prefix is None:
prefix = name.split("-")[0]
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path
if module is not None:
DEFAULT_MODULE[prefix] = module
if template is not None:
DEFAULT_TEMPLATE[prefix] = template
DEFAULT_TEMPLATE = {
"LLaMA2": "llama2",
"ChineseLLaMA2": "llama2_zh",
"Baichuan": "baichuan",
"Baichuan2": "baichuan2",
"InternLM": "intern",
"Qwen": "chatml",
"XVERSE": "xverse",
"ChatGLM2": "chatglm2"
}
register_model_group(
models={
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
},
module="W_pack",
template="baichuan"
)
register_model_group(
models={
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
},
module="W_pack",
template="baichuan2"
)
register_model_group(
models={
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1"
},
module="query_key_value"
)
register_model_group(
models={
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
},
module="query_key_value"
)
register_model_group(
models={
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
},
template="bluelm"
)
register_model_group(
models={
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
},
module="query_key_value",
template="chatglm2"
)
register_model_group(
models={
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
},
module="query_key_value",
template="chatglm3"
)
register_model_group(
models={
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
},
template="llama2_zh"
)
register_model_group(
models={
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-180B": "tiiuae/falcon-180B",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
},
module="query_key_value",
template="falcon"
)
register_model_group(
models={
"InternLM-7B": "internlm/internlm-7b",
"InternLM-20B": "internlm/internlm-20b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
},
template="intern"
)
register_model_group(
models={
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
},
module="qkv_proj"
)
register_model_group(
models={
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b"
}
)
register_model_group(
models={
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
},
template="llama2"
)
register_model_group(
models={
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
},
template="mistral"
)
register_model_group(
models={
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
},
template="openchat"
)
register_model_group(
models={
"Phi1.5-1.3B": "microsoft/phi-1_5"
},
module="Wqkv"
)
register_model_group(
models={
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
},
module="c_attn",
template="qwen"
)
register_model_group(
models={
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
}
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
},
template="vicuna"
)
register_model_group(
models={
"XVERSE-7B": "xverse/XVERSE-7B",
"XVERSE-13B": "xverse/XVERSE-13B",
"XVERSE-65B": "xverse/XVERSE-65B",
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
},
template="xverse"
)
register_model_group(
models={
"Yayi-7B": "wenge-research/yayi-7b-llama2",
"Yayi-13B": "wenge-research/yayi-13b-llama2"
},
template="yayi"
)
register_model_group(
models={
"Yi-6B": "01-ai/Yi-6B",
"Yi-34B": "01-ai/Yi-34B"
}
)
register_model_group(
models={
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
"Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
},
template="zephyr"
)

View File

@@ -3,6 +3,9 @@ import logging
class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
"""
def __init__(self):
super().__init__()
@@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
r"""
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"
@@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler)
return logger
def reset_logging() -> None:
r"""
Removes basic config of root logger. (unused in script)
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))

View File

@@ -1,6 +1,8 @@
import gc
import os
import sys
import torch
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
try:
@@ -11,13 +13,13 @@ try:
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers import HfArgumentParser
class AverageMeter:
@@ -62,6 +64,25 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else:
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
@@ -74,13 +95,15 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32
def get_logits_processor() -> LogitsProcessorList:
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def torch_gc() -> None:
@@ -91,28 +114,3 @@ def torch_gc() -> None:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()

View File

@@ -0,0 +1,55 @@
import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_rouge_available = is_package_available("rouge-chinese")
_starlette_available = is_package_available("sse-starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
def is_flash_attn2_available():
return _flash_attn2_available
def is_jieba_available():
return _jieba_available
def is_matplotlib_available():
return _matplotlib_available
def is_nltk_available():
return _nltk_available
def is_rouge_available():
return _rouge_available
def is_starlette_available():
return _starlette_available
def is_uvicorn_available():
return _uvicorn_available

View File

@@ -3,13 +3,19 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
except ImportError:
print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
logger = logging.get_logger(__name__)

View File

@@ -1,11 +1,14 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)

View File

@@ -1,21 +0,0 @@
import os
import torch
from transformers.trainer import WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if not os.path.exists(vhead_file):
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@@ -1,4 +1,5 @@
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments

View File

@@ -11,11 +11,17 @@ class DatasetAttr:
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None
subset: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
role: Optional[str] = "from"
content: Optional[str] = "value"
def __repr__(self) -> str:
return self.dataset_name
@@ -36,7 +42,7 @@ class DataArguments:
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "The name of the folder containing datasets."}
metadata={"help": "Path to the folder containing the datasets."}
)
split: Optional[str] = field(
default="train",
@@ -46,6 +52,10 @@ class DataArguments:
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -60,7 +70,7 @@ class DataArguments:
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing."}
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
)
interleave_probs: Optional[str] = field(
default=None,
@@ -98,20 +108,33 @@ class DataArguments:
default=False,
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
)
cache_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the preprocessed datasets."}
)
def __post_init__(self):
if self.reserved_label_len >= self.cutoff_len:
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
def init_for_training(self): # support mixing multiple datasets
if self.streaming and self.cache_path:
raise ValueError("`cache_path` is incompatible with `streaming`.")
def init_for_training(self, seed: int): # support mixing multiple datasets
self.seed = seed
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
try:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
except Exception:
if self.dataset is not None:
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
dataset_info = None
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
@@ -142,7 +165,12 @@ class DataArguments:
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)

View File

@@ -0,0 +1,55 @@
import os
from typing import Literal, Optional
from dataclasses import dataclass, field
from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field(
metadata={"help": "Name of the evaluation task."}
)
task_dir: Optional[str] = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."}
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."}
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."}
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."}
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."}
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."}
)
download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."}
)
def __post_init__(self):
task_available = []
for folder in os.listdir(self.task_dir):
if os.path.isdir(os.path.join(self.task_dir, folder)):
task_available.append(folder)
if self.task not in task_available:
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")

View File

@@ -4,38 +4,38 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
class FreezeArguments:
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
Arguments pertaining to the freeze (partial-parameter) training.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
name_module_trainable: Optional[str] = field(
default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
Qwen choices: [\"mlp\", \"attn\"], \
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
Others choices: the same as LLaMA."}
)
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
)
lora_dropout: Optional[float] = field(
default=0.1,
@@ -45,11 +45,11 @@ class FinetuningArguments:
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
Others choices: the same as LLaMA."}
)
additional_target: Optional[str] = field(
default=None,
@@ -59,31 +59,115 @@ class FinetuningArguments:
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
"""
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
)
ppo_target: Optional[float] = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
)
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
ppo_whiten_rewards: Optional[bool] = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
)
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
)
ref_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
reward_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."}
)
reward_model_type: Optional[Literal["lora", "full"]] = field(
default="lora",
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)
neft_alpha: Optional[float] = field(
default=0,
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
if isinstance(self.additional_target, str):
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
self.name_module_trainable = split_arg(self.name_module_trainable)
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
self.lora_target = split_arg(self.lora_target)
self.additional_target = split_arg(self.additional_target)
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Lora reward model only supports lora training.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
@@ -96,4 +180,5 @@ class FinetuningArguments:
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))

View File

@@ -1,13 +0,0 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
r"""
Arguments pertaining to which stage we are going to perform.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)

View File

@@ -28,7 +28,7 @@ class GeneratingArguments:
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
max_length: Optional[int] = field(
default=None,
default=512,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
)
max_new_tokens: Optional[int] = field(
@@ -46,6 +46,8 @@ class GeneratingArguments:
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", None):
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
return args

View File

@@ -1,5 +1,5 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
@@ -22,10 +22,6 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
)
use_auth_token: Optional[bool] = field(
default=False,
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
)
model_revision: Optional[str] = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
@@ -48,7 +44,7 @@ class ModelArguments:
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
)
flash_attn: Optional[bool] = field(
default=False,
@@ -58,26 +54,10 @@ class ModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
hf_auth_token: Optional[str] = field(
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
def __post_init__(self):
self.compute_dtype = None
@@ -89,9 +69,7 @@ class ModelArguments:
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.use_auth_token == True and self.hf_auth_token is not None:
from huggingface_hub.hf_api import HfFolder # lazy load
HfFolder.save_token(self.hf_auth_token)
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

View File

@@ -0,0 +1,5 @@
# Level: loader > adapter > parser, utils
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params

View File

@@ -1,15 +1,9 @@
import torch
from typing import TYPE_CHECKING
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.utils import find_all_linear_modules
from llmtuner.model.utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
@@ -23,8 +17,7 @@ def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -34,22 +27,33 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.info("Checkpoint is not found at evaluation, load the original model.")
return model
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze":
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = getattr(model.config, "num_layers")
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
for idx in trainable_layer_ids:
trainable_layers.append("{:d}.{}".format(idx, module_name))
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in trainable_layers):
param.requires_grad_(False)
@@ -58,11 +62,11 @@ def init_adapter(
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
latest_checkpoint = None
checkpoint_to_resume = None
if model_args.checkpoint_dir is not None:
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
if is_trainable and finetuning_args.resume_lora_training:
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
@@ -73,10 +77,10 @@ def init_adapter(
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if latest_checkpoint is not None: # resume lora training or quantized inference
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
if checkpoint_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
if is_trainable and latest_checkpoint is None: # create new lora weights while training
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
else:
@@ -92,8 +96,6 @@ def init_adapter(
modules_to_save=finetuning_args.additional_target
)
model = get_peft_model(model, lora_config)
if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
model.base_model.peft_config = model.peft_config
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

View File

@@ -14,7 +14,6 @@ from transformers import (
PreTrainedTokenizerBase
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
@@ -24,12 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
from llmtuner.tuner.core.utils import prepare_model_for_training
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -39,11 +38,11 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
check_min_version("4.31.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
def load_model_and_tokenizer(
@@ -57,15 +56,12 @@ def load_model_and_tokenizer(
Support both training and inference.
"""
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"token": model_args.hf_hub_token
}
tokenizer = AutoTokenizer.from_pretrained(
@@ -77,21 +73,21 @@ def load_model_and_tokenizer(
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
# Fix tokenizer (for ChatGLM2)
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is not None: # for training
setattr(config, "torch_dtype", model_args.compute_dtype)
else: # for evaluation, priority: bf16 > fp16 > fp32
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
@@ -100,15 +96,9 @@ def load_model_and_tokenizer(
# Set RoPE scaling
if model_args.rope_scaling is not None:
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
if is_trainable:
logger.warning("Qwen model does not support RoPE scaling in training.")
else:
setattr(config, "use_dynamic_ntk", True)
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
else:
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
@@ -130,19 +120,19 @@ def load_model_and_tokenizer(
model_args.rope_scaling, scaling_factor
))
else:
logger.warning("Current model does not support RoPE scaling.")
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
elif getattr(config, "model_type", None) == "qwen":
logger.info("Qwen models automatically enable FlashAttention if installed.")
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention-2.")
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
@@ -155,8 +145,7 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -166,7 +155,7 @@ def load_model_and_tokenizer(
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -176,11 +165,10 @@ def load_model_and_tokenizer(
bnb_4bit_quant_type=model_args.quantization_type
)
is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pre-trained models (without valuehead).
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
@@ -193,11 +181,12 @@ def load_model_and_tokenizer(
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
# Fix LM head (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files.
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@@ -206,29 +195,20 @@ def load_model_and_tokenizer(
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
if is_trainable:
model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model._keys_to_ignore_on_save = None
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if getattr(model, "is_peft_model", False):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference
if not is_trainable:
@@ -240,4 +220,7 @@ def load_model_and_tokenizer(
trainable_params, all_param, 100 * trainable_params / all_param
))
if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
return model, tokenizer

View File

@@ -1,5 +1,4 @@
import os
import sys
import torch
import datasets
import transformers
@@ -8,9 +7,11 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
@@ -19,62 +20,42 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
@@ -88,8 +69,8 @@ def get_train_args(
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
# Check arguments
data_args.init_for_training(training_args.seed)
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
@@ -100,26 +81,20 @@ def get_train_args(
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if finetuning_args.stage in ["rm", "ppo"]:
if training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stages can only be performed at training.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
@@ -135,18 +110,14 @@ def get_train_args(
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if model_args.quantization_bit is not None:
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
if not finetuning_args.resume_lora_training:
raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
@@ -155,6 +126,9 @@ def get_train_args(
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
if (
training_args.local_rank != -1
@@ -203,14 +177,7 @@ def get_train_args(
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
@@ -219,11 +186,25 @@ def get_infer_args(
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

165
src/llmtuner/model/utils.py Normal file
View File

@@ -0,0 +1,165 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
else:
linear_cls = torch.nn.Linear
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, linear_cls)
and not any([output_layer in name for output_layer in output_layer_names])
):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def generate_model_card(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
}
def load_valuehead_params(
path_or_repo_id: str,
model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
if finetuning_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)
logger.info("Upcasting weights in layernorm in float32.")
if finetuning_args.neft_alpha > 1e-6:
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear):
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
return args[0].to(output_layer.weight.dtype)
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
output_layer.register_forward_hook(fp32_forward_post_hook)
return model

View File

@@ -0,0 +1 @@
from llmtuner.train.tuner import export_model, run_exp

View File

@@ -0,0 +1 @@
from llmtuner.train.dpo.workflow import run_dpo

View File

@@ -1,6 +1,6 @@
import torch
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
@@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
**kwargs
):
if disable_dropout:
@@ -29,9 +30,11 @@ class CustomDPOTrainer(DPOTrainer):
self.is_encoder_decoder = model.config.is_encoder_decoder
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = beta
self.loss_type = loss_type
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
@@ -40,8 +43,11 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
self.ref_model.eval()
if not (
getattr(ref_model, "is_loaded_in_8bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

View File

@@ -1,20 +1,21 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
from llmtuner.hparams import ModelArguments
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.hparams import DataArguments, FinetuningArguments
def run_dpo(
@@ -33,6 +34,13 @@ def run_dpo(
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Create reference model
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
@@ -41,7 +49,7 @@ def run_dpo(
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
model=model,
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
ref_model=ref_model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
@@ -52,9 +60,26 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1 @@
from llmtuner.train.ppo.workflow import run_ppo

View File

@@ -1,10 +1,11 @@
import os
import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
@@ -13,12 +14,12 @@ from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments, GeneratingArguments
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -33,26 +34,46 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self,
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
if reward_model is not None:
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -60,10 +81,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
if self.args.max_steps > 0:
num_examples = total_train_batch_size * self.args.max_steps
num_train_epochs = sys.maxsize
max_steps = self.args.max_steps
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
else:
len_dataloader = len(self.dataloader)
num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * len_dataloader)
steps_in_epoch = len_dataloader
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
@@ -82,14 +110,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
reward_meter = AverageMeter()
self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
batch = next(dataiter)
steps_trained += 1
try:
batch = next(dataiter)
except StopIteration:
dataiter = iter(self.dataloader)
batch = next(dataiter)
# Cast to inference mode
unwrapped_model.gradient_checkpointing_disable()
@@ -97,9 +127,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
queries, responses = self.get_inputs(batch)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
rewards.extend(mini_batch_rewards)
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
@@ -128,7 +163,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
loss=round(loss_meter.avg, 4),
reward=round(reward_meter.avg, 4),
learning_rate=stats["ppo/learning_rate"],
epoch=round(step / len_dataloader, 2)
epoch=round(step / steps_in_epoch, 2)
)
tqdm.write(str(logs))
logs["step"] = step
@@ -148,21 +183,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
self.save_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
if self.model_args.upcast_layernorm:
if self.finetuning_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -172,21 +203,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
**batch
)
if self.model_args.upcast_layernorm:
if self.finetuning_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
response_length = response_index[-1] + 2 # save the EOS token
else:
response_length = response_index[-1] + 1
response_length = response_index[-1].item() + 1
queries.append(query[i, query_length:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
@@ -203,24 +232,30 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Computes scores using given reward model.
"""
replace_model(unwrapped_model, target="reward")
if self.reward_model is None:
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
reward_model = self.reward_model if self.reward_model is not None else self.model
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
replace_model(unwrapped_model, target="default")
if self.reward_model is None:
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
@@ -264,7 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for j in range(len(query_batch)):
start = len(query_batch[j]) - 1
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
start += attention_mask[j, :].nonzero()[0].item()
end = start + len(response_batch[j])
if response_masks is not None:

View File

@@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,6 +34,11 @@ def run_ppo(
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
@@ -42,19 +48,25 @@ def run_ppo(
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_cuda_cache=True,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
@@ -66,11 +78,13 @@ def run_ppo(
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=None,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
@@ -83,5 +97,5 @@ def run_ppo(
ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -0,0 +1 @@
from llmtuner.train.pt.workflow import run_pt

View File

@@ -1,12 +1,12 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -38,11 +38,11 @@ def run_pt(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
@@ -56,3 +56,10 @@ def run_pt(
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1 @@
from llmtuner.train.rm.workflow import run_rm

View File

@@ -34,7 +34,7 @@ class PairwiseTrainer(Trainer):
Subclass and override to inject custom behavior.
Note that the first element will be removed from the output tuple.
Note that the first element will be removed from the output tuple.
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
"""
# Compute rewards
@@ -45,9 +45,6 @@ class PairwiseTrainer(Trainer):
# Split the inputs and rewards into two parts, chosen and rejected
batch_size = inputs["input_ids"].size(0) // 2
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
chosen_attn_mask, rejected_attn_mask = (
inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
)
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
chosen_scores, rejected_scores = [], []
@@ -55,8 +52,8 @@ class PairwiseTrainer(Trainer):
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
loss = 0
for i in range(batch_size):
chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
if len(check_divergence) == 0:
@@ -69,7 +66,7 @@ class PairwiseTrainer(Trainer):
assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the EOS token for inference
if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length-1])
rejected_scores.append(rejected_rewards[i, rejected_length-1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
@@ -95,7 +92,6 @@ class PairwiseTrainer(Trainer):
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:

View File

@@ -1,16 +1,15 @@
# Inspired by:
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwiseTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -29,6 +28,7 @@ def run_rm(
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
# Update arguments
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
training_args = Seq2SeqTrainingArguments(**training_args_dict)
@@ -47,11 +47,11 @@ def run_rm(
# Training
if training_args.do_train:
train_result = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
@@ -66,3 +66,10 @@ def run_rm(
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -0,0 +1 @@
from llmtuner.train.sft.workflow import run_sft

View File

@@ -2,15 +2,23 @@ import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
if is_rouge_available():
from rouge_chinese import Rouge
@dataclass
class ComputeMetrics:

View File

@@ -33,28 +33,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior.
"""
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len:
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
if "attention_mask" in inputs:
inputs["attention_mask"] = self._pad_tensors_to_target_len(
inputs["attention_mask"], inputs["labels"], pad_token_id=0
)
if "position_ids" in inputs:
inputs["position_ids"] = self._pad_tensors_to_target_len(
inputs["position_ids"], inputs["labels"], pad_token_id=0
)
inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
loss, generated_tokens, labels = super().prediction_step(
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :max(prompt_len, label_len)] = self.tokenizer.pad_token_id
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
@@ -62,14 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def _pad_tensors_to_target_len(
self,
src_tensor: torch.Tensor,
tgt_tensor: torch.Tensor,
pad_token_id: Optional[int] = None
tgt_tensor: torch.Tensor
) -> torch.Tensor:
r"""
Pads the tensor to the same length as the target tensor.
"""
pad_token_id = pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory

View File

@@ -1,15 +1,15 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -33,7 +33,7 @@ def run_sft(
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=4, # for shift short attention
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
@@ -65,11 +65,11 @@ def run_sft(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
@@ -88,3 +88,10 @@ def run_sft(
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
else:
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))

View File

@@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.train.pt import run_pt
from llmtuner.train.sft import run_sft
from llmtuner.train.rm import run_rm
from llmtuner.train.ppo import run_ppo
from llmtuner.train.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.config.use_cache = True
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
try:
tokenizer.save_pretrained(model_args.export_dir)
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(finetuning_args.export_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@@ -0,0 +1,80 @@
import torch
from typing import TYPE_CHECKING, Literal, Union
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
checkpoint_dir=finetuning_args.ref_model_checkpoint,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
checkpoint_dir=finetuning_args.reward_model_checkpoint,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model

View File

@@ -1 +0,0 @@
from llmtuner.tuner.tune import export_model, run_exp

View File

@@ -1,2 +0,0 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer

View File

@@ -1,74 +0,0 @@
import torch
from typing import TYPE_CHECKING, List, Optional
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None,
output_layer_name: Optional[str] = "lm_head"
) -> List[str]:
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
else:
linear_cls = torch.nn.Linear
module_names = set()
for name, module in model.named_modules():
if output_layer_name not in name and isinstance(module, linear_cls):
module_names.add(name.split(".")[-1])
if output_layer_name in module_names:
module_names.pop(output_layer_name)
return list(module_names)
def prepare_model_for_training(
model: "PreTrainedModel",
upcast_layernorm: bool,
finetuning_type: str,
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
if upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
return model

View File

@@ -1 +0,0 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@@ -1 +0,0 @@
from llmtuner.tuner.ppo.workflow import run_ppo

View File

@@ -1 +0,0 @@
from llmtuner.tuner.pt.workflow import run_pt

View File

@@ -1 +0,0 @@
from llmtuner.tuner.rm.workflow import run_rm

View File

@@ -1 +0,0 @@
from llmtuner.tuner.sft.workflow import run_sft

View File

@@ -1,7 +1,8 @@
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
@@ -13,32 +14,44 @@ if TYPE_CHECKING:
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load openchat 3.5 by default
super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
@property
def loaded(self) -> bool:
return self.model is not None
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem(name)]
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang = get("top.lang")
error = ""
if self.loaded:
yield ALERTS["err_exists"][lang]
return
error = ALERTS["err_exists"][lang]
elif not get("top.model_name"):
error = ALERTS["err_no_model"][lang]
elif not get("top.model_path"):
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
if not get("top.model_name"):
yield ALERTS["err_no_model"][lang]
return
if not get("top.model_path"):
yield ALERTS["err_no_path"][lang]
if error:
gr.Warning(error)
yield error
return
if get("top.checkpoints"):
@@ -65,8 +78,11 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang]
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem(name)]
lang = get("top.lang")
lang = data[self.manager.get_elem_by_name("top.lang")]
if self.demo_mode:
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.model = None

View File

@@ -61,26 +61,31 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
def get_prefix(model_name: str) -> str:
return model_name.split("-")[0]
def get_module(model_name: str) -> str:
return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
def get_template(model_name: str) -> str:
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
):
checkpoints.append(checkpoint)
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
):
checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints)

View File

@@ -11,11 +11,9 @@ def create_chat_box(
engine: "Engine",
visible: Optional[bool] = False
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
elem_dict = dict()
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
history = gr.State([])
with gr.Row():
with gr.Column(scale=4):
system = gr.Textbox(show_label=False)
@@ -29,13 +27,6 @@ def create_chat_box(
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
elem_dict.update(dict(
system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn,
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
))
history = gr.State([])
submit_btn.click(
engine.chatter.predict,
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
@@ -47,4 +38,12 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, elem_dict
return chat_box, chatbot, history, dict(
system=system,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
)

View File

@@ -1,21 +1,103 @@
import os
import json
import gradio as gr
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Any, Dict, Tuple
from llmtuner.webui.common import DATA_CONFIG
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
PAGE_SIZE = 2
def prev_page(page_index: int) -> int:
return page_index - 1 if page_index > 0 else page_index
def next_page(page_index: int, total_num: int) -> int:
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
):
return gr.update(interactive=True)
else:
return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file: str = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
if data_file.endswith(".json"):
data = json.load(f)
elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f]
else:
data = [line for line in f]
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)
preview_count = gr.Number(value=0, interactive=False, precision=0)
page_index = gr.Number(value=0, interactive=False, precision=0)
with gr.Row():
prev_btn = gr.Button()
next_btn = gr.Button()
close_btn = gr.Button()
with gr.Row():
preview_samples = gr.JSON(interactive=False)
close_btn = gr.Button()
dataset.change(
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
)
prev_btn.click(
prev_page, [page_index], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
)
next_btn.click(
next_page, [page_index, preview_count], [page_index], queue=False
).then(
get_preview,
[dataset_dir, dataset, page_index],
[preview_count, preview_samples, preview_box],
queue=False
)
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return preview_box, preview_count, preview_samples, close_btn
return dict(
data_preview_btn=data_preview_btn,
preview_count=preview_count,
page_index=page_index,
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples
)

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
@@ -17,28 +16,12 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_elems = create_preview_box(dataset_dir, dataset)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(
dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
))
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
elem_dict.update(dict(
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
))
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)

View File

@@ -1,16 +1,54 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.webui.utils import save_model
from llmtuner.train import export_model
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
elem_dict = dict()
def save_model(
lang: str,
model_name: str,
model_path: str,
checkpoints: List[str],
finetuning_type: str,
template: str,
max_shard_size: int,
export_dir: str
) -> Generator[str, None, None]:
error = ""
if not model_name:
error = ALERTS["err_no_model"][lang]
elif not model_path:
error = ALERTS["err_no_path"][lang]
elif not checkpoints:
error = ALERTS["err_no_checkpoint"][lang]
elif not export_dir:
error = ALERTS["err_no_export_dir"][lang]
if error:
gr.Warning(error)
yield error
return
args = dict(
model_name_or_path=model_path,
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
finetuning_type=finetuning_type,
template=template,
export_dir=export_dir
)
yield ALERTS["info_exporting"][lang]
export_model(args, max_shard_size="{}GB".format(max_shard_size))
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
export_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
@@ -21,23 +59,21 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_btn.click(
save_model,
[
engine.manager.get_elem("top.lang"),
engine.manager.get_elem("top.model_name"),
engine.manager.get_elem("top.model_path"),
engine.manager.get_elem("top.checkpoints"),
engine.manager.get_elem("top.finetuning_type"),
engine.manager.get_elem("top.template"),
engine.manager.get_elem_by_name("top.lang"),
engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.model_path"),
engine.manager.get_elem_by_name("top.checkpoints"),
engine.manager.get_elem_by_name("top.finetuning_type"),
engine.manager.get_elem_by_name("top.template"),
max_shard_size,
export_dir
],
[info_box]
)
elem_dict.update(dict(
return dict(
export_dir=export_dir,
max_shard_size=max_shard_size,
export_btn=export_btn,
info_box=info_box
))
return elem_dict
)

View File

@@ -1,8 +1,8 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
from llmtuner.webui.utils import can_quantize
@@ -31,9 +31,10 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
with gr.Row():
flash_attn = gr.Checkbox(value=False)
shift_attn = gr.Checkbox(value=False)
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
with gr.Column():
flash_attn = gr.Checkbox(value=False)
shift_attn = gr.Checkbox(value=False)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False

View File

@@ -5,7 +5,7 @@ from transformers.trainer_utils import SchedulerType
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
from llmtuner.webui.utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
@@ -22,28 +22,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)
preview_elems = create_preview_box(dataset_dir, dataset)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
input_elems.update({training_stage, dataset_dir, dataset})
elem_dict.update(dict(
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
))
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
data_preview_btn.click(
get_preview,
[dataset_dir, dataset],
[preview_count, preview_samples, preview_box],
queue=False
)
elem_dict.update(dict(
preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
))
with gr.Row():
@@ -79,26 +65,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
input_elems.update({logging_steps, save_steps, warmup_steps})
with gr.Column():
train_on_prompt = gr.Checkbox(value=False)
upcast_layernorm = gr.Checkbox(value=False)
input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
elem_dict.update(dict(
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
))
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
lora_target = gr.Textbox(scale=1)
additional_target = gr.Textbox(scale=1)
resume_lora_training = gr.Checkbox(value=True, scale=1)
input_elems.update({lora_rank, lora_dropout, lora_target, resume_lora_training})
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
elem_dict.update(dict(
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
resume_lora_training=resume_lora_training,
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
additional_target=additional_target, resume_lora_training=resume_lora_training,
))
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
@@ -109,7 +99,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
refresh_btn.click(
list_checkpoint,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")],
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model],
queue=False
)
@@ -139,19 +129,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.add(output_dir)
output_elems = [output_box, process_bar]
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort, queue=False)
resume_btn.change(engine.runner.monitor, outputs=output_elems)
elem_dict.update(dict(
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
))
output_box.change(
gen_plot,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
[
engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.finetuning_type"),
output_dir
],
loss_viewer,
queue=False
)

View File

@@ -1,4 +1,11 @@
CSS = r"""
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
.modal-box {
position: fixed !important;
top: 50%;
@@ -6,10 +13,12 @@ CSS = r"""
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: scroll !important;
overflow-y: auto;
background-color: var(--input-background-fill);
flex-wrap: nowrap !important;
border: 2px solid black !important;
z-index: 1000;
padding: 10px;
}
.dark .modal-box {

View File

@@ -12,43 +12,43 @@ from llmtuner.webui.utils import get_time
class Engine:
def __init__(self, pure_chat: Optional[bool] = False) -> None:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.pure_chat = pure_chat
self.manager: "Manager" = Manager()
self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
self.manager = Manager()
self.runner = Runner(self.manager, demo_mode=demo_mode)
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
user_config = load_config()
lang = user_config.get("lang", None) or "en"
resume_dict = {
init_dict = {
"top.lang": {"value": lang},
"infer.chat_box": {"visible": self.chatter.loaded}
}
if not self.pure_chat:
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
if user_config.get("last_model", None):
resume_dict["top.model_name"] = {"value": user_config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
init_dict["top.model_name"] = {"value": user_config["last_model"]}
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
yield self._form_dict(resume_dict)
yield self._form_dict(init_dict)
if self.runner.alive:
yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
if self.runner.do_train:
resume_dict = {"train.resume_btn": {"value": True}}
if not self.pure_chat:
if self.runner.alive:
yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
if self.runner.do_train:
yield self._form_dict({"train.resume_btn": {"value": True}})
else:
yield self._form_dict({"eval.resume_btn": {"value": True}})
else:
resume_dict = {"eval.resume_btn": {"value": True}}
else:
resume_dict = {"train.output_dir": {"value": get_time()}}
yield self._form_dict(resume_dict)
yield self._form_dict({"train.output_dir": {"value": get_time()}})
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return {

View File

@@ -1,4 +1,5 @@
import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
@@ -14,27 +15,38 @@ from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine
require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
def create_ui() -> gr.Blocks:
engine = Engine(pure_chat=False)
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
with gr.Tab("Evaluate & Predict"):
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):
engine.manager.all_elems["infer"] = create_infer_tab(engine)
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
if not demo_mode:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)

View File

@@ -163,12 +163,28 @@ LOCALES = {
"label": "数量"
}
},
"preview_samples": {
"page_index": {
"en": {
"label": "Samples"
"label": "Page"
},
"zh": {
"label": "样例"
"label": "页数"
}
},
"prev_btn": {
"en": {
"value": "Prev"
},
"zh": {
"value": "上一页"
}
},
"next_btn": {
"en": {
"value": "Next"
},
"zh": {
"value": "下一页"
}
},
"close_btn": {
@@ -179,6 +195,14 @@ LOCALES = {
"value": "关闭"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
},
"cutoff_len": {
"en": {
"label": "Cutoff length",
@@ -309,6 +333,36 @@ LOCALES = {
"info": "学习率预热采用的步数。"
}
},
"neft_alpha": {
"en": {
"label": "NEFTune Alpha",
"info": "Magnitude of noise adding to embedding vectors."
},
"zh": {
"label": "NEFTune 噪声参数",
"info": "嵌入向量所添加的噪声大小。"
}
},
"train_on_prompt": {
"en": {
"label": "Train on prompt",
"info": "Compute loss on the prompt tokens in supervised fine-tuning."
},
"zh": {
"label": "计算输入损失",
"info": "在监督微调时候计算输入序列的损失。"
}
},
"upcast_layernorm": {
"en": {
"label": "Upcast LayerNorm",
"info": "Upcast weights of layernorm in float32."
},
"zh": {
"label": "缩放归一化层",
"info": "将归一化层权重缩放至 32 位浮点数。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
@@ -340,11 +394,21 @@ LOCALES = {
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
},
"zh": {
"label": "LoRA 作用(非必填)",
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
"label": "LoRA 作用模块(非必填)",
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
}
},
"additional_target": {
"en": {
"label": "Additional modules (optional)",
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
},
"zh": {
"label": "附加模块(非必填)",
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
}
},
"resume_lora_training": {
@@ -595,6 +659,10 @@ ALERTS = {
"en": "Failed.",
"zh": "训练出错。"
},
"err_demo": {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, List, Set
if TYPE_CHECKING:
from gradio.components import Component
@@ -9,14 +9,14 @@ class Manager:
def __init__(self) -> None:
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
def get_elem(self, name: str) -> "Component":
def get_elem_by_name(self, name: str) -> "Component":
r"""
Example: top.lang, train.dataset
"""
tab_name, elem_name = name.split(".")
return self.all_elems[tab_name][elem_name]
def get_base_elems(self):
def get_base_elems(self) -> Set["Component"]:
return {
self.all_elems["top"]["lang"],
self.all_elems["top"]["model_name"],

View File

@@ -4,7 +4,7 @@ import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
@@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -24,14 +24,17 @@ if TYPE_CHECKING:
class Runner:
def __init__(self, manager: "Manager") -> None:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.thread: "Thread" = None
self.data: Dict["Component", Any] = None
self.do_train = True
self.monitor_inputs: Dict[str, str] = None
self.running_data: Dict["Component", Any] = None
""" State """
self.aborted = False
self.running = False
""" Handler """
self.logger_handler = LoggerHandler()
self.logger_handler.setLevel(logging.INFO)
logging.root.addHandler(self.logger_handler)
@@ -43,9 +46,12 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
self.running = False
def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
if self.running:
return ALERTS["err_conflict"][lang]
@@ -58,6 +64,9 @@ class Runner:
if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang]
if self.demo_mode and (not from_preview):
return ALERTS["err_demo"][lang]
self.aborted = False
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
@@ -65,6 +74,7 @@ class Runner:
def _finalize(self, lang: str, finish_info: str) -> str:
self.thread = None
self.running_data = None
self.running = False
torch_gc()
if self.aborted:
@@ -72,24 +82,21 @@ class Runner:
else:
return finish_info
def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
get = lambda name: data[self.manager.get_elem_by_name(name)]
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
else:
checkpoint_dir = None
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
model_name_or_path=get("top.model_path"),
do_train=True,
overwrite_cache=False,
cache_dir=user_config.get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
finetuning_type=get("top.finetuning_type"),
@@ -112,11 +119,15 @@ class Runner:
logging_steps=get("train.logging_steps"),
save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"),
neft_alpha=get("train.neft_alpha"),
train_on_prompt=get("train.train_on_prompt"),
upcast_layernorm=get("train.upcast_layernorm"),
lora_rank=get("train.lora_rank"),
lora_dropout=get("train.lora_dropout"),
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
additional_target=get("train.additional_target") if get("train.additional_target") else None,
resume_lora_training=get("train.resume_lora_training"),
output_dir=output_dir
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
)
args[get("train.compute_type")] = True
args["disable_tqdm"] = True
@@ -128,7 +139,10 @@ class Runner:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = get("train.reward_model")
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
)
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
@@ -139,16 +153,16 @@ class Runner:
args["eval_steps"] = get("train.save_steps")
args["load_best_model_at_end"] = True
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args
return args
def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
get = lambda name: data[self.manager.get_elem(name)]
def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
get = lambda name: data[self.manager.get_elem_by_name(name)]
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
@@ -160,7 +174,6 @@ class Runner:
stage="sft",
model_name_or_path=get("top.model_path"),
do_eval=True,
overwrite_cache=False,
predict_with_generate=True,
cache_dir=user_config.get("cache_dir", None),
checkpoint_dir=checkpoint_dir,
@@ -179,34 +192,33 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get("eval.output_dir")
output_dir=output_dir
)
if get("eval.predict"):
args.pop("do_eval", None)
args["do_predict"] = True
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
return args
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
parse_func = self._parse_train_args if do_train else self._parse_eval_args
lang, model_name, model_path, dataset, _, args = parse_func(data)
error = self._initialize(lang, model_name, model_path, dataset)
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield gen_cmd(args), gr.update(visible=False)
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
parse_func = self._parse_train_args if do_train else self._parse_eval_args
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
error = self._initialize(lang, model_name, model_path, dataset)
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
else:
self.running = True
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.do_train, self.running_data = do_train, data
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
yield from self.monitor()
@@ -224,7 +236,12 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval")
))
while self.thread.is_alive():
time.sleep(2)
if self.aborted:

View File

@@ -1,19 +1,20 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.common import get_save_dir
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
@@ -33,37 +34,6 @@ def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
):
return gr.update(interactive=True)
else:
return gr.update(interactive=False)
def get_preview(
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file: str = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
if data_file.endswith(".json"):
data = json.load(f)
elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f]
else:
data = [line for line in f]
return len(data), data[start:end], gr.update(visible=True)
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
if finetuning_type != "lora":
return gr.update(value="None", interactive=False)
@@ -89,10 +59,12 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
return None
return
plt.close("all")
fig = plt.figure()
@@ -114,42 +86,3 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def save_model(
lang: str,
model_name: str,
model_path: str,
checkpoints: List[str],
finetuning_type: str,
template: str,
max_shard_size: int,
export_dir: str
) -> Generator[str, None, None]:
if not model_name:
yield ALERTS["err_no_model"][lang]
return
if not model_path:
yield ALERTS["err_no_path"][lang]
return
if not checkpoints:
yield ALERTS["err_no_checkpoint"][lang]
return
if not export_dir:
yield ALERTS["err_no_export_dir"][lang]
return
args = dict(
model_name_or_path=model_path,
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
finetuning_type=finetuning_type,
template=template,
export_dir=export_dir
)
yield ALERTS["info_exporting"][lang]
export_model(args, max_shard_size="{}GB".format(max_shard_size))
yield ALERTS["info_exported"][lang]

View File

@@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
def calculate(
def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
@@ -41,4 +41,4 @@ def calculate(
if __name__ == "__main__":
fire.Fire(calculate)
fire.Fire(calculate_flops)

63
tests/cal_lr.py Normal file
View File

@@ -0,0 +1,63 @@
# coding=utf-8
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import fire
import math
import torch
from tqdm import tqdm
from typing import Optional
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.model import get_train_args, load_model_and_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
dataset: str,
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data"
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template="default",
cutoff_len=cutoff_len,
output_dir="dummy_dir"
))
trainset = get_dataset(model_args, data_args)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
dataloader = DataLoader(
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
))
if __name__ == "__main__":
fire.Fire(calculate_lr)

View File

@@ -4,7 +4,6 @@
# --max_length 1024 --max_samples 1024
# dataset format: instruction (string), input (string), output (string), history (List[string])
import fire
from datasets import load_dataset
from transformers import AutoTokenizer