赞
踩
class Adam: """Adam (http://arxiv.org/abs/1412.6980v8)""" def __init__(self, lr=0.001, beta1=0.9, beta2=0.999): self.lr = lr self.beta1 = beta1 self.beta2 = beta2 self.iter = 0 self.m = None self.v = None def update(self, params, grads): if self.m is None: self.m, self.v = {}, {} for key, val in params.items(): self.m[key] = np.zeros_like(val) self.v[key] = np.zeros_like(val) self.iter += 1 lr_t = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter) for key in params.keys(): #self.m[key] = self.beta1*self.m[key] + (1-self.beta1)*grads[key] #self.v[key] = self.beta2*self.v[key] + (1-self.beta2)*(grads[key]**2) self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key]) self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key]) params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7) #unbias_m += (1 - self.beta1) * (grads[key] - self.m[key]) # correct bias #unbisa_b += (1 - self.beta2) * (grads[key]*grads[key] - self.v[key]) # correct bias #params[key] += self.lr * unbias_m / (np.sqrt(unbisa_b) + 1e-7)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。