commit 485a80d294 (parent ca67b7a568)
Author: hiyouga
Date:   2024-06-17 17:47:25 +08:00

    Former-commit-id: 2289436567a7860d25d9da0afb39e4a3e5e83839

7 changed files with 32 additions and 39 deletions


@@ -56,9 +56,15 @@ INFER_ARGS = {
 }


-def load_reference_model() -> "torch.nn.Module":
-    model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA)
-    return PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER)
+def load_reference_model(is_trainable: bool = False) -> "LoraModel":
+    model = AutoModelForCausalLM.from_pretrained(
+        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+    )
+    lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable)
+    for param in filter(lambda p: p.requires_grad, lora_model.parameters()):
+        param.data = param.data.to(torch.float32)
+
+    return lora_model


 def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
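The float32 upcast in the new helper mirrors what the training path does: trainable parameters kept in float16 would silently drop small optimizer updates. A minimal standalone sketch of the effect (illustrative only, not from this repo):

    import torch

    # float16 keeps roughly 3 decimal digits of precision, so a typical
    # small optimizer step is lost entirely when added to a float16 weight.
    w_fp16 = torch.tensor(1.0, dtype=torch.float16)
    w_fp32 = torch.tensor(1.0, dtype=torch.float32)
    update = 1e-4

    print(w_fp16 + update)  # tensor(1., dtype=torch.float16) -- step vanished
    print(w_fp32 + update)  # tensor(1.0001) -- step preserved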
@@ -148,13 +154,7 @@ def test_lora_train_old_adapters():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)

-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
-    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
-        param.data = param.data.to(torch.float32)
-
+    ref_model = load_reference_model(is_trainable=True)
     compare_model(model, ref_model)
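The is_trainable=True argument matters here: PEFT loads adapters frozen by default, and this flag re-enables gradients on the LoRA weights so the reference model matches the trainable path. A usage sketch, reusing the test module's TINY_LLAMA and TINY_LLAMA_ADAPTER constants (the assertion assumes a plain LoRA adapter with no modules_to_save):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Adapters load with requires_grad=False unless is_trainable=True is passed.
    base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA)
    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
    trainable = [name for name, p in ref_model.named_parameters() if p.requires_grad]
    assert trainable and all("lora" in name for name in trainable)  # only LoRA weights train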
@@ -165,13 +165,7 @@ def test_lora_train_new_adapters():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)

-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
-    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
-        param.data = param.data.to(torch.float32)
-
+    ref_model = load_reference_model(is_trainable=True)
     compare_model(
         model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
     )
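compare_model's body sits outside these hunks; only its signature is visible above. A plausible sketch consistent with that signature, where diff_keys lists name substrings (here the attention and MLP projections that new adapters attach to) whose parameters are expected to differ, while everything else must match:

    import torch
    from typing import Sequence

    def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
        # Sketch: parameters whose names match a diff_keys substring may
        # (and should) differ; all remaining parameters must be identical.
        state_dict_a = model_a.state_dict()
        state_dict_b = model_b.state_dict()
        assert set(state_dict_a.keys()) == set(state_dict_b.keys())
        for name in state_dict_a.keys():
            if any(key in name for key in diff_keys):
                assert not torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-4)
            else:
                assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-4)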
@@ -200,9 +194,5 @@ def test_lora_inference():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)

-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER)
-    ref_model = ref_model.merge_and_unload()
+    ref_model = load_reference_model().merge_and_unload()
     compare_model(model, ref_model)
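merge_and_unload folds each adapter's low-rank update into the base weights and strips the LoRA modules, leaving a plain model for inference, which is why this test can compare against the merged reference with no diff_keys. A minimal numeric sketch of the merge (shapes and names are illustrative, not PEFT internals):

    import torch

    # Conceptually, merging computes W' = W + (alpha / rank) * B @ A,
    # so no adapter modules remain and inference needs a single matmul.
    d_out, d_in, rank, alpha = 8, 8, 2, 16
    base_w = torch.randn(d_out, d_in)
    lora_a = torch.randn(rank, d_in)   # lora_A weight: (rank, d_in)
    lora_b = torch.randn(d_out, rank)  # lora_B weight: (d_out, rank)
    scaling = alpha / rank

    merged_w = base_w + scaling * (lora_b @ lora_a)
    x = torch.randn(d_in)
    adapter_out = base_w @ x + scaling * (lora_b @ (lora_a @ x))
    assert torch.allclose(merged_w @ x, adapter_out, atol=1e-5)  # identical outputs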