SAC (Soft Actor-Critic) Study Notes

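The SAC (Soft Actor-Critic) algorithm has attracted considerable attention in recent years and is well regarded by many deep reinforcement learning researchers. This post covers a theoretical analysis of SAC and an implementation of its core code.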

SAC maximizes an entropy-regularized return, where $\alpha$ weights the policy entropy $\mathcal{H}$:

$$\max_{\pi_{\theta}}\ \mathbb{E}_{\pi_{\theta}}\left[\sum_{t}\gamma^{t}\left(r(S_{t},A_{t})+\alpha\mathcal{H}\left(\pi_{\theta}(\cdot\mid S_{t})\right)\right)\right]$$

In standard reinforcement learning, the state value is the expected Q value under the policy:

$$\hat{V}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t})\equiv\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]$$
In SAC we use the soft version of the value function instead, which adds the entropy term:

$$
\begin{aligned}
\hat{V}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t})
&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\alpha\mathcal{H}\left(\pi_{\theta}(\cdot\mid\mathbf{s}_{t})\right)\\
&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})\right]+\alpha\,\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[-\log\pi_{\theta}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]\\
&=\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\theta}(\cdot\mid\mathbf{s}_{t})}\left[\hat{Q}_{\phi}^{\pi_{\theta}}(\mathbf{s}_{t},\mathbf{a}_{t})-\alpha\log\pi_{\theta}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]
\end{aligned}
$$
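The last step of the derivation can be checked numerically. The sketch below assumes a discrete action space with made-up Q values and action probabilities (the names `q_values`, `probs` and both helper functions are mine, not from the papers), and compares the "expected Q plus entropy" form with the fused "expected Q minus alpha log pi" form.

```python
import math

def soft_value_split(q_values, probs, alpha):
    """E_a[Q(s, a)] + alpha * H(pi(.|s)) for a discrete policy."""
    expected_q = sum(p * q for p, q in zip(probs, q_values))
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return expected_q + alpha * entropy

def soft_value_fused(q_values, probs, alpha):
    """E_a[Q(s, a) - alpha * log pi(a|s)], the last line of the derivation."""
    return sum(p * (q - alpha * math.log(p))
               for p, q in zip(probs, q_values) if p > 0.0)

# Hypothetical numbers for a single state with three actions.
q_values = [1.0, 2.0, 0.5]
probs = [0.2, 0.5, 0.3]
alpha = 0.2

v_split = soft_value_split(q_values, probs, alpha)
v_fused = soft_value_fused(q_values, probs, alpha)
plain_value = sum(p * q for p, q in zip(probs, q_values))
print(v_split, v_fused, plain_value)
```

Both forms give the same number, and it is larger than the plain expected Q value because the entropy bonus is positive for a discrete policy that is not deterministic.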
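SAC comes in two versions. The first uses a Q network, a V network and a policy network, with a fixed entropy-regularization coefficient. The second drops the V network, uses double Q networks, and proposes a way to adjust the entropy-regularization coefficient dynamically. This post introduces the first version and then the second.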


Objective of the V network (in the first paper the temperature $\alpha$ is absorbed into the reward scale, so it does not appear explicitly):

$$J_{V}(\psi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D}}\left[\frac{1}{2}\left(V_{\psi}(\mathbf{s}_{t})-\mathbb{E}_{\mathbf{a}_{t}\sim\pi_{\phi}}\left[Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\log\pi_{\phi}(\mathbf{a}_{t}\mid\mathbf{s}_{t})\right]\right)^{2}\right]$$
Objective of the Q network:

$$J_{Q}(\theta)=\mathbb{E}_{(\mathbf{s}_{t},\mathbf{a}_{t})\sim\mathcal{D}}\left[\frac{1}{2}\left(Q_{\theta}(\mathbf{s}_{t},\mathbf{a}_{t})-\hat{Q}(\mathbf{s}_{t},\mathbf{a}_{t})\right)^{2}\right]$$

where the target is $\hat{Q}(\mathbf{s}_{t},\mathbf{a}_{t})=r(\mathbf{s}_{t},\mathbf{a}_{t})+\gamma\,\mathbb{E}_{\mathbf{s}_{t+1}\sim p}\left[V_{\bar{\psi}}(\mathbf{s}_{t+1})\right]$ and $V_{\bar{\psi}}$ is the target V network.
Objective of the policy network:

$$J_{\pi}(\phi)=\mathbb{E}_{\mathbf{s}_{t}\sim\mathcal{D}}\left[\mathrm{D}_{\mathrm{KL}}\left(\pi_{\phi}(\cdot\mid\mathbf{s}_{t})\,\middle\|\,\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}\right)\right]$$
At first glance the policy objective may look puzzling. In fact, $\frac{\exp\left(Q_{\theta}(\mathbf{s}_{t},\cdot)\right)}{Z_{\theta}(\mathbf{s}_{t})}$ is the policy that maximizes the entropy-regularized objective (expected Q value plus entropy) at state $\mathbf{s}_{t}$, so minimizing the KL divergence pulls $\pi_{\phi}$ toward that optimum. Here $Z_{\theta}(\mathbf{s}_{t})$ is the normalizer; with the temperature written explicitly and discrete actions, $Z(s)=\sum_{a}\exp\left(\frac{1}{\alpha}Q(s,a)\right)$.
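To make this less abstract, here is a small numerical check, assuming a discrete action space with hypothetical Q values (all function names are mine). It builds the policy proportional to exp(Q / alpha) and compares its value of "expected Q plus alpha times entropy" with that of many random policies. The maximum should equal alpha * log Z.

```python
import math
import random

def boltzmann_policy(q_values, alpha):
    """pi(a) = exp(Q(a) / alpha) / Z, with the max subtracted for numerical stability."""
    m = max(q_values)
    weights = [math.exp((q - m) / alpha) for q in q_values]
    z = sum(weights)
    return [w / z for w in weights]

def entropy_regularized_objective(q_values, probs, alpha):
    """E_a[Q(a)] + alpha * H(pi) for a discrete distribution."""
    return sum(p * (q - alpha * math.log(p))
               for p, q in zip(probs, q_values) if p > 0.0)

def random_policy(n, rng):
    """A random point on the probability simplex with strictly positive entries."""
    raw = [rng.random() + 1e-9 for _ in range(n)]
    total = sum(raw)
    return [r / total for r in raw]

q_values = [1.0, 2.0, 0.5]   # hypothetical Q(s, a) for three actions
alpha = 0.5
rng = random.Random(0)

best_policy = boltzmann_policy(q_values, alpha)
best_value = entropy_regularized_objective(q_values, best_policy, alpha)
other_values = [entropy_regularized_objective(q_values, random_policy(3, rng), alpha)
                for _ in range(1000)]
alpha_log_z = alpha * math.log(sum(math.exp(q / alpha) for q in q_values))

print(best_policy, best_value, max(other_values), alpha_log_z)
```

None of the random policies beats the Boltzmann policy, which is what the KL form of the policy objective relies on.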
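The first version of SAC uses three networks. However, the V network and the Q network are closely related, so the second version of SAC removes the V network and uses double Q networks to address overestimation. It also provides a way to adjust $\alpha$ dynamically. In general, the second version is the recommended one. It resembles the first version in most respects, so this post focuses on the differences.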

The regularization coefficient can be tuned automatically by minimizing the following loss, where $\kappa=-\dim(\mathcal{A})$ is the target entropy:

$$J(\alpha)=\mathbb{E}_{a\sim\pi_{\theta}}\left[-\alpha\log\pi_{\theta}(a\mid s)-\alpha\kappa\right]$$
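A framework-free sketch of one gradient step on this loss may help to see which way alpha moves. It assumes that alpha is parameterized as exp(log_alpha) to keep it positive (a common choice, but an assumption here), and the log-probability batches are made-up numbers. The helper names `alpha_loss` and `alpha_step` are mine.

```python
import math

def alpha_loss(alpha, log_probs, target_entropy):
    """Sample estimate of J(alpha) = E[-alpha * log pi(a|s) - alpha * kappa]."""
    return sum(-alpha * lp - alpha * target_entropy for lp in log_probs) / len(log_probs)

def alpha_step(log_alpha, log_probs, target_entropy, lr=0.1):
    """One gradient-descent step on J with respect to log_alpha (alpha = exp(log_alpha))."""
    alpha = math.exp(log_alpha)
    # dJ/dalpha = mean(-log pi - kappa); the chain rule gives dJ/dlog_alpha = alpha * dJ/dalpha.
    grad_alpha = sum(-lp - target_entropy for lp in log_probs) / len(log_probs)
    return log_alpha - lr * alpha * grad_alpha

action_dim = 2
target_entropy = -float(action_dim)             # kappa = -dim(A)

low_entropy_log_probs = [3.2, 2.9, 3.1, 2.8]    # entropy estimate -3.0, below the target
high_entropy_log_probs = [0.4, 0.6, 0.5, 0.5]   # entropy estimate -0.5, above the target

log_alpha = 0.0                                 # alpha starts at 1
alpha_up = math.exp(alpha_step(log_alpha, low_entropy_log_probs, target_entropy))
alpha_down = math.exp(alpha_step(log_alpha, high_entropy_log_probs, target_entropy))
print(alpha_up, alpha_down)
```

When the policy's entropy estimate is below the target, alpha grows and the entropy bonus is weighted more; when it is above the target, alpha shrinks.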


Reparameterization reduces the variance of the expectation estimate and lets gradients propagate back through the sampling step, so SAC uses the reparameterization trick. Suppose we already know the mean $\mu_{\theta}$ and standard deviation $\sigma_{\theta}$ of the action; we then set

$$a_{t}=\tanh(\mu_{\theta}+\epsilon\cdot\sigma_{\theta}),\qquad\epsilon\sim\mathcal{N}(0,1)$$

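    import torch
    from torch.distributions import Normal

    # mean and std are the policy network outputs for a batch of states
    normal = Normal(mean, std)
    z = normal.rsample()      # z = mean + std * eps with eps ~ N(0, 1), so gradients reach mean and std
    action = torch.tanh(z)    # squash the sample into (-1, 1)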
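To see why this form is differentiable, the scalar sketch below (plain Python, no autograd; the function names and numbers are mine) writes out the pathwise derivatives of the action with respect to the mean and standard deviation and compares them with finite differences for a fixed noise sample.

```python
import math
import random

def sample_action(mu, sigma, eps):
    """Reparameterized action a = tanh(mu + sigma * eps) for a given noise draw eps."""
    return math.tanh(mu + sigma * eps)

def pathwise_grads(mu, sigma, eps):
    """Analytic da/dmu and da/dsigma with eps held fixed."""
    a = sample_action(mu, sigma, eps)
    d_tanh = 1.0 - a * a          # derivative of tanh at the pre-squash value
    return d_tanh, d_tanh * eps

rng = random.Random(0)
mu, sigma = 0.3, 0.8              # hypothetical policy outputs for one state
eps = rng.gauss(0.0, 1.0)         # the noise does not depend on mu or sigma

a = sample_action(mu, sigma, eps)
da_dmu, da_dsigma = pathwise_grads(mu, sigma, eps)

# Central finite differences with the same eps
h = 1e-6
fd_mu = (sample_action(mu + h, sigma, eps) - sample_action(mu - h, sigma, eps)) / (2 * h)
fd_sigma = (sample_action(mu, sigma + h, eps) - sample_action(mu, sigma - h, eps)) / (2 * h)
print(a, da_dmu, fd_mu, da_dsigma, fd_sigma)
```

Because the randomness sits entirely in eps, the action is an ordinary differentiable function of the network outputs, which is what `rsample()` exploits.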



Policy network

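    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Normal

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    class Actor(nn.Module):
        def __init__(self, state_dim, action_dim, max_action=1, init_w=3e-3):
            super(Actor, self).__init__()
            self.l1 = nn.Linear(state_dim, 128)
            self.l2 = nn.Linear(128, 128)
            self.l3_mean = nn.Linear(128, action_dim)
            self.log_std_linear = nn.Linear(128, action_dim)
            self.max_action = max_action
            # Small uniform initialization for both output heads
            self.l3_mean.weight.data.uniform_(-init_w, init_w)
            self.l3_mean.bias.data.uniform_(-init_w, init_w)
            self.log_std_linear.weight.data.uniform_(-init_w, init_w)
            self.log_std_linear.bias.data.uniform_(-init_w, init_w)

        def forward(self, x):
            x = F.relu(self.l1(x))
            x = F.relu(self.l2(x))
            mean = self.l3_mean(x)
            log_std = self.log_std_linear(x)
            # Keep the standard deviation in a numerically safe range
            log_std = torch.clamp(log_std, -20, 2)
            return mean, log_std

        def evaluate(self, state, epsilon=1e-6):
            # Sample a differentiable action and its log-probability (used in the losses)
            mean, log_std = self.forward(state)
            std = log_std.exp()
            normal = Normal(mean, std)
            z = normal.rsample()
            action = torch.tanh(z)
            # Change-of-variables correction for the tanh squashing
            log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
            log_prob = log_prob.sum(1, keepdim=True)
            return action, log_prob

        def select_action(self, state):
            # Sample an action for interacting with the environment
            state = torch.FloatTensor(state).to(device)
            mean, log_std = self.forward(state)
            std = log_std.exp()
            normal = Normal(mean, std)
            z = normal.rsample()
            action = torch.tanh(z)
            action = action.detach().cpu().numpy()
            return action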


 log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)

In the code, log_prob corresponds to $\log\pi(\mathbf{a}\mid\mathbf{s})$. This line follows the formula below from the original paper. The epsilon is added so that the argument of the logarithm in the second term never reaches zero, which would make the term diverge to negative infinity.

$$\log\pi(\mathbf{a}\mid\mathbf{s})=\log\mu(\mathbf{u}\mid\mathbf{s})-\sum_{i=1}^{D}\log\left(1-\tanh^{2}(u_{i})\right)$$
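One way to convince yourself that the correction term is right is to check that the resulting density over the squashed action integrates to one. The sketch below does this for a one-dimensional action with hypothetical mean and standard deviation, using a simple midpoint rule. The function names are mine, and epsilon is set to zero inside the integral so that the exact formula is tested.

```python
import math

def normal_log_pdf(u, mean, std):
    """Log-density of a univariate Gaussian."""
    return -0.5 * ((u - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def squashed_log_prob(u, mean, std, epsilon=1e-6):
    """log pi(a) for a = tanh(u): Gaussian log-density minus log(1 - tanh(u)^2)."""
    a = math.tanh(u)
    return normal_log_pdf(u, mean, std) - math.log(1.0 - a * a + epsilon)

def integrate_density_over_actions(mean, std, n=20000):
    """Midpoint-rule integral of pi(a) over a in (-1, 1)."""
    width = 2.0 / n
    total = 0.0
    for i in range(n):
        a = -1.0 + (i + 0.5) * width
        u = math.atanh(a)                      # pre-squash value that maps to a
        total += math.exp(squashed_log_prob(u, mean, std, epsilon=0.0)) * width
    return total

mean, std = 0.3, 0.8                           # hypothetical policy outputs
total_mass = integrate_density_over_actions(mean, std)
corrected = squashed_log_prob(1.0, mean, std)
uncorrected = normal_log_pdf(1.0, mean, std)
print(total_mass, corrected, uncorrected)
```

Without the correction, the Gaussian log-density would be assigned to the squashed action directly, which underestimates the density wherever tanh compresses the space.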

Q network

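    class Critic(nn.Module):
        def __init__(self, state_dim, action_dim, init_w=3e-3):
            super(Critic, self).__init__()
            self.l1 = nn.Linear(state_dim + action_dim, 128)
            self.l2 = nn.Linear(128, 128)
            self.l3 = nn.Linear(128, 1)
            # Small uniform initialization for the output layer
            self.l3.weight.data.uniform_(-init_w, init_w)
            self.l3.bias.data.uniform_(-init_w, init_w)

        def forward(self, x, u):
            # Concatenate state and action along the feature dimension
            x = F.relu(self.l1(torch.cat([x, u], 1)))
            x = F.relu(self.l2(x))
            x = self.l3(x)
            return x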

This part is essentially the same as the Q network definitions used in earlier algorithms.

Update parameters

    def update(self):
        # Sample a batch from the replay buffer
        state, action, reward, next_state, done = self.replay_buffer.sample(args.batch_size)
        state = torch.FloatTensor(state).to(device)
        action = torch.FloatTensor(action).to(device)
        # Reshape to (batch, 1) so reward and mask broadcast correctly against the Q values
        reward = torch.FloatTensor(reward).reshape(-1, 1).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        not_done = torch.FloatTensor(1 - done).reshape(-1, 1).to(device)

        # Compute the target Q value (alpha is fixed at 1 in this implementation)
        next_action, next_log_prob = self.policy_network.evaluate(next_state)
        target_Q_1 = self.critic_target_1(next_state, next_action)
        target_Q_2 = self.critic_target_2(next_state, next_action)
        target_Q = torch.min(target_Q_1, target_Q_2) - next_log_prob
        my_target_Q = reward + not_done * args.gamma * target_Q

        # Get current Q estimates
        current_Q_1 = self.critic_1(state, action)
        current_Q_2 = self.critic_2(state, action)

        # Compute critic loss
        critic_loss_1 = F.mse_loss(current_Q_1, my_target_Q.detach())
        critic_loss_2 = F.mse_loss(current_Q_2, my_target_Q.detach())
        critic_loss = critic_loss_1 + critic_loss_2

        # Optimize the critics
        self.critic_optimizer_1.zero_grad()
        self.critic_optimizer_2.zero_grad()
        critic_loss.backward()
        self.critic_optimizer_1.step()
        self.critic_optimizer_2.step()

        if self.update_step % 2 == 0:
            new_action, log_prob = self.policy_network.evaluate(state)

            # Compute actor loss
            min_q = torch.min(self.critic_1(state, new_action),
                              self.critic_2(state, new_action))
            actor_loss = (log_prob - min_q).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic_1.parameters(), self.critic_target_1.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
            for param, target_param in zip(self.critic_2.parameters(), self.critic_target_2.parameters()):
                target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

        self.update_step += 1
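The two less obvious pieces of the update, the clipped double-Q target and the soft update of the target networks, can be worked through on plain numbers. The sketch below is a scalar illustration with made-up values, not part of the training code; the function names and the explicit `alpha` argument are my own additions (the code above effectively uses alpha = 1).

```python
def soft_q_target(reward, done, q1_next, q2_next, next_log_prob, gamma=0.99, alpha=1.0):
    """y = r + (1 - done) * gamma * (min(Q1', Q2') - alpha * log pi(a'|s'))."""
    min_q_next = min(q1_next, q2_next)                    # pessimistic estimate of the two critics
    soft_next_value = min_q_next - alpha * next_log_prob  # add the entropy bonus
    return reward + (1.0 - done) * gamma * soft_next_value

def polyak_update(target_params, online_params, tau=0.005):
    """Soft update: target <- tau * online + (1 - tau) * target, element by element."""
    return [tau * p + (1.0 - tau) * t for p, t in zip(online_params, target_params)]

# Non-terminal transition: 1.0 + 0.9 * (min(5.0, 4.0) + 1.5)
y = soft_q_target(reward=1.0, done=0.0, q1_next=5.0, q2_next=4.0,
                  next_log_prob=-1.5, gamma=0.9)

# Terminal transition: the bootstrap term is masked out and only the reward remains
y_terminal = soft_q_target(reward=1.0, done=1.0, q1_next=5.0, q2_next=4.0,
                           next_log_prob=-1.5, gamma=0.9)

# The target parameters move a fraction tau of the way toward the online parameters
new_target = polyak_update(target_params=[0.0, 1.0], online_params=[1.0, 1.0], tau=0.1)
print(y, y_terminal, new_target)
```

With a small tau the target networks change slowly, which keeps the regression target for the critics stable between updates.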

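The parameter update has three parts: the Q networks, the policy network, and $\alpha$. The code above does not implement the third part. Readers who want automatic tuning only need to write the corresponding code from the formula for $J(\alpha)$.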


1:Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor

2:Soft Actor-Critic Algorithms and Applications

3:Deep Reinforcement Learning Fundamentals, Research and Applications

4:From Policy Gradient to Actor-Critic methods Soft Actor Critic, ISIR



