[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -12,7 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
@@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
|
||||
|
||||
|
||||
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
|
||||
def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
|
||||
linear_modules, extra_modules = set(), set()
|
||||
for name, param in model.named_parameters():
|
||||
if any(module in name for module in ["lora_A", "lora_B"]):
|
||||
@@ -83,7 +84,7 @@ def load_reference_model(
|
||||
) -> Union["PreTrainedModel", "LoraModel"]:
|
||||
current_device = get_current_device()
|
||||
if add_valuehead:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map=current_device
|
||||
)
|
||||
if not is_trainable:
|
||||
@@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
|
||||
|
||||
|
||||
def patch_valuehead_model() -> None:
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
|
||||
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
|
||||
self.v_head.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
|
||||
Reference in New Issue
Block a user