add unittest
Former-commit-id: 8a1f0c5f922989e08a19c65de0b2c4afd2a5771f
This commit is contained in:
@@ -16,8 +16,7 @@ import os
|
||||
|
||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import load_infer_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -42,9 +41,7 @@ def test_attention():
|
||||
"fa2": "LlamaFlashAttention2",
|
||||
}
|
||||
for requested_attention in attention_available:
|
||||
model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
|
||||
model = load_infer_model(flash_attn=requested_attention, **INFER_ARGS)
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
|
||||
|
||||
@@ -17,8 +17,7 @@ import os
|
||||
import torch
|
||||
|
||||
from llamafactory.extras.misc import get_current_device
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import load_train_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -41,34 +40,26 @@ TRAIN_ARGS = {
|
||||
|
||||
|
||||
def test_checkpointing_enable():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert getattr(module, "gradient_checkpointing") is True
|
||||
|
||||
|
||||
def test_checkpointing_disable():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert getattr(module, "gradient_checkpointing") is False
|
||||
|
||||
|
||||
def test_upcast_layernorm():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and "norm" in name:
|
||||
assert param.dtype == torch.float32
|
||||
|
||||
|
||||
def test_upcast_lmhead_output():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
|
||||
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
|
||||
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
|
||||
assert outputs.dtype == torch.float32
|
||||
|
||||
@@ -13,16 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llamafactory.extras.misc import get_current_device
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import (
|
||||
compare_model,
|
||||
load_infer_model,
|
||||
load_reference_model,
|
||||
patch_valuehead_model,
|
||||
)
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -36,45 +35,19 @@ INFER_ARGS = {
|
||||
}
|
||||
|
||||
|
||||
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||
state_dict_a = model_a.state_dict()
|
||||
state_dict_b = model_b.state_dict()
|
||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||
for name in state_dict_a.keys():
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fix_valuehead_cpu_loading():
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
|
||||
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
|
||||
self.v_head.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
|
||||
AutoModelForCausalLMWithValueHead.post_init = post_init
|
||||
patch_valuehead_model()
|
||||
|
||||
|
||||
def test_base():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA)
|
||||
compare_model(model, ref_model)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
|
||||
def test_valuehead():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(
|
||||
tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
|
||||
ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
ref_model.v_head = ref_model.v_head.to(torch.float16)
|
||||
model = load_infer_model(add_valuehead=True, **INFER_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, add_valuehead=True)
|
||||
compare_model(model, ref_model)
|
||||
|
||||
@@ -16,8 +16,7 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from llamafactory.hparams import get_infer_args, get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import load_infer_model, load_train_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -46,10 +45,7 @@ INFER_ARGS = {
|
||||
|
||||
|
||||
def test_freeze_train_all_modules():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith("model.layers.1."):
|
||||
assert param.requires_grad is True
|
||||
@@ -60,12 +56,7 @@ def test_freeze_train_all_modules():
|
||||
|
||||
|
||||
def test_freeze_train_extra_modules():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||
{"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]):
|
||||
assert param.requires_grad is True
|
||||
@@ -76,10 +67,7 @@ def test_freeze_train_extra_modules():
|
||||
|
||||
|
||||
def test_freeze_inference():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
for param in model.parameters():
|
||||
assert param.requires_grad is False
|
||||
assert param.dtype == torch.float16
|
||||
|
||||
@@ -16,8 +16,7 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from llamafactory.hparams import get_infer_args, get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import load_infer_model, load_train_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -46,20 +45,14 @@ INFER_ARGS = {
|
||||
|
||||
|
||||
def test_full_train():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
model = load_train_model(**TRAIN_ARGS)
|
||||
for param in model.parameters():
|
||||
assert param.requires_grad is True
|
||||
assert param.dtype == torch.float32
|
||||
|
||||
|
||||
def test_full_inference():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
for param in model.parameters():
|
||||
assert param.requires_grad is False
|
||||
assert param.dtype == torch.float16
|
||||
|
||||
@@ -13,17 +13,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from peft import LoraModel, PeftModel
|
||||
from transformers import AutoModelForCausalLM
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llamafactory.extras.misc import get_current_device
|
||||
from llamafactory.hparams import get_infer_args, get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import (
|
||||
check_lora_model,
|
||||
compare_model,
|
||||
load_infer_model,
|
||||
load_reference_model,
|
||||
load_train_model,
|
||||
patch_valuehead_model,
|
||||
)
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -56,116 +57,38 @@ INFER_ARGS = {
|
||||
}
|
||||
|
||||
|
||||
def load_reference_model(is_trainable: bool = False) -> "LoraModel":
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable)
|
||||
for param in filter(lambda p: p.requires_grad, lora_model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
return lora_model
|
||||
|
||||
|
||||
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
|
||||
state_dict_a = model_a.state_dict()
|
||||
state_dict_b = model_b.state_dict()
|
||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||
for name in state_dict_a.keys():
|
||||
if any(key in name for key in diff_keys):
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
|
||||
else:
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fix_valuehead_cpu_loading():
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
|
||||
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
|
||||
self.v_head.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
|
||||
AutoModelForCausalLMWithValueHead.post_init = post_init
|
||||
patch_valuehead_model()
|
||||
|
||||
|
||||
def test_lora_train_qv_modules():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
linear_modules = set()
|
||||
for name, param in model.named_parameters():
|
||||
if any(module in name for module in ["lora_A", "lora_B"]):
|
||||
linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
|
||||
assert param.requires_grad is True
|
||||
assert param.dtype == torch.float32
|
||||
else:
|
||||
assert param.requires_grad is False
|
||||
assert param.dtype == torch.float16
|
||||
|
||||
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
|
||||
linear_modules, _ = check_lora_model(model)
|
||||
assert linear_modules == {"q_proj", "v_proj"}
|
||||
|
||||
|
||||
def test_lora_train_all_modules():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
linear_modules = set()
|
||||
for name, param in model.named_parameters():
|
||||
if any(module in name for module in ["lora_A", "lora_B"]):
|
||||
linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
|
||||
assert param.requires_grad is True
|
||||
assert param.dtype == torch.float32
|
||||
else:
|
||||
assert param.requires_grad is False
|
||||
assert param.dtype == torch.float16
|
||||
|
||||
model = load_train_model(lora_target="all", **TRAIN_ARGS)
|
||||
linear_modules, _ = check_lora_model(model)
|
||||
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
|
||||
|
||||
|
||||
def test_lora_train_extra_modules():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||
{"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
extra_modules = set()
|
||||
for name, param in model.named_parameters():
|
||||
if any(module in name for module in ["lora_A", "lora_B"]):
|
||||
assert param.requires_grad is True
|
||||
assert param.dtype == torch.float32
|
||||
elif "modules_to_save" in name:
|
||||
extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
|
||||
assert param.requires_grad is True
|
||||
assert param.dtype == torch.float32
|
||||
else:
|
||||
assert param.requires_grad is False
|
||||
assert param.dtype == torch.float16
|
||||
|
||||
model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
|
||||
_, extra_modules = check_lora_model(model)
|
||||
assert extra_modules == {"embed_tokens", "lm_head"}
|
||||
|
||||
|
||||
def test_lora_train_old_adapters():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||
{"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
ref_model = load_reference_model(is_trainable=True)
|
||||
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
|
||||
compare_model(model, ref_model)
|
||||
|
||||
|
||||
def test_lora_train_new_adapters():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||
{"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
ref_model = load_reference_model(is_trainable=True)
|
||||
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
|
||||
compare_model(
|
||||
model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
|
||||
)
|
||||
@@ -173,26 +96,15 @@ def test_lora_train_new_adapters():
|
||||
|
||||
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
|
||||
def test_lora_train_valuehead():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(
|
||||
tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True
|
||||
)
|
||||
|
||||
ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, is_trainable=True, add_valuehead=True)
|
||||
state_dict = model.state_dict()
|
||||
ref_state_dict = ref_model.state_dict()
|
||||
|
||||
assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"])
|
||||
assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
|
||||
|
||||
|
||||
def test_lora_inference():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||
|
||||
ref_model = load_reference_model().merge_and_unload()
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
|
||||
compare_model(model, ref_model)
|
||||
|
||||
@@ -14,13 +14,7 @@
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from peft import LoraModel, PeftModel
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from llamafactory.extras.misc import get_current_device
|
||||
from llamafactory.hparams import get_infer_args, get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -54,37 +48,14 @@ INFER_ARGS = {
|
||||
}
|
||||
|
||||
|
||||
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||
state_dict_a = model_a.state_dict()
|
||||
state_dict_b = model_b.state_dict()
|
||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||
for name in state_dict_a.keys():
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
|
||||
|
||||
|
||||
def test_pissa_init():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True)
|
||||
for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
def test_pissa_train():
|
||||
model = load_train_model(**TRAIN_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
|
||||
compare_model(model, ref_model)
|
||||
|
||||
|
||||
def test_pissa_inference():
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
|
||||
)
|
||||
ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init")
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
|
||||
ref_model = ref_model.merge_and_unload()
|
||||
compare_model(model, ref_model)
|
||||
|
||||
Reference in New Issue
Block a user