luckay-knows-demo/konws-python/embed/embed.py
liushuang 98d7406a99 init
2025-03-05 14:14:54 +08:00

77 lines
2.1 KiB
Python

import os
import secrets
from typing import List

import numpy as np
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, models
# Bearer token that clients must present, read from the environment.
# NOTE(review): the fallback is a hard-coded placeholder key, so the service
# accepts it whenever the 'sk-key' variable is unset -- confirm that every
# real deployment sets the variable.
sk_key = os.getenv('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')

# FastAPI application instance.
app = FastAPI()

# CORS is left fully open: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# HTTP Bearer scheme used as a dependency to extract the Authorization header.
security = HTTPBearer()

# Load the pretrained Transformer weights from the local model directory.
transformer_model = models.Transformer(
    './m3e-large',
    cache_dir='./cache',
)

# Mean-pooling layer sized to the transformer's word-embedding dimension.
pooling_model = models.Pooling(
    transformer_model.get_word_embedding_dimension(),
    pooling_mode='mean',
)

# Assemble the SentenceTransformer pipeline: transformer followed by pooling.
model = SentenceTransformer(
    modules=[
        transformer_model,
        pooling_model,
    ],
)
class EmbeddingRequest(BaseModel):
    """Request body for the ``/v1/embed`` endpoint."""

    # Texts to embed; the endpoint produces one embedding per entry, in order.
    input: List[str]
class EmbeddingResponse(BaseModel):
    """Response body returned by the ``/v1/embed`` endpoint."""

    # One ``{"embedding": [...], "index": i}`` item per input text.
    data: list
    # Number of components in each embedding vector.
    dimension: int
def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """Scale each row of *vectors* to unit Euclidean length.

    Rows whose norm is zero are returned unchanged rather than divided by
    zero, which would yield NaN values that cannot be serialized as JSON.
    """
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Substitute 1 for zero norms so all-zero rows stay all-zero.
    return vectors / np.where(norms == 0, 1.0, norms)


@app.post("/v1/embed", response_model=EmbeddingResponse)
async def get_embed(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return L2-normalized embeddings for every text in the request.

    Args:
        request: Body carrying the list of texts to embed in ``input``.
        credentials: Bearer token extracted from the Authorization header.

    Returns:
        A dict with ``data`` (one ``{"embedding", "index"}`` item per input
        text, in input order) and ``dimension`` (length of each embedding).

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``;
            400 when ``input`` is empty.
    """
    # Compare as bytes in constant time: avoids leaking the key through
    # timing, and avoids the TypeError compare_digest raises on non-ASCII str.
    supplied_token = credentials.credentials.encode("utf-8")
    expected_token = sk_key.encode("utf-8")
    if not secrets.compare_digest(supplied_token, expected_token):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # An empty list has no first embedding to take the dimension from;
    # reject it explicitly instead of failing with an IndexError (HTTP 500).
    if not request.input:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="input must contain at least one text",
        )

    # Encode the whole list in one call so the model can batch the texts;
    # for a list input this yields a 2-D array with one row per text.
    # NOTE(review): encode() is a blocking call inside an async handler, so
    # other requests on this worker wait while it runs -- consider offloading
    # it to a thread pool if that becomes a problem.
    vectors = _l2_normalize(np.asarray(model.encode(request.input)))

    return {
        "data": [
            {"embedding": embedding, "index": index}
            for index, embedding in enumerate(vectors.tolist())
        ],
        "dimension": int(vectors.shape[1]),
    }
if __name__ == "__main__":
    # Start the ASGI server when the file is run as a script. The app is
    # given as a "module:attribute" import string and served by two workers
    # on all interfaces, port 6009.
    uvicorn.run(
        "embed:app",
        workers=2,
        port=6009,
        host='0.0.0.0',
    )