-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathquery_server.py
More file actions
112 lines (93 loc) · 3.51 KB
/
query_server.py
File metadata and controls
112 lines (93 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import asyncio
import os
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import uvicorn
from pymongo import MongoClient
from dotenv import load_dotenv
from backend.main_backend import QueryRequest
from mongo.general.functions import add_chat_to_conversation, create_conversation
from mongo.general.schema import PyMongoConversation
from agent import run_pipeline
# from common.evaluator import EnhancedRAGPipeline
from fastapi.middleware.cors import CORSMiddleware
# Load environment variables via python-dotenv before the connection string
# is read from os.environ below.
load_dotenv()

# Create the module-wide MongoDB client and ping the server once so that a
# bad connection string or unreachable server shows up in the startup output.
# NOTE(review): a failure here is only printed. If the exception is raised
# before the assignment completes (e.g. the environment variable is missing),
# `client` is never bound and later uses will fail — confirm that starting
# the server without a database is intended.
try:
    client = MongoClient(os.environ["MONGO_CONNECTION_STRING"])
    client.admin.command("ping")
except Exception as connect_error:
    print("Failed to connect to MongoDB:", connect_error)
else:
    print("MongoDB connected successfully!")
# ASGI application object; the websocket route in this module registers on it.
app = FastAPI()

# CORS policy handed to the middleware: every origin, method and header is
# accepted and credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
def _with_string_id(document):
    """Return a shallow copy of *document* whose ``_id`` value is a string.

    ``websocket.send_json`` can only emit JSON-serializable values, and a raw
    MongoDB ``_id`` may not be one. Anything that is not a dict carrying an
    ``_id`` key is returned unchanged.
    """
    if isinstance(document, dict) and "_id" in document:
        return {**document, "_id": str(document["_id"])}
    return document


async def handle_conversation(websocket: WebSocket, request: QueryRequest):
    """Answer one query over *websocket*, persisting it to a conversation.

    When ``request.id`` is empty or blank, a new conversation is created from
    the query, echoed to the client as a ``"request"`` message, and then the
    pipeline answer is appended and sent as a ``"response"`` message.
    Otherwise the pipeline answer is appended to the conversation identified
    by ``request.id`` and the result is sent as a ``"response"`` message.

    Args:
        websocket: Accepted websocket the messages are written to.
        request: Parsed query; ``request.id`` and ``request.query`` are read.

    Raises:
        WebSocketDisconnect: Re-raised untouched so the caller can end the
            session. Any other exception is printed and reported to the
            client as an ``"error"`` message instead of propagating.
    """
    db = client["sirius"]
    conversations_collection = db["Message"]
    print(request)
    try:
        if not request.id or not request.id.strip():
            conversation_data = {
                "title": request.query,
                "chats": [{"message": request.query, "role": "USER", "order": 1}],
            }
            new_conversation = PyMongoConversation.model_validate(conversation_data)
            result = create_conversation(client, new_conversation)
            inserted_conversation = conversations_collection.find_one(
                {"_id": result.inserted_id}
            )
            if inserted_conversation is None:
                # Fail with a clear message rather than a subscript TypeError.
                raise RuntimeError(
                    "Conversation was created but could not be read back"
                )
            conversation_id = str(inserted_conversation["_id"])

            # The stored document is sent with its `_id` stringified; the raw
            # value is not guaranteed to be JSON-serializable (the response
            # below already has to wrap it in str()).
            await websocket.send_json(
                {
                    "type": "request",
                    "conversation": _with_string_id(inserted_conversation),
                }
            )

            # Yield to the event loop before the pipeline call — presumably so
            # the message above is flushed first; confirm.
            await asyncio.sleep(0.5)

            # NOTE(review): run_pipeline is invoked synchronously; if it is
            # long-running it blocks the event loop for every connection.
            rag_response = run_pipeline(request.query)
            print(rag_response)
            updated_conversation = add_chat_to_conversation(
                client,
                conversation_id,
                rag_response,
                "RAG",
            )
            await websocket.send_json(
                {
                    "type": "response",
                    "conversation": {
                        "id": str(updated_conversation["_id"]),
                        "title": updated_conversation["title"],
                        "chats": updated_conversation["chats"],
                    },
                }
            )
        else:
            rag_response = run_pipeline(request.query)
            from swarm.util import debug_print

            debug_print(True, "Processing tree call: Final RAG Response")
            # NOTE(review): unlike the new-conversation branch, no role
            # argument is passed here — confirm the helper's default is the
            # intended one.
            assistant_msg = add_chat_to_conversation(client, request.id, rag_response)
            print(assistant_msg)
            # Stringify `_id` if the helper returned a document carrying one;
            # any other return value is forwarded unchanged.
            await websocket.send_json(
                {"type": "response", "message": _with_string_id(assistant_msg)}
            )
    except WebSocketDisconnect:
        # The peer is gone: an error frame cannot be delivered, so let the
        # caller handle the disconnect.
        raise
    except Exception as e:
        print(f"Conversation handling error: {e}")
        await websocket.send_json({"type": "error", "message": str(e)})
@app.websocket("/ws/query")
async def query_websocket_handler(websocket: WebSocket):
    """Serve the ``/ws/query`` endpoint.

    Accepts the connection, then repeatedly reads a text frame, parses it as
    a ``QueryRequest`` and hands it to ``handle_conversation``. A frame that
    fails to parse is reported to the client as an ``"error"`` message and
    the loop continues. The loop ends when the client disconnects or an
    unexpected error occurs; the socket is then closed.

    Args:
        websocket: The incoming websocket connection.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                request = QueryRequest.model_validate_json(data)
            except ValueError as e:
                # Pydantic validation errors are ValueError subclasses. One
                # malformed message should not tear down the whole session.
                print(f"Query WebSocket error: {e}")
                await websocket.send_json({"type": "error", "message": str(e)})
                continue
            await handle_conversation(websocket, request)
    except WebSocketDisconnect:
        print("Query WebSocket disconnected")
    except Exception as e:
        print(f"Query WebSocket error: {e}")
    finally:
        try:
            await websocket.close()
        except RuntimeError:
            # Closing a socket whose peer has already disconnected can raise
            # RuntimeError; there is nothing left to close in that case.
            pass
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on localhost, port 5050.
    uvicorn.run(app, port=5050, host="localhost")