update webUI, fix #179
Former-commit-id: f9074fed5e22585679661588befcf266a79009f2
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
@@ -23,7 +23,7 @@ def get_time() -> str:
|
||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> dict:
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
if (
|
||||
@@ -36,7 +36,7 @@ def can_preview(dataset_dir: str, dataset: list) -> dict:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
|
||||
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
data_file = dataset_info[dataset[0]]["file_name"]
|
||||
@@ -45,6 +45,13 @@ def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
|
||||
return len(data), data[:2], gr.update(visible=True)
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
if finetuning_type != "lora":
|
||||
return gr.update(value="", interactive=False)
|
||||
else:
|
||||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def get_eval_results(path: os.PathLike) -> str:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
result = json.dumps(json.load(f), indent=4)
|
||||
@@ -66,6 +73,10 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
if log_info.get("loss", None):
|
||||
steps.append(log_info["current_steps"])
|
||||
losses.append(log_info["loss"])
|
||||
|
||||
if len(losses) == 0:
|
||||
return None
|
||||
|
||||
ax.plot(steps, losses, alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), label="smoothed")
|
||||
ax.legend()
|
||||
|
||||
Reference in New Issue
Block a user