当前位置:   article > 正文

网络模型(Seq2Seq-注意力机制-编解码)_基于注意力机制的编解码网络

基于注意力机制的编解码网络

概念

用于处理序列问题:翻译(N vs N)、信息提取(N vs 1)、生成(1 vs N)。
Seq2Seq
RNN 要求输入队列和输出队列等长,Seq2Seq 可以解决输入队列与输出队列不等长的问题。

实验(验证码识别)

数据集:生成 4 位数字的验证码图片(测试集和训练集各 1000 张),图片名称为 index.code.jpg,截取 code 作为标签。

网络结构:

  • 编码:全连接 + 标准化(BN)+ 激活(ReLU)+ LSTM。
  • 解码:LSTM + 全连接 + softmax(多分类)。

优化器:Adam。

损失函数:均方差(MSELoss)。

输出:4 个 one-hot 类型,结果为最大的索引值。

生成验证码

import random
from PIL import Image, ImageDraw, ImageFont


# 随机数字
def rand_char():
    return chr(random.randint(48, 57))


# 随机背景颜色
def rand_bg():
    return (random.randint(50, 150), random.randint(50, 150), random.randint(50, 150))


# 随机数字颜色
def rand_color():
    return (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))


width = 240
height = 60
font = ImageFont.truetype("arial.ttf", size=36)
for i in range(1000):
    img = Image.new("RGB", (width, height), (255, 255, 255))
    draw = ImageDraw.ImageDraw(img)
    # 画背景
    for x in range(width):
        for y in range(height):
            draw.point((x, y), rand_bg())
    # 写数字
    chrs = []
    for n in range(4):
        each = rand_char()
        chrs.append(each)
        draw.text((n * 60 + 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/514856
推荐阅读
相关标签
  

闽ICP备14008679号