赞
踩
RM算法求解函数根
RM（Robbins-Monro）算法求解 $g(w) = w^3 - 5 = 0$ 的根（真实根为 $w^* = \sqrt[3]{5} \approx 1.710$）
学习西湖大学赵世钰老师《强化学习的数学原理》课程（GitHub 资料与 B 站课程视频）第 6 课时后，对 RM 算法求根进行了编程复现，但运行时一直提示 g(w) 数值过大、超出可计算范围。
原因是：迭代初期步长 $a_k = 1/k$ 较大，$w$ 一旦越过根，立方项 $w^3$ 会把下一步的更新量迅速放大，使 $w$ 来回震荡并发散，最终溢出。下面是加入观测噪声、并限制单步更新幅度（`max_update`）之后的完整代码：
"""Find the root of g(w) = w**3 - 5 with the Robbins-Monro (RM) algorithm.

The algorithm only sees noisy observations g~(w, eta) = g(w) + eta with
eta ~ N(0, 1), and updates w_{k+1} = w_k - a_k * g~(w_k, eta_k) with gain
a_k = 1/k.  Because g grows cubically, the large early gains make w
overshoot the root and diverge until math.pow overflows, so every update
is clipped to at most ``max_update`` in magnitude.
"""

import math

import numpy as np


def g_w(w_k):
    """Return the noise-free target function g(w) = w**3 - 5 at ``w_k``."""
    return math.pow(w_k, 3) - 5


def g_tilde(w_k, eta):
    """Return the noisy observation g(w_k) + eta seen by the RM algorithm."""
    return g_w(w_k) + eta


def RM_algorithm_improved(w_initial, max_iterations=100, max_value=1000,
                          max_update=10, *, rng=None):
    """Run the Robbins-Monro iteration with a clipped update step.

    Args:
        w_initial: Starting estimate.
        max_iterations: Number of RM updates to perform.
        max_value: Currently unused.  The divergence bail-out on
            ``abs(w) > max_value`` was disabled in the original code; the
            parameter is kept so existing callers keep working.
        max_update: Upper bound on the magnitude of each step
            ``a_k * g~(w_k, eta_k)``; must be positive.
        rng: Optional object with a ``normal(loc, scale)`` method, e.g.
            ``np.random.default_rng(seed)``, used to draw the noise for
            reproducible runs.  When None, the global ``np.random.normal``
            is used.

    Returns:
        ``(ws, etas)``: ``ws`` holds ``max_iterations + 1`` estimates starting
        with ``w_initial``; ``etas[k - 1]`` is the noise drawn at iteration k,
        so ``len(etas) == max_iterations``.

    Raises:
        ValueError: If ``max_update`` is not positive.
    """
    if max_update <= 0:
        raise ValueError(f"max_update must be positive, got {max_update!r}")
    normal = np.random.normal if rng is None else rng.normal

    w = w_initial
    ws = [w]   # all estimates, including the initial one
    etas = []  # one noise sample per iteration; w_initial is not noise
    for k in range(1, max_iterations + 1):
        eta = normal(0, 1)  # zero-mean, unit-variance observation noise
        # RM step with gain a_k = 1/k.
        update = 1 / k * g_tilde(w, eta)
        # Clip the step so an overshoot cannot be amplified by the cubic term.
        if abs(update) > max_update:
            update = math.copysign(max_update, update)
        w = w - update
        etas.append(eta)
        ws.append(w)
    return ws, etas


def main():
    """Run the RM algorithm from w = 0 and plot the estimates and the noise."""
    # Imported here so the numerical routines above can be used (and tested)
    # without matplotlib installed.
    import matplotlib.pyplot as plt

    w_initial = 0
    estimates_improved, etas = RM_algorithm_improved(w_initial)

    plt.figure(figsize=(15, 12))

    # Estimates versus the true root 5**(1/3).
    plt.subplot(2, 1, 1)
    plt.plot(estimates_improved, label='w_k estimates')
    plt.axhline(y=5**(1/3), color='r', linestyle='--', label='True root')
    plt.xlabel('Iteration')
    plt.ylabel('Estimate of w')
    plt.title('Convergence of RM Algorithm (Improved)')
    plt.legend()
    plt.grid(True)

    # Noise drawn at each iteration; etas[k - 1] belongs to iteration k.
    plt.subplot(2, 1, 2)
    plt.plot(range(1, len(etas) + 1), etas, label='eta estimates')
    plt.xlabel('Iteration')
    plt.ylabel('Eta')
    plt.title('noisy of RM Algorithm (Improved)')
    plt.legend()
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。