mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 08:13:38 +00:00
[misc] fix parser (#9730)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
|
||||
if args is not None:
|
||||
return args
|
||||
|
||||
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
|
||||
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||
elif sys.argv[1].endswith(".json"):
|
||||
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
|
||||
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||
dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
|
||||
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||
else:
|
||||
return sys.argv[1:]
|
||||
|
||||
Reference in New Issue
Block a user