[misc] update format (#7277)
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -33,7 +32,7 @@ if TYPE_CHECKING:
|
||||
from ..data.data_utils import DatasetModule
|
||||
|
||||
|
||||
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
|
||||
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
|
||||
state_dict_a = model_a.state_dict()
|
||||
state_dict_b = model_b.state_dict()
|
||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||
|
||||
Reference in New Issue
Block a user