691 lines
36 KiB
Python
691 lines
36 KiB
Python
"""同时启动 LangGraph API、调度器、聊天前端,支持技能管理"""
|
||
"""测试"""
|
||
import sys
|
||
import io
|
||
if sys.platform == "win32":
    # Force UTF-8 console output on Windows so non-ASCII log lines cannot raise
    # UnicodeEncodeError under a legacy console code page. reconfigure() (3.7+)
    # keeps the existing stream objects and their line buffering; wrapping
    # .buffer in a brand-new TextIOWrapper silently switched to block buffering.
    # Each stream is handled on its own so a failure on stdout still fixes stderr.
    for _stream in (sys.stdout, sys.stderr):
        try:
            _stream.reconfigure(encoding="utf-8", errors="replace")
        except (AttributeError, ValueError, OSError):
            # Best effort: the stream may be None (pythonw), replaced by an object
            # without reconfigure(), or not reconfigurable at this point.
            pass
|
||
|
||
import subprocess, threading, webbrowser, re, os, json, shutil, zipfile, time, secrets, hashlib
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from http.server import HTTPServer, SimpleHTTPRequestHandler
|
||
from urllib.parse import urlparse, parse_qs
|
||
from email.parser import BytesParser
|
||
|
||
from dotenv import load_dotenv
|
||
from scheduler import start_scheduler_background
|
||
from simple_agent.utils.log_setup import setup_all_logging
|
||
from simple_agent.utils.path_utils import get_workspace_dir, get_skills_dir
|
||
from simple_agent.skills.registry import initialize_skills
|
||
from simple_agent.utils.audit import get_confirmation, update_confirmation
|
||
|
||
# Populate os.environ from .env before any os.getenv() lookups below.
load_dotenv()

# ========== Auth ==========

# Login name of the single admin account (AGENT_ADMIN_USER, default "admin").
ADMIN_USER = os.getenv("AGENT_ADMIN_USER", "admin")

# PBKDF2 parameters shared by password hashing and verification.
PBKDF2_HASH_ALGO = "sha256"
PBKDF2_ITERATIONS = 600_000
|
||
|
||
def _verify_password(password, stored_hash, salt):
    """Check *password* against a stored PBKDF2 digest in constant time.

    Args:
        password: candidate plaintext password.
        stored_hash: expected digest as a hex string.
        salt: salt as a hex string.

    Returns:
        True when PBKDF2(PBKDF2_HASH_ALGO, PBKDF2_ITERATIONS rounds, 32-byte key)
        of the password reproduces ``stored_hash`` exactly.
    """
    derived = hashlib.pbkdf2_hmac(
        PBKDF2_HASH_ALGO,
        password.encode(),
        bytes.fromhex(salt),
        PBKDF2_ITERATIONS,
        dklen=32,
    )
    # compare_digest avoids leaking how many leading characters matched.
    return secrets.compare_digest(derived.hex(), stored_hash)
|
||
|
||
def _init_auth():
    """Return the ``(hash_hex, salt_hex)`` pair that admin logins are checked against.

    When AGENT_ADMIN_PASS_HASH and AGENT_ADMIN_PASS_SALT are both set they are
    returned unchanged. Otherwise a credential is derived for this run from
    AGENT_ADMIN_PASS or, if that is unset too, from a freshly generated random
    password. The ``.env`` lines needed to persist it are printed, along with a
    generated password, so the operator can log in and save the configuration.
    No password is hard-coded in the source.
    """
    pw_hash = os.getenv("AGENT_ADMIN_PASS_HASH", "")
    salt = os.getenv("AGENT_ADMIN_PASS_SALT", "")
    if pw_hash and salt:
        return pw_hash, salt

    password = os.getenv("AGENT_ADMIN_PASS", "")
    generated = not password
    if generated:
        password = secrets.token_urlsafe(12)

    salt = secrets.token_hex(32)
    pw_hash = hashlib.pbkdf2_hmac(
        PBKDF2_HASH_ALGO, password.encode(), bytes.fromhex(salt), PBKDF2_ITERATIONS, dklen=32
    ).hex()
    print("\n[AUTH] 首次运行,请将以下配置写入 .env 文件:")
    print(f"AGENT_ADMIN_PASS_HASH={pw_hash}")
    print(f"AGENT_ADMIN_PASS_SALT={salt}\n")
    if generated:
        # Shown once so the operator can log in; the plaintext is not stored anywhere.
        print(f"[AUTH] 已生成随机管理员密码(仅本次显示): {password}\n")
    return pw_hash, salt
|
||
|
||
# Admin credential, resolved once at import time.
ADMIN_PASS_HASH, ADMIN_PASS_SALT = _init_auth()

# Failed-login counters keyed by client IP, and the lockout policy.
LOGIN_ATTEMPTS = {}
MAX_LOGIN_ATTEMPTS = 5
LOGIN_LOCKOUT_MINUTES = 15

# In-memory session store: token -> {"user", "created_at", "ip"} (see the login
# handler); sessions older than TOKEN_EXPIRE_HOURS are purged on validation.
_sessions = {}
TOKEN_EXPIRE_HOURS = 24
|
||
|
||
def _new_token(): return secrets.token_hex(32)
|
||
def _cleanup_expired_tokens():
    """Delete every entry of ``_sessions`` created more than TOKEN_EXPIRE_HOURS ago."""
    current = time.time()
    max_age = TOKEN_EXPIRE_HOURS * 3600
    # Collect first, delete afterwards: the dict must not change while iterating.
    stale = [tok for tok, info in _sessions.items() if current - info["created_at"] > max_age]
    for tok in stale:
        del _sessions[tok]
|
||
def _validate_token(token):
    """Purge expired sessions, then report whether *token* names a live session."""
    _cleanup_expired_tokens()
    is_live = token in _sessions
    return is_live
|
||
def _check_auth(handler):
|
||
token = handler.headers.get("Authorization", "").replace("Bearer ", "")
|
||
if token and _validate_token(token): return True
|
||
handler._json_response(401, {"error": "未登录或会话已过期"}); return False
|
||
|
||
UUID_PATTERN = re.compile(r'^[a-f0-9\-]{36}$')
|
||
def _is_valid_uuid(s): return bool(UUID_PATTERN.match(s)) if s else False
|
||
|
||
def parse_multipart_body(body, boundary):
    """Split a multipart/form-data request body into ``(name, filename, data)`` tuples.

    Args:
        body: raw request body bytes.
        boundary: boundary bytes from the Content-Type header; one pair of
            surrounding double quotes is stripped.

    Returns:
        One tuple per ``form-data`` part, in body order. ``filename`` is None for
        plain fields. ``data`` is the part payload from ``get_payload(decode=True)``.
        An empty list is returned if the body does not parse as multipart.
    """
    if boundary.startswith(b'"') and boundary.endswith(b'"'):
        boundary = boundary[1:-1]
    # Prepend a synthetic header so the stdlib email parser can split the parts.
    synthetic_header = b'Content-Type: multipart/form-data; boundary=' + boundary + b'\r\n\r\n'
    message = BytesParser().parsebytes(synthetic_header + body)
    if not message.is_multipart():
        return []
    return [
        (
            part.get_param('name', header='content-disposition'),
            part.get_filename(),
            part.get_payload(decode=True),
        )
        for part in message.walk()
        if part.get_content_maintype() != 'multipart'
        and part.get_content_disposition() == 'form-data'
    ]
|
||
|
||
class ChatHandler(SimpleHTTPRequestHandler):
    """HTTP handler for the chat front end: static assets, workspace files, auth and admin APIs.

    * ``POST /api/login`` is open. Every other POST, and every GET under ``/skills/``
      or ``/api/``, needs an ``Authorization: Bearer <token>`` session token.
    * ``GET /workspace/<name>`` serves generated files without a token, as before.
    * Only paths that normalise to ``/static/...`` are handed to
      SimpleHTTPRequestHandler. Previously any other GET served files from the
      working directory (the project root, including ``.env``).
    * Errors are sent as JSON via ``_error`` instead of ``send_error``.
      BaseHTTPRequestHandler puts send_error's message into the status line encoded
      as strict latin-1, so the Chinese messages used here raised UnicodeEncodeError
      and dropped the connection instead of replying.
    """

    # URL prefix of the only space served from disk by SimpleHTTPRequestHandler.
    _STATIC_PREFIX = "/static/"

    # client IP -> time.time() of its latest failed login. Class-level on purpose:
    # a new handler instance is created per request. Used to expire lockouts after
    # LOGIN_LOCKOUT_MINUTES.
    _login_failed_at = {}

    # .env key -> JSON field accepted by /api/save_config and /api/restart.
    _ENV_FIELDS = {
        "LLM_MODEL": "m_model",
        "LLM_BASE_URL": "m_url",
        "LLM_API_KEY": "m_key",
        "LLM_MODEL_W": "w_model",
        "LLM_BASE_URL_W": "w_url",
        "LLM_API_KEY_W": "w_key",
        "LLM_MODEL_M": "mcp_model",
        "LLM_BASE_URL_M": "mcp_url",
        "LLM_API_KEY_M": "mcp_key",
    }

    # Workspace file extension -> Content-Type.
    _WORKSPACE_CTYPES = {
        ".svg": "image/svg+xml", ".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
        ".gif": "image/gif", ".html": "text/html", ".css": "text/css", ".js": "application/javascript",
        ".json": "application/json", ".pdf": "application/pdf", ".txt": "text/plain",
        ".md": "text/markdown", ".csv": "text/csv",
    }

    # Non-"text/*" types that are still text and get "; charset=utf-8".
    _CHARSET_CTYPES = ("image/svg+xml", "application/javascript", "application/json")

    # ------------------------------------------------------------------ routing

    def do_GET(self):
        """Dispatch GET requests; paths under /skills/ and /api/ require a session token."""
        if self.path.startswith(("/skills/", "/api/")) and not _check_auth(self):
            return
        if self.path.startswith("/workspace/"):
            # No token required (unchanged; presumably so <img>/<a> tags in the chat
            # can load generated files — TODO confirm this is intended).
            self._serve_workspace_file()
        elif self.path == "/skills/list":
            self._handle_skills_list()
        elif self.path.startswith("/skills/trash"):
            self._handle_skills_trash_list()
        elif self.path.startswith("/api/confirm_info"):
            self._handle_confirm_info()
        elif self.path in ("/", ""):
            self.send_response(302)
            self.send_header("Location", "/static/chat.html")
            self.end_headers()
        elif self._is_static_request():
            super().do_GET()
        else:
            self._error(404, "Not Found")

    def do_HEAD(self):
        """Answer HEAD only for static assets, mirroring the GET restriction."""
        if self._is_static_request():
            super().do_HEAD()
        else:
            self.send_error(404)

    def do_POST(self):
        """Dispatch POST requests; everything except /api/login requires a session token."""
        if self.path == "/api/login":
            return self._handle_login()
        if not _check_auth(self):
            return
        if self.path.startswith("/upload"):
            self._handle_upload()
        elif self.path == "/skills/reload":
            self._handle_skills_reload()
        elif self.path.startswith("/skills/delete"):
            self._handle_skills_delete()
        elif self.path == "/skills/upload":
            self._handle_skills_upload()
        elif self.path.startswith("/skills/permanent_delete"):
            self._handle_skills_permanent_delete()
        elif self.path.startswith("/skills/recover"):
            self._handle_skills_recover()
        elif self.path == "/api/write_file":
            self._handle_write_file()
        elif self.path == "/api/confirm_write":
            self._handle_confirm_write()
        elif self.path == "/api/deploy":
            self._handle_deploy()
        elif self.path == "/api/mcp_query":
            self._handle_mcp_query()
        elif self.path == "/api/mcp_reconnect":
            self._handle_mcp_reconnect()
        elif self.path == "/api/restart":
            self._handle_restart()
        elif self.path == "/api/save_config":
            self._handle_save_config()
        else:
            # SimpleHTTPRequestHandler defines no do_POST: the old super().do_POST()
            # raised AttributeError instead of answering.
            self._error(404, "Not Found")

    def list_directory(self, path):
        """Forbid directory listings."""
        self.send_error(403, "Directory listing not allowed")

    # ------------------------------------------------------------------ helpers

    def _json_response(self, code, data):
        """Send *data* as a UTF-8 JSON response with status *code* and defensive headers."""
        payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
        self.send_response(code)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(payload)))
        self.send_header("X-Content-Type-Options", "nosniff")
        self.send_header("X-Frame-Options", "DENY")
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(payload)

    def _error(self, code, message):
        """Send ``{"error": message}`` with status *code* (safe for non-latin-1 messages)."""
        self._json_response(code, {"error": message})

    def _read_body(self):
        """Return the request body bytes per Content-Length; b"" if absent, invalid or <= 0."""
        try:
            length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            length = 0
        # A negative length would make rfile.read() block until the client closes.
        return self.rfile.read(length) if length > 0 else b""

    def _get_query_param(self, name):
        """Return the first value of query parameter *name*, or None if absent."""
        params = parse_qs(urlparse(self.path).query)
        return params.get(name, [None])[0]

    def _is_static_request(self):
        """Return True if the request path normalises to a location under /static/.

        Mirrors SimpleHTTPRequestHandler.translate_path (strip query/fragment,
        percent-decode, posixpath.normpath), so e.g. "/static/../.env" is rejected.
        """
        import posixpath
        from urllib.parse import unquote

        raw = self.path.split("?", 1)[0].split("#", 1)[0]
        normalized = posixpath.normpath(unquote(raw, errors="surrogatepass"))
        return normalized.startswith(self._STATIC_PREFIX)

    def _read_multipart(self):
        """Read and parse a multipart/form-data request body.

        Returns the list of ``(name, filename, data)`` tuples from
        parse_multipart_body, or None after sending a 400 reply when the
        Content-Type is not multipart or carries no boundary.
        """
        content_type = self.headers.get("Content-Type", "")
        if "multipart/form-data" not in content_type.lower():
            self._error(400, "需要 multipart/form-data")
            return None
        _, sep, rest = content_type.partition("boundary=")
        # Drop any parameters that follow the boundary (e.g. "; charset=...").
        boundary = rest.split(";", 1)[0].strip()
        if not sep or not boundary:
            self._error(400, "缺少 multipart boundary")
            return None
        return parse_multipart_body(self._read_body(), boundary.encode("latin-1", "replace"))

    def _extract_file(self):
        """Return the ``("file", filename, data)`` part of a multipart upload.

        Returns None after sending a 400 reply if the body is not valid multipart
        or has no part named "file".
        """
        parts = self._read_multipart()
        if parts is None:
            return None
        for name, fname, data in parts:
            if name == "file":
                return name, fname, data
        self._error(400, "缺少文件")
        return None

    @staticmethod
    def _is_safe_skill_name(name):
        """Return True if *name* is usable as one directory name under the skills dir.

        Rejects non-strings, empty names, path separators, "..", "*" and names
        starting with "." (which would address the skills dir itself or .trash).
        """
        return (
            isinstance(name, str)
            and name != ""
            and not name.startswith(".")
            and ".." not in name
            and not any(ch in name for ch in "/\\*")
        )

    @staticmethod
    def _trash_entries(trash_dir, name):
        """Return ``(timestamp, path)`` pairs for trashed copies of skill *name*, oldest first.

        Trash entries are named ``<name>_<millis>``. Matching the exact name part
        keeps "web" from matching "web_search_<millis>", which the old
        ``glob(f"{name}_*")`` did. Timestamps are compared as integers.
        """
        entries = []
        if trash_dir.is_dir():
            for item in trash_dir.iterdir():
                base, sep, stamp = item.name.rpartition("_")
                if sep and base == name and stamp.isdecimal():
                    entries.append((int(stamp), item))
        entries.sort(key=lambda entry: entry[0])
        return entries

    @staticmethod
    def _skill_name_from_markdown(content):
        """Return the ``name`` field of a leading ``---`` YAML front-matter block, or None."""
        if not content.startswith("---"):
            return None
        parts = content.split("---", maxsplit=2)
        if len(parts) < 3:
            return None
        try:
            import yaml
            meta = yaml.safe_load(parts[1])
        except Exception:
            # PyYAML missing or malformed front matter: the caller falls back to the file name.
            return None
        name = meta.get("name") if isinstance(meta, dict) else None
        name = str(name).strip() if name else ""
        return name or None

    @classmethod
    def _config_updates(cls, data):
        """Map the config form fields in *data* to ``{env_key: value}`` (missing/None -> "")."""
        updates = {}
        for env_key, field in cls._ENV_FIELDS.items():
            value = data.get(field, "")
            value = "" if value is None else str(value)
            # A line break would split the entry and could inject extra .env keys.
            updates[env_key] = value.replace("\r", "").replace("\n", "")
        return updates

    @staticmethod
    def _apply_env_updates(env_path, updates):
        """Merge *updates* into the .env file at *env_path*.

        A non-empty value replaces that key's existing ``KEY=...`` line(s) or is
        appended. An empty value removes the key's line(s). All other lines are
        kept verbatim. If the file does not exist, it is created with the non-empty
        values only.
        """
        if not env_path.exists():
            with open(env_path, "w", encoding="utf-8") as f:
                for k, v in updates.items():
                    if v:
                        f.write(f"{k}={v}\n")
            return
        new_lines = []
        updated = set()
        for line in env_path.read_text(encoding="utf-8").split("\n"):
            key = line.split("=")[0].strip() if "=" in line else ""
            if key in updates:
                if updates[key]:
                    new_lines.append(f"{key}={updates[key]}")
                    updated.add(key)
                # Empty value: drop the line so the key is cleared.
            else:
                new_lines.append(line)
        for k, v in updates.items():
            if k not in updated and v:
                new_lines.append(f"{k}={v}")
        env_path.write_text("\n".join(new_lines), encoding="utf-8")

    # ------------------------------------------------------------------ GET handlers

    def _serve_workspace_file(self):
        """Serve ``/workspace/<name>`` by locating a file called <name> in the workspaces.

        Only the last path component of the request is used. ``_get_writer_workspace()``
        is searched first, then ``get_workspace_dir()``, each recursively. The first
        regular file with that name is returned.
        """
        from urllib.parse import unquote

        rel = unquote(urlparse(self.path).path)[len("/workspace/"):]
        if ".." in rel or "\\" in rel:
            self._error(403, "禁止访问")
            return
        safe_name = Path(rel).name
        if not safe_name:
            self._error(400, "文件名不能为空")
            return
        if any(ch in safe_name for ch in "*?["):
            # safe_name is used as an rglob() pattern; wildcards would let an
            # unauthenticated request pull arbitrary files out of the workspaces.
            self._error(403, "禁止访问")
            return

        from simple_agent.utils.path_security import _get_writer_workspace

        search_dirs = []
        for resolve_dir in (_get_writer_workspace, get_workspace_dir):
            try:
                search_dirs.append(resolve_dir())
            except Exception:
                # Best effort: a workspace that cannot be resolved is skipped.
                continue

        # Skip directories that share the name instead of giving up on them.
        found = next(
            (f for sd in search_dirs for f in sd.rglob(safe_name) if f.is_file()),
            None,
        )
        if found is None:
            self._error(404, "文件不存在")
            return

        content_type = self._WORKSPACE_CTYPES.get(found.suffix.lower(), "application/octet-stream")
        if content_type.startswith("text/") or content_type in self._CHARSET_CTYPES:
            content_type += "; charset=utf-8"
        # Read before sending headers so a read error cannot leave a half-sent reply.
        payload = found.read_bytes()
        self.send_response(200)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(payload)))
        self.send_header("Cache-Control", "no-cache")
        self.end_headers()
        self.wfile.write(payload)

    def _handle_skills_list(self):
        """Return get_skills_info() as JSON."""
        try:
            from simple_agent.skills.registry import get_skills_info
            info = get_skills_info()
        except Exception as e:
            self._error(500, str(e))
            return
        self._json_response(200, info)

    def _handle_skills_trash_list(self):
        """List trashed skills (``.trash/<name>_<millis>`` folders) as JSON."""
        trash_dir = get_skills_dir() / ".trash"
        items = []
        if trash_dir.exists():
            for item in trash_dir.iterdir():
                if not item.is_dir():
                    continue
                try:
                    name, stamp = item.name.rsplit("_", 1)
                    items.append({"name": name, "deleted_at": int(stamp), "description": "已删除"})
                except ValueError:
                    continue
        self._json_response(200, items)

    def _handle_confirm_info(self):
        """Return the details of the write confirmation named by ``?confirm_id=``."""
        confirm_id = self._get_query_param("confirm_id")
        if not confirm_id:
            self._json_response(400, {"error": "缺少 confirm_id"})
            return
        if not _is_valid_uuid(confirm_id):
            self._json_response(400, {"error": "confirm_id 格式无效"})
            return
        try:
            c = get_confirmation(confirm_id)
            if not c:
                self._json_response(404, {"error": "确认记录不存在"})
                return
            self._json_response(200, {
                "confirm_id": c["id"], "target_path": c["target_path"],
                "operation_details": c["operation_details"], "risk_analysis": c["risk_analysis"],
                "result": c["result"], "created_at": c["created_at"],
            })
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    # ------------------------------------------------------------------ POST handlers

    def _handle_upload(self):
        """Save an uploaded file into ``<workspace>/<type>/`` (type: knowledge | memory)."""
        extracted = self._extract_file()
        if extracted is None:
            return
        _, filename, data = extracted
        upload_type = self._get_query_param("type") or "knowledge"
        if upload_type not in ("knowledge", "memory"):
            self._error(400, "type 必须为 knowledge 或 memory")
            return
        safe_name = Path(filename).name if filename else ""
        if safe_name in ("", ".", ".."):
            safe_name = "uploaded_file"
        target_dir = get_workspace_dir() / upload_type
        # The sub-directory may not exist yet on a fresh workspace.
        target_dir.mkdir(parents=True, exist_ok=True)
        target_path = target_dir / safe_name
        target_path.write_bytes(data)
        self._json_response(200, {"status": "ok", "path": str(target_path)})

    def _handle_skills_reload(self):
        """Call hotreload.reload_all() and return its message."""
        try:
            from simple_agent.utils.hotreload import reload_all
            msg = reload_all()
        except Exception as e:
            self._error(500, str(e))
            return
        self._json_response(200, {"status": "ok", "message": msg})

    def _handle_skills_delete(self):
        """Move skill ``?name=`` into ``.trash/<name>_<millis>`` and reload skills."""
        name = self._get_query_param("name")
        if not self._is_safe_skill_name(name):
            self._error(400, "无效的技能名称")
            return
        skills_dir = get_skills_dir()
        skill_folder = skills_dir / name
        if not skill_folder.exists():
            self._error(404, "技能不存在")
            return
        trash_dir = skills_dir / ".trash"
        try:
            trash_dir.mkdir(exist_ok=True)
            shutil.move(str(skill_folder), str(trash_dir / f"{name}_{int(time.time() * 1000)}"))
        except OSError as e:
            self._error(500, str(e))
            return
        from simple_agent.skills.registry import reload_skills
        msg = reload_skills()
        self._json_response(200, {"status": "ok", "message": f"技能 {name} 已移至垃圾桶。{msg}"})

    def _handle_skills_upload(self):
        """Install skills from an uploaded ``.zip`` or a single ``.md``/``.markdown`` file."""
        extracted = self._extract_file()
        if extracted is None:
            return
        _, filename, data = extracted
        if not filename:
            self._error(400, "文件名不能为空")
            return
        file_ext = Path(filename).suffix.lower()
        skills_dir = get_skills_dir()
        skills_dir.mkdir(parents=True, exist_ok=True)
        try:
            if file_ext == ".zip":
                error = self._install_skills_zip(skills_dir, data)
            elif file_ext in (".md", ".markdown"):
                error = self._install_skill_markdown(skills_dir, filename, data)
            else:
                error = "不支持的文件类型,请上传 .zip 或 .md"
        except Exception as e:
            self._error(500, str(e))
            return
        if error:
            self._error(400, error)
            return
        from simple_agent.skills.registry import reload_skills
        self._json_response(200, {"status": "ok", "message": f"技能已安装。{reload_skills()}"})

    def _install_skills_zip(self, skills_dir, data):
        """Extract a ZIP into *skills_dir*, then turn each top-level .md file into ``<name>/SKILL.md``.

        The archive is read from memory; there is no temp file on disk.
        Returns an error message for a 400 reply, or None on success.
        """
        try:
            archive = zipfile.ZipFile(io.BytesIO(data or b""))
        except zipfile.BadZipFile:
            return "无效的 ZIP 文件"
        with archive:
            for member in archive.infolist():
                member_path = Path(member.filename)
                if member_path.is_absolute() or ".." in member_path.parts:
                    return f"ZIP中包含不安全的路径: {member.filename}"
            archive.extractall(skills_dir)
        for md_file in list(skills_dir.glob("*.md")) + list(skills_dir.glob("*.markdown")):
            content = md_file.read_text(encoding="utf-8")
            skill_name = self._skill_name_from_markdown(content)
            if not self._is_safe_skill_name(skill_name):
                # Front-matter names are untrusted: fall back to the file name so a
                # value such as "../x" cannot move the file outside skills_dir.
                skill_name = md_file.stem
            if not self._is_safe_skill_name(skill_name):
                continue
            target_dir = skills_dir / skill_name
            target_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(str(md_file), str(target_dir / "SKILL.md"))
        return None

    def _install_skill_markdown(self, skills_dir, filename, data):
        """Write a single uploaded markdown skill to ``<skills_dir>/<name>/SKILL.md``.

        The name comes from the YAML front matter, else the file stem.
        Returns an error message for a 400 reply, or None on success.
        """
        try:
            content = data.decode("utf-8")
        except UnicodeDecodeError:
            return "文件必须是 UTF-8 编码"
        skill_name = self._skill_name_from_markdown(content) or Path(filename).stem
        if not self._is_safe_skill_name(skill_name):
            return "无效的技能名称"
        skill_dir = skills_dir / skill_name
        skill_dir.mkdir(parents=True, exist_ok=True)
        (skill_dir / "SKILL.md").write_text(content, encoding="utf-8")
        return None

    def _handle_skills_permanent_delete(self):
        """Permanently delete the oldest trashed copy of skill ``?name=``."""
        name = self._get_query_param("name")
        if not self._is_safe_skill_name(name):
            self._error(400, "无效的技能名称")
            return
        entries = self._trash_entries(get_skills_dir() / ".trash", name)
        if not entries:
            self._error(404, "垃圾桶中没有该技能")
            return
        victim = entries[0][1]  # oldest copy, as before
        try:
            if victim.is_dir():
                shutil.rmtree(str(victim))
            else:
                victim.unlink()
        except OSError as e:
            self._error(500, str(e))
            return
        self._json_response(200, {"status": "ok", "message": f"已永久删除技能 {name}。"})

    def _handle_skills_recover(self):
        """Restore the newest trashed copy of skill ``?name=`` and reload skills."""
        name = self._get_query_param("name")
        if not self._is_safe_skill_name(name):
            self._error(400, "无效的技能名称")
            return
        skills_dir = get_skills_dir()
        entries = self._trash_entries(skills_dir / ".trash", name)
        if not entries:
            self._error(404, "垃圾桶中没有该技能")
            return
        destination = skills_dir / name
        if destination.exists():
            # rename() would fail (or clobber, depending on platform) an installed skill.
            self._error(409, "同名技能已存在,请先删除或重命名")
            return
        try:
            entries[-1][1].rename(destination)  # newest copy
        except OSError as e:
            self._error(500, str(e))
            return
        # The skill was previously restored on disk but never re-registered.
        from simple_agent.skills.registry import reload_skills
        msg = reload_skills()
        self._json_response(200, {"status": "ok", "message": f"已恢复技能 {name}。{msg}"})

    def _handle_login(self):
        """Check admin credentials from a JSON body and issue a session token.

        After MAX_LOGIN_ATTEMPTS failures a client IP gets 429 for
        LOGIN_LOCKOUT_MINUTES, counted from its last failure. Previously the
        lockout never expired.
        """
        client_ip = self.client_address[0]
        now = time.time()
        if LOGIN_ATTEMPTS.get(client_ip, 0) >= MAX_LOGIN_ATTEMPTS:
            last_failure = self._login_failed_at.get(client_ip, 0)
            if now - last_failure < LOGIN_LOCKOUT_MINUTES * 60:
                self._json_response(429, {"error": "登录尝试次数过多,请15分钟后再试"})
                return
            # Lockout window elapsed: start counting afresh.
            LOGIN_ATTEMPTS.pop(client_ip, None)
            self._login_failed_at.pop(client_ip, None)
        body = self._read_body()
        try:
            data = json.loads(body)
            username = data.get("username", "")
            password = data.get("password", "")
            if username == ADMIN_USER and _verify_password(password, ADMIN_PASS_HASH, ADMIN_PASS_SALT):
                LOGIN_ATTEMPTS[client_ip] = 0
                self._login_failed_at.pop(client_ip, None)
                token = _new_token()
                _sessions[token] = {"user": username, "created_at": time.time(), "ip": client_ip}
                self._json_response(200, {"token": token, "user": username})
            else:
                LOGIN_ATTEMPTS[client_ip] = LOGIN_ATTEMPTS.get(client_ip, 0) + 1
                self._login_failed_at[client_ip] = now
                self._json_response(401, {"error": "用户名或密码错误"})
        except Exception as e:
            self._json_response(400, {"error": str(e)})

    def _handle_write_file(self):
        """Write UTF-8 text (JSON body: path, content) to a path approved by resolve_write_path."""
        body = self._read_body()
        try:
            data = json.loads(body)
            file_path = data.get("path", "")
            content = data.get("content", "")
            if not file_path:
                self._json_response(400, {"error": "缺少文件路径"})
                return
            from simple_agent.utils.path_security import resolve_write_path
            resolved, err = resolve_write_path(file_path)
            if err:
                self._json_response(403, {"error": err})
                return
            p = Path(resolved)
            try:
                p.parent.mkdir(parents=True, exist_ok=True)
                p.write_text(content, encoding="utf-8")
            except PermissionError as pe:
                self._json_response(403, {"error": f"权限不足: {pe}"})
                return
            self._json_response(200, {"success": True, "path": str(p), "size": len(content)})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    @staticmethod
    def _write_pdf(path, content):
        """Render *content* line by line into a PDF at *path*, using a CJK font when one is found."""
        from fpdf import FPDF
        from simple_agent.utils.path_security import find_cjk_font

        pdf = FPDF()
        pdf.add_page()
        font_path = find_cjk_font()
        if font_path:
            pdf.add_font("CJK", "", font_path)
            pdf.set_font("CJK", "", 12)
        else:
            pdf.set_font("Helvetica", "", 12)
        for line in content.split("\n"):
            pdf.cell(0, 10, line, ln=True)
        pdf.output(str(path))

    def _handle_confirm_write(self):
        """Approve or reject a pending write confirmation; on approval, perform the write."""
        body = self._read_body()
        try:
            data = json.loads(body)
            confirm_id = data.get("confirm_id", "")
            choice = data.get("choice", "")
            if not confirm_id or not choice:
                self._json_response(400, {"error": "缺少 confirm_id 或 choice"})
                return
            if not _is_valid_uuid(confirm_id):
                self._json_response(400, {"error": "confirm_id 格式无效"})
                return
            if choice not in ("approve", "reject"):
                self._json_response(400, {"error": "choice 必须为 approve 或 reject"})
                return
            confirmation = get_confirmation(confirm_id)
            if not confirmation:
                self._json_response(404, {"error": "确认记录不存在"})
                return
            if confirmation["result"] != "pending":
                self._json_response(400, {"error": "该确认已处理"})
                return
            if choice == "reject":
                update_confirmation(confirm_id, "rejected", "user")
                self._json_response(200, {"success": True, "message": "用户拒绝写入"})
                return

            p = Path(confirmation["target_path"])
            content = confirmation["content"] or ""
            suffix = p.suffix.lower()
            if suffix in (".docx", ".pptx", ".xlsx"):
                # Checked before approving, so the record stays pending rather than
                # being marked approved for a write that is never performed.
                self._json_response(400, {"error": f"{p.suffix} 需通过 WriterAgent 直接生成"})
                return
            update_confirmation(confirm_id, "approved", "user")
            try:
                p.parent.mkdir(parents=True, exist_ok=True)
                if suffix == ".pdf":
                    self._write_pdf(p, content)
                else:
                    p.write_text(content, encoding="utf-8")
            except PermissionError as pe:
                self._json_response(403, {"error": f"权限不足: {pe}"})
                return
            self._json_response(200, {"success": True, "message": "写入已确认并执行", "path": str(p), "size": len(content)})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    @staticmethod
    def _wait_for_health(health_url, attempts=5, delay=3):
        """Poll *health_url* up to *attempts* times, sleeping *delay* s before each try.

        Any HTTP response, including 4xx/5xx, counts as "service is up".
        Returns True on the first response, False if every attempt fails.
        """
        import urllib.error
        import urllib.request

        for attempt in range(attempts):
            time.sleep(delay)
            try:
                with urllib.request.urlopen(health_url, timeout=5) as resp:
                    print(f"[DEPLOY] 健康检查通过 ({resp.status})")
                return True
            except urllib.error.HTTPError as e:
                print(f"[DEPLOY] 服务已响应 ({e.code})")
                return True
            except Exception:
                print(f"[DEPLOY] 健康检查重试 {attempt + 1}/{attempts}...")
        return False

    def _handle_deploy(self):
        """Deploy an uploaded JAR: back up -> save -> run start.sh -> health check -> roll back on failure.

        Multipart fields: ``jar_file`` (the JAR), ``target_dir``, and optional ``health_url``.
        """
        parts = self._read_multipart()
        if parts is None:
            return
        jar_data = None
        jar_name = "deploy.jar"
        target_dir = ""
        health_url = ""
        try:
            for name, fname, data in parts:
                if name == "jar_file":
                    jar_data = data
                    # Keep only the last path component: a crafted filename such as
                    # "../../x.jar" must not place the JAR outside target_dir.
                    base = Path(fname).name if fname else ""
                    jar_name = base if base not in ("", ".", "..") else "deploy.jar"
                elif name == "target_dir":
                    target_dir = data.decode("utf-8")
                elif name == "health_url":
                    health_url = data.decode("utf-8")
        except UnicodeDecodeError:
            self._json_response(400, {"error": "target_dir/health_url 必须为 UTF-8 文本"})
            return
        if not jar_data:
            self._json_response(400, {"error": "缺少 jar_file"})
            return
        if not target_dir:
            self._json_response(400, {"error": "缺少 target_dir"})
            return

        target = Path(target_dir)
        jar_path = target / jar_name
        backup_name = ""
        try:
            target.mkdir(parents=True, exist_ok=True)

            # 1. Back up the current JAR, if any.
            if jar_path.exists():
                backup_name = f"{jar_name}.bak.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                shutil.move(str(jar_path), str(target / backup_name))
                print(f"[DEPLOY] 已备份: {backup_name}")

            # 2. Save the new JAR.
            jar_path.write_bytes(jar_data)
            print(f"[DEPLOY] 已保存: {jar_path} ({len(jar_data)} bytes)")
            result = {"backup": backup_name, "saved": str(jar_path), "started": False,
                      "rolled_back": False, "health_ok": False}

            # 3. Run start.sh, if present.
            start_script = target / "start.sh"
            if start_script.exists():
                launcher = ["cmd", "/c"] if sys.platform == "win32" else ["bash"]
                proc = subprocess.run(launcher + [str(start_script)], cwd=str(target),
                                      capture_output=True, text=True, errors="replace", timeout=120)
                result["started"] = True
                result["returncode"] = proc.returncode
                result["output"] = (proc.stdout + proc.stderr)[-1000:]
                print(f"[DEPLOY] start.sh 完成 (code={proc.returncode})")

                # 4. Health check: with a health_url, any HTTP response counts as success;
                #    without one, only the script's exit code decides.
                if health_url and proc.returncode == 0:
                    health_ok = self._wait_for_health(health_url)
                    result["health_ok"] = health_ok
                elif health_url:
                    health_ok = False
                else:
                    health_ok = proc.returncode == 0

                # 5. Roll back to the backup when the deploy failed and one exists.
                if not health_ok and backup_name:
                    backup_path = target / backup_name
                    if backup_path.exists():
                        jar_path.unlink(missing_ok=True)
                        shutil.move(str(backup_path), str(jar_path))
                        result["rolled_back"] = True
                        print(f"[DEPLOY] 已回滚: {backup_name} -> {jar_name}")
            else:
                print("[DEPLOY] start.sh 不存在,跳过")

            self._json_response(200, {"success": True, **result})
        except subprocess.TimeoutExpired:
            # NOTE(review): the new JAR stays in place on timeout (no rollback) — confirm intended.
            self._json_response(500, {"error": "start.sh 执行超时"})
        except PermissionError as e:
            self._json_response(403, {"error": f"权限不足: {e}"})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    @staticmethod
    async def _direct_mcp_query(q, amap_key):
        """Route query *q* by keyword to one AMap MCP tool and return its result as JSON text.

        Returns a fixed hint string when no tool matches or the tool is unavailable.
        """
        import asyncio
        from langchain_mcp_adapters.client import MultiServerMCPClient

        def decode(raw):
            # A list result is parsed as JSON from its first item's "text"; any
            # other value is returned unchanged.
            return json.loads(raw[0]["text"]) if isinstance(raw, list) else raw

        client = MultiServerMCPClient({
            "amap": {
                "transport": "http",
                "url": f"https://mcp.amap.com/mcp?key={amap_key}",
                "headers": {},
            }
        })
        tools = await asyncio.wait_for(client.get_tools(), timeout=10)
        tool_map = {t.name: t for t in tools}
        result = None

        if any(kw in q for kw in ["天气", "weather"]):
            # Weather
            city = "北京"
            city_match = re.search(r'([\u4e00-\u9fff]{2,4}(?:市|县|区)?)\s*(?:今天|明天|后天|的|这)', q)
            if city_match:
                city = city_match.group(1).rstrip("市")
            else:
                city_match2 = re.search(r'查询\s*([\u4e00-\u9fff]{2,4})', q)
                if city_match2:
                    city = city_match2.group(1)
            wtool = tool_map.get("maps_weather")
            if wtool:
                result = decode(await wtool.ainvoke({"city": city}))
        elif any(kw in q for kw in ["地址", "地理编码", "经纬度", "在哪", "哪里"]):
            # Geocoding
            gtool = tool_map.get("maps_geo")
            if gtool:
                addr_match = re.search(r'(?:地址|查询|搜索|编码)[::\s]*([^\s,,。]+)', q)
                addr = addr_match.group(1) if addr_match else q[-20:]
                result = decode(await gtool.ainvoke({"address": addr}))
        elif any(kw in q for kw in ["附近", "周边", "周围"]):
            # Nearby search (around a fixed Beijing coordinate, 3 km radius)
            stool = tool_map.get("maps_around_search")
            if stool:
                kw_match = re.search(r'(?:搜索|查|找|附近|周边|周围)[::\s]*([^\s,,。]+)', q)
                keyword = kw_match.group(1) if kw_match else "餐厅"
                result = decode(await stool.ainvoke(
                    {"keywords": keyword, "location": "116.397428,39.90923", "radius": "3000"}))
        elif any(kw in q for kw in ["ip", "IP", "定位"]):
            # IP geolocation
            itool = tool_map.get("maps_ip_location")
            if itool:
                ip_match = re.search(r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})', q)
                ip = ip_match.group(1) if ip_match else ""
                result = decode(await itool.ainvoke({"ip": ip}))
        else:
            # Default: treat the query as a weather lookup.
            wtool = tool_map.get("maps_weather")
            if wtool:
                city_match = re.search(r'([\u4e00-\u9fff]{2,4})(?:市|县|区)?', q)
                city = city_match.group(1) if city_match else "北京"
                result = decode(await wtool.ainvoke({"city": city}))

        if result:
            return json.dumps(result, ensure_ascii=False, indent=2)
        return "未能匹配到合适的MCP工具,请尝试更具体的查询。"

    def _handle_mcp_query(self):
        """Read-only MCP query: call AMap MCP tools directly, without going through the LLM agent."""
        # Imported before the try: the ``except asyncio.TimeoutError`` clause is
        # evaluated whenever the try body raises. With the import inside the try, an
        # early error (e.g. bad JSON) turned into UnboundLocalError.
        import asyncio

        body = self._read_body()
        try:
            data = json.loads(body)
            query = data.get("query", "")
            if not query:
                self._json_response(400, {"error": "缺少 query 参数"})
                return
            amap_key = os.getenv("AMAP_KEY", "")
            if not amap_key:
                self._json_response(503, {"error": "AMAP_KEY 未配置"})
                return
            result_text = asyncio.run(
                asyncio.wait_for(self._direct_mcp_query(query, amap_key), timeout=30))
            self._json_response(200, {"result": result_text})
        except asyncio.TimeoutError:
            self._json_response(504, {"error": "MCP 查询超时"})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    def _handle_restart(self):
        """Back up .env, write the submitted settings into it, reply, then re-exec the process."""
        body = self._read_body()
        try:
            data = json.loads(body)
            env_path = Path(__file__).parent / ".env"
            if env_path.exists():
                backup_path = env_path.with_suffix(f".bak.{datetime.now().strftime('%Y%m%d_%H%M%S')}")
                shutil.copy2(env_path, backup_path)
                print(f"[RESTART] 已备份 .env → {backup_path.name}")
            self._apply_env_updates(env_path, self._config_updates(data))
            self._json_response(200, {"success": True, "message": "配置已应用并备份,服务正在重启..."})

            def restart():
                # Give the reply a moment to reach the client, then replace this
                # process with a fresh interpreter running the same command line.
                time.sleep(1)
                os.environ["PYTHONIOENCODING"] = "utf-8"
                os.execv(sys.executable, [sys.executable] + sys.argv)

            threading.Thread(target=restart, daemon=True).start()
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    def _handle_save_config(self):
        """Write the submitted LLM settings into .env (per the reply, hot reload applies them)."""
        body = self._read_body()
        try:
            data = json.loads(body)
            env_path = Path(__file__).parent / ".env"
            self._apply_env_updates(env_path, self._config_updates(data))
            self._json_response(200, {"success": True, "message": "配置已保存,热更新自动生效"})
        except Exception as e:
            self._json_response(500, {"error": str(e)})

    def _handle_mcp_reconnect(self):
        """Force reconnection of all MCP servers via reconnect_mcp()."""
        try:
            from simple_agent.agents.mcp_manager_agent import reconnect_mcp
            ok = reconnect_mcp()
        except Exception as e:
            self._json_response(500, {"error": str(e)})
            return
        self._json_response(200, {"success": ok, "message": "MCP 重连成功" if ok else "MCP 重连失败,使用降级模式"})
|
||
|
||
def start_chat_server(port=8765):
    """Run the chat UI / admin API server on AGENT_HOST:*port*; blocks forever.

    First changes the working directory to this file's folder, because
    SimpleHTTPRequestHandler serves files relative to the current directory.
    """
    project_root = Path(__file__).parent.resolve()
    os.chdir(project_root)
    bind_host = os.getenv("AGENT_HOST", "127.0.0.1")
    httpd = HTTPServer((bind_host, port), ChatHandler)
    print(f"[Chat] 聊天界面已启动:http://{bind_host}:{port}/static/chat.html")
    httpd.serve_forever()
|
||
|
||
def _find_langgraph_bin():
|
||
base = Path(__file__).parent
|
||
for p in [base/".venv"/"Scripts"/"langgraph.exe", base/".venv"/"bin"/"langgraph"]:
|
||
if p.exists(): return str(p)
|
||
found = shutil.which("langgraph")
|
||
if found: return found
|
||
return None
|
||
|
||
def main():
    """Start every service: logging, skills, config watcher, scheduler, chat UI and LangGraph.

    The chat server and an MCP reconnect loop run in daemon threads. The LangGraph
    dev server runs in the foreground, so this call blocks until it exits. If the
    langgraph CLI is missing, it blocks on the chat server thread instead.
    """
    setup_all_logging()
    initialize_skills()

    # Watch config files so edits are hot-reloaded.
    from simple_agent.utils.hotreload import start_watcher
    start_watcher()

    # Keep a reference to the scheduler for the lifetime of main().
    sched = start_scheduler_background()

    chat_thread = threading.Thread(target=start_chat_server, daemon=True)
    chat_thread.start()

    def mcp_retry_loop():
        """Reconnect MCP servers every 5 minutes, after a 30 s grace period."""
        time.sleep(30)  # let the first connection attempt finish
        while True:
            time.sleep(300)
            try:
                from simple_agent.agents.mcp_manager_agent import reconnect_mcp
                reconnect_mcp()
            except Exception as e:
                # Best effort: keep looping, but do not fail silently.
                print(f"[MCP] 定时重连失败: {e}")

    threading.Thread(target=mcp_retry_loop, daemon=True).start()

    host = os.getenv("AGENT_HOST", "127.0.0.1")
    if sys.platform == "win32":
        # A wildcard bind address cannot be browsed to; open loopback instead.
        browse_host = "127.0.0.1" if host in ("0.0.0.0", "") else host
        webbrowser.open(f"http://{browse_host}:8765/static/chat.html")

    lb = _find_langgraph_bin()
    if not lb:
        print("[ERROR] 找不到 langgraph")
        chat_thread.join()
        return

    # Bind LangGraph to the same interface as the chat server unless LANGGRAPH_HOST
    # overrides it. The hard-coded 0.0.0.0 exposed the agent API (which has no login)
    # to the network even when the chat server was local-only.
    langgraph_host = os.getenv("LANGGRAPH_HOST", host)
    subprocess.run(
        [lb, "dev", "--allow-blocking", "--host", langgraph_host, "--port", "2024",
         "--no-browser", "--no-reload"],
        env={**os.environ, "PYTHONIOENCODING": "utf-8"},
    )


if __name__ == "__main__":
    main()
|