class SGD:

    """Stochastic Gradient Descent (SGD)"""

    def __init__(self, lr=0.01):
        self.lr = lr

    def update(self, params, grads):
        # Move every parameter a small step against its gradient.
        for key in params.keys():
            params[key] -= self.lr * grads[key]


# Pseudocode: how the optimizer class is used in a training loop
network = TwoLayerNet(...)
optimizer = SGD()

for i in range(10000):
    ...
    x_batch, t_batch = get_mini_batch(...)  # mini-batch
    grads = network.gradient(x_batch, t_batch)
    params = network.params
    optimizer.update(params, grads)
    ...
By implementing the optimization step as a separate class like this, it becomes easier to modularize functionality. For example, the next optimization method we implement, Momentum, will likewise be written as a class with the same update(params, grads) method. Switching from SGD to Momentum then only requires replacing the single line optimizer = SGD() with optimizer = Momentum() (a short demonstration of this swap follows the Momentum class below).
import numpy as np

class Momentum:

    """Momentum SGD"""

    def __init__(self, lr=0.01, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.v = None  # velocity, created lazily on the first update

    def update(self, params, grads):
        if self.v is None:
            self.v = {}
            for key, val in params.items():
                self.v[key] = np.zeros_like(val)

        for key in params.keys():
            # The velocity decays by `momentum` and is pushed by the current gradient,
            # then the parameter moves along the accumulated velocity.
            self.v[key] = self.momentum * self.v[key] - self.lr * grads[key]
            params[key] += self.v[key]
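As a quick, hypothetical illustration of this interchangeability, the sketch below minimizes a toy objective f(x) = x**2 (the objective and the variable names are chosen only for this example and are not part of the original text). It assumes the SGD and Momentum classes defined above are in scope; switching optimizers means changing one line.

import numpy as np

# Toy problem: minimize f(x) = x**2, whose gradient is 2*x.
params = {'x': np.array([5.0])}

optimizer = SGD(lr=0.1)
# optimizer = Momentum(lr=0.1)   # swapping this single line is the only change needed

for i in range(100):
    grads = {'x': 2 * params['x']}   # gradient of f at the current x
    optimizer.update(params, grads)

print(params['x'])  # close to 0 with either optimizer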
Another common refinement is to adapt the learning rate per parameter, as AdaGrad does: it accumulates the squared gradients in h (self.h[key] += grads[key] * grads[key]) and divides each step by the square root of that sum (params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)), where the small constant 1e-7 keeps the division well defined while h is still zero.
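Below is a minimal sketch of such an AdaGrad class, assuming the same update(params, grads) interface as SGD and Momentum. Apart from the two update lines described above, the surrounding class structure is an assumption modeled on the Momentum class.

import numpy as np

class AdaGrad:

    """AdaGrad"""

    def __init__(self, lr=0.01):
        self.lr = lr
        self.h = None  # running sum of squared gradients, created lazily

    def update(self, params, grads):
        if self.h is None:
            self.h = {}
            for key, val in params.items():
                self.h[key] = np.zeros_like(val)

        for key in params.keys():
            self.h[key] += grads[key] * grads[key]
            # Divide the step by the root of the accumulated squared gradients;
            # 1e-7 avoids division by zero for parameters not yet updated.
            params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)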