Files
my_agent/my_agent/start_all.py
T
2026-06-29 23:22:14 +08:00

691 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Start the LangGraph API, the scheduler and the chat front-end together, with skill management."""
import sys
import io
if sys.platform == "win32":
    # Force UTF-8 on the Windows console so the Chinese print/log output below does
    # not raise UnicodeEncodeError under a legacy code page.
    # reconfigure() keeps the existing stream objects and their line buffering;
    # wrapping .buffer in a fresh TextIOWrapper (the old approach) silently switched
    # the console to block buffering, so log lines appeared late or were lost.
    for _stream in (sys.stdout, sys.stderr):
        try:
            _stream.reconfigure(encoding="utf-8", errors="replace")
        except (AttributeError, ValueError, OSError):
            # Stream is None (pythonw), not a TextIOWrapper, or cannot be reconfigured:
            # keep it as-is, this is best effort only.
            pass
    del _stream
import subprocess, threading, webbrowser, re, os, json, shutil, zipfile, time, secrets, hashlib
from pathlib import Path
from datetime import datetime
from http.server import HTTPServer, SimpleHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
from email.parser import BytesParser
from dotenv import load_dotenv
from scheduler import start_scheduler_background
from simple_agent.utils.log_setup import setup_all_logging
from simple_agent.utils.path_utils import get_workspace_dir, get_skills_dir
from simple_agent.skills.registry import initialize_skills
from simple_agent.utils.audit import get_confirmation, update_confirmation
# Populate os.environ from .env before the configuration below is read.
load_dotenv()

# ========== Auth ==========
# Admin login name (overridable through the environment).
ADMIN_USER = os.getenv("AGENT_ADMIN_USER", "admin")
# PBKDF2-HMAC parameters shared by password hashing and verification.
PBKDF2_ITERATIONS = 600_000
PBKDF2_HASH_ALGO = "sha256"
def _verify_password(password, stored_hash, salt):
    """Return True if *password* matches the stored PBKDF2-HMAC hash.

    Args:
        password: Candidate password. It comes straight from a JSON body, so any
            non-string value is rejected with False instead of raising.
        stored_hash: Expected derived key as a hex string (32 bytes, any case).
        salt: Salt as a hex string.

    Raises:
        ValueError: if *salt* is not valid hex (a configuration error).
    """
    if not isinstance(password, str) or not isinstance(stored_hash, str):
        return False
    derived = hashlib.pbkdf2_hmac(
        PBKDF2_HASH_ALGO, password.encode(), bytes.fromhex(salt), PBKDF2_ITERATIONS, dklen=32
    ).hex()
    # bytes.hex() is lower-case; normalise the configured value so an upper-case hash
    # (or one with stray whitespace) copied into .env still matches.
    expected = stored_hash.strip().lower()
    if not expected.isascii():
        # compare_digest raises TypeError on non-ASCII str; such a value can never match anyway
        return False
    return secrets.compare_digest(derived, expected)
def _init_auth():
    """Return the admin credentials as ``(pbkdf2_hash_hex, salt_hex)``.

    Uses AGENT_ADMIN_PASS_HASH / AGENT_ADMIN_PASS_SALT when both are set. Otherwise
    it generates a random one-off password (instead of the password that used to be
    hard-coded here, which anyone with access to the source could use to log in),
    prints it with the matching hash/salt to put in .env, and returns its hash.
    Without those variables, every start yields a new password.
    """
    pw_hash = os.getenv("AGENT_ADMIN_PASS_HASH", "")
    salt = os.getenv("AGENT_ADMIN_PASS_SALT", "")
    if pw_hash and salt:
        return pw_hash, salt
    initial_password = secrets.token_urlsafe(12)
    salt = secrets.token_hex(32)
    pw_hash = hashlib.pbkdf2_hmac(
        PBKDF2_HASH_ALGO, initial_password.encode(), bytes.fromhex(salt), PBKDF2_ITERATIONS, dklen=32
    ).hex()
    print(f"\n[AUTH] 首次运行,请将以下配置写入 .env 文件:")
    print(f"AGENT_ADMIN_PASS_HASH={pw_hash}")
    print(f"AGENT_ADMIN_PASS_SALT={salt}\n")
    # The only place the generated password is ever shown.
    print(f"[AUTH] 本次随机生成的管理员密码: {initial_password}\n")
    return pw_hash, salt
ADMIN_PASS_HASH, ADMIN_PASS_SALT = _init_auth()

# Failed-login bookkeeping: client IP -> number of failed attempts.
LOGIN_ATTEMPTS = {}
MAX_LOGIN_ATTEMPTS = 5
LOGIN_LOCKOUT_MINUTES = 15

# Active sessions: token -> {"user", "created_at", "ip"}; lifetime in hours.
_sessions = {}
TOKEN_EXPIRE_HOURS = 24
def _new_token():
    """Return a fresh random session token (64 hex characters)."""
    token = secrets.token_hex(32)
    return token
def _cleanup_expired_tokens():
    """Remove, in place, every session older than TOKEN_EXPIRE_HOURS from _sessions."""
    now = time.time()
    max_age = TOKEN_EXPIRE_HOURS * 3600
    stale = [tok for tok, info in _sessions.items() if now - info["created_at"] > max_age]
    for tok in stale:
        del _sessions[tok]
def _validate_token(token):
    """Purge expired sessions, then report whether *token* is a live session."""
    _cleanup_expired_tokens()
    is_live = token in _sessions
    return is_live
def _check_auth(handler):
    """Return True if *handler*'s request carries a valid session token.

    Otherwise sends a 401 JSON response on *handler* and returns False, so callers
    can simply ``return`` when this fails.
    """
    token = handler.headers.get("Authorization", "").replace("Bearer ", "")
    if not token or not _validate_token(token):
        handler._json_response(401, {"error": "未登录或会话已过期"})
        return False
    return True
# Canonical lower-case textual UUID (8-4-4-4-12 hex), the form str(uuid.uuid4()) produces.
# NOTE(review): assumes confirmation IDs are generated that way -- confirm in
# simple_agent.utils.audit. \Z (not $) so a trailing newline is never accepted.
UUID_PATTERN = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\Z')
def _is_valid_uuid(s):
    """Return True only for a string in canonical UUID form (rejects None, non-str, '---…')."""
    return isinstance(s, str) and UUID_PATTERN.fullmatch(s) is not None
def parse_multipart_body(body, boundary):
    """Split a multipart/form-data request body into its form fields.

    Args:
        body: Raw request body bytes.
        boundary: Boundary bytes from the Content-Type header (surrounding double
            quotes, if present, are removed).

    Returns:
        A list of ``(field_name, filename_or_None, payload_bytes)`` tuples in body order,
        for the ``form-data`` parts only. Returns an empty list if the body is not multipart.
    """
    if boundary.startswith(b'"') and boundary.endswith(b'"'):
        boundary = boundary[1:-1]
    # The email parser needs a full MIME message, so prepend a synthetic header.
    raw_message = b"".join((b'Content-Type: multipart/form-data; boundary=', boundary, b'\r\n\r\n', body))
    message = BytesParser().parsebytes(raw_message)
    if not message.is_multipart():
        return []
    return [
        (
            part.get_param('name', header='content-disposition'),
            part.get_filename(),
            part.get_payload(decode=True),
        )
        for part in message.walk()
        if part.get_content_maintype() != 'multipart' and part.get_content_disposition() == 'form-data'
    ]
class ChatHandler(SimpleHTTPRequestHandler):
    """HTTP handler for the chat front-end and its management API.

    URL layout:

    * ``POST /api/login`` is public and returns a session token.
    * Every other POST, and every GET under ``/skills/`` or ``/api/``, requires an
      ``Authorization: Bearer <token>`` header (checked by ``_check_auth``).
    * ``GET /workspace/<name>`` serves a generated file, looked up by base name in the
      agent workspaces. It is not authenticated (presumably so the UI can embed it with
      plain tags -- NOTE(review): confirm this is intended).
    * ``GET /static/...`` serves the front-end from the current working directory.
      No other path is served from disk, so files such as ``.env`` (API keys, admin
      password hash), its backups and the source tree are never exposed.

    API errors are JSON ``{"error": ...}``. ``send_error`` is not used for them because
    it writes the message into the HTTP status line, which must be Latin-1: a Chinese
    message raises ``UnicodeEncodeError`` and the client gets no response at all.
    """

    # Time of the most recent failed login per client IP. A class attribute so it is
    # shared by every per-request handler instance.
    _login_failure_times = {}

    # (.env key, JSON body field) pairs written by /api/restart and /api/save_config.
    _ENV_FIELD_MAP = (
        ("LLM_MODEL", "m_model"),
        ("LLM_BASE_URL", "m_url"),
        ("LLM_API_KEY", "m_key"),
        ("LLM_MODEL_W", "w_model"),
        ("LLM_BASE_URL_W", "w_url"),
        ("LLM_API_KEY_W", "w_key"),
        ("LLM_MODEL_M", "mcp_model"),
        ("LLM_BASE_URL_M", "mcp_url"),
        ("LLM_API_KEY_M", "mcp_key"),
    )

    # Content types for /workspace/ files, keyed by lower-case extension.
    _WORKSPACE_CONTENT_TYPES = {
        ".svg": "image/svg+xml", ".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
        ".gif": "image/gif", ".html": "text/html", ".css": "text/css", ".js": "application/javascript",
        ".json": "application/json", ".pdf": "application/pdf", ".txt": "text/plain",
        ".md": "text/markdown", ".csv": "text/csv",
    }
    # The textual subset of the above, which gets a "; charset=utf-8" parameter.
    _TEXT_CONTENT_TYPES = frozenset({
        "image/svg+xml", "text/html", "text/css", "application/javascript",
        "application/json", "text/plain", "text/markdown", "text/csv",
    })

    # ------------------------------------------------------------------ routing

    def do_GET(self):
        """Dispatch a GET request (see the class docstring for the URL layout)."""
        self._response_started = False
        if self.path.startswith(("/skills/", "/api/")) and not _check_auth(self):
            return
        try:
            if self.path.startswith("/workspace/"):
                self._serve_workspace_file()
            elif self.path == "/skills/list":
                from simple_agent.skills.registry import get_skills_info
                self._json_response(200, get_skills_info())
            elif self.path.startswith("/skills/trash"):
                self._json_response(200, self._list_trash())
            elif self.path.startswith("/api/confirm_info"):
                self._handle_confirm_info()
            elif self.path in ("/", ""):
                self._response_started = True
                self.send_response(302)
                self.send_header("Location", "/static/chat.html")
                self.end_headers()
            elif urlparse(self.path).path.startswith("/static/"):
                # Only the front-end directory is served from disk (see class docstring).
                self._response_started = True
                super().do_GET()
            else:
                self._response_started = True
                self.send_error(404, "Not Found")
        except Exception as exc:
            self._fail_unhandled(exc)

    def do_HEAD(self):
        """Answer HEAD only for static front-end files, mirroring do_GET's restriction."""
        if urlparse(self.path).path.startswith("/static/"):
            super().do_HEAD()
        else:
            self.send_error(404, "Not Found")

    def do_POST(self):
        """Dispatch a POST request; everything except /api/login requires a session."""
        self._response_started = False
        try:
            if self.path == "/api/login":
                return self._handle_login()
            if not _check_auth(self):
                return
            handler = self._post_route(self.path)
            if handler is None:
                # SimpleHTTPRequestHandler has no do_POST, so delegating to super() raised
                # AttributeError and dropped the connection; answer 404 instead.
                self._send_json_error(404, "Not Found")
                return
            handler()
        except Exception as exc:
            self._fail_unhandled(exc)

    def _post_route(self, path):
        """Return the bound handler method for POST *path*, or None if unknown.

        Exact routes and prefix routes never overlap, so lookup order does not matter.
        """
        exact = {
            "/skills/reload": self._handle_skills_reload,
            "/skills/upload": self._handle_skills_upload,
            "/api/write_file": self._handle_write_file,
            "/api/confirm_write": self._handle_confirm_write,
            "/api/deploy": self._handle_deploy,
            "/api/mcp_query": self._handle_mcp_query,
            "/api/mcp_reconnect": self._handle_mcp_reconnect,
            "/api/restart": self._handle_restart,
            "/api/save_config": self._handle_save_config,
        }
        if path in exact:
            return exact[path]
        prefixed = (
            ("/upload", self._handle_upload),
            ("/skills/delete", self._handle_skills_delete),
            ("/skills/permanent_delete", self._handle_skills_permanent_delete),
            ("/skills/recover", self._handle_skills_recover),
        )
        for prefix, handler in prefixed:
            if path.startswith(prefix):
                return handler
        return None

    def list_directory(self, path):
        """Disable directory browsing (overrides SimpleHTTPRequestHandler)."""
        self.send_error(403, "Directory listing not allowed")

    # ------------------------------------------------------------------ request / response helpers

    def _json_response(self, code, data):
        """Send *data* as a UTF-8 JSON response with status *code* and hardening headers."""
        payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
        self._response_started = True
        self.send_response(code)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(payload)))
        self.send_header("X-Content-Type-Options", "nosniff")
        self.send_header("X-Frame-Options", "DENY")
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(payload)

    def _send_json_error(self, code, message):
        """Send ``{"error": message}`` with status *code* (safe for non-Latin-1 text)."""
        self._json_response(code, {"error": message})

    def _fail_unhandled(self, exc):
        """Log an exception that escaped a route handler; answer 500 if nothing was sent yet."""
        self.log_error("Unhandled error for %s %s: %r", self.command, self.path, exc)
        if getattr(self, "_response_started", False):
            return  # headers already went out; a second response would corrupt the stream
        try:
            self._send_json_error(500, str(exc))
        except OSError:
            pass  # client already disconnected

    def _read_body(self):
        """Read the request body (Content-Length bytes); b"" if the header is absent or invalid.

        A negative length would make ``rfile.read`` block until the client closes the
        socket, so anything that is not a positive integer reads nothing.
        """
        try:
            length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            return b""
        return self.rfile.read(length) if length > 0 else b""

    def _get_query_param(self, name):
        """Return the first value of query parameter *name*, or None."""
        params = parse_qs(urlparse(self.path).query)
        return params.get(name, [None])[0]

    def _read_multipart(self):
        """Parse a multipart/form-data body into ``(name, filename, data)`` tuples.

        Sends 400 and returns None when the Content-Type is not multipart or has no
        boundary (previously a missing boundary raised IndexError).
        """
        content_type = self.headers.get("Content-Type", "")
        if "multipart/form-data" not in content_type or "boundary=" not in content_type:
            self._send_json_error(400, "需要 multipart/form-data")
            return None
        boundary = content_type.split("boundary=", 1)[1].encode()
        return parse_multipart_body(self._read_body(), boundary)

    def _extract_file(self):
        """Return the ``(name, filename, data)`` tuple of the ``file`` field, or None after an error reply."""
        parts = self._read_multipart()
        if parts is None:
            return None
        for name, fname, data in parts:
            if name == "file":
                return name, fname, data
        self._send_json_error(400, "缺少文件")
        return None

    # ------------------------------------------------------------------ pure helpers

    @staticmethod
    def _is_safe_skill_name(name):
        """True if *name* is usable as a single directory name directly under the skills dir.

        Rejects non-strings, empty names, path separators, ``..``, NUL and names starting
        with ``.`` (which covers ``.``, ``..`` and the ``.trash`` folder itself).
        """
        return (
            isinstance(name, str)
            and bool(name)
            and not name.startswith(".")
            and ".." not in name
            and not any(ch in name for ch in ("/", "\\", "\x00"))
        )

    @staticmethod
    def _safe_file_name(fname, default):
        """Return the base name of client-supplied *fname*, or *default* if it is empty/``.``/``..``."""
        base = Path(fname).name if fname else ""
        return base if base not in ("", ".", "..") else default

    @staticmethod
    def _split_trash_name(dirname):
        """Split a trash entry name ``<skill>_<millis>`` into ``(skill, millis)``; None if malformed."""
        name, sep, stamp = dirname.rpartition("_")
        if not sep or not name:
            return None
        try:
            return name, int(stamp)
        except ValueError:
            return None

    @staticmethod
    def _trash_entries(trash_dir, name):
        """Return ``[(millis, path), ...]`` for trashed copies of exactly skill *name*, oldest first.

        Entry names are parsed instead of globbing ``{name}_*``, which also matched other
        skills (``a_*`` matches ``a_b_123``) and interpreted ``?``/``[`` in *name*.
        """
        if not trash_dir.is_dir():
            return []
        entries = []
        for entry in trash_dir.iterdir():
            if not entry.is_dir():
                continue
            parsed = ChatHandler._split_trash_name(entry.name)
            if parsed is not None and parsed[0] == name:
                entries.append((parsed[1], entry))
        entries.sort(key=lambda pair: pair[0])
        return entries

    @staticmethod
    def _skill_name_from_markdown(content, default):
        """Return the ``name`` from a skill's YAML front matter, or *default*.

        *default* is used when there is no front matter, PyYAML is unavailable, the YAML
        is invalid or not a mapping, or ``name`` is missing, empty or not a string. The
        result is NOT validated -- callers check it with ``_is_safe_skill_name``.
        """
        name = None
        if content.startswith("---"):
            parts = content.split("---", maxsplit=2)
            if len(parts) >= 3:
                try:
                    import yaml
                    meta = yaml.safe_load(parts[1])
                    if isinstance(meta, dict):
                        name = meta.get("name")
                except Exception:
                    name = None
        return name if isinstance(name, str) and name else default

    @classmethod
    def _config_updates(cls, data):
        """Map the JSON config fields of *data* to their .env keys (missing fields -> "")."""
        return {env_key: data.get(field, "") for env_key, field in cls._ENV_FIELD_MAP}

    @staticmethod
    def _invalid_env_keys(updates):
        """Return the keys whose value contains a line break (it would inject extra .env lines)."""
        return [key for key, value in updates.items()
                if isinstance(value, str) and ("\n" in value or "\r" in value)]

    @staticmethod
    def _merge_env_text(text, updates):
        """Return .env content *text* with *updates* applied.

        For each key in *updates*: a truthy value replaces every existing ``KEY=...`` line
        (or is appended when the key is absent); a falsy value removes the key's lines.
        NOTE(review): so a field missing from the request deletes that key -- confirm the
        UI always sends every field. Other lines, comments included, keep their order.
        The result ends with a newline unless it is empty.
        """
        lines = text.split("\n")
        if lines and lines[-1] == "":
            lines.pop()  # trailing newline; re-added below so appended keys stay on their own line
        merged = []
        written = set()
        for line in lines:
            key = line.split("=", 1)[0].strip() if "=" in line else ""
            if key in updates:
                if updates[key]:
                    merged.append(f"{key}={updates[key]}")
                    written.add(key)
                continue
            merged.append(line)
        merged.extend(f"{k}={v}" for k, v in updates.items() if v and k not in written)
        return "\n".join(merged) + "\n" if merged else ""

    @staticmethod
    def _write_env_file(env_path, updates):
        """Apply *updates* to the .env file at *env_path*, creating it if needed."""
        existing = env_path.read_text(encoding="utf-8") if env_path.exists() else ""
        env_path.write_text(ChatHandler._merge_env_text(existing, updates), encoding="utf-8")

    # ------------------------------------------------------------------ GET handlers

    def _serve_workspace_file(self):
        """Serve ``/workspace/<name>``: the first regular file called <name>, searched
        recursively in the writer workspace, then in the project workspace.

        Only the last path component is used. It is glob-escaped so a request such as
        ``/workspace/*`` cannot match arbitrary files.
        """
        import glob
        from urllib.parse import unquote
        rel = unquote(urlparse(self.path).path)[len("/workspace/"):]
        if ".." in rel or "\\" in rel:
            self._send_json_error(403, "禁止访问")
            return
        safe_name = Path(rel).name
        if not safe_name:
            self._send_json_error(400, "文件名不能为空")
            return
        # Search order: desktop AgentWorkspace (writer) -> project workspace.
        from simple_agent.utils.path_security import _get_writer_workspace
        search_dirs = []
        for get_dir in (_get_writer_workspace, get_workspace_dir):
            try:
                search_dirs.append(Path(get_dir()))
            except Exception:
                pass  # a workspace that cannot be resolved is simply not searched
        pattern = glob.escape(safe_name)
        found = next(
            (p for d in search_dirs if d.is_dir() for p in d.rglob(pattern) if p.is_file()),
            None,
        )
        if found is None:
            self._send_json_error(404, "文件不存在")
            return
        try:
            data = found.read_bytes()
        except OSError as exc:
            self._send_json_error(500, str(exc))
            return
        content_type = self._WORKSPACE_CONTENT_TYPES.get(found.suffix.lower(), "application/octet-stream")
        if content_type in self._TEXT_CONTENT_TYPES:
            content_type += "; charset=utf-8"
        self._response_started = True
        self.send_response(200)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(data)))
        self.send_header("X-Content-Type-Options", "nosniff")
        self.send_header("Cache-Control", "no-cache")
        self.end_headers()
        self.wfile.write(data)

    def _list_trash(self):
        """List trashed skills as ``[{"name", "deleted_at" (ms), "description"}, ...]``."""
        trash_dir = get_skills_dir() / ".trash"
        items = []
        if trash_dir.exists():
            for item in trash_dir.iterdir():
                if not item.is_dir():
                    continue
                parsed = self._split_trash_name(item.name)
                if parsed is None:
                    continue
                name, stamp = parsed
                items.append({"name": name, "deleted_at": stamp, "description": "已删除"})
        return items

    def _handle_confirm_info(self):
        """GET /api/confirm_info?confirm_id=...: return one write-confirmation record."""
        confirm_id = self._get_query_param("confirm_id")
        if not confirm_id:
            self._json_response(400, {"error": "缺少 confirm_id"})
            return
        if not _is_valid_uuid(confirm_id):
            self._json_response(400, {"error": "confirm_id 格式无效"})
            return
        try:
            c = get_confirmation(confirm_id)
            if not c:
                self._json_response(404, {"error": "确认记录不存在"})
                return
            self._json_response(200, {"confirm_id": c["id"], "target_path": c["target_path"],
                                      "operation_details": c["operation_details"], "risk_analysis": c["risk_analysis"],
                                      "result": c["result"], "created_at": c["created_at"]})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    # ------------------------------------------------------------------ POST handlers: files & skills

    def _handle_upload(self):
        """POST /upload?type=knowledge|memory: store the uploaded ``file`` in workspace/<type>/."""
        extracted = self._extract_file()
        if extracted is None:
            return
        _, filename, data = extracted
        upload_type = self._get_query_param("type") or "knowledge"
        if upload_type not in ("knowledge", "memory"):
            self._send_json_error(400, "type 必须为 knowledge 或 memory")
            return
        safe_name = self._safe_file_name(filename, "uploaded_file")
        target_dir = get_workspace_dir() / upload_type
        target_dir.mkdir(parents=True, exist_ok=True)  # may not exist on a fresh workspace
        target_path = target_dir / safe_name
        target_path.write_bytes(data or b"")
        self._json_response(200, {"status": "ok", "path": str(target_path)})

    def _handle_skills_reload(self):
        """POST /skills/reload: hot-reload everything via hotreload.reload_all()."""
        try:
            from simple_agent.utils.hotreload import reload_all
            msg = reload_all()
            self._json_response(200, {"status": "ok", "message": msg})
        except Exception as e:
            self._send_json_error(500, str(e))

    def _handle_skills_delete(self):
        """POST /skills/delete?name=...: move a skill to ``.trash/<name>_<millis>`` and reload skills."""
        name = self._get_query_param("name")
        if not self._is_safe_skill_name(name):
            self._send_json_error(400, "无效的技能名称")
            return
        skills_dir = get_skills_dir()
        skill_folder = skills_dir / name
        if not skill_folder.exists():
            self._send_json_error(404, "技能不存在")
            return
        trash_dir = skills_dir / ".trash"
        trash_dir.mkdir(exist_ok=True)
        shutil.move(str(skill_folder), str(trash_dir / f"{name}_{int(time.time()*1000)}"))
        from simple_agent.skills.registry import reload_skills
        msg = reload_skills()
        self._json_response(200, {"status": "ok", "message": f"技能 {name} 已移至垃圾桶。{msg}"})

    def _handle_skills_upload(self):
        """POST /skills/upload: install skills from a .zip archive or a single .md/.markdown file.

        A skill is stored as ``<skills_dir>/<name>/SKILL.md``; <name> is the YAML
        front-matter ``name`` when present and safe, else the file stem.
        """
        extracted = self._extract_file()
        if extracted is None:
            return
        _, filename, data = extracted
        if not filename:
            self._send_json_error(400, "文件名不能为空")
            return
        file_ext = Path(filename).suffix.lower()
        skills_dir = get_skills_dir()
        skills_dir.mkdir(parents=True, exist_ok=True)
        if file_ext == ".zip":
            if not self._install_skills_zip(skills_dir, data):
                return
        elif file_ext in (".md", ".markdown"):
            try:
                content = data.decode("utf-8")
            except UnicodeDecodeError:
                self._send_json_error(400, "技能文件必须是 UTF-8 编码")
                return
            skill_name = self._skill_name_from_markdown(content, Path(filename).stem)
            if not self._is_safe_skill_name(skill_name):
                self._send_json_error(400, "无效的技能名称")
                return
            skill_dir = skills_dir / skill_name
            skill_dir.mkdir(parents=True, exist_ok=True)
            (skill_dir / "SKILL.md").write_text(content, encoding="utf-8")
        else:
            self._send_json_error(400, "不支持的文件类型,请上传 .zip 或 .md")
            return
        from simple_agent.skills.registry import reload_skills
        self._json_response(200, {"status": "ok", "message": f"技能已安装。{reload_skills()}"})

    def _install_skills_zip(self, skills_dir, data):
        """Extract a skills zip into *skills_dir*, then turn top-level *.md files into skill folders.

        Returns True on success, False after an error response has been sent.
        """
        temp_zip = skills_dir / "_upload_temp.zip"
        temp_zip.write_bytes(data)
        try:
            with zipfile.ZipFile(temp_zip, "r") as zf:
                for member in zf.infolist():
                    member_path = Path(member.filename)
                    if member_path.is_absolute() or ".." in member_path.parts:
                        self._send_json_error(400, f"ZIP中包含不安全的路径: {member.filename}")
                        return False
                zf.extractall(skills_dir)
        except zipfile.BadZipFile:
            self._send_json_error(400, "无效的 ZIP 文件")
            return False
        finally:
            temp_zip.unlink(missing_ok=True)
        for md_file in list(skills_dir.glob("*.md")) + list(skills_dir.glob("*.markdown")):
            try:
                content = md_file.read_text(encoding="utf-8")
            except UnicodeDecodeError:
                continue  # not a UTF-8 skill file; leave it where it is
            skill_name = self._skill_name_from_markdown(content, md_file.stem)
            if not self._is_safe_skill_name(skill_name):
                # Front matter such as "name: ../../x" must not escape skills_dir.
                skill_name = md_file.stem
                if not self._is_safe_skill_name(skill_name):
                    continue
            target_dir = skills_dir / skill_name
            target_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(str(md_file), str(target_dir / "SKILL.md"))
        return True

    def _handle_skills_permanent_delete(self):
        """POST /skills/permanent_delete?name=...: permanently remove the OLDEST trashed copy of a skill."""
        name = self._get_query_param("name")
        if not self._is_safe_skill_name(name):
            self._send_json_error(400, "无效的技能名称")
            return
        entries = self._trash_entries(get_skills_dir() / ".trash", name)
        if not entries:
            self._send_json_error(404, "垃圾桶中没有该技能")
            return
        shutil.rmtree(str(entries[0][1]))
        self._json_response(200, {"status": "ok", "message": f"已永久删除技能 {name}。"})

    def _handle_skills_recover(self):
        """POST /skills/recover?name=...: restore the NEWEST trashed copy of a skill and reload skills.

        Answers 409 instead of overwriting when a skill of that name already exists.
        """
        name = self._get_query_param("name")
        if not self._is_safe_skill_name(name):
            self._send_json_error(400, "无效的技能名称")
            return
        skills_dir = get_skills_dir()
        entries = self._trash_entries(skills_dir / ".trash", name)
        if not entries:
            self._send_json_error(404, "垃圾桶中没有该技能")
            return
        target = skills_dir / name
        if target.exists():
            self._send_json_error(409, f"技能 {name} 已存在,无法恢复")
            return
        entries[-1][1].rename(target)
        from simple_agent.skills.registry import reload_skills
        msg = reload_skills()  # the recovered skill must be registered, as after delete/upload
        self._json_response(200, {"status": "ok", "message": f"已恢复技能 {name}。{msg}"})

    # ------------------------------------------------------------------ POST handlers: auth & writes

    def _handle_login(self):
        """POST /api/login: exchange ``{"username", "password"}`` for a session token.

        After MAX_LOGIN_ATTEMPTS failures a client IP gets 429 until LOGIN_LOCKOUT_MINUTES
        have passed since its last failure; then its counter starts over.
        """
        client_ip = self.client_address[0]
        now = time.time()
        last_failure = self._login_failure_times.get(client_ip)
        if last_failure is not None and now - last_failure >= LOGIN_LOCKOUT_MINUTES * 60:
            LOGIN_ATTEMPTS.pop(client_ip, None)
            self._login_failure_times.pop(client_ip, None)
        if LOGIN_ATTEMPTS.get(client_ip, 0) >= MAX_LOGIN_ATTEMPTS:
            self._json_response(429, {"error": "登录尝试次数过多,请15分钟后再试"})
            return
        body = self._read_body()
        try:
            data = json.loads(body)
            username = data.get("username", "")
            password = data.get("password", "")
            if username == ADMIN_USER and _verify_password(password, ADMIN_PASS_HASH, ADMIN_PASS_SALT):
                LOGIN_ATTEMPTS[client_ip] = 0
                self._login_failure_times.pop(client_ip, None)
                token = _new_token()
                _sessions[token] = {"user": username, "created_at": time.time(), "ip": client_ip}
                self._json_response(200, {"token": token, "user": username})
            else:
                LOGIN_ATTEMPTS[client_ip] = LOGIN_ATTEMPTS.get(client_ip, 0) + 1
                self._login_failure_times[client_ip] = now
                self._json_response(401, {"error": "用户名或密码错误"})
        except Exception as e:
            self._json_response(400, {"error": str(e)})

    def _handle_write_file(self):
        """POST /api/write_file: write ``content`` to ``path`` after resolve_write_path() vets it."""
        body = self._read_body()
        try:
            data = json.loads(body)
            file_path = data.get("path", "")
            content = data.get("content", "")
            if not file_path:
                self._json_response(400, {"error": "缺少文件路径"})
                return
            from simple_agent.utils.path_security import resolve_write_path
            resolved, err = resolve_write_path(file_path)
            if err:
                self._json_response(403, {"error": err})
                return
            p = Path(resolved)
            try:
                p.parent.mkdir(parents=True, exist_ok=True)
                p.write_text(content, encoding="utf-8")
            except PermissionError as pe:
                self._json_response(403, {"error": f"权限不足: {pe}"})
                return
            self._json_response(200, {"success": True, "path": str(p), "size": len(content)})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    def _handle_confirm_write(self):
        """POST /api/confirm_write: approve or reject a pending write confirmation.

        On approval the stored content is written to the stored target path (PDF via
        fpdf; .docx/.pptx/.xlsx are refused). The record is marked approved only after
        the write succeeds, so a refused or failed write leaves it pending for a retry.
        """
        body = self._read_body()
        try:
            data = json.loads(body)
            confirm_id = data.get("confirm_id", "")
            choice = data.get("choice", "")
            if not confirm_id or not choice:
                self._json_response(400, {"error": "缺少 confirm_id 或 choice"})
                return
            if not _is_valid_uuid(confirm_id):
                self._json_response(400, {"error": "confirm_id 格式无效"})
                return
            if choice not in ("approve", "reject"):
                self._json_response(400, {"error": "choice 必须为 approve 或 reject"})
                return
            confirmation = get_confirmation(confirm_id)
            if not confirmation:
                self._json_response(404, {"error": "确认记录不存在"})
                return
            if confirmation["result"] != "pending":
                self._json_response(400, {"error": "该确认已处理"})
                return
            if choice == "reject":
                update_confirmation(confirm_id, "rejected", "user")
                self._json_response(200, {"success": True, "message": "用户拒绝写入"})
                return
            p = Path(confirmation["target_path"])
            content = confirmation["content"] or ""
            suffix = p.suffix.lower()
            if suffix in (".docx", ".pptx", ".xlsx"):
                self._json_response(400, {"error": f"{p.suffix} 需通过 WriterAgent 直接生成"})
                return
            try:
                p.parent.mkdir(parents=True, exist_ok=True)
                if suffix == ".pdf":
                    self._write_pdf(p, content)
                else:
                    p.write_text(content, encoding="utf-8")
            except PermissionError as pe:
                self._json_response(403, {"error": f"权限不足: {pe}"})
                return
            update_confirmation(confirm_id, "approved", "user")
            self._json_response(200, {"success": True, "message": "写入已确认并执行", "path": str(p), "size": len(content)})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    @staticmethod
    def _write_pdf(path, content):
        """Render *content* line by line into a PDF at *path* (CJK font if one is found)."""
        from fpdf import FPDF
        from simple_agent.utils.path_security import find_cjk_font
        pdf = FPDF()
        pdf.add_page()
        font_path = find_cjk_font()
        if font_path:
            pdf.add_font("CJK", "", font_path)
            pdf.set_font("CJK", "", 12)
        else:
            pdf.set_font("Helvetica", "", 12)
        for line in content.split("\n"):
            pdf.cell(0, 10, line, ln=True)
        pdf.output(str(path))

    # ------------------------------------------------------------------ POST handlers: deploy & MCP

    def _handle_deploy(self):
        """POST /api/deploy: back up -> save JAR -> run start.sh -> health check -> roll back on failure.

        Multipart fields: ``jar_file`` (required), ``target_dir`` (required) and
        ``health_url`` (optional; any HTTP answer counts as healthy). A rollback only
        restores the previous JAR file -- NOTE(review): start.sh is not re-run for it.
        """
        parts = self._read_multipart()
        if parts is None:
            return
        jar_data = None
        jar_name = "deploy.jar"
        target_dir = ""
        health_url = ""
        try:
            for name, fname, data in parts:
                if name == "jar_file":
                    jar_data = data
                    # keep only the base name so the upload cannot escape target_dir
                    jar_name = self._safe_file_name(fname, "deploy.jar")
                elif name == "target_dir":
                    target_dir = data.decode("utf-8")
                elif name == "health_url":
                    health_url = data.decode("utf-8")
        except UnicodeDecodeError:
            self._json_response(400, {"error": "target_dir / health_url 必须是 UTF-8 文本"})
            return
        if not jar_data:
            self._json_response(400, {"error": "缺少 jar_file"})
            return
        if not target_dir:
            self._json_response(400, {"error": "缺少 target_dir"})
            return
        target = Path(target_dir)
        jar_path = target / jar_name
        backup_name = ""
        try:
            target.mkdir(parents=True, exist_ok=True)
            # 1. back up the current JAR
            if jar_path.exists():
                backup_name = f"{jar_name}.bak.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                shutil.move(str(jar_path), str(target / backup_name))
                print(f"[DEPLOY] 已备份: {backup_name}")
            # 2. save the new JAR
            jar_path.write_bytes(jar_data)
            print(f"[DEPLOY] 已保存: {jar_path} ({len(jar_data)} bytes)")
            result = {"backup": backup_name, "saved": str(jar_path), "started": False, "rolled_back": False, "health_ok": False}
            # 3. run start.sh
            start_script = target / "start.sh"
            if start_script.exists():
                if sys.platform == "win32":
                    cmd = ["cmd", "/c", str(start_script)]
                else:
                    cmd = ["bash", str(start_script)]
                # errors="replace": undecodable script output must not abort the deploy
                # after the JAR has already been swapped.
                proc = subprocess.run(cmd, cwd=str(target), capture_output=True, text=True,
                                      errors="replace", timeout=120)
                result["started"] = True
                result["returncode"] = proc.returncode
                result["output"] = (proc.stdout + proc.stderr)[-1000:]
                print(f"[DEPLOY] start.sh 完成 (code={proc.returncode})")
                # 4. health check (result["health_ok"] only reports an actual probe)
                if health_url:
                    health_ok = proc.returncode == 0 and self._probe_health(health_url)
                    result["health_ok"] = health_ok
                else:
                    health_ok = proc.returncode == 0  # no health URL: exit code decides
                # 5. roll back on failure when a backup exists
                if not health_ok and backup_name:
                    backup_path = target / backup_name
                    if backup_path.exists():
                        jar_path.unlink(missing_ok=True)
                        shutil.move(str(backup_path), str(jar_path))
                        result["rolled_back"] = True
                        print(f"[DEPLOY] 已回滚: {backup_name} -> {jar_name}")
            else:
                print(f"[DEPLOY] start.sh 不存在,跳过")
            self._json_response(200, {"success": True, **result})
        except subprocess.TimeoutExpired:
            self._json_response(500, {"error": "start.sh 执行超时"})
        except PermissionError as e:
            self._json_response(403, {"error": f"权限不足: {e}"})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    @staticmethod
    def _probe_health(health_url):
        """Poll *health_url* up to 5 times, 3 s apart; True as soon as any HTTP response arrives."""
        import urllib.error
        import urllib.request
        for attempt in range(5):
            time.sleep(3)
            try:
                with urllib.request.urlopen(health_url, timeout=5) as resp:
                    print(f"[DEPLOY] 健康检查通过 ({resp.status})")
                return True
            except urllib.error.HTTPError as e:
                # Even a 4xx/5xx answer means the service is up.
                e.close()
                print(f"[DEPLOY] 服务已响应 ({e.code})")
                return True
            except Exception:
                print(f"[DEPLOY] 健康检查重试 {attempt+1}/5...")
        return False

    def _handle_mcp_query(self):
        """POST /api/mcp_query: read-only query routed by keyword to an AMap MCP tool, bypassing the LLM.

        Body ``{"query": "..."}``; replies ``{"result": text}``. 400 without a query,
        503 when AMAP_KEY is unset, 504 after 30 s.
        """
        # Imported before the try: the "except asyncio.TimeoutError" clause is evaluated
        # for every exception, and with the import inside the try an early failure
        # (e.g. invalid JSON) became UnboundLocalError and the client got no answer.
        import asyncio
        body = self._read_body()
        try:
            data = json.loads(body)
            query = data.get("query", "")
            if not query:
                self._json_response(400, {"error": "缺少 query 参数"})
                return
            from langchain_mcp_adapters.client import MultiServerMCPClient
            amap_key = os.getenv("AMAP_KEY", "")
            if not amap_key:
                self._json_response(503, {"error": "AMAP_KEY 未配置"})
                return

            def _unwrap(raw):
                # Tool results come back either as a list of text blocks or as a plain value.
                return json.loads(raw[0]["text"]) if isinstance(raw, list) else raw

            async def _direct_mcp_query(q: str) -> str:
                client = MultiServerMCPClient({
                    "amap": {
                        "transport": "http",
                        "url": f"https://mcp.amap.com/mcp?key={amap_key}",
                        "headers": {}
                    }
                })
                tools = await asyncio.wait_for(client.get_tools(), timeout=10)
                tool_map = {t.name: t for t in tools}
                result = None
                if any(kw in q for kw in ["天气", "weather"]):
                    # Weather: city before 今天/明天/..., else after 查询, else the default.
                    city = "北京"
                    city_match = re.search(r'([\u4e00-\u9fff]{2,4}(?:市|县|区)?)\s*(?:今天|明天|后天|的|这)', q)
                    if city_match:
                        city = city_match.group(1).rstrip("市")
                    else:
                        city_match2 = re.search(r'查询\s*([\u4e00-\u9fff]{2,4})', q)
                        if city_match2:
                            city = city_match2.group(1)
                    wtool = tool_map.get("maps_weather")
                    if wtool:
                        result = _unwrap(await wtool.ainvoke({"city": city}))
                elif any(kw in q for kw in ["地址", "地理编码", "经纬度", "在哪", "哪里"]):
                    # Geocoding
                    gtool = tool_map.get("maps_geo")
                    if gtool:
                        addr_match = re.search(r'(?:地址|查询|搜索|编码)[:\s]*([^\s,。]+)', q)
                        addr = addr_match.group(1) if addr_match else q[-20:]
                        result = _unwrap(await gtool.ainvoke({"address": addr}))
                elif any(kw in q for kw in ["附近", "周边", "周围"]):
                    # Nearby search. NOTE(review): the centre point is hard-coded
                    # (appears to be central Beijing).
                    stool = tool_map.get("maps_around_search")
                    if stool:
                        kw_match = re.search(r'(?:搜索|查|找|附近|周边|周围)[:\s]*([^\s,。]+)', q)
                        keyword = kw_match.group(1) if kw_match else "餐厅"
                        result = _unwrap(await stool.ainvoke({"keywords": keyword, "location": "116.397428,39.90923", "radius": "3000"}))
                elif any(kw in q for kw in ["ip", "IP", "定位"]):
                    # IP geolocation
                    itool = tool_map.get("maps_ip_location")
                    if itool:
                        ip_match = re.search(r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})', q)
                        ip = ip_match.group(1) if ip_match else ""
                        result = _unwrap(await itool.ainvoke({"ip": ip}))
                else:
                    # Fallback: try weather for the first Chinese place-like token.
                    wtool = tool_map.get("maps_weather")
                    if wtool:
                        city_match = re.search(r'([\u4e00-\u9fff]{2,4})(?:市|县|区)?', q)
                        city = city_match.group(1) if city_match else "北京"
                        result = _unwrap(await wtool.ainvoke({"city": city}))
                if result:
                    return json.dumps(result, ensure_ascii=False, indent=2)
                return "未能匹配到合适的MCP工具,请尝试更具体的查询。"

            result_text = asyncio.run(asyncio.wait_for(_direct_mcp_query(query), timeout=30))
            self._json_response(200, {"result": result_text})
        except asyncio.TimeoutError:
            self._json_response(504, {"error": "MCP 查询超时"})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    def _handle_mcp_reconnect(self):
        """POST /api/mcp_reconnect: force a reconnect of all MCP servers."""
        try:
            from simple_agent.agents.mcp_manager_agent import reconnect_mcp
            ok = reconnect_mcp()
        except Exception as e:
            self._json_response(500, {"error": str(e)})
            return
        self._json_response(200, {"success": ok, "message": "MCP 重连成功" if ok else "MCP 重连失败,使用降级模式"})

    # ------------------------------------------------------------------ POST handlers: configuration

    def _handle_restart(self):
        """POST /api/restart: back up .env, apply the LLM settings, then re-exec this process."""
        body = self._read_body()
        try:
            data = json.loads(body)
            updates = self._config_updates(data)
            bad_keys = self._invalid_env_keys(updates)
            if bad_keys:
                self._json_response(400, {"error": f"配置值不能包含换行符: {', '.join(bad_keys)}"})
                return
            env_path = Path(__file__).parent / ".env"
            if env_path.exists():
                backup_path = env_path.with_suffix(f".bak.{datetime.now().strftime('%Y%m%d_%H%M%S')}")
                shutil.copy2(env_path, backup_path)
                print(f"[RESTART] 已备份 .env → {backup_path.name}")
            self._write_env_file(env_path, updates)
            self._json_response(200, {"success": True, "message": "配置已应用并备份,服务正在重启..."})
            threading.Thread(target=self._reexec_after_delay, daemon=True).start()
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    @staticmethod
    def _reexec_after_delay():
        """Replace the current process with a fresh copy of itself after a 1 s delay."""
        time.sleep(1)  # give the HTTP response time to reach the client
        os.environ["PYTHONIOENCODING"] = "utf-8"
        # os.execv does not flush Python's stdio buffers; flush so pending log lines survive.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.flush()
            except (AttributeError, OSError, ValueError):
                pass
        os.execv(sys.executable, [sys.executable] + sys.argv)

    def _handle_save_config(self):
        """POST /api/save_config: write the LLM settings to .env (picked up by the hot-reload watcher)."""
        body = self._read_body()
        try:
            data = json.loads(body)
            updates = self._config_updates(data)
            bad_keys = self._invalid_env_keys(updates)
            if bad_keys:
                self._json_response(400, {"error": f"配置值不能包含换行符: {', '.join(bad_keys)}"})
                return
            self._write_env_file(Path(__file__).parent / ".env", updates)
            self._json_response(200, {"success": True, "message": "配置已保存,热更新自动生效"})
        except Exception as e:
            self._json_response(500, {"error": str(e)})
def start_chat_server(port=8765):
    """Serve the chat UI and API on AGENT_HOST (default 127.0.0.1) and *port*; blocks forever.

    The process working directory is switched to this file's folder first, because
    the static-file handler resolves paths relative to the working directory.
    """
    project_root = Path(__file__).parent.resolve()
    os.chdir(project_root)
    bind_host = os.getenv("AGENT_HOST", "127.0.0.1")
    httpd = HTTPServer((bind_host, port), ChatHandler)
    print(f"[Chat] 聊天界面已启动:http://{bind_host}:{port}/static/chat.html")
    httpd.serve_forever()
def _find_langgraph_bin():
    """Locate the ``langgraph`` CLI.

    Checks the project ``.venv`` first (Windows layout, then POSIX layout), then PATH.
    Returns the executable path as a string, or None if it is not found.
    """
    venv = Path(__file__).parent / ".venv"
    candidates = (venv / "Scripts" / "langgraph.exe", venv / "bin" / "langgraph")
    local = next((str(c) for c in candidates if c.exists()), None)
    return local or shutil.which("langgraph") or None
def main():
    """Start every service, then run the LangGraph dev server in the foreground.

    Order: logging, skills, config hot-reload watcher, scheduler, chat server
    (daemon thread), periodic MCP reconnect (daemon thread), optional browser
    launch on Windows, and finally ``langgraph dev`` on port 2024 until it exits.
    If the langgraph CLI is missing, it blocks on the chat server thread instead.
    """
    setup_all_logging()
    initialize_skills()
    # File watcher that hot-reloads configuration.
    from simple_agent.utils.hotreload import start_watcher
    start_watcher()
    sched = start_scheduler_background()  # handle kept referenced for the process lifetime
    chat_thread = threading.Thread(target=start_chat_server, daemon=True)
    chat_thread.start()

    def mcp_retry_loop():
        """Wait 30 s for the first connection, then call reconnect_mcp every 300 s."""
        time.sleep(30)
        while True:
            time.sleep(300)
            try:
                from simple_agent.agents.mcp_manager_agent import reconnect_mcp
                reconnect_mcp()
            except Exception as exc:
                # Keep the loop alive, but do not hide the failure completely.
                print(f"[MCP] 定时重连失败: {exc}")

    threading.Thread(target=mcp_retry_loop, daemon=True).start()

    host = os.getenv("AGENT_HOST", "127.0.0.1")
    # A wildcard bind address is not browsable (notably on Windows); use loopback instead.
    browse_host = "127.0.0.1" if host in ("0.0.0.0", "") else host
    if sys.platform == "win32":
        webbrowser.open(f"http://{browse_host}:8765/static/chat.html")
    lb = _find_langgraph_bin()
    if not lb:
        print("[ERROR] 找不到 langgraph")
        chat_thread.join()
        return
    # The LangGraph dev API has no authentication. Keep it loopback-only while the chat
    # UI itself is loopback-only (the default). Otherwise keep the previous 0.0.0.0.
    # LANGGRAPH_HOST overrides either choice.
    lg_host = os.getenv("LANGGRAPH_HOST") or ("127.0.0.1" if host in ("127.0.0.1", "localhost") else "0.0.0.0")
    # Explicit cwd: previously this relied on the chat thread's os.chdir having already
    # happened, which is a race (langgraph dev looks for its config in the cwd).
    subprocess.run([lb, "dev", "--allow-blocking", "--host", lg_host, "--port", "2024", "--no-browser", "--no-reload"],
                   env={**os.environ, "PYTHONIOENCODING": "utf-8"},
                   cwd=str(Path(__file__).parent.resolve()))
if __name__ == "__main__":
    # Launch all services only when run as a script, not when imported.
    main()