forked from sugarlabs/speak-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLLM.py
More file actions
91 lines (75 loc) · 3.07 KB
/
LLM.py
File metadata and controls
91 lines (75 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import requests
import json
import socket
import logging
# TODO: Don't hard-code these; need to see how Sugar as a whole manages API keys.
API_URL = "https://ai.sugarlabs.org/ask-llm-prompted"

# Load the API key from a sibling file at import time. Fall back to None so
# callers can detect the missing-key condition instead of crashing on import.
try:
    with open("API_KEY.txt", "r", encoding="utf-8") as f:
        API_KEY = f.read().strip()
except OSError:
    logging.error("Missing API_KEY.txt file.")
    API_KEY = None

# System prompt sent with every request unless the caller overrides it.
DEFAULT_PROMPT = "You are a friendly teacher named Jane who is 28 years old. You teach 10 year old children. Always give helpful, educational responses in simple words that children can understand. Keep your answers between 20-40 words. Be encouraging and enthusiastic but never use emojis(ever). If you notice spelling mistakes, gently correct them. Stay focused on the topic and give relevant answers."
def is_connected():
    """Return True if the machine has internet access, else False.

    Probes Google's public DNS (8.8.8.8:53) with a 5-second timeout.
    Logs an error and returns False when the connection cannot be made.
    """
    try:
        # Use the socket as a context manager so it is closed after the
        # probe; the original leaked the open socket.
        with socket.create_connection(("8.8.8.8", 53), timeout=5):
            logging.debug("Connection to 8.8.8.8 successful")
        return True
    except OSError:
        logging.error("Error: No internet connection. Please check your network.")
        return False
def ask_llm_prompted(question, custom_prompt=DEFAULT_PROMPT, timeout=120, max_length=200):
    """Send *question* to the Sugar Labs LLM endpoint and return its answer.

    Args:
        question: The user's question text.
        custom_prompt: System prompt framing the model's behavior.
        timeout: Read timeout in seconds (connect timeout is fixed at 10s).
        max_length: Maximum response length requested from the model.

    Returns:
        The model's answer string when the JSON response has an "answer"
        key, the decoded JSON otherwise, or False on any failure (missing
        key, no network, server/HTTP error, timeout).
    """
    if API_KEY is None:
        logging.error("Missing API key file: API_KEY.txt")
        return False
    if not is_connected():
        return False

    headers = {
        "X-API-Key": API_KEY,
        "Content-Type": "application/json",
    }
    payload = {
        "question": question,
        "custom_prompt": custom_prompt,
        "max_length": max_length,
        "truncation": True,
        "repetition_penalty": 1.2,  # Slightly higher to avoid repetition
        "temperature": 0.3,         # Lower for more consistent responses
        "top_p": 0.8,               # Slightly lower for better focus
        "top_k": 20,                # Much lower for more predictable responses
    }

    response = None  # kept in scope so the error handler can inspect it safely
    try:
        # json= serializes the payload; timeout is a (connect, read) tuple.
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=(10, timeout),
        )
        if 500 <= response.status_code < 600:
            logging.error("Server error: %s", response.status_code)
            return False
        response.raise_for_status()
        data = response.json()
        # Return the "answer" field when present, otherwise the whole payload.
        if isinstance(data, dict) and "answer" in data:
            return data["answer"]
        return data
    except requests.exceptions.Timeout:
        logging.error("The request timed out after %s seconds. The server might be slow.", timeout)
    except requests.exceptions.RequestException as e:
        logging.error("An error occurred: %s", e)
        # Guard against `response` never being assigned (e.g. connection
        # error inside requests.post) — the original relied on a blanket
        # except to hide the resulting NameError.
        if response is not None:
            logging.error("Response content: %s", response.text)
    return False
if __name__ == "__main__":
    # Simple interactive loop for manual testing of the LLM endpoint.
    try:
        while True:
            # Fixed prompt string: the original had no separator, so the
            # user's typing ran directly into the prompt text.
            question = input("Enter question to LLM: ")
            answer = ask_llm_prompted(question=question, custom_prompt=DEFAULT_PROMPT)
            if answer:
                print(f"LLM ANS: {answer}")
            else:
                print("Error, LLM did not respond")
    except (KeyboardInterrupt, EOFError):
        # Exit cleanly on Ctrl+C / Ctrl+D instead of dumping a traceback.
        print()