Distinguished technical colleagues, good day!
Today we will take a deep dive into a topic of growing importance in modern computing: computational load balancing, and specifically how to dispatch long-running, complex "chain-of-thought" computations efficiently onto the local compute nodes around us that so often sit idle.
In our day-to-day work, whether it is complex data analysis, large-scale simulation, AI model training and inference, or optimization algorithms that need many iterations of trial and error, we keep hitting the same bottleneck: a single machine cannot finish the job in an acceptable amount of time. Cloud resources are one answer, but for scenarios where data is sensitive, latency matters, or we simply want to exploit hardware we already own without extra spend, intelligently distributing the work across multiple machines on the local network is a highly attractive alternative.
We will use the human "chain of thought" as our analogy: a large, complex reasoning process can usually be decomposed into a series of smaller, more concrete sub-problems that can be solved in parallel or in a well-defined dependency order. Our goal is to build a system that behaves like an efficient brain, intelligently assigning these "fragments of thought" to the "idle brains" on the network and then gathering the results back into one complete "thought".
1. Why Local Distributed Computing, and Its Core Challenges
First, let us be clear about what we mean by "long-running chain-of-thought computations". Typical examples include:
- Scientific simulation and engineering computation: Monte Carlo simulation, finite element analysis, molecular dynamics.
- Big data processing: ETL batch jobs, log analysis, report generation.
- Artificial intelligence: model training on large datasets, batch inference, hyperparameter tuning.
- Optimization problems: combinatorial optimization, genetic algorithms, parallel exploration of reinforcement learning environments.
- Development and testing: massively parallel test execution, build farms.
What these tasks have in common is that they can usually be decomposed into many independent subtasks, or into subtasks whose dependencies are clear and manageable, which is what makes parallel processing possible.
Why local distribution rather than the cloud?
- Data privacy and security: sensitive data never leaves the local network.
- Low latency: LAN communication is far faster than the WAN, which matters for interactive or near-real-time workloads.
- Cost effectiveness: existing PCs, servers, and workstations are put to work, with no additional cloud bills.
- Resource utilization: in many offices and labs, CPUs and GPUs sit idle outside working hours or while users run only light workloads.
The core challenges:
- Node discovery and management: how do we know which nodes are available on the network, and what state they are in?
- Task decomposition and scheduling: how do we split a complex job, and how do we assign subtasks to the most suitable nodes?
- Data transfer and synchronization: how do we move the data each subtask needs, and how do we collect the results?
- Fault tolerance: what happens when a node fails? Can its tasks be reassigned?
- Heterogeneity: how do we handle differences in compute power, memory, and storage across nodes?
In the rest of this talk we will tackle these challenges one by one, with working code that shows how to build a practical local load-balancing system.
2. Infrastructure and Components: the Building Blocks of a Distributed System
A typical local load-balancing system consists of the following core components:
| Component | Responsibilities | Key techniques |
|---|---|---|
| Master node (Master/Scheduler) | Accepts the original job, decomposes it, manages the task queue, discovers and monitors workers, schedules tasks onto workers, collects and aggregates results. | Task queue management, node registration and health checks, load-balancing algorithms, result aggregation. |
| Worker node (Worker/Agent) | Registers with the master, receives and executes assigned tasks, reports task status and results back to the master, monitors its own resource usage. | Task executor, heartbeat mechanism, resource monitoring, result reporting. |
| Task | The smallest unit of computation to execute, carrying its execution logic and required data. | Task serialization and deserialization, task IDs, state management. |
| Communication layer | The channel over which the master and workers exchange data and control messages. | TCP/UDP sockets, RPC (Remote Procedure Call), message queues. |
We will focus on the master-worker architecture: it is relatively straightforward to implement and fits the chain-of-thought computation scenario we are discussing very well.
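Throughout the rest of this talk, the master and workers will exchange a small, fixed vocabulary of control messages. As a preview, here is a minimal sketch that collects those message types in one place; the later examples simply use the equivalent string literals inline, so this class is illustrative rather than part of the code we build.

```python
from enum import Enum

class MessageType(str, Enum):
    """Control messages exchanged between master and workers (used as plain strings later)."""
    REGISTER = "REGISTER"          # worker announces itself over TCP after discovery
    REGISTER_ACK = "REGISTER_ACK"  # master confirms the registration
    HEARTBEAT = "HEARTBEAT"        # worker reports liveness and resource usage
    EXECUTE_TASK = "EXECUTE_TASK"  # master dispatches a serialized task
    TASK_RESULT = "TASK_RESULT"    # worker reports success or failure

print(MessageType.EXECUTE_TASK.value)  # "EXECUTE_TASK"
```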
3. Node Discovery and Health Management
In a dynamic local network, the first thing we need to know is which nodes can take part in the computation, and how healthy and how loaded they currently are.
3.1 Node Discovery Strategies
- Static configuration: the simplest approach; the IP address and port of every worker are configured up front. Suitable when the set of nodes is fixed and rarely changes.
- UDP broadcast/multicast: workers announce themselves with broadcast or multicast messages on startup; the master listens for these messages to discover them.
- Service registry: workers register with a central registry (Consul, etcd, ZooKeeper, or even a plain Redis instance), and the master queries the registry for available nodes.
For a local environment, UDP broadcast/multicast is a lightweight and effective option that requires no extra services.
Code example: node discovery via UDP broadcast (Python)
Worker side (worker_discovery.py):
import socket
import time
import json
import psutil # For resource monitoring
BROADCAST_PORT = 50000
MASTER_PORT = 50001 # Master will listen on this port for specific commands
WORKER_ID = f"worker_{socket.gethostname()}_{time.time_ns()}"
def get_system_info():
"""获取系统资源信息"""
cpu_percent = psutil.cpu_percent(interval=1)
memory_info = psutil.virtual_memory()
disk_info = psutil.disk_usage('/')
return {
"cpu_percent": cpu_percent,
"memory_percent": memory_info.percent,
"disk_percent": disk_info.percent,
"worker_id": WORKER_ID,
"worker_ip": socket.gethostbyname(socket.gethostname()),
"status": "available"
}
def start_worker_discovery():
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
print(f"Worker {WORKER_ID} starting discovery...")
while True:
try:
sys_info = get_system_info()
message = json.dumps(sys_info).encode('utf-8')
sock.sendto(message, ('<broadcast>', BROADCAST_PORT))
print(f"Sent discovery message: {sys_info['cpu_percent']}% CPU")
            time.sleep(5)  # broadcast every 5 seconds
except Exception as e:
print(f"Error during discovery: {e}")
time.sleep(5)
if __name__ == "__main__":
start_worker_discovery()
Master side (master_discovery.py):
import socket
import json
import threading
import time
BROADCAST_PORT = 50000
MASTER_LISTEN_PORT = 50001
WORKER_TIMEOUT = 15 # Seconds
class WorkerRegistry:
def __init__(self):
self.workers = {} # {worker_id: {"last_seen": timestamp, "info": {...}}}
self.lock = threading.Lock()
def update_worker(self, worker_id, worker_info):
with self.lock:
self.workers[worker_id] = {
"last_seen": time.time(),
"info": worker_info
}
print(f"Worker {worker_id} updated. CPU: {worker_info['cpu_percent']}%")
def get_available_workers(self):
available = []
current_time = time.time()
with self.lock:
for worker_id, data in list(self.workers.items()):
if current_time - data["last_seen"] < WORKER_TIMEOUT:
available.append(data["info"])
else:
print(f"Worker {worker_id} timed out.")
del self.workers[worker_id]
return available
def print_workers(self):
with self.lock:
print("n--- Current Workers ---")
for worker_id, data in self.workers.items():
print(f" ID: {worker_id}, IP: {data['info']['worker_ip']}, CPU: {data['info']['cpu_percent']}%, Last Seen: {time.time() - data['last_seen']:.2f}s ago")
print("-----------------------n")
def listen_for_workers(registry):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    # Bind to 0.0.0.0 so broadcasts arriving on any interface are received
sock.bind(('', BROADCAST_PORT))
print(f"Master listening for worker broadcasts on port {BROADCAST_PORT}...")
while True:
try:
data, addr = sock.recvfrom(1024)
worker_info = json.loads(data.decode('utf-8'))
registry.update_worker(worker_info["worker_id"], worker_info)
except Exception as e:
print(f"Error receiving worker broadcast: {e}")
if __name__ == "__main__":
worker_registry = WorkerRegistry()
    # Start a thread that listens for worker broadcasts
listener_thread = threading.Thread(target=listen_for_workers, args=(worker_registry,), daemon=True)
listener_thread.start()
    # The main thread periodically prints the worker list
while True:
worker_registry.print_workers()
time.sleep(10)
Run master_discovery.py, then run worker_discovery.py in other terminals or on other machines, and you will see the master discover and track the workers.
3.2 Health Monitoring and Resource Assessment
Beyond discovering nodes, we need to keep monitoring their health and resource utilization so that scheduling can be intelligent. The usual tool is a heartbeat: each worker periodically tells the master "I am alive" and attaches its current resource usage (CPU, memory, network IO, and so on). If the master does not hear from a node within a certain window, it treats that node as offline or failed.
In the UDP broadcast example above we already embed resource information in the broadcast message, which serves as a simplified heartbeat.
4. Defining, Serializing, and Transporting Tasks
At the heart of chain-of-thought computation is the "task". How we define tasks so that they can be shipped between nodes and executed there is the key question.
4.1 Task Granularity
- Coarse-grained tasks: each task contains a lot of computation, so communication overhead is relatively small. Management is simple, but parallelism may be limited, and a failed task means a lot of work to redo.
- Fine-grained tasks: each task contains little computation, so communication overhead is relatively large. Parallelism and fault tolerance improve, but management is more complex and communication can become the bottleneck.
For chain-of-thought computation we usually aim for medium-grained, independent or weakly dependent subtasks that balance communication against computation, as the sketch below illustrates.
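To make the granularity trade-off concrete, here is a minimal sketch of how one large Monte Carlo job could be split into medium-grained, independent subtasks. The monte_carlo_chunk task name and the helper function are assumptions for illustration only; they are not part of the task definitions we introduce below.

```python
import uuid

def make_monte_carlo_subtasks(total_samples, target_chunks=16):
    """Split one large Monte Carlo job into medium-grained, independent subtasks.

    Each subtask carries only its sample count, so results can be averaged
    regardless of which worker ran which chunk or in what order.
    """
    chunk_size = max(1, total_samples // target_chunks)
    subtasks = []
    remaining = total_samples
    while remaining > 0:
        samples = min(chunk_size, remaining)
        subtasks.append({
            "task_id": str(uuid.uuid4()),
            "function_name": "monte_carlo_chunk",  # hypothetical task name
            "kwargs": {"samples": samples},
        })
        remaining -= samples
    return subtasks

# Example: 10 million samples become 16 chunks of ~625k samples each, which keeps
# per-task compute time high relative to dispatch and serialization overhead.
subtasks = make_monte_carlo_subtasks(10_000_000)
print(len(subtasks), subtasks[0]["kwargs"])
```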
4.2 Task Serialization
To send a task over the network we must turn its execution logic, parameters, and data from in-memory objects into a byte stream, a process called serialization; the receiver then deserializes it.
Common serialization formats:
- JSON: human readable, excellent cross-language support, but weak for complex objects and binary data.
- Pickle (Python): Python's native serialization; it handles almost any Python object, but it is a security risk (deserialization can execute malicious code) and is not cross-language.
- Protocol Buffers (Protobuf), Apache Thrift, Avro: language neutral, efficient, compact, but they require a schema definition.
- MessagePack: similar to JSON but more compact and faster.
In a Python environment, pickle is convenient if every worker is fully trusted, but it carries real security risk. A safer approach is to describe a task as a simple function name plus an argument list and serialize that with JSON or MessagePack; the actual task code is deployed to every worker in advance.
Code example: task definition and serialization
task_definitions.py (shared by all nodes):
import time
import math
import random
def heavy_computation_task(task_id, iterations):
"""一个模拟耗时计算的函数"""
print(f"[{task_id}] Starting heavy computation with {iterations} iterations...")
result = 0
for i in range(iterations):
        # simulate a non-trivial mathematical workload
result += math.sin(i) * math.cos(math.sqrt(i + random.random()))
if i % (iterations // 10 if iterations > 10 else 1) == 0:
            time.sleep(0.001)  # simulate a bit of IO or waiting
print(f"[{task_id}] Finished heavy computation. Result snippet: {result:.4f}")
return {"task_id": task_id, "status": "completed", "result_snippet": result}
def another_task(task_id, data_size):
"""另一个简单的任务"""
print(f"[{task_id}] Starting another task with data size {data_size}...")
    time.sleep(data_size / 1000)  # simulate processing time proportional to data size
result = f"Processed {data_size} units."
print(f"[{task_id}] Finished another task. Result: {result}")
return {"task_id": task_id, "status": "completed", "result": result}
# Task registry: workers look up the function to run by name
TASK_FUNCTIONS = {
"heavy_computation": heavy_computation_task,
"another_task": another_task,
}
On the master, creating tasks looks like this:
import json
import uuid
from task_definitions import TASK_FUNCTIONS  # imported only to know which tasks exist; the master never executes them
class Task:
    def __init__(self, function_name, args=None, kwargs=None, task_id=None):
self.task_id = task_id if task_id else str(uuid.uuid4())
self.function_name = function_name
self.args = args if args is not None else []
self.kwargs = kwargs if kwargs is not None else {}
self.status = "PENDING"
self.result = None
self.worker_id = None
self.start_time = None
self.end_time = None
def to_json(self):
        # serialize the task object to a JSON string
return json.dumps({
"task_id": self.task_id,
"function_name": self.function_name,
"args": self.args,
"kwargs": self.kwargs,
"status": self.status,
            # result, worker_id, etc. are omitted: they only exist after a worker has run the task
})
@staticmethod
def from_json(json_string):
data = json.loads(json_string)
task = Task(data["function_name"], data["args"], data["kwargs"], data["task_id"])
task.status = data.get("status", "PENDING")
return task
# Example: creating tasks
task1 = Task("heavy_computation", args=[1000000])
task2 = Task("another_task", kwargs={"data_size": 500})
print(f"Serialized Task 1: {task1.to_json()}")
deserialized_task = Task.from_json(task1.to_json())
print(f"Deserialized Task 1 function name: {deserialized_task.function_name}")
Here we define a Task class that encapsulates the task information and provides to_json and from_json methods for serialization and deserialization. The actual function code in task_definitions.py must be deployed to every worker node in advance.
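If JSON's verbosity or its poor handling of binary data becomes a problem, MessagePack is a near drop-in replacement for this kind of payload. The sketch below assumes the third-party msgpack package is installed (pip install msgpack); it is an optional variant, and none of the later examples depend on it.

```python
import msgpack  # pip install msgpack

task_payload = {
    "task_id": "demo-1234",
    "function_name": "heavy_computation",
    "args": [1000000],
    "kwargs": {},
}

# Serialize to a compact binary blob instead of a JSON string
packed = msgpack.packb(task_payload, use_bin_type=True)
print(f"MessagePack payload: {len(packed)} bytes")

# Deserialize back into plain Python types on the receiving side
restored = msgpack.unpackb(packed, raw=False)
assert restored == task_payload
```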
5. Load-Balancing Strategies and Task Scheduling
With available workers and well-defined tasks in hand, the next step is to assign tasks to nodes effectively. This is the core of load balancing.
5.1 Static vs. Dynamic Load Balancing
- Static strategies: tasks are assigned by preset rules (such as round robin or hashing), ignoring real-time node load.
  - Round Robin: hand tasks to workers in turn. Simple, but slow nodes can become overloaded.
  - Weighted Round Robin: each node gets a weight reflecting its capability; stronger nodes receive more tasks.
- Dynamic strategies: assignment adapts to real-time load, available resources, response time, and similar signals.
  - Least Connections / Least Tasks: give the task to the node with the fewest active connections or running tasks.
  - Least Resource Utilization: give the task to the node with the lowest CPU, memory, or other utilization.
  - Work Stealing: idle workers actively "steal" tasks from busy ones.
For the chain-of-thought computations we are discussing, dynamic strategies are usually the better choice, especially least-resource-utilization or least-tasks combined with centralized scheduling on the master; a standalone sketch of the simpler static approach follows below.
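Before wiring a strategy into the master, it helps to see one in isolation. Below is a minimal sketch of static weighted round-robin selection over a hard-coded worker list; the worker names and weights are made up for illustration. The dynamic "fewest active tasks, then lowest CPU" strategy appears later inside the master's _select_worker method.

```python
import itertools

# Illustrative static configuration: weight roughly proportional to core count
WORKERS = [
    {"worker_id": "workstation-a", "weight": 4},
    {"worker_id": "laptop-b", "weight": 1},
    {"worker_id": "server-c", "weight": 8},
]

def weighted_round_robin(workers):
    """Yield worker ids in proportion to their configured weights."""
    expanded = [w["worker_id"] for w in workers for _ in range(w["weight"])]
    return itertools.cycle(expanded)

picker = weighted_round_robin(WORKERS)
# server-c receives 8 of every 13 tasks, workstation-a 4, laptop-b 1
assignments = [next(picker) for _ in range(13)]
print(assignments)
```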
5.2 Implementing the Scheduler in a Master-Worker Architecture
In the master-worker model, the master plays the role of central scheduler. It maintains a task queue and a worker registry.
The scheduling flow:
- Task enqueue: newly created tasks are appended to the master's pending queue.
- Node selection: the scheduler periodically inspects the list of available workers and picks the "best" one according to the load-balancing strategy.
- Task dispatch: the scheduler pops a pending task from the queue, serializes it, and sends it to the chosen worker.
- Status update: the task's status moves from PENDING to ASSIGNED or RUNNING.
- Result collection: when a worker finishes, it sends the result back; the master marks the task COMPLETED and stores the result.
- Error handling: if a worker reports failure, or goes silent for too long, the scheduler marks the task as failed or puts it back in the queue.
Code example: core master-worker scheduling logic (Python)
We now extend the earlier master_discovery.py and worker_discovery.py with task scheduling and execution.
Master side (master_node.py):
import socket
import json
import threading
import time
import queue
import uuid
import select
from task_definitions import TASK_FUNCTIONS  # imported for reference only; the master never executes tasks
BROADCAST_PORT = 50000
MASTER_LISTEN_PORT = 50001
WORKER_TIMEOUT = 15 # Seconds for worker heartbeat
TASK_EXECUTION_TIMEOUT = 300 # Seconds for a single task execution
class Task:
def __init__(self, function_name, args, kwargs=None, task_id=None):
self.task_id = task_id if task_id else str(uuid.uuid4())
self.function_name = function_name
self.args = args if args is not None else []
self.kwargs = kwargs if kwargs is not None else {}
self.status = "PENDING" # PENDING, ASSIGNED, RUNNING, COMPLETED, FAILED
self.result = None
self.worker_id = None
self.start_time = None
self.end_time = None
self.retries = 0
def to_json(self):
return json.dumps({
"task_id": self.task_id,
"function_name": self.function_name,
"args": self.args,
"kwargs": self.kwargs,
})
@staticmethod
def from_json(json_string):
data = json.loads(json_string)
task = Task(data["function_name"], data["args"], data["kwargs"], data["task_id"])
return task
def __repr__(self):
return f"Task(ID: {self.task_id[:8]}, Func: {self.function_name}, Status: {self.status}, Worker: {self.worker_id})"
class MasterNode:
def __init__(self):
self.worker_registry = {} # {worker_id: {"last_seen": timestamp, "info": {...}, "socket": conn_socket}}
self.worker_lock = threading.Lock()
self.task_queue = queue.Queue() # PENDING tasks
self.running_tasks = {} # {task_id: Task object}
self.completed_tasks = {} # {task_id: Task object}
self.failed_tasks = {} # {task_id: Task object}
self.task_lock = threading.Lock()
self.master_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.master_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.master_socket.bind(('', MASTER_LISTEN_PORT))
self.master_socket.listen(5)
self.master_socket.setblocking(False) # Non-blocking for select
print(f"Master listening for worker connections on port {MASTER_LISTEN_PORT}...")
self.inputs = [self.master_socket] # Sockets to monitor for readability
# Start background threads
threading.Thread(target=self._listen_for_worker_broadcasts, daemon=True).start()
threading.Thread(target=self._manage_worker_connections, daemon=True).start()
threading.Thread(target=self._task_scheduler, daemon=True).start()
threading.Thread(target=self._monitor_tasks, daemon=True).start()
threading.Thread(target=self._display_status, daemon=True).start()
def _listen_for_worker_broadcasts(self):
# UDP socket for discovery
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', BROADCAST_PORT))
print(f"Master listening for worker broadcasts on port {BROADCAST_PORT}...")
while True:
try:
data, addr = sock.recvfrom(1024)
worker_info = json.loads(data.decode('utf-8'))
worker_id = worker_info["worker_id"]
worker_ip = worker_info["worker_ip"]
with self.worker_lock:
if worker_id not in self.worker_registry:
print(f"Discovered new worker: {worker_id} at {worker_ip}. Waiting for TCP connection.")
# Update worker info from broadcast
if worker_id in self.worker_registry:
self.worker_registry[worker_id]["last_seen"] = time.time()
self.worker_registry[worker_id]["info"].update(worker_info)
else:
# Placeholder until TCP connection is established
self.worker_registry[worker_id] = {
"last_seen": time.time(),
"info": worker_info,
"socket": None,
"is_connected": False,
"active_tasks": 0 # Number of tasks assigned to this worker
}
except Exception as e:
print(f"Error receiving worker broadcast: {e}")
def _manage_worker_connections(self):
while True:
readable, _, _ = select.select(self.inputs, [], [], 1) # 1 second timeout
for s in readable:
if s is self.master_socket:
conn, addr = self.master_socket.accept()
conn.setblocking(False)
self.inputs.append(conn)
print(f"Accepted connection from {addr}")
else:
try:
data = s.recv(4096)
if data:
self._handle_worker_message(s, data)
else: # Connection closed by client
self._remove_worker_connection(s)
except ConnectionResetError:
self._remove_worker_connection(s)
except Exception as e:
# print(f"Error handling worker message: {e}")
pass # Non-blocking socket might raise errors if no data
self._check_worker_timeouts()
time.sleep(0.1) # Small delay to avoid busy-waiting
def _handle_worker_message(self, conn_socket, data):
message = data.decode('utf-8')
try:
msg_obj = json.loads(message)
msg_type = msg_obj.get("type")
worker_id = msg_obj.get("worker_id")
if msg_type == "REGISTER":
# Worker sends its ID after connecting
with self.worker_lock:
if worker_id in self.worker_registry:
self.worker_registry[worker_id]["socket"] = conn_socket
self.worker_registry[worker_id]["is_connected"] = True
print(f"Worker {worker_id} registered successfully via TCP.")
# Send back a confirmation
conn_socket.sendall(json.dumps({"type": "REGISTER_ACK"}).encode('utf-8'))
else:
print(f"Worker {worker_id} connected but not discovered yet. Closing connection.")
conn_socket.close()
self.inputs.remove(conn_socket)
elif msg_type == "TASK_RESULT":
task_id = msg_obj["task_id"]
status = msg_obj["status"]
result = msg_obj.get("result")
error = msg_obj.get("error")
with self.task_lock:
if task_id in self.running_tasks:
task = self.running_tasks.pop(task_id)
task.end_time = time.time()
task.worker_id = worker_id # Ensure worker_id is set
if status == "COMPLETED":
task.status = "COMPLETED"
task.result = result
self.completed_tasks[task_id] = task
print(f"Task {task_id[:8]} completed by {worker_id[:8]}. Result: {result}")
elif status == "FAILED":
task.status = "FAILED"
task.result = {"error": error}
self.failed_tasks[task_id] = task
print(f"Task {task_id[:8]} failed by {worker_id[:8]}. Error: {error}")
with self.worker_lock:
if worker_id in self.worker_registry:
self.worker_registry[worker_id]["active_tasks"] -= 1
else:
print(f"Received result for unknown/stale task {task_id[:8]} from {worker_id[:8]}.")
elif msg_type == "HEARTBEAT":
# Update last_seen from TCP heartbeat, more reliable than UDP for active workers
with self.worker_lock:
if worker_id in self.worker_registry:
self.worker_registry[worker_id]["last_seen"] = time.time()
self.worker_registry[worker_id]["info"].update(msg_obj.get("info", {})) # Update resource info
else:
print(f"Unknown message type: {msg_type}")
except json.JSONDecodeError:
print(f"Received malformed JSON from worker {conn_socket.getpeername()}: {message}")
except Exception as e:
print(f"Error processing worker message: {e}")
def _remove_worker_connection(self, conn_socket):
worker_id_to_remove = None
with self.worker_lock:
for worker_id, data in list(self.worker_registry.items()):
if data["socket"] == conn_socket:
worker_id_to_remove = worker_id
break
if worker_id_to_remove:
print(f"Worker {worker_id_to_remove[:8]} disconnected.")
self.worker_registry[worker_id_to_remove]["is_connected"] = False
self.worker_registry[worker_id_to_remove]["socket"] = None
self.inputs.remove(conn_socket)
# Re-queue tasks assigned to this worker
with self.task_lock:
for task_id, task in list(self.running_tasks.items()):
if task.worker_id == worker_id_to_remove:
print(f"Re-queueing task {task_id[:8]} due to worker disconnection.")
task.status = "PENDING"
task.worker_id = None
task.retries += 1
self.task_queue.put(task)
del self.running_tasks[task_id]
def _check_worker_timeouts(self):
current_time = time.time()
with self.worker_lock:
for worker_id, data in list(self.worker_registry.items()):
if data["is_connected"] and (current_time - data["last_seen"] > WORKER_TIMEOUT):
print(f"Worker {worker_id[:8]} timed out (no heartbeat). Marking as disconnected.")
self.worker_registry[worker_id]["is_connected"] = False
# No need to close socket here, it might be handled by select if it truly disconnected
# Re-queue tasks assigned to this worker
with self.task_lock:
for task_id, task in list(self.running_tasks.items()):
if task.worker_id == worker_id:
print(f"Re-queueing task {task_id[:8]} due to worker timeout.")
task.status = "PENDING"
task.worker_id = None
task.retries += 1
self.task_queue.put(task)
del self.running_tasks[task_id]
def _monitor_tasks(self):
while True:
current_time = time.time()
with self.task_lock:
for task_id, task in list(self.running_tasks.items()):
if task.start_time and (current_time - task.start_time > TASK_EXECUTION_TIMEOUT):
print(f"Task {task_id[:8]} timed out on worker {task.worker_id[:8]}. Re-queueing.")
task.status = "PENDING"
task.worker_id = None
task.retries += 1
self.task_queue.put(task)
del self.running_tasks[task_id]
with self.worker_lock:
if task.worker_id in self.worker_registry:
self.worker_registry[task.worker_id]["active_tasks"] -= 1
time.sleep(5) # Check every 5 seconds
def _task_scheduler(self):
while True:
if not self.task_queue.empty():
worker = self._select_worker()
if worker:
task = self.task_queue.get()
self._dispatch_task(task, worker)
else:
time.sleep(1) # No workers available, wait a bit
else:
time.sleep(1) # No tasks, wait a bit
def _select_worker(self):
"""
        Load-balancing policy: pick the connected worker with the fewest active tasks and relatively low CPU utilization.
"""
eligible_workers = []
with self.worker_lock:
for worker_id, data in self.worker_registry.items():
if data["is_connected"] and data["socket"] and data["info"]["status"] == "available":
# Filter out workers that are too busy (e.g., CPU > 90%)
if data["info"]["cpu_percent"] < 90:
eligible_workers.append(data)
if not eligible_workers:
return None
# Sort by active_tasks (ascending) then by cpu_percent (ascending)
eligible_workers.sort(key=lambda x: (x["active_tasks"], x["info"]["cpu_percent"]))
return eligible_workers[0] # Select the best worker
def _dispatch_task(self, task, worker_data):
worker_id = worker_data["info"]["worker_id"]
worker_socket = worker_data["socket"]
try:
message = json.dumps({
"type": "EXECUTE_TASK",
"task": Task.from_json(task.to_json()).__dict__ # Send raw dict to avoid circular deps
}).encode('utf-8')
worker_socket.sendall(message)
task.status = "RUNNING"
task.worker_id = worker_id
task.start_time = time.time()
with self.task_lock:
self.running_tasks[task.task_id] = task
with self.worker_lock:
self.worker_registry[worker_id]["active_tasks"] += 1
print(f"Dispatched task {task.task_id[:8]} to worker {worker_id[:8]}.")
except Exception as e:
print(f"Failed to dispatch task {task.task_id[:8]} to worker {worker_id[:8]}: {e}. Re-queueing.")
task.status = "PENDING"
task.worker_id = None
task.retries += 1
self.task_queue.put(task) # Put back to queue
# Note: active_tasks for this worker might be wrongly incremented if sendall fails mid-way.
# A more robust solution might involve an ACK from worker.
def add_task(self, function_name, *args, **kwargs):
task = Task(function_name, args, kwargs)
self.task_queue.put(task)
print(f"Added new task {task.task_id[:8]} to queue.")
return task.task_id
def _display_status(self):
while True:
time.sleep(5)
print("n--- Master Status ---")
print(f"Pending tasks: {self.task_queue.qsize()}")
print(f"Running tasks: {len(self.running_tasks)}")
print(f"Completed tasks: {len(self.completed_tasks)}")
print(f"Failed tasks: {len(self.failed_tasks)}")
print("nConnected Workers:")
with self.worker_lock:
for worker_id, data in self.worker_registry.items():
if data["is_connected"]:
cpu = data["info"].get("cpu_percent", "N/A")
mem = data["info"].get("memory_percent", "N/A")
print(f" ID: {worker_id[:8]}, IP: {data['info']['worker_ip']}, CPU: {cpu}%, Mem: {mem}%, Active Tasks: {data['active_tasks']}, Last Seen: {time.time() - data['last_seen']:.2f}s ago")
print("--------------------n")
if __name__ == "__main__":
master = MasterNode()
# Example: Add some tasks
master.add_task("heavy_computation", 5000000)
master.add_task("another_task", data_size=2000)
master.add_task("heavy_computation", 7000000)
master.add_task("heavy_computation", 3000000)
master.add_task("another_task", data_size=1500)
master.add_task("heavy_computation", 6000000)
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Master shutting down.")
Worker side (worker_node.py):
import socket
import time
import json
import psutil
import threading
from task_definitions import TASK_FUNCTIONS  # the worker needs the actual task function definitions
BROADCAST_PORT = 50000
MASTER_PORT = 50001
WORKER_ID = f"worker_{socket.gethostname()}_{time.time_ns()}"
class WorkerNode:
def __init__(self):
self.master_ip = None
self.master_conn = None
self.is_connected_to_master = False
self.worker_id = WORKER_ID
self.worker_ip = socket.gethostbyname(socket.gethostname())
        self.active_tasks = threading.Semaphore(psutil.cpu_count(logical=False) or 1)  # limit concurrent tasks to the physical core count (fall back to 1 if unknown)
self.task_threads = {} # To keep track of running task threads
# Start background threads
threading.Thread(target=self._start_discovery_broadcast, daemon=True).start()
threading.Thread(target=self._connect_to_master, daemon=True).start()
threading.Thread(target=self._send_heartbeats, daemon=True).start()
threading.Thread(target=self._listen_for_master_commands, daemon=True).start()
def get_system_info(self):
"""获取系统资源信息"""
cpu_percent = psutil.cpu_percent(interval=1)
memory_info = psutil.virtual_memory()
disk_info = psutil.disk_usage('/')
return {
"cpu_percent": cpu_percent,
"memory_percent": memory_info.percent,
"disk_percent": disk_info.percent,
"worker_id": self.worker_id,
"worker_ip": self.worker_ip,
"status": "available" if self.is_connected_to_master else "disconnected"
}
def _start_discovery_broadcast(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
print(f"Worker {self.worker_id[:8]} starting discovery broadcast...")
while True:
try:
sys_info = self.get_system_info()
message = json.dumps(sys_info).encode('utf-8')
sock.sendto(message, ('<broadcast>', BROADCAST_PORT))
# print(f"Sent discovery message: {sys_info['cpu_percent']}% CPU")
time.sleep(5)
except Exception as e:
print(f"Error during discovery broadcast: {e}")
time.sleep(5)
def _connect_to_master(self):
while not self.is_connected_to_master:
# First, try to find master IP from broadcast (simplified for now, assumes master IP is known or discovered via broadcast)
# In a real scenario, worker might listen for a specific master broadcast or connect to a known IP.
            # For this example, assume the master runs on localhost for testing
if not self.master_ip:
self.master_ip = "127.0.0.1" # For testing locally, replace with actual master IP or implement discovery
print(f"Worker {self.worker_id[:8]} attempting to connect to master at {self.master_ip}:{MASTER_PORT}...")
try:
self.master_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.master_conn.connect((self.master_ip, MASTER_PORT))
self.master_conn.setblocking(True) # Blocking for receiving full messages
# Register with master
register_msg = json.dumps({"type": "REGISTER", "worker_id": self.worker_id}).encode('utf-8')
self.master_conn.sendall(register_msg)
# Wait for ACK
response = self.master_conn.recv(1024).decode('utf-8')
if json.loads(response).get("type") == "REGISTER_ACK":
self.is_connected_to_master = True
print(f"Worker {self.worker_id[:8]} successfully connected and registered with master.")
else:
raise Exception("Master registration failed.")
except Exception as e:
print(f"Failed to connect/register with master: {e}. Retrying in 5 seconds.")
self.is_connected_to_master = False
if self.master_conn:
self.master_conn.close()
time.sleep(5)
def _send_heartbeats(self):
while True:
if self.is_connected_to_master:
try:
sys_info = self.get_system_info()
heartbeat_msg = json.dumps({
"type": "HEARTBEAT",
"worker_id": self.worker_id,
"info": sys_info
}).encode('utf-8')
self.master_conn.sendall(heartbeat_msg)
# print(f"Sent heartbeat from {self.worker_id[:8]}")
except Exception as e:
print(f"Error sending heartbeat: {e}. Reconnecting to master.")
self.is_connected_to_master = False
if self.master_conn:
self.master_conn.close()
# Trigger reconnection in _connect_to_master thread
time.sleep(2) # Send heartbeat every 2 seconds
def _listen_for_master_commands(self):
buffer = ""
while True:
if not self.is_connected_to_master:
time.sleep(1)
continue
try:
data = self.master_conn.recv(4096).decode('utf-8')
if not data: # Master disconnected
print("Master disconnected. Attempting to reconnect.")
self.is_connected_to_master = False
if self.master_conn:
self.master_conn.close()
continue
                # Simple message framing: repeatedly decode complete JSON objects
                # from the front of the buffer. A production system would use
                # length-prefixed messages or an explicit delimiter instead.
                buffer += data
                decoder = json.JSONDecoder()
                while buffer:
                    stripped = buffer.lstrip()
                    if not stripped:
                        buffer = ""
                        break
                    try:
                        message, offset = decoder.raw_decode(stripped)
                    except json.JSONDecodeError:
                        # Incomplete (or malformed) message; wait for more data
                        buffer = stripped
                        break
                    self._handle_master_command(message)
                    buffer = stripped[offset:]
except Exception as e:
print(f"Error listening for master commands: {e}. Reconnecting.")
self.is_connected_to_master = False
if self.master_conn:
self.master_conn.close()
time.sleep(1)
def _handle_master_command(self, command):
command_type = command.get("type")
if command_type == "EXECUTE_TASK":
task_data = command.get("task")
task_id = task_data["task_id"]
function_name = task_data["function_name"]
args = task_data["args"]
kwargs = task_data["kwargs"]
print(f"Received task {task_id[:8]} ({function_name}) from master.")
# Use a separate thread to execute the task to avoid blocking the main worker loop
task_thread = threading.Thread(target=self._execute_task,
args=(task_id, function_name, args, kwargs))
task_thread.daemon = True # Allow main program to exit even if tasks are running
self.task_threads[task_id] = task_thread
task_thread.start()
else:
print(f"Unknown master command: {command_type}")
def _execute_task(self, task_id, function_name, args, kwargs):
self.active_tasks.acquire() # Acquire a slot to run a task
try:
if function_name in TASK_FUNCTIONS:
task_func = TASK_FUNCTIONS[function_name]
result = task_func(task_id, *args, **kwargs)
self._report_task_result(task_id, "COMPLETED", result)
else:
error_msg = f"Unknown function: {function_name}"
print(f"[{task_id[:8]}] Error: {error_msg}")
self._report_task_result(task_id, "FAILED", error=error_msg)
except Exception as e:
error_msg = f"Task execution failed: {e}"
print(f"[{task_id[:8]}] Error: {error_msg}")
self._report_task_result(task_id, "FAILED", error=error_msg)
finally:
self.active_tasks.release() # Release the slot
if task_id in self.task_threads:
del self.task_threads[task_id]
def _report_task_result(self, task_id, status, result=None, error=None):
if not self.is_connected_to_master:
print(f"Cannot report result for task {task_id[:8]}. Not connected to master.")
return
report_msg = {
"type": "TASK_RESULT",
"worker_id": self.worker_id,
"task_id": task_id,
"status": status,
}
if result:
report_msg["result"] = result
if error:
report_msg["error"] = error
try:
self.master_conn.sendall(json.dumps(report_msg).encode('utf-8'))
print(f"Reported task {task_id[:8]} as {status} to master.")
except Exception as e:
print(f"Failed to report task result for {task_id[:8]}: {e}. Master connection might be lost.")
if __name__ == "__main__":
worker = WorkerNode()
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Worker shutting down.")
To run this system:
- Make sure the task_definitions.py file exists on the master and on every worker node.
- Start python master_node.py first.
- Then run python worker_node.py in one or more terminals (or on different machines).
- Watch the master's output: it discovers the workers, dispatches tasks, and collects results, while the workers execute tasks and report their status.
Communication protocol design (simplified):
We have used a simplified JSON-over-TCP protocol here. To improve robustness, a production system would adopt something stricter, for example:
- Length-prefixed messages: each JSON message is preceded by a fixed-size field holding the byte length of the body; the receiver reads the length first, then exactly that many bytes (sketched below).
- Message delimiters: a special character sequence separates consecutive messages.
- Heartbeats and timeouts: stricter heartbeat handling and task execution timeout checks.
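As a concrete illustration of the length-prefix idea, here is a minimal sketch of send/receive helpers that put a 4-byte big-endian length in front of every JSON message. These helpers are not used by the example code above, which deliberately keeps the simpler, more fragile framing.

```python
import json
import socket
import struct

def send_message(sock: socket.socket, obj) -> None:
    """Serialize obj to JSON and send it with a 4-byte big-endian length prefix."""
    body = json.dumps(obj).encode("utf-8")
    sock.sendall(struct.pack(">I", len(body)) + body)

def recv_exactly(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes, or raise if the peer closes the connection early."""
    chunks = b""
    while len(chunks) < n:
        chunk = sock.recv(n - len(chunks))
        if not chunk:
            raise ConnectionError("socket closed while reading")
        chunks += chunk
    return chunks

def recv_message(sock: socket.socket):
    """Read one length-prefixed JSON message and return the decoded object."""
    (length,) = struct.unpack(">I", recv_exactly(sock, 4))
    return json.loads(recv_exactly(sock, length).decode("utf-8"))
```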
6. Data Transfer and Storage Strategies
Chain-of-thought tasks usually need input data and produce output data. Efficient data movement is key to avoiding bottlenecks in a distributed system.
- Inline data: for small payloads, serialize the data directly into the task parameters.
- Shared storage:
  - Network file systems (NFS/SMB): every node mounts the same share; tasks only pass file paths, and data is read and written locally, avoiding transfer through the task channel. Good when data is large and every node can reach the share.
  - Distributed file systems (HDFS/Ceph): heavier solutions suited to large clusters.
- Peer-to-peer transfer: workers exchange data directly instead of relaying through the master; this requires more coordination.
- Data caching: frequently used data is cached locally on workers to avoid repeated transfers.
- Databases / message queues: task inputs and outputs live in a shared database (such as Redis or MongoDB) or a message queue, and tasks only carry data IDs.
In our master-worker example we assume the task parameters (iterations, data_size) are small enough to travel as JSON. For large data, say a 1 GB CSV file, passing a file path and relying on shared storage (an NFS mount, for instance) is the more efficient choice, as the sketch below shows.
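To make the shared-storage option concrete, the sketch below shows what a task function might look like when workers receive only a path relative to a commonly mounted share. The /mnt/shared mount point and the process_csv_file name are assumptions for illustration; this function is not part of task_definitions.py above.

```python
import csv
import os

SHARED_ROOT = "/mnt/shared"  # assumed NFS/SMB mount present on every node

def process_csv_file(task_id, relative_path):
    """Aggregate a large CSV that lives on shared storage.

    Only the (small) path travels over the task channel; the bulk data is
    read locally from the shared mount on whichever worker runs the task.
    """
    full_path = os.path.join(SHARED_ROOT, relative_path)
    row_count = 0
    total = 0.0
    with open(full_path, newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            row_count += 1
            total += float(row.get("value", 0) or 0)
    return {"task_id": task_id, "rows": row_count, "sum": total}
```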
7. Fault Tolerance and Reliability
Handling failure is one of the hardest parts of any distributed system.
- Worker failure:
  - Heartbeat detection: the master notices through missing heartbeats that a worker is offline.
  - Task retry: unfinished tasks assigned to the failed worker go back into the queue and are handed to other available workers.
  - Maximum retry count: cap the number of retries per task so that endless retries do not waste resources.
- Master failure:
  - Single point of failure: in a master-worker architecture, the master is the single point of failure.
  - High availability: an active-standby pair or a cluster that elects a new master solves this, usually with the help of a coordination service (ZooKeeper, Consul).
- Network partitions: a network fault can cut some nodes off from the master while they keep running; this calls for careful failure detection and recovery strategies.
- Idempotency: design tasks so that running the same task multiple times produces the same result. This is essential for safe retries; a minimal sketch follows below.
In our example, the master already handles worker timeouts and task timeouts in a basic way by re-queueing the affected tasks.
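One simple pattern that makes retries safe is to cap the retry count on the master and to store results keyed by task_id, so that a re-executed task overwrites rather than duplicates its output. A minimal sketch, assuming a MAX_TASK_RETRIES policy that our Task class does not yet enforce:

```python
MAX_TASK_RETRIES = 3  # assumed policy, not part of the Task class above

def requeue_or_fail(task, task_queue, failed_tasks):
    """Re-queue a task that lost its worker, unless it has been retried too often."""
    task.retries += 1
    task.worker_id = None
    if task.retries <= MAX_TASK_RETRIES:
        task.status = "PENDING"
        task_queue.put(task)
    else:
        task.status = "FAILED"
        task.result = {"error": f"gave up after {task.retries - 1} retries"}
        failed_tasks[task.task_id] = task

def store_result(results, task_id, result):
    """Idempotent result handling: keyed writes make a duplicate completion harmless."""
    results[task_id] = result  # overwrite, never append
```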
8. Scaling Up and Optimizing: Toward More Professional Solutions
The master-worker system we built is a good starting point, but for more demanding scenarios it pays to bring in dedicated libraries and frameworks:
8.1 Options in the Python Ecosystem
| Framework / library | Description | Typical use |
|---|---|---|
| `multiprocessing.managers` | Python standard library; creates shared objects for inter-process communication. | Simple work queues and shared state, across processes on one machine or a small local network. |
| `concurrent.futures` | Python standard library; high-level asynchronous execution interfaces, including ThreadPoolExecutor and ProcessPoolExecutor. | Simplifies submitting parallel/concurrent tasks and collecting results, mainly on a single multi-core machine. |
| `Celery` | A mature distributed task queue supporting several message brokers (Redis, RabbitMQ) and result backends. | Reliable task queues, scheduled tasks, task chains, flexible routing, and monitoring. |
| `Dask` | A parallel computing library focused on large-scale data analysis, with parallelized DataFrame, Array, and similar structures. | Datasets larger than single-machine memory, scientific computing, machine learning. |
| `Ray` | A general-purpose distributed computing framework that simplifies building distributed applications, especially for AI/ML. | Distributed reinforcement learning, hyperparameter tuning, complex AI model serving; supports dynamic task graphs. |
| `ZeroMQ` | A high-performance asynchronous messaging library with multiple patterns (Pub/Sub, Req/Rep, Push/Pull). | Scenarios with extreme throughput and latency requirements; building custom messaging protocols. |
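Before reaching for a full framework, remember that concurrent.futures from the standard library already covers the single-machine case that often precedes a distributed deployment. A minimal sketch, reusing heavy_computation_task from task_definitions.py with an arbitrary worker count:

```python
from concurrent.futures import ProcessPoolExecutor, as_completed

from task_definitions import heavy_computation_task

if __name__ == "__main__":
    jobs = [2_000_000, 3_000_000, 1_500_000, 2_500_000]
    # One process per worker slot; tasks run in parallel on this machine only
    with ProcessPoolExecutor(max_workers=4) as pool:
        futures = {pool.submit(heavy_computation_task, f"local_{i}", n): n
                   for i, n in enumerate(jobs)}
        for future in as_completed(futures):
            print(future.result())
```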
Using Celery to simplify the task queue (with Redis as the broker):
Celery is a very powerful option: it packages the task queue, worker management, and result storage behind a clean interface.
1. Install Celery and the Redis client:
pip install celery redis
2. tasks.py (shared by all nodes):
from celery import Celery
import time
import math
import random
# Configure Celery to use Redis as both the message broker and the result backend
app = Celery('my_distributed_app',
broker='redis://localhost:6379/0',
backend='redis://localhost:6379/0')
@app.task
def heavy_computation_task(iterations):
"""一个模拟耗时计算的函数"""
print(f"[Celery Task] Starting heavy computation with {iterations} iterations...")
result = 0
for i in range(iterations):
result += math.sin(i) * math.cos(math.sqrt(i + random.random()))
if i % (iterations // 10 if iterations > 10 else 1) == 0:
time.sleep(0.001)
print(f"[Celery Task] Finished heavy computation. Result snippet: {result:.4f}")
return {"status": "completed", "result_snippet": result}
@app.task
def another_task(data_size):
"""另一个简单的任务"""
print(f"[Celery Task] Starting another task with data size {data_size}...")
time.sleep(data_size / 1000)
result = f"Processed {data_size} units."
print(f"[Celery Task] Finished another task. Result: {result}")
return {"status": "completed", "result": result}
3. Start a Redis server.
4. Start Celery workers (in different terminals or on different machines):
celery -A tasks worker -l info
5. Submit tasks (from another terminal):
from tasks import heavy_computation_task, another_task
import time
import random
print("Submitting tasks via Celery...")
# Submit an expensive computation task
result1 = heavy_computation_task.delay(5000000)
print(f"Submitted heavy_computation_task. Task ID: {result1.id}")
# Submit another task
result2 = another_task.delay(data_size=2500)
print(f"Submitted another_task. Task ID: {result2.id}")
# Submit a few more tasks
for i in range(3):
heavy_computation_task.delay(random.randint(1000000, 8000000))
print("nWaiting for results (this is blocking, in real app, you'd check asynchronously)...")
# 获取结果,会阻塞直到任务完成
print(f"Result 1 status: {result1.status}, value: {result1.get(timeout=300)}")
print(f"Result 2 status: {result2.status}, value: {result2.get(timeout=300)}")
print("All tasks submitted and results retrieved.")
Celery dramatically simplifies distributed task management and adds advanced features such as automatic retries, task priorities, and monitoring.
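As one example of those advanced features, a Celery task can declare its own retry policy. Below is a minimal sketch of what an automatically retried variant of our task might look like; the flaky_computation name and the retry parameters are illustrative, not part of tasks.py above.

```python
from celery import Celery

app = Celery('my_distributed_app',
             broker='redis://localhost:6379/0',
             backend='redis://localhost:6379/0')

@app.task(bind=True, max_retries=3, default_retry_delay=10)
def flaky_computation(self, iterations):
    """Retry up to 3 times, 10 seconds apart, if the computation raises."""
    try:
        return sum(i * i for i in range(iterations))
    except Exception as exc:
        # Re-enqueue this task; Celery gives up after max_retries attempts
        raise self.retry(exc=exc)
```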
8.2 A Message Queue as the Communication Backbone
Besides Celery, using a message queue directly (Redis Streams/PubSub, RabbitMQ, Kafka) as the backbone between the master and workers is another common and powerful pattern.
Using Redis as a simple task queue:
Master (producer) side:
import redis
import json
import uuid
import time
r = redis.Redis(host='localhost', port=6379, db=0)
def add_task_to_redis_queue(function_name, *args, **kwargs):
task_id = str(uuid.uuid4())
task_payload = {
"task_id": task_id,
"function_name": function_name,
"args": args,
"kwargs": kwargs
}
r.rpush("task_queue", json.dumps(task_payload)) # Push to right end of list
print(f"Added task {task_id[:8]} to Redis queue.")
return task_id
if __name__ == "__main__":
print("Master (producer) adding tasks to Redis...")
add_task_to_redis_queue("heavy_computation", 6000000)
add_task_to_redis_queue("another_task", data_size=3000)
add_task_to_redis_queue("heavy_computation", 8000000)
time.sleep(1)
print("Tasks added.")
Worker (consumer) side:
import redis
import json
import time
import threading
import psutil
import socket  # needed below for gethostname() in the worker id
from task_definitions import TASK_FUNCTIONS  # the worker needs the actual task function definitions
r = redis.Redis(host='localhost', port=6379, db=0)
WORKER_ID = f"redis_worker_{socket.gethostname()}_{time.time_ns()}"
def get_system_info():
cpu_percent = psutil.cpu_percent(interval=1)
memory_info = psutil.virtual_memory()
return {
"worker_id": WORKER_ID,
"cpu_percent": cpu_percent,
"memory_percent": memory_info.percent,
"timestamp": time.time()
}
def execute_redis_task(task_payload):
task_id = task_payload["task_id"]
function_name = task_payload["function_name"]
args = task_payload["args"]
kwargs = task_payload["kwargs"]
print(f"[{WORKER_ID[:8]}] Executing task {task_id[:8]} ({function_name})...")
try:
if function_name in TASK_FUNCTIONS:
task_func = TASK_FUNCTIONS[function_name]
result = task_func(task_id, *args, **kwargs)
r.hset(f"task_results:{task_id}", "status", "COMPLETED")
r.hset(f"task_results:{task_id}", "result", json.dumps(result))
else:
error_msg = f"Unknown function: {function_name}"
r.hset(f"task_results:{task_id}", "status", "FAILED")
r.hset(f"task_results:{task_id}", "error", error_msg)
except Exception as e:
error_msg = f"Task execution failed: {e}"
r.hset(f"task_results:{task_id}", "status", "FAILED")
r.hset(f"task_results:{task_id}", "error", error_msg)
print(f"[{WORKER_ID[:8]}] Finished task {task_id[:8]}.")
def worker_loop():
print(f"Worker {WORKER_ID[:8]} starting to poll Redis queue...")
while True:
        # BLPOP blocks until an item is available or the timeout expires (returning None)
        item = r.blpop("task_queue", timeout=5)
        if item:
            _, task_json = item
            task_payload = json.loads(task_json.decode('utf-8'))
            threading.Thread(target=execute_redis_task, args=(task_payload,), daemon=True).start()
# Send heartbeat (simplified)
r.hset("worker_heartbeats", WORKER_ID, json.dumps(get_system_info()))
if __name__ == "__main__":
worker_loop()
This Redis-based task queue pattern decouples producers from consumers, which improves elasticity and scalability. The master only pushes tasks into the queue; workers pull tasks from it, so no direct connection between them is needed. Redis can also hold worker registration information and task results, as sketched below.
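To close the loop on the producer side, the results that the worker above writes into the task_results:<task_id> hashes can be polled back as sketched below; the polling interval and timeout are arbitrary choices for the example.

```python
import json
import time

import redis

r = redis.Redis(host='localhost', port=6379, db=0)

def wait_for_result(task_id, timeout=300, poll_interval=1.0):
    """Poll the task_results:<task_id> hash written by the worker until it is done."""
    deadline = time.time() + timeout
    key = f"task_results:{task_id}"
    while time.time() < deadline:
        status = r.hget(key, "status")
        if status == b"COMPLETED":
            return json.loads(r.hget(key, "result"))
        if status == b"FAILED":
            raise RuntimeError(r.hget(key, "error").decode("utf-8"))
        time.sleep(poll_interval)
    raise TimeoutError(f"No result for task {task_id} within {timeout}s")

# Usage: result = wait_for_result(add_task_to_redis_queue("heavy_computation", 6000000))
```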
9. Empowering Local Compute Resources: an Outlook
The local load balancing we explored today is not only a technical challenge; it is also a shift in mindset. It encourages us to look again at the compute resources around us that are so easily overlooked. Imagine a dozen office PCs automatically joining a compute network after hours or over lunch and contributing cycles to a large simulation or data analysis project: the gains in efficiency and the cost savings are obvious.
Looking ahead, as edge computing and personal AI model training grow, this kind of local distributed computing will become increasingly common. It empowers individuals, small teams, and organizations with strict data privacy requirements to handle demanding computations without depending on expensive cloud services. We have seen how to build a master-worker system from scratch, and how professional tools such as Celery or Redis help us build more robust, scalable solutions.
Making full use of local computing potential, so that every idle cycle creates value: that is exactly what the computational load balancing we discussed today sets out to achieve.