support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions

View File

@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
root_path = os.getenv("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY", None)
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(
@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)