Implementing Actor-Critic Models in a Python Reinforcement Learning Framework: Parallel Sampling and Distributed Gradient Update Strategies


Hello everyone. Today we take a deep dive into implementing Actor-Critic models in a Python reinforcement learning framework, focusing on parallel sampling and distributed gradient update strategies. Actor-Critic methods are a powerful family of reinforcement learning algorithms that combine the strengths of policy gradient methods with those of temporal-difference (TD) learning: policy gradient methods handle continuous action spaces well but suffer from high variance, while TD learning is sample-efficient but introduces bias through bootstrapping. In an Actor-Critic model the Actor learns the policy and the Critic evaluates its value, which yields more stable and more efficient learning.

1. Actor-Critic Fundamentals

An Actor-Critic model consists of two components:

  • Actor (policy network): learns the policy π(a|s), i.e. the probability of taking action a in state s. Its objective is to maximize the expected return.
  • Critic (value network): estimates the value function V(s) or Q(s, a) of the current policy. Its objective is to estimate these values accurately and provide a learning signal for the Actor.

Training an Actor-Critic model typically proceeds as follows:

  1. Sampling: the Actor interacts with the environment according to the current policy π(a|s) and generates transitions (s, a, r, s').
  2. Evaluation: the Critic updates the value function V(s) or Q(s, a) using TD learning (e.g. SARSA- or Q-learning-style targets).
  3. Update: the Actor adjusts the policy π(a|s) using the Critic's value estimates as feedback.

Common Actor-Critic algorithms include:

  • A2C (Advantage Actor-Critic): reduces variance by using the advantage function A(s, a) = Q(s, a) - V(s); see the sketch after this list.
  • A3C (Asynchronous Advantage Actor-Critic): runs several Actor-Critic agents against the environment in parallel and updates a global network asynchronously.
  • DDPG (Deep Deterministic Policy Gradient): designed for continuous action spaces; the Actor outputs a deterministic action and the Critic evaluates its value.
  • TD3 (Twin Delayed Deep Deterministic Policy Gradient): an improvement on DDPG that uses two Critic networks to reduce overestimation bias in the Q-value estimates.
  • SAC (Soft Actor-Critic): adds entropy regularization to encourage exploration and learn more robust policies.
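
To make the two roles concrete, here is a minimal, hedged sketch of an A2C-style loss computation for a discrete action space; the helper a2c_losses and the batched tensor shapes are illustrative assumptions rather than the API of any particular library.

import torch
import torch.nn.functional as F

def a2c_losses(policy_logits, values, actions, returns):
    # policy_logits: (B, num_actions) raw Actor outputs
    # values:        (B,) state values V(s) from the Critic
    # actions:       (B,) integer (long) actions that were actually taken
    # returns:       (B,) discounted returns (or TD targets) from the rollout

    # Advantage A(s, a) ≈ return - V(s); detached so the actor loss
    # does not backpropagate into the Critic
    advantages = (returns - values).detach()

    log_probs = F.log_softmax(policy_logits, dim=-1)
    chosen_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

    actor_loss = -(chosen_log_probs * advantages).mean()  # policy gradient weighted by the advantage
    critic_loss = F.mse_loss(values, returns)             # value regression toward the targets
    return actor_loss, critic_loss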

2. Parallel Sampling Strategies

Traditional reinforcement learning usually runs a single agent against the environment, which is inefficient. Parallel sampling lets multiple agents interact with the environment simultaneously, collecting more training data and thereby speeding up learning.

2.1 Multi-Process Sampling

Multi-process sampling can be implemented with Python's multiprocessing library. Each process maintains its own Actor-Critic agent and its own copy of the environment.

import multiprocessing
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Actor network (the policy)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim) # assumes a continuous action space

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        action = torch.tanh(self.fc3(x)) # squash actions into [-1, 1]
        return action

# Critic network (the action-value function Q(s, a))
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

# Sampling-and-update function executed by each worker process
def collect_samples(actor, critic, env, num_episodes, gamma, actor_optimizer, critic_optimizer):
    episodes = []
    for _ in range(num_episodes):
        state = env.reset()
        episode = []
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = actor(state_tensor).detach().numpy()[0]
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward, next_state, done))
            state = next_state

        # Compute TD errors and update the networks
        states, actions, rewards, next_states, dones = zip(*episode)
        states = torch.FloatTensor(np.array(states))
        actions = torch.FloatTensor(np.array(actions))
        rewards = torch.FloatTensor(np.array(rewards))
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(np.array(dones))

        values = critic(states, actions).squeeze()
        next_actions = actor(next_states).detach()
        next_values = critic(next_states, next_actions).squeeze().detach() # detach: the TD target is treated as a constant
        td_targets = rewards + gamma * next_values * (1 - dones)
        td_errors = td_targets - values

        # Update the Critic
        critic_optimizer.zero_grad()
        critic_loss = torch.mean(td_errors ** 2)
        critic_loss.backward()
        critic_optimizer.step()

        # Update the Actor (deterministic policy gradient: maximize Q(s, actor(s)))
        actor_optimizer.zero_grad()
        policy_loss = -critic(states, actor(states)).mean()
        policy_loss.backward()
        actor_optimizer.step()

        episodes.append(episode)
    return episodes

# Main entry point
if __name__ == '__main__':
    env_name = "Pendulum-v1" # Pendulum environment; requires gym (classic API: reset() -> obs, step() -> 4 values)
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    actor = Actor(state_dim, action_dim)
    critic = Critic(state_dim, action_dim)

    actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
    critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)
    gamma = 0.99

    num_processes = 4 # number of worker processes
    num_episodes_per_process = 10 # episodes sampled by each process

    processes = []
    for i in range(num_processes):
        # Create a fresh environment instance for each process
        env_process = gym.make(env_name)
        # Create local copies of the actor and critic
        actor_process = Actor(state_dim, action_dim)
        critic_process = Critic(state_dim, action_dim)
        # Copy the main networks' parameters into the local copies
        actor_process.load_state_dict(actor.state_dict())
        critic_process.load_state_dict(critic.state_dict())
        # Each process gets its own optimizers, bound to its local copies
        actor_optimizer_process = optim.Adam(actor_process.parameters(), lr=1e-3)
        critic_optimizer_process = optim.Adam(critic_process.parameters(), lr=1e-3)
        process = multiprocessing.Process(target=collect_samples,
                                          args=(actor_process, critic_process, env_process,
                                                num_episodes_per_process, gamma,
                                                actor_optimizer_process, critic_optimizer_process))
        processes.append(process)
        process.start()

    for process in processes:
        process.join()

    print("Training complete!")

Code explanation:

  • The Actor and Critic classes define the policy network and the value network with simple fully connected layers.
  • collect_samples interacts with the environment, collects transitions, computes TD errors, and updates the Actor and Critic networks.
  • multiprocessing.Process spawns several processes, each running collect_samples.
  • Note that each process works on its own copies of the Actor and Critic networks. For simplicity, each process here also uses its own optimizers; in a real application you would merge the workers' results into a main network, for example via shared memory or an asynchronous update scheme.
  • Before starting each process we create local copies of the Actor and Critic and copy the main networks' parameters into them with load_state_dict, so every process starts from the same initial parameters.
  • When a process finishes, its updates are not automatically merged back into the main networks. You need a mechanism to combine them, for example shared memory or a message queue.
  • Important: in practice it is not safe to step the main process's optimizers (actor_optimizer and critic_optimizer) from inside child processes. A better approach is to let each worker compute gradients and send them back to the main process, which then updates the main networks, as sketched below.
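
As a minimal sketch of that "send gradients back" idea (assuming the Actor/Critic classes, the gym/numpy/torch imports from the listing above, and the classic Gym API), each worker runs one episode, computes gradients on its local copies, and pushes them through a multiprocessing.Queue; the main process averages them and steps the main optimizers. The helper names worker_gradients and apply_worker_gradients are illustrative, not part of any library.

def worker_gradients(actor_state, critic_state, env_name, gamma, grad_queue):
    # Rebuild local copies of the networks from the main process's parameters
    env = gym.make(env_name)
    actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
    critic = Critic(env.observation_space.shape[0], env.action_space.shape[0])
    actor.load_state_dict(actor_state)
    critic.load_state_dict(critic_state)

    # Roll out a single episode with the current policy
    state, done, episode = env.reset(), False, []
    while not done:
        action = actor(torch.FloatTensor(state).unsqueeze(0)).detach().numpy()[0]
        next_state, reward, done, _ = env.step(action)
        episode.append((state, action, reward, next_state, done))
        state = next_state

    states, actions, rewards, next_states, dones = map(
        lambda x: torch.FloatTensor(np.array(x)), zip(*episode))

    # Same losses as in collect_samples, but gradients are collected instead of applied
    values = critic(states, actions).squeeze()
    next_values = critic(next_states, actor(next_states).detach()).squeeze().detach()
    td_targets = rewards + gamma * next_values * (1 - dones)

    critic_loss = torch.mean((td_targets - values) ** 2)
    critic_loss.backward()
    critic_grads = [p.grad.clone() for p in critic.parameters()]

    for p in critic.parameters():
        p.grad = None  # keep the actor loss from polluting the collected critic gradients
    actor_loss = -critic(states, actor(states)).mean()
    actor_loss.backward()
    actor_grads = [p.grad.clone() for p in actor.parameters()]

    grad_queue.put((actor_grads, critic_grads))

def apply_worker_gradients(model, optimizer, grad_lists):
    # Average one gradient list per worker and apply it to the main network
    optimizer.zero_grad()
    for param, *grads in zip(model.parameters(), *grad_lists):
        param.grad = torch.stack(grads).mean(dim=0)
    optimizer.step()

# Usage in the main process (sketch):
#   grad_queue = multiprocessing.Queue()
#   workers = [multiprocessing.Process(target=worker_gradients,
#                                      args=(actor.state_dict(), critic.state_dict(),
#                                            env_name, gamma, grad_queue))
#              for _ in range(num_processes)]
#   for w in workers: w.start()
#   results = [grad_queue.get() for _ in workers]   # one (actor_grads, critic_grads) per worker
#   apply_worker_gradients(actor, actor_optimizer, [r[0] for r in results])
#   apply_worker_gradients(critic, critic_optimizer, [r[1] for r in results])
#   for w in workers: w.join()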

2.2 Using torch.multiprocessing

PyTorch provides a torch.multiprocessing module that makes parallel computation more convenient. It supports shared memory, so tensors can be shared across processes.

import torch.multiprocessing as mp
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Actor network (the policy)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim) # assumes a continuous action space

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        action = torch.tanh(self.fc3(x)) # squash actions into [-1, 1]
        return action

# Critic network (the action-value function Q(s, a))
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

# Sampling-and-update function executed by each worker process
def collect_samples(actor, critic, env, num_episodes, gamma, actor_optimizer, critic_optimizer, actor_state_dict, critic_state_dict, lock):
    for _ in range(num_episodes):
        state = env.reset()
        episode = []
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            # Refresh the local actor from the shared parameters
            with lock:
                actor.load_state_dict(actor_state_dict.copy())
            action = actor(state_tensor).detach().numpy()[0]
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward, next_state, done))
            state = next_state

        # Compute TD errors and update the networks
        states, actions, rewards, next_states, dones = zip(*episode)
        states = torch.FloatTensor(np.array(states))
        actions = torch.FloatTensor(np.array(actions))
        rewards = torch.FloatTensor(np.array(rewards))
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(np.array(dones))

        values = critic(states, actions).squeeze()
        with lock:
            actor.load_state_dict(actor_state_dict.copy())
        next_actions = actor(next_states).detach()
        next_values = critic(next_states, next_actions).squeeze().detach() # detach: the TD target is treated as a constant
        td_targets = rewards + gamma * next_values * (1 - dones)
        td_errors = td_targets - values

        # Update the Critic
        critic_optimizer.zero_grad()
        critic_loss = torch.mean(td_errors ** 2)
        critic_loss.backward()
        critic_optimizer.step()

        # Update the Actor (deterministic policy gradient: maximize Q(s, actor(s)))
        actor_optimizer.zero_grad()
        policy_loss = -critic(states, actor(states)).mean()
        policy_loss.backward()
        actor_optimizer.step()

        # Write the updated parameters back to the shared dictionaries
        with lock:
            for k, v in actor.state_dict().items():
                actor_state_dict[k] = v

            for k, v in critic.state_dict().items():
                critic_state_dict[k] = v
# Main entry point
if __name__ == '__main__':
    mp.set_start_method('spawn') # force the 'spawn' start method to avoid shared-memory issues

    env_name = "Pendulum-v1" # Pendulum environment; requires gym (classic API)
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    actor = Actor(state_dim, action_dim)
    critic = Critic(state_dim, action_dim)

    # Put the network parameters into manager-backed shared dictionaries
    actor_state_dict = mp.Manager().dict()
    critic_state_dict = mp.Manager().dict()
    for k, v in actor.state_dict().items():
        actor_state_dict[k] = v
    for k, v in critic.state_dict().items():
        critic_state_dict[k] = v

    actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
    critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)
    gamma = 0.99

    num_processes = 4 # number of worker processes
    num_episodes_per_process = 10 # episodes sampled by each process
    lock = mp.Lock() # lock that synchronizes access to the shared dictionaries

    processes = []
    for i in range(num_processes):
        # Create a fresh environment instance for each process
        env_process = gym.make(env_name)
        process = mp.Process(target=collect_samples, args=(actor, critic, env_process, num_episodes_per_process, gamma, actor_optimizer, critic_optimizer, actor_state_dict, critic_state_dict, lock))
        processes.append(process)
        process.start()

    for process in processes:
        process.join()

    print("Training complete!")

Code explanation:

  • mp.Manager().dict() creates manager-backed dictionaries that every process can read and write; they hold the Actor and Critic parameters.
  • mp.Lock() creates a lock that serializes access to those dictionaries and avoids data races.
  • Before sampling, each process loads the latest parameters from the shared dictionaries.
  • After updating its networks, each process writes the new parameters back to the shared dictionaries.
  • mp.set_start_method('spawn') forces the 'spawn' start method. On some platforms the default ('fork') interacts badly with shared state; 'spawn' starts a fresh Python interpreter and avoids those problems.
  • Important: with 'spawn', each child receives its own pickled copies of the networks and optimizers, so optimizer state (e.g. Adam moments) is never shared and the only synchronization point is the shared dictionary, which can make training noisy. A more robust design is to have the workers send gradients back to a single process that owns the optimizers, or to use the torch.distributed module for distributed gradient updates (section 3.1). A more idiomatic torch.multiprocessing pattern based on share_memory() is sketched below.
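
The share_memory() pattern (used by A3C-style implementations) moves the model itself into shared memory so that every worker reads and writes the same underlying parameter tensors. A minimal, hedged sketch, reusing the collect_samples function from section 2.1 and with hyperparameters chosen purely for illustration:

def hogwild_worker(shared_actor, shared_critic, env_name, num_episodes, gamma, lr):
    # The optimizers are bound directly to the shared parameters, so optimizer.step()
    # updates the tensors seen by every process (Hogwild-style: lock-free and approximate).
    env = gym.make(env_name)
    actor_optimizer = optim.Adam(shared_actor.parameters(), lr=lr)
    critic_optimizer = optim.Adam(shared_critic.parameters(), lr=lr)
    collect_samples(shared_actor, shared_critic, env, num_episodes, gamma,
                    actor_optimizer, critic_optimizer)  # the version from section 2.1

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    env = gym.make("Pendulum-v1")
    actor = Actor(env.observation_space.shape[0], env.action_space.shape[0])
    critic = Critic(env.observation_space.shape[0], env.action_space.shape[0])
    actor.share_memory()   # move the parameters into shared memory
    critic.share_memory()

    workers = [mp.Process(target=hogwild_worker,
                          args=(actor, critic, "Pendulum-v1", 10, 0.99, 1e-3))
               for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()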

2.3 Using Ray

Ray is a general-purpose distributed computing framework that makes parallel sampling straightforward.

import ray
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Actor network (the policy)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim) # assumes a continuous action space

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        action = torch.tanh(self.fc3(x)) # squash actions into [-1, 1]
        return action

# Critic network (the action-value function Q(s, a))
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

@ray.remote
class Worker:
    def __init__(self, env_name):
        self.env = gym.make(env_name)
        self.actor = Actor(self.env.observation_space.shape[0], self.env.action_space.shape[0])
        self.critic = Critic(self.env.observation_space.shape[0], self.env.action_space.shape[0])
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-3)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
        self.gamma = 0.99

    def collect_samples(self, num_episodes):
        episodes = []
        for _ in range(num_episodes):
            state = self.env.reset()
            episode = []
            done = False
            while not done:
                state_tensor = torch.FloatTensor(state).unsqueeze(0)
                action = self.actor(state_tensor).detach().numpy()[0]
                next_state, reward, done, _ = self.env.step(action)
                episode.append((state, action, reward, next_state, done))
                state = next_state

            # Compute TD errors and update the networks
            states, actions, rewards, next_states, dones = zip(*episode)
            states = torch.FloatTensor(np.array(states))
            actions = torch.FloatTensor(np.array(actions))
            rewards = torch.FloatTensor(np.array(rewards))
            next_states = torch.FloatTensor(np.array(next_states))
            dones = torch.FloatTensor(np.array(dones))

            values = self.critic(states, actions).squeeze()
            next_actions = self.actor(next_states).detach()
            next_values = self.critic(next_states, next_actions).squeeze().detach() # detach: the TD target is treated as a constant
            td_targets = rewards + self.gamma * next_values * (1 - dones)
            td_errors = td_targets - values

            # Update the Critic
            self.critic_optimizer.zero_grad()
            critic_loss = torch.mean(td_errors ** 2)
            critic_loss.backward()
            self.critic_optimizer.step()

            # Update the Actor (deterministic policy gradient: maximize Q(s, actor(s)))
            self.actor_optimizer.zero_grad()
            policy_loss = -self.critic(states, self.actor(states)).mean()
            policy_loss.backward()
            self.actor_optimizer.step()

            episodes.append(episode)
        return episodes

    def get_weights(self):
        return self.actor.state_dict(), self.critic.state_dict()

    def set_weights(self, actor_weights, critic_weights):
        self.actor.load_state_dict(actor_weights)
        self.critic.load_state_dict(critic_weights)

if __name__ == '__main__':
    ray.init()

    env_name = "Pendulum-v1" # Pendulum environment; requires gym (classic API)
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    actor = Actor(state_dim, action_dim)
    critic = Critic(state_dim, action_dim)

    actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
    critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)
    gamma = 0.99

    num_workers = 4
    num_episodes_per_worker = 10

    # Create the workers
    workers = [Worker.remote(env_name) for _ in range(num_workers)]

    # Initialize the workers with the main networks' weights
    actor_weights, critic_weights = actor.state_dict(), critic.state_dict()
    for worker in workers:
        worker.set_weights.remote(actor_weights, critic_weights)

    # Sample in parallel
    results = [worker.collect_samples.remote(num_episodes_per_worker) for worker in workers]
    ray.get(results)  # block until every sampling task has finished

    # Pull the workers' weights and update the main networks
    worker_weights = ray.get([worker.get_weights.remote() for worker in workers])
    # Simple averaging of the workers' weights
    for key in actor_weights:
        actor_weights[key] = torch.mean(torch.stack([w[0][key] for w in worker_weights]), dim=0)
    for key in critic_weights:
        critic_weights[key] = torch.mean(torch.stack([w[1][key] for w in worker_weights]), dim=0)

    actor.load_state_dict(actor_weights)
    critic.load_state_dict(critic_weights)

    print("Training complete!")
    ray.shutdown()

Code explanation:

  • The @ray.remote decorator turns the Worker class into a Ray actor.
  • Each Worker holds its own environment instance and its own Actor-Critic model.
  • collect_samples interacts with the environment, collects samples, and updates the worker's local Actor and Critic networks.
  • ray.init() initializes the Ray cluster.
  • Worker.remote() creates a remote actor instance.
  • worker.collect_samples.remote() invokes collect_samples asynchronously and returns an object reference.
  • ray.get() blocks until the corresponding tasks finish and returns their results (ray.wait() offers finer-grained control, but by default it waits for only one result).
  • Key point: Ray actors can run in separate processes or on separate machines, and their methods are invoked asynchronously, which makes parallel sampling straightforward. A sketch of repeating the sample/average cycle follows below.
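
The example above performs a single round of sampling and weight averaging. In practice you would repeat the cycle: broadcast the averaged weights, sample again, and average again. A minimal sketch of that outer loop, reusing the Worker actors and variables defined above (the number of rounds is an arbitrary choice, and average_state_dicts is an illustrative helper):

def average_state_dicts(state_dicts):
    # Element-wise mean of a list of state_dicts that share the same keys
    return {key: torch.mean(torch.stack([sd[key] for sd in state_dicts]), dim=0)
            for key in state_dicts[0]}

num_rounds = 50  # illustrative
for round_idx in range(num_rounds):
    # 1. Broadcast the current global weights to every worker
    ray.get([w.set_weights.remote(actor.state_dict(), critic.state_dict()) for w in workers])

    # 2. Sample (and locally update) in parallel
    ray.get([w.collect_samples.remote(num_episodes_per_worker) for w in workers])

    # 3. Pull the workers' weights and average them into the global networks
    worker_weights = ray.get([w.get_weights.remote() for w in workers])
    actor.load_state_dict(average_state_dicts([w[0] for w in worker_weights]))
    critic.load_state_dict(average_state_dicts([w[1] for w in worker_weights]))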

3. Distributed Gradient Update Strategies

In a distributed setting, the gradients computed by the individual agents must be aggregated before the global network is updated. Common strategies include:

  • Synchronous updates: every agent sends its gradients to a central server; the server aggregates them, updates the global network, and sends the new parameters back to all agents.
  • Asynchronous updates: each agent computes gradients and updates the global network independently, without waiting for the others.

3.1 Synchronous Gradient Updates with torch.distributed

PyTorch's torch.distributed module makes distributed training convenient.

import os
import torch.distributed as dist
import torch.multiprocessing as mp
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Actor network (the policy)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim) # assumes a continuous action space

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        action = torch.tanh(self.fc3(x)) # squash actions into [-1, 1]
        return action

# Critic network (the action-value function Q(s, a))
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

def collect_samples(rank, world_size, actor, critic, env, num_episodes, gamma, actor_optimizer, critic_optimizer):
    for _ in range(num_episodes):
        state = env.reset()
        episode = []
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = actor(state_tensor).detach().numpy()[0]
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward, next_state, done))
            state = next_state

        # Compute TD errors and update the networks
        states, actions, rewards, next_states, dones = zip(*episode)
        states = torch.FloatTensor(np.array(states))
        actions = torch.FloatTensor(np.array(actions))
        rewards = torch.FloatTensor(np.array(rewards))
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(np.array(dones))

        values = critic(states, actions).squeeze()
        next_actions = actor(next_states).detach()
        next_values = critic(next_states, next_actions).squeeze().detach() # detach: the TD target is treated as a constant
        td_targets = rewards + gamma * next_values * (1 - dones)
        td_errors = td_targets - values

        # Update the Critic
        critic_optimizer.zero_grad()
        critic_loss = torch.mean(td_errors ** 2)
        critic_loss.backward()
        for param in critic.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size # average the gradients across processes
        critic_optimizer.step()

        # Update the Actor (deterministic policy gradient: maximize Q(s, actor(s)))
        actor_optimizer.zero_grad()
        policy_loss = -critic(states, actor(states)).mean()
        policy_loss.backward()
        for param in actor.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size # average the gradients across processes
        actor_optimizer.step()

def init_process(rank, world_size, fn, backend, env_name, num_episodes, gamma, learning_rate):
    # Rendezvous settings for single-machine training (illustrative defaults)
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    actor = Actor(state_dim, action_dim)
    critic = Critic(state_dim, action_dim)

    actor_optimizer = optim.Adam(actor.parameters(), lr=learning_rate)
    critic_optimizer = optim.Adam(critic.parameters(), lr=learning_rate)

    fn(rank, world_size, actor, critic, env, num_episodes, gamma, actor_optimizer, critic_optimizer)

if __name__ == "__main__":
    mp.set_start_method('spawn')
    world_size = 4  # number of processes
    env_name = "Pendulum-v1"  # Pendulum environment; requires gym (classic API)
    num_episodes = 10
    gamma = 0.99
    learning_rate = 1e-3
    backend = 'gloo'  # 'gloo' (CPU) or 'nccl' (GPU)

    processes = []
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, collect_samples, backend, env_name, num_episodes, gamma, learning_rate))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
    print("Training complete!")

Code explanation:

  • dist.init_process_group() initializes the distributed environment.
  • dist.all_reduce() aggregates (sums) the gradients across all processes.
  • Each process divides the aggregated gradients by the number of processes to obtain the average gradient.
  • Every process then applies the same averaged gradients, so the network copies stay in sync.
  • Important: torch.distributed needs a properly configured rendezvous. Here the rank and world size are passed explicitly and MASTER_ADDR / MASTER_PORT are set inside init_process; in multi-machine setups these are usually provided as environment variables.
  • The backend can be 'gloo' or 'nccl': 'gloo' is appropriate for CPU training, 'nccl' for GPU training.
  • Every process must execute the same code (and, in particular, the same number of collective calls).
  • torch.nn.parallel.DistributedDataParallel can wrap the models to simplify distributed training, as sketched below.
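
As the last point suggests, torch.nn.parallel.DistributedDataParallel (DDP) can take over the gradient averaging: it all-reduces gradients automatically during backward(). A minimal, hedged sketch of how the networks inside init_process could be wrapped (the manual all_reduce loops in collect_samples would then be removed); treat this as a wiring sketch rather than a drop-in replacement.

from torch.nn.parallel import DistributedDataParallel as DDP

# Inside init_process, after dist.init_process_group(...):
actor = DDP(Actor(state_dim, action_dim))    # gradients are averaged across ranks on backward()
critic = DDP(Critic(state_dim, action_dim))

actor_optimizer = optim.Adam(actor.parameters(), lr=learning_rate)
critic_optimizer = optim.Adam(critic.parameters(), lr=learning_rate)

# Caveat: forward passes that are not followed by a backward (environment sampling,
# computing detached TD targets) should run under torch.no_grad(), so that DDP's
# gradient bookkeeping only sees the training forward/backward pairs.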

3.2 Asynchronous Gradient Updates with a Parameter Server

The parameter server is a common architecture for distributed training. A central server stores and updates the global model parameters, while multiple worker nodes compute gradients and push them to the server.

Since a full parameter-server implementation is fairly involved, the following pseudocode illustrates the basic idea:

Parameter server:

# Parameter server
import threading
import torch.optim as optim

class ParameterServer:
    def __init__(self, model):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=1e-3)
        self.lock = threading.Lock() # guards concurrent access to the model

    def get_parameters(self):
        with self.lock:
            return [param.data.clone() for param in self.model.parameters()] # return copies of the current parameters

    def update_parameters(self, gradients):
        with self.lock:
            # apply the received gradients to the global model
            for param, grad in zip(self.model.parameters(), gradients):
                param.grad = grad
            self.optimizer.step()
            self.optimizer.zero_grad() # clear gradients for the next update

Worker node:

# Worker node
def worker(parameter_server, env, model):
    while True:
        # 1. Pull the latest model parameters from the parameter server
        parameters = parameter_server.get_parameters()
        for param, server_param in zip(model.parameters(), parameters):
            param.data = server_param.clone() # overwrite the local model with the server's parameters

        # 2. Interact with the environment and compute gradients with the local model
        # (environment interaction and gradient computation omitted)
        gradients = calculate_gradients(model, env) # hypothetical gradient-computation function

        # 3. Push the gradients to the parameter server, which updates the global model
        parameter_server.update_parameters(gradients)

Code explanation:

  • The parameter server holds the global model parameters and exposes get_parameters and update_parameters.
  • Worker nodes periodically pull the latest parameters from the server and use them to interact with the environment.
  • After computing gradients, a worker pushes them to the server, which applies them to the global model.
  • A lock protects the model parameters against concurrent access.
  • Important: a production parameter server has to handle many more concerns, such as gradient compression, the staleness of asynchronous updates, and fault tolerance. A toy wiring of the protocol is sketched below.
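
To show how the two pieces fit together, here is a runnable toy wiring of the pull/compute/push protocol using threads. The "gradient computation" is a synthetic stand-in (maximizing the Actor's output on random states) purely to exercise the protocol; in a real setup it would be replaced by the episode-based TD and actor losses from the earlier sections. The Pendulum-like dimensions (3, 1) are chosen only for illustration.

import threading
import torch

def run_worker(server, model, num_rounds=20):
    for _ in range(num_rounds):
        # 1. Pull the latest parameters from the server
        for param, server_param in zip(model.parameters(), server.get_parameters()):
            param.data = server_param.clone()

        # 2. Compute gradients locally (synthetic objective for illustration)
        model.zero_grad()
        states = torch.randn(32, 3)          # Pendulum-like state dimension
        loss = -model(states).mean()
        loss.backward()
        gradients = [p.grad.clone() for p in model.parameters()]

        # 3. Push gradients; the server applies them asynchronously
        server.update_parameters(gradients)

server = ParameterServer(Actor(3, 1))
local_models = [Actor(3, 1) for _ in range(4)]
threads = [threading.Thread(target=run_worker, args=(server, m)) for m in local_models]
for t in threads:
    t.start()
for t in threads:
    t.join()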

4. Summary Table of the Actor-Critic Training Process

Step                                     | Actor (policy network)                                                   | Critic (value network)
1. Sampling                              | Selects action a according to the current policy π(a|s)                  | /
2. Evaluation                            | /                                                                        | Updates the value function V(s) or Q(s, a) with TD learning
3. Update                                | Adjusts the policy π(a|s) using the Critic's value estimates as feedback | /
Parallel sampling optimization           | Multi-process / multi-thread / Ray sampling to speed up data collection  | /
Distributed gradient update optimization | Synchronous/asynchronous gradient aggregation to update the global network parameters | Synchronous/asynchronous gradient aggregation to update the global network parameters

5. Choosing a Strategy: Parallel Sampling vs. Distributed Gradient Updates

Parallel sampling and distributed gradient updates are two different levers for speeding up reinforcement learning training.

  • Parallel sampling mainly accelerates data collection: multiple agents interact with the environment simultaneously and gather more training data. It is most useful when environment interaction is expensive.

  • Distributed gradient updates mainly accelerate model training: multiple worker nodes compute gradients in parallel, and the aggregated gradients are used to update the global model parameters. It is most useful when the model is large and each update is computationally heavy.

In practice, pick the strategy that matches the bottleneck: parallel sampling when environment interaction dominates, distributed gradient updates when model computation dominates. The two strategies can also be combined to accelerate both data collection and training.
