"""HTTP service that returns L2-normalized sentence embeddings.

A local ``./m3e-large`` transformer with mean pooling is loaded once at
import time and exposed through a bearer-token protected ``POST /v1/embed``
endpoint.
"""

import os
import secrets
from typing import List

import numpy as np
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, models

# API key expected in the ``Authorization: Bearer <key>`` header, read from
# the ``sk-key`` environment variable.
# NOTE(review): the hard-coded fallback means a deployment that forgets to set
# the variable silently accepts a publicly known key. Kept for backward
# compatibility; consider failing fast when the variable is missing.
sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')

# Create the FastAPI application.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins if browsers should not call this service
# from arbitrary sites.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Extracts the bearer token from the Authorization header.
security = HTTPBearer()

# Load the pretrained Transformer model from the local directory.
transformer_model = models.Transformer('./m3e-large', cache_dir='./cache')

# Mean-pooling layer that turns token embeddings into one sentence vector.
pooling_model = models.Pooling(
    transformer_model.get_word_embedding_dimension(),
    pooling_mode='mean',
)

# Assemble the SentenceTransformer pipeline: transformer -> mean pooling.
model = SentenceTransformer(modules=[transformer_model, pooling_model])


class EmbeddingRequest(BaseModel):
    """Request body: the list of texts to embed."""

    input: List[str]


class EmbeddingResponse(BaseModel):
    """Response body: one ``{"embedding", "index"}`` entry per input text."""

    data: list
    dimension: int


def _l2_normalize_rows(matrix: np.ndarray) -> np.ndarray:
    """Return *matrix* with every row scaled to unit Euclidean length.

    Rows whose norm is zero are returned unchanged (all zeros) rather than
    being divided by zero, which would produce NaN values that cannot be
    serialized to JSON.
    """
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return matrix / norms


@app.post("/v1/embed", response_model=EmbeddingResponse)
async def get_embed(
    request: EmbeddingRequest,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Embed every text in ``request.input`` and return unit-length vectors.

    Args:
        request: Parsed request body holding the texts to embed.
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        A dict with ``data`` (a list of ``{"embedding": [...], "index": i}``
        entries in input order) and ``dimension`` (the embedding length).
        An empty ``input`` yields an empty ``data`` list.

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``.
    """
    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    token_ok = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        sk_key.encode("utf-8"),
    )
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Nothing to encode: answer directly instead of indexing an empty list.
    if not request.input:
        return {
            "data": [],
            "dimension": model.get_sentence_embedding_dimension() or 0,
        }

    # Encode the whole batch in one call; the result has one row per text.
    # NOTE(review): encode() is blocking and runs inside this async handler,
    # so other requests handled by the same worker wait until it finishes.
    vectors = np.asarray(model.encode(request.input))

    # Normalize each embedding to unit length.
    vectors = _l2_normalize_rows(vectors)

    # Convert the numpy rows to plain lists so they are JSON serializable.
    return {
        "data": [
            {"embedding": row, "index": index}
            for index, row in enumerate(vectors.tolist())
        ],
        "dimension": int(vectors.shape[1]),
    }


if __name__ == "__main__":
    # workers > 1 requires an import string; "embed:app" assumes this module
    # is saved as embed.py.
    uvicorn.run("embed:app", host='0.0.0.0', port=6009, workers=2)