当前位置:   article > 正文

《神经网络与深度学习:案例与实践》动手练习1.3

《神经网络与深度学习:案例与实践》动手练习1.3

飞桨AI Studio星河社区-人工智能学习与实训社区

动手练习1.3

执行上述算子的反向过程,并验证梯度是否正确。

  1. import math
  2. class Op(object):
  3. def __init__(self):
  4. pass
  5. def __call__(self, inputs):
  6. return self.forward(inputs)
  7. # 前向函数
  8. # 输入:张量inputs
  9. # 输出:张量outputs
  10. def forward(self, inputs):
  11. # return outputs
  12. raise NotImplementedError
  13. # 反向函数
  14. # 输入:最终输出对outputs的梯度outputs_grads
  15. # 输出:最终输出对inputs的梯度inputs_grads
  16. def backward(self, outputs_grads):
  17. # return inputs_grads
  18. raise NotImplementedError
  19. class add(Op):
  20. def __init__(self):
  21. super(add, self).__init__()
  22. def __call__(self, x, y):
  23. return self.forward(x, y)
  24. def forward(self, x, y):
  25. self.x = x
  26. self.y = y
  27. outputs = x + y
  28. return outputs
  29. def backward(self, grads):
  30. grads_x = grads * 1
  31. grads_y = grads * 1
  32. return grads_x, grads_y
  33. class multiply(Op):
  34. def __init__(self):
  35. super(multiply, self).__init__()
  36. def __call__(self, x, y):
  37. return self.forward(x, y)
  38. def forward(self, x, y):
  39. self.x = x
  40. self.y = y
  41. outputs = x * y
  42. return outputs
  43. def backward(self, grads):
  44. grads_x = grads * self.y
  45. grads_y = grads * self.x
  46. return grads_x, grads_y
  47. class exponential(Op):
  48. def __init__(self):
  49. super(exponential, self).__init__()
  50. def forward(self, x):
  51. self.x = x
  52. outputs = math.exp(x)
  53. return outputs
  54. def backward(self, grads):
  55. grads = grads * math.exp(self.x)
  56. return grads
  57. a, b, c, d = 2, 3, 2, 2
  58. multiply_op1 = multiply()
  59. f1=multiply_op1(a,b)
  60. multiply_op2 = multiply()
  61. f2=multiply_op2(c,d)
  62. add_op = add()
  63. f3=add_op(f1,f2)
  64. exp_op = exponential()
  65. f4=exp_op(f3)
  66. print(f4)
  67. val1=exp_op.backward(grads=1)
  68. val2=add_op.backward(val1)
  69. val3=multiply_op1.backward(val2[0])
  70. val4=multiply_op2.backward(val2[0])
  71. print(val3)
  72. print(val4)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/498562
推荐阅读
  

闽ICP备14008679号