FizzBuzz
FizzBuzz is a very simple game. The rules: count upward from 1; when you hit a multiple of 3, say "fizz"; a multiple of 5, say "buzz"; a multiple of 15, say "fizzbuzz"; otherwise just say the number.
We can write a small program that decides whether to return the plain number, "fizz", "buzz", or "fizzbuzz".
```python
# Encode the desired output as one of four classes: [number, "fizz", "buzz", "fizzbuzz"]
def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0

def fizz_buzz_decode(i, prediction):
    # [str(i), "fizz", "buzz", "fizzbuzz"] is a four-element list, one entry per class,
    # so the predicted class index selects the right output
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

for i in range(1, 16):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
```
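Running the loop prints the correct sequence for 1 through 15:

```
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
```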
Define the model's inputs and outputs (the training data)
```python
import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits (little-endian).
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

# Train on 101..1023 so that the test range 1..100 stays unseen.
trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])
```
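To make the encoding concrete, here is a quick check; the printed values follow directly from the definitions above:

```python
print(binary_encode(6, NUM_DIGITS))  # [0 1 1 0 0 0 0 0 0 0] -- little-endian bits of 6
print(trX.shape, trY.shape)          # torch.Size([923, 10]) torch.Size([923])
```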
Define the model with PyTorch
```python
# Define the model
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)
```
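As a quick sanity check, a single forward pass should map the 10 input bits to 4 class logits:

```python
with torch.no_grad():
    print(model(trX[:2]).shape)  # torch.Size([2, 4]) -- one logit per class
```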
Define the loss function and optimizer
```python
# CrossEntropyLoss applies log-softmax internally, which is why the model's
# final layer outputs raw logits with no softmax of its own.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
```
The training loop
```python
# Start training
BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        # predict
        y_pred = model(batchX)          # forward pass
        loss = loss_fn(y_pred, batchY)

        # optimize in three steps
        optimizer.zero_grad()           # clear old gradients
        loss.backward()                 # backward pass
        optimizer.step()                # gradient descent update

    # Report the loss on the full training set once per epoch
    loss = loss_fn(model(trX), trY).item()
    print('Epoch:', epoch, 'Loss:', loss)
```

Training on the CPU, the final loss is decent, though nothing special.
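If a CUDA-capable GPU is available, a minimal sketch of moving the training onto it (assuming the same loop as above) looks like this:

```python
# Move the model and data to the GPU when one is available (sketch)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
trX, trY = trX.to(device), trY.to(device)
# The training loop then runs unchanged: slices of trX/trY stay on `device`.
# (testX below would need .to(device) as well.)
```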
Play FizzBuzz on the numbers 1 to 100 with the trained model
```python
# Predict on the held-out numbers 1..100
testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
with torch.no_grad():
    testY = model(testX)
predictions = zip(range(1, 101), testY.max(1)[1].tolist())

print([fizz_buzz_decode(i, x) for (i, x) in predictions])
```
Tip on `with torch.no_grad()`: during training every tensor tracks gradients so the parameters can be optimized, but at test time gradients are unnecessary; computing them anyway wastes memory and can easily blow it up.
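A small illustration of the effect:

```python
x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False -- no graph is built, so no extra memory is held
```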
Of the numbers 1 to 10, the model gets 3 wrong; 11 to 20 are all correct……
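Since `fizz_buzz_encode` gives the ground truth for any number, the accuracy on 1 to 100 can also be measured directly; a short sketch using only the functions defined above:

```python
# Compare predicted class indices with the true labels on 1..100
testY_true = torch.LongTensor([fizz_buzz_encode(i) for i in range(1, 101)])
with torch.no_grad():
    pred = model(testX).max(1)[1]
print('accuracy on 1-100:', (pred == testY_true).float().mean().item())
```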
Reference: https://blog.csdn.net/u013075024/article/details/104365933