[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.misc import get_current_device
|
||||
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
def _get_unsloth_kwargs(
|
||||
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model_name": model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length or 4096,
|
||||
@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
|
||||
def load_unsloth_pretrained_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments"
|
||||
) -> Optional["PreTrainedModel"]:
|
||||
r"""
|
||||
Optionally loads pretrained model with unsloth. Used in training.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
r"""Optionally load pretrained model with unsloth. Used in training."""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
|
||||
try:
|
||||
@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
|
||||
|
||||
|
||||
def get_unsloth_peft_model(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Gets the peft model for the pretrained model with unsloth. Used in training.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
r"""Get the peft model for the pretrained model with unsloth. Used in training."""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_peft_kwargs = {
|
||||
"model": model,
|
||||
@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
|
||||
def load_unsloth_peft_model(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Loads peft model with unsloth. Used in both training and inference.
|
||||
"""
|
||||
from unsloth import FastLanguageModel
|
||||
r"""Load peft model with unsloth. Used in both training and inference."""
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user