add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import os
 from contextlib import asynccontextmanager
+from functools import partial
 from typing import Optional
 
 from typing_extensions import Annotated
@@ -50,15 +52,24 @@ if is_uvicorn_available():
     import uvicorn
 
 
+async def sweeper() -> None:
+    while True:
+        torch_gc()
+        await asyncio.sleep(300)
+
+
 @asynccontextmanager
-async def lifespan(app: "FastAPI"):  # collects GPU memory
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
+    if chat_model.engine_type == "huggingface":
+        asyncio.create_task(sweeper())
+
     yield
     torch_gc()
 
 
 def create_app(chat_model: "ChatModel") -> "FastAPI":
     root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
-    app = FastAPI(lifespan=lifespan, root_path=root_path)
+    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],
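The hunk above does two things: it adds a background sweeper() task that periodically frees GPU memory when the HuggingFace engine is in use, and it binds the ChatModel instance into the lifespan handler with functools.partial so FastAPI still receives a factory that takes only the app. The following is a minimal, runnable sketch of that pattern under stated assumptions: DummyChatModel, the print() call, and the interval argument are illustrative stand-ins for the repository's ChatModel and torch_gc(), not code from the commit.

import asyncio
from contextlib import asynccontextmanager
from functools import partial

from fastapi import FastAPI


class DummyChatModel:
    # Hypothetical stand-in for the repository's ChatModel.
    engine_type = "huggingface"


async def sweeper(interval: float = 300.0) -> None:
    # Periodically reclaim GPU memory; print() stands in for torch_gc().
    while True:
        print("collecting GPU memory")
        await asyncio.sleep(interval)


@asynccontextmanager
async def lifespan(app: FastAPI, chat_model: DummyChatModel):
    # Only the HuggingFace engine needs the periodic collector.
    if chat_model.engine_type == "huggingface":
        asyncio.create_task(sweeper())  # fire-and-forget background task
    yield  # the application serves requests while suspended here


chat_model = DummyChatModel()
# partial() binds chat_model so FastAPI still sees a one-argument lifespan factory.
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model))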
@@ -66,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY")
+    api_key = os.environ.get("API_KEY", None)
     security = HTTPBearer(auto_error=False)
 
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -80,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id="gpt-3.5-turbo")
+        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])
 
     @app.post(
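The two one-line changes above make the server configurable through environment variables: API_KEY guards the endpoints via a bearer token, and API_MODEL_NAME controls the model id reported by the models route. Below is a minimal sketch of how such an env-driven check is commonly wired with FastAPI's HTTPBearer; the 401 status, the error message, the route path, and the response body are assumptions for illustration rather than code copied from the repository.

import os
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing_extensions import Annotated

app = FastAPI()
api_key = os.environ.get("API_KEY")  # no key configured -> auth effectively disabled
security = HTTPBearer(auto_error=False)


async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
    # Reject the request only when a key is configured and the bearer token does not match.
    if api_key and (auth is None or auth.credentials != api_key):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")


@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    # Advertised model id falls back to a default when API_MODEL_NAME is unset.
    return {"data": [{"id": os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo")}]}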
@@ -52,9 +52,8 @@ if is_requests_available():
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
     from ..chat import ChatModel
+    from ..data.mm_plugin import ImageInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 
 
@@ -70,7 +69,7 @@ ROLE_MAPPING = {
 
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
 
     if len(request.messages) == 0:
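For context on the logger.info call kept in the hunk above: it dumps the incoming request as pretty-printed JSON before any processing. A small self-contained sketch of that logging pattern follows, using pydantic's model_dump() as an assumed stand-in for the repository's dictify() helper and a deliberately simplified ChatCompletionRequest model.

import json
import logging

from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChatCompletionRequest(BaseModel):  # simplified stand-in for the real protocol model
    model: str
    messages: list


def log_request(request: ChatCompletionRequest) -> None:
    payload = request.model_dump()  # plays the role of dictify(request)
    # ensure_ascii=False keeps non-ASCII message content readable in the log
    logger.info("==== request ====\n{}".format(json.dumps(payload, indent=2, ensure_ascii=False)))


log_request(ChatCompletionRequest(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "你好"}]))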