update tests

Former-commit-id: 4e92b656e324725048d914946e70867be20032ff
This commit is contained in:
hiyouga
2024-11-02 12:21:41 +08:00
parent 83479c9ef0
commit ba66ac084f
23 changed files with 53 additions and 62 deletions

View File

@@ -23,8 +23,8 @@ from llamafactory.chat import ChatModel
def main():
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)

View File

@@ -86,19 +86,19 @@ def main():
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}")
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
nnodes=os.getenv("NNODES", "1"),
node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,

View File

@@ -19,7 +19,7 @@
# limitations under the License.
import inspect
from functools import partial, wraps
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func

View File

@@ -80,18 +80,17 @@ def load_reference_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=get_current_device()
model_path, torch_dtype=torch.float16, device_map=current_device
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=get_current_device()
)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
if use_lora or use_pissa:
model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
@@ -110,7 +109,7 @@ def load_train_dataset(**kwargs) -> "Dataset":
return dataset_module["train_dataset"]
def patch_valuehead_model():
def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False)