77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
import os
import secrets
from typing import List

import numpy as np
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, models
|
|
|
|
# API key that clients must send as their bearer token. It is read from the
# 'sk-key' environment variable, with a built-in placeholder as the fallback.
# NOTE(review): the fallback is a fixed value visible in the source -- confirm
# that every deployment sets 'sk-key' explicitly.
sk_key = os.getenv('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
|
|
|
|
# Create the FastAPI application instance.
app = FastAPI()

# CORS policy: accept any origin, method and header, with credentials enabled.
# NOTE(review): browsers generally refuse credentialed cross-origin requests
# when the allowed origin is the "*" wildcard -- confirm whether
# allow_credentials is actually needed, or list the origins explicitly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# HTTPBearer instance; the endpoint uses it as a dependency to obtain the
# caller's bearer credentials.
security = HTTPBearer()
|
|
# Load the pretrained Transformer module from the local ./m3e-large directory.
transformer_model = models.Transformer('./m3e-large', cache_dir='./cache')

# Mean-pooling layer sized to the transformer's word-embedding dimension.
_word_embedding_dim = transformer_model.get_word_embedding_dimension()
pooling_model = models.Pooling(_word_embedding_dim, pooling_mode='mean')

# Assemble the SentenceTransformer pipeline: transformer followed by pooling.
model = SentenceTransformer(modules=[transformer_model, pooling_model])
|
|
|
|
|
|
class EmbeddingRequest(BaseModel):
    """Request body accepted by the embedding endpoint."""

    # Texts to embed; the endpoint produces one embedding per entry.
    input: List[str]
|
|
|
|
|
|
class EmbeddingResponse(BaseModel):
    """Response body returned by the embedding endpoint."""

    # One entry per input text; the endpoint fills this with
    # {"embedding": [...], "index": i} dicts.
    data: list
    # Length of each embedding vector in ``data``.
    dimension: int
|
|
|
|
|
|
@app.post("/v1/embed", response_model=EmbeddingResponse)
async def get_embed(
    request: EmbeddingRequest,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Embed every input text and return L2-normalized vectors.

    Args:
        request: Request body whose ``input`` field lists the texts to embed.
        credentials: Bearer credentials supplied by the ``security``
            dependency.

    Returns:
        A dict with ``data`` (one ``{"embedding", "index"}`` entry per input
        text, in request order) and ``dimension`` (the embedding length).

    Raises:
        HTTPException: 401 if the bearer token does not match ``sk_key``;
            400 if ``input`` is empty.
    """
    # Compare in constant time so response timing does not leak how much of
    # the key matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    token_matches = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        sk_key.encode("utf-8"),
    )
    if not token_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    # An empty list has no first embedding to read the dimension from;
    # reject it explicitly instead of failing with an IndexError (HTTP 500).
    if not request.input:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="input must contain at least one text",
        )

    # Encode the whole list in one call so the model can batch the texts.
    # NOTE(review): this call is synchronous and runs inside an async
    # handler, so it occupies the event loop while encoding -- confirm this
    # is acceptable for the expected load.
    vectors = np.asarray(model.encode(request.input))

    # L2-normalize each row; a zero-norm row is left unchanged rather than
    # divided by zero (NaN values cannot be serialized to JSON).
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    vectors = vectors / norms

    # Convert the numpy rows to plain Python lists for the JSON response.
    embeddings = vectors.tolist()

    return {
        "data": [
            {"embedding": embedding, "index": index}
            for index, embedding in enumerate(embeddings)
        ],
        "dimension": int(vectors.shape[1]),
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the ``app`` object from the ``embed`` module on all interfaces,
    # port 6009, using two worker processes.
    uvicorn.run(
        "embed:app",
        host='0.0.0.0',
        port=6009,
        workers=2,
    )
|