[misc] fix packing and eval plot (#7623)
@@ -27,14 +27,14 @@ from llamafactory.train.test_utils import (
 )
 
 
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
 
 TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
 
 TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
 
 TRAIN_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "stage": "sft",
     "do_train": True,
     "finetuning_type": "lora",
@@ -48,7 +48,7 @@ TRAIN_ARGS = {
 }
 
 INFER_ARGS = {
-    "model_name_or_path": TINY_LLAMA,
+    "model_name_or_path": TINY_LLAMA3,
     "adapter_name_or_path": TINY_LLAMA_ADAPTER,
     "finetuning_type": "lora",
     "template": "llama3",
@@ -81,13 +81,13 @@ def test_lora_train_extra_modules():
 
 def test_lora_train_old_adapters():
     model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
+    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
     compare_model(model, ref_model)
 
 
 def test_lora_train_new_adapters():
     model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
+    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
     compare_model(
         model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
     )
@@ -105,5 +105,5 @@ def test_lora_train_valuehead():
 
 def test_lora_inference():
     model = load_infer_model(**INFER_ARGS)
-    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
+    ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
     compare_model(model, ref_model)
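
For context, the tests in this diff compare a trained model against a reference via compare_model from llamafactory.train.test_utils, where parameters named in diff_keys are expected to have changed. The sketch below is only an assumption of how such a state-dict comparison could be written; the project's actual helper may differ in names and tolerances.

# Hypothetical sketch, not the actual llamafactory.train.test_utils implementation:
# compare two models' state dicts, requiring tensors whose names contain any entry
# of `diff_keys` (e.g. freshly trained LoRA projections) to differ, and all others to match.
import torch


def compare_model_sketch(model_a, model_b, diff_keys=()):
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    assert set(state_a.keys()) == set(state_b.keys())
    for name, tensor in state_a.items():
        matches = torch.allclose(tensor, state_b[name], rtol=1e-4, atol=1e-5)
        if any(key in name for key in diff_keys):
            assert not matches, f"{name} was expected to differ"
        else:
            assert matches, f"{name} was expected to match"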