
昇思MindSpore Study Notes 2-01 LLM Principles and Practice -- Implementing BERT Dialogue Emotion Recognition with MindSpore

Abstract:

Using a worked example of BERT-based dialogue emotion recognition, this note demonstrates the principles behind large language models in the 昇思MindSpore AI framework, along with the practical usage and steps.

I. Environment Setup

%%capture captured_output
# The environment comes with mindspore==2.2.14 pre-installed; to switch versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# This case was validated against mindnlp 0.3.1; if it fails to run, pin the version with `!pip install mindnlp==0.3.1`
!pip install mindnlp

Output:

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting mindnlp
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/72/37/ef313c23fd587c3d1f46b0741c98235aecdfd93b4d6d446376f3db6a552c/mindnlp-0.3.1-py3-none-any.whl (5.7 MB)
━━━━━━━━━━━━━━━━ 5.7/5.7 MB 14.2 MB/s eta 0:00:00
Requirement already satisfied: mindspore in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.2.14)
Requirement already satisfied: tqdm in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (4.66.4)
Requirement already satisfied: requests in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.32.3)
Collecting datasets (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/60/2d/963b266bb8f88492d5ab4232d74292af8beb5b6fdae97902df9e284d4c32/datasets-2.20.0-py3-none-any.whl (547 kB)
━━━━━━━━━━━━━━━━ 547.8/547.8 kB 21.2 MB/s eta 0:00:00
Collecting evaluate (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c2/d6/ff9baefc8fc679dcd9eb21b29da3ef10c81aa36be630a7ae78e4611588e1/evaluate-0.4.2-py3-none-any.whl (84 kB)
━━━━━━━━━━━━━━━━ 84.1/84.1 kB 24.8 MB/s eta 0:00:00
Collecting tokenizers (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ba/26/139bd2371228a0e203da7b3e3eddcb02f45b2b7edd91df00e342e4b55e13/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (3.6 MB)
━━━━━━━━━━━━━━━━ 3.6/3.6 MB 14.7 MB/s eta 0:00:00
Collecting safetensors (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/02/28e6280ed0f1bde89eed644b80f2ece4e5ae212dc9ee70d7f56fadc93602/safetensors-0.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.2 MB)
━━━━━━━━━━━━━━━━ 1.2/1.2 MB 17.8 MB/s eta 0:00:00
Collecting sentencepiece (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a3/69/e96ef68261fa5b82379fdedb325ceaf1d353c6e839ec346d8244e0da5f2f/sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.3 MB)
━━━━━━━━━━━━━━━━ 1.3/1.3 MB 14.4 MB/s eta 0:00:00
Collecting regex (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/70/70/fea4865c89a841432497d1abbfd53878513b55c6543245fabe31cf8df0b8/regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (774 kB)
━━━━━━━━━━━━━━━━ 774.7/774.7 kB 15.3 MB/s eta 0:00:00
Collecting addict (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl (3.8 kB)
Collecting ml-dtypes (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/50/96/13d7c3cc82d5ef597279216cf56ff461f8b57e7096a3ef10246a83ca80c0/ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (2.2 MB)
━━━━━━━━━━━━━━━━ 2.2/2.2 MB 11.9 MB/s eta 0:00:00
Collecting pyctcdecode (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a5/8a/93e2118411ae5e861d4f4ce65578c62e85d0f1d9cb389bd63bd57130604e/pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)
Collecting jieba (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/cb/18eeb235f833b726522d7ebed54f2278ce28ba9438e3135ab0278d9792a2/jieba-0.42.1.tar.gz (19.2 MB)
━━━━━━━━━━━━━━━━ 19.2/19.2 MB 16.5 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Collecting pytest==7.2.0 (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/67/68/a5eb36c3a8540594b6035e6cdae40c1ef1b6a2bfacbecc3d1a544583c078/pytest-7.2.0-py3-none-any.whl (316 kB)
━━━━━━━━━━━━━━━━ 316.8/316.8 kB 16.7 MB/s eta 0:00:00
Requirement already satisfied: attrs>=19.2.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2.0)
Requirement already satisfied: iniconfig in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.0)
Requirement already satisfied: packaging in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2)
Requirement already satisfied: pluggy<2.0,>=0.12 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.5.0)
Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.2.0)
Requirement already satisfied: tomli>=1.0.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.1)
Requirement already satisfied: filelock in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.15.3)
Requirement already satisfied: numpy>=1.17 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (1.26.4)
Collecting pyarrow>=15.0.0 (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/60/cc0645eb4ef73f88847e40a7f9d238bae6b7409d6c1f6a5d200d8ade1f09/pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl (38.1 MB)
━━━━━━━━━━━━━━━━ 38.1/38.1 MB 14.2 MB/s eta 0:00:00
Collecting pyarrow-hotfix (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e4/f4/9ec2222f5f5f8ea04f66f184caafd991a39c8782e31f5b0266f101cb68ca/pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.3.8)
Requirement already satisfied: pandas in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (2.2.2)
Collecting xxhash (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7c/b9/93f860969093d5d1c4fa60c75ca351b212560de68f33dc0da04c89b7dc1b/xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (220 kB)
━━━━━━━━━━━━━━━━ 220.6/220.6 kB 15.6 MB/s eta 0:00:00
Collecting multiprocess (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl (133 kB)
━━━━━━━━━━━━━━━━ 133.4/133.4 kB 15.8 MB/s eta 0:00:00
Collecting fsspec<=2024.5.0,>=2023.1.0 (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ba/a3/16e9fe32187e9c8bc7f9b7bcd9728529faa725231a0c96f2f98714ff2fc5/fsspec-2024.5.0-py3-none-any.whl (316 kB)
━━━━━━━━━━━━━━━━ 316.1/316.1 kB 16.8 MB/s eta 0:00:00
Collecting aiohttp (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/eb/45/eebe8d2215328434f33ccb44a05d2741ff7ed4b96b56ca507e2ecf598b73/aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.2 MB)
━━━━━━━━━━━━━━━━ 1.2/1.2 MB 17.1 MB/s eta 0:00:00
Requirement already satisfied: huggingface-hub>=0.21.2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.23.4)
Requirement already satisfied: pyyaml>=5.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (6.0.1)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2.2.2)
Requirement already satisfied: certifi>=2017.4.17 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2024.6.2)
Requirement already satisfied: protobuf>=3.13.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.27.1)
Requirement already satisfied: asttokens>=2.0.4 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (2.0.5)
Requirement already satisfied: pillow>=6.2.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (10.3.0)
Requirement already satisfied: scipy>=1.5.4 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.13.1)
Requirement already satisfied: psutil>=5.6.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.9.0)
Requirement already satisfied: astunparse>=1.6.3 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.6.3)
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ec/cd/bd196b2cf014afb1009de8b0f05ecd54011d881944e62763f3c1b1e8ef37/pygtrie-2.5.0-py3-none-any.whl (25 kB)
Collecting hypothesis<7,>=6.14 (from pyctcdecode->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/ea/526a7a629fcf6c78a1a6d37f988ca7e02e5b5785ec4de8a194deb40529f4/hypothesis-6.104.2-py3-none-any.whl (462 kB)
━━━━━━━━━━━━━━━━ 462.4/462.4 kB 14.4 MB/s eta 0:00:00
Requirement already satisfied: six in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from asttokens>=2.0.4->mindspore->mindnlp) (1.16.0)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore->mindnlp) (0.43.0)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/57/15/172af60c7e150a1d88ecc832f2590721166ae41eab582172fe1e9844eab4/frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (239 kB)
━━━━━━━━━━━━━━━━ 239.4/239.4 kB 17.1 MB/s eta 0:00:00
Collecting multidict<7.0,>=4.5 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d0/10/2ff646c471e84af25fe8111985ffb8ec85a3f6e1ade8643bfcfcc0f4d2b1/multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (125 kB)
━━━━━━━━━━━━━━━━ 125.9/125.9 kB 31.0 MB/s eta 0:00:00
Collecting yarl<2.0,>=1.0 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/d6/5b30ae1d8a13104ee2ceb649f28f2db5ad42afbd5697fd0fc61528bb112c/yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (300 kB)
━━━━━━━━━━━━━━━━ 300.9/300.9 kB 20.5 MB/s eta 0:00:00
Collecting async-timeout<5.0,>=4.0 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl (5.7 kB)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub>=0.21.2->datasets->mindnlp) (4.11.0)
Collecting sortedcontainers<3.0.0,>=2.1.0 (from hypothesis<7,>=6.14->pyctcdecode->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Requirement already satisfied: python-dateutil>=2.8.2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)
Building wheels for collected packages: jieba
Building wheel for jieba (setup.py) ... done
Created wheel for jieba: filename=jieba-0.42.1-py3-none-any.whl size=19314459 sha256=352f23b7dc8b4bade2f918165e055bc707601544400a4918136ba69f220ce9f6
Stored in directory: /home/nginx/.cache/pip/wheels/1a/76/68/b6d79c4db704bb18d54f6a73ab551185f4711f9730c0c15d97
Successfully built jieba
Installing collected packages: sortedcontainers, sentencepiece, pygtrie, jieba, addict, xxhash, safetensors, regex, pytest, pyarrow-hotfix, pyarrow, multiprocess, multidict, ml-dtypes, hypothesis, fsspec, frozenlist, async-timeout, yarl, pyctcdecode, aiosignal, tokenizers, aiohttp, datasets, evaluate, mindnlp
Attempting uninstall: pytest
Found existing installation: pytest 8.0.0
Uninstalling pytest-8.0.0:
Successfully uninstalled pytest-8.0.0
Attempting uninstall: fsspec
Found existing installation: fsspec 2024.6.0
Uninstalling fsspec-2024.6.0:
Successfully uninstalled fsspec-2024.6.0
Successfully installed addict-2.4.0 aiohttp-3.9.5 aiosignal-1.3.1 async-timeout-4.0.3 datasets-2.20.0 evaluate-0.4.2 frozenlist-1.4.1 fsspec-2024.5.0 hypothesis-6.104.2 jieba-0.42.1 mindnlp-0.3.1 ml-dtypes-0.4.0 multidict-6.0.5 multiprocess-0.70.16 pyarrow-16.1.0 pyarrow-hotfix-0.6 pyctcdecode-0.5.0 pygtrie-2.5.0 pytest-7.2.0 regex-2024.5.15 safetensors-0.4.3 sentencepiece-0.2.0 sortedcontainers-2.4.0 tokenizers-0.19.1 xxhash-3.4.1 yarl-1.9.4
[notice] A new release of pip is available: 24.1 -> 24.1.1
[notice] To update, run: python -m pip install --upgrade pip

Show basic information about the mindspore package:

!pip show mindspore

Output:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: mindnlp

II. Model Overview

BERT, short for Bidirectional Encoder Representations from Transformers, is a pre-trained language model released by Google in 2018.

Pre-trained language models of this kind are used in natural language processing tasks such as:

        question answering

        named entity recognition

        natural language inference

        text classification, etc.

The BERT model builds on:

        the Transformer encoder

        a bidirectional structure

BERT's main innovation is its pre-training method:

        a Masked Language Model to capture word-level representations

        Next Sentence Prediction to capture sentence-level relationships

Masked Language Model (MLM) training:

        15% of the tokens in the corpus are randomly selected for masking.

        A selected token is handled in one of three ways:

                80% are replaced directly with the [MASK] token

                10% are replaced with a random token

                10% are left unchanged.
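The minimal sketch below illustrates this 80/10/10 rule on space-segmented text. It is an illustration only, not BERT's or MindSpore's actual pre-training code (which operates on token ids and never masks special tokens):

import random

def mask_tokens(tokens, vocab, mask_rate=0.15):
    # Select roughly 15% of tokens; of those: 80% -> [MASK],
    # 10% -> a random token from `vocab`, 10% -> unchanged.
    masked = list(tokens)
    for i in range(len(masked)):
        if random.random() < mask_rate:
            r = random.random()
            if r < 0.8:
                masked[i] = "[MASK]"
            elif r < 0.9:
                masked[i] = random.choice(vocab)
            # else: keep the original token
    return masked

print(mask_tokens("我 很 高兴 见到 你".split(), vocab=["我", "你", "好"]))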

The Next Sentence Prediction (NSP) pre-training task:

        Purpose:

                teach the model the relationship between two sentences; this benefits downstream tasks such as Question Answering (QA) and Natural Language Inference (NLI).

        Training setup (see the sketch after this list):

                the input is a pair of sentences A and B

                with 50% probability, B is the sentence that actually follows A

                the model predicts whether B is the next sentence after A

        Training outputs:

                the embedding table

                12 layers of Transformer weights (BERT-BASE)

                or 24 layers of Transformer weights (BERT-LARGE)

        Fine-tuning on downstream tasks:

                text classification

                similarity judgment

                reading comprehension, etc.
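The sketch below shows how NSP training pairs can be constructed. The helper is hypothetical, for illustration only, and not part of MindSpore or the BERT codebase:

import random

def make_nsp_pair(corpus, i):
    # With 50% probability, pair sentence i with its true successor (label 1, IsNext);
    # otherwise pair it with a randomly drawn sentence (label 0, NotNext).
    sent_a = corpus[i]
    if random.random() < 0.5 and i + 1 < len(corpus):
        return sent_a, corpus[i + 1], 1
    return sent_a, random.choice(corpus), 0  # may rarely draw the true successor

corpus = ["今天 天气 不错", "我们 出去 走走 吧", "这 本 书 很 有趣"]
print(make_nsp_pair(corpus, 0))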

Dialogue emotion detection (Emotion Detection, EmoTect for short):

        input: dialogue text

        output: the emotion category of the text

                positive

                negative

                neutral

        plus a confidence score for the prediction.

Import the mindspore, dataset, nn, context, and mindnlp modules:

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

Output:

Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 1.037 seconds.
Prefix dict has been built successfully.

III. Preparing the Dataset

1. Dataset description

The experiment uses Baidu PaddlePaddle's chatbot conversation dataset, which is:

        already labeled

        pre-segmented into words

Each record has two columns separated by a tab character ('\t'):

        emotion label

                0 negative

                1 neutral

                2 positive

        Chinese text

                tokens separated by spaces

                UTF-8 encoded
Sample data:

label    text_a
0    谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1    我有事等会儿就回来和你聊
2    我见到你很高兴谢谢你帮我
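As a quick illustration of the format, each record splits on the tab character (a hypothetical two-liner, not part of the original notebook):

label, text_a = "1\t我有事等会儿就回来和你聊".split("\t")
print(int(label), text_a)  # -> 1 我有事等会儿就回来和你聊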

2. Download the dataset

# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

Output:

--2024-07-01 13:38:50-- https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz
Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 119.249.103.5, 113.200.2.111, 2409:8c04:1001:1203:0:ff:b0bb:4f27
Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|119.249.103.5|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1710581 (1.6M) [application/x-gzip]
Saving to: ‘emotion_detection.tar.gz’
emotion_detection.t 100%[===================>] 1.63M 8.04MB/s in 0.2s
2024-07-01 13:38:50 (8.04 MB/s) - ‘emotion_detection.tar.gz’ saved [1710581/1710581]
data/
data/test.tsv
data/infer.tsv
data/dev.tsv
data/train.tsv
data/vocab.txt

3. Define the dataset class

# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        # skip the header line and the trailing empty line
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)
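A quick sanity check of the class (assuming the files from step 2 were extracted into data/; the printed values depend on the dataset):

train_data = SentimentDataset("data/train.tsv")
print(len(train_data))  # number of samples
print(train_data[0])    # the first (label, text) pair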

IV. Data Loading and Preprocessing

The data loading and preprocessing function process_dataset():

import numpy as np

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]

    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)

    def tokenize_and_pad(text):
        if is_ascend:
            # static shape: pad every sample to max_seq_len
            tokenized = tokenizer(text, padding='max_length',
                                  truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']

    # map dataset
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a",
                          output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=[type_cast_op], input_columns="label",
                          output_columns='labels')
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        # dynamic shape: pad per batch to the longest sample
        dataset = dataset.padded_batch(batch_size,
                                       pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                 'attention_mask': (None, 0)})

    return dataset

The preprocessing uses static shapes (padding every sample to max_seq_len):

        dynamic shapes are not yet supported on Ascend NPUs

Load the BERT tokenizer:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

Output:

100%━━━━━━━━━━━━━━━━━━━━━ 49.0/49.0 [00:00<00:00, 3.05kB/s]
━107k/0.00 [00:05<00:00, 36.3kB/s]
━263k/0.00 [00:15<00:00, 10.2kB/s]
━━━━━━━━━━━━━━━━━━━━━ 624/? [00:00<00:00, 56.0kB/s]
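As a quick check, the tokenizer returns exactly the fields that process_dataset() consumes (a sketch; the sample sentence is taken from the inference data, and 101/102 are BERT's standard [CLS]/[SEP] ids):

sample = tokenizer("我 要 客观")
print(sample['input_ids'])       # starts with 101 ([CLS]) and ends with 102 ([SEP])
print(sample['attention_mask'])  # all ones, since no padding was requested here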

Check the pad token id used by padded_batch:

tokenizer.pad_token_id

Output:

0

Build the train/dev/test datasets and get the training set's column names:

dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)

dataset_train.get_col_names()

Output:

['input_ids', 'attention_mask', 'labels']

Peek at one batch of the training dataset:

print(next(dataset_train.create_tuple_iterator()))

Output:

[Tensor(shape=[32, 64], dtype=Int64, value=
[[ 101, 2769, 4638 ... 0, 0, 0],
 [ 101, 2769, 3221 ... 0, 0, 0],
 [ 101, 758, 1282 ... 0, 0, 0],
 ...
 [ 101, 1217, 678 ... 0, 0, 0],
 [ 101, 872, 679 ... 0, 0, 0],
 [ 101, 872, 3766 ... 0, 0, 0]]),
Tensor(shape=[32, 64], dtype=Int64, value=
[[1, 1, 1 ... 0, 0, 0],
 [1, 1, 1 ... 0, 0, 0],
 [1, 1, 1 ... 0, 0, 0],
 ...
 [1, 1, 1 ... 0, 0, 0],
 [1, 1, 1 ... 0, 0, 0],
 [1, 1, 1 ... 0, 0, 0]]),
Tensor(shape=[32], dtype=Int32, value=
[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1])]

V. Building the Model

The BERT model is:

        built with the BertForSequenceClassification module

                loading the pre-trained weights

                configured for three-way sentiment classification

We then:

        enable automatic mixed precision

        instantiate the optimizer

        instantiate the evaluation metric

        set the checkpoint-saving strategy for the model weights

        build the trainer

        start training

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision

# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)


Output:

100%━━━━━━━━━━━━━━━━━━ 392M/392M [00:53<00:00, 6.82MB/s]

The following parameters in checkpoint files are not loaded:

['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']

The following parameters in models are missing parameter:

['classifier.weight', 'classifier.bias']

metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

# build the trainer
trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])

Start training in its own cell (the %%time magic must be the first line of a cell):

%%time
# start training
trainer.run(tgt_columns="labels")

Output:

The train will start from the checkpoint saved in 'checkpoint'.
Epoch 0: 100%━━━━━━━━━━━━━━ 302/302 [04:07<00:00,  2.25s/it, loss=0.3460012]
Checkpoint: 'bert_emotect_epoch_0.ckpt' has been saved in epoch: 0.
Evaluate: 100%━━━━━━━━━━━━━━ 34/34 [00:07<00:00,  1.07it/s]
Evaluate Score: {'Accuracy': 0.9351851851851852}
---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 0.---------------
Epoch 1: 100%━━━━━━━━━━━━━━ 302/302 [02:38<00:00,  1.95it/s, loss=0.19017023]
Checkpoint: 'bert_emotect_epoch_1.ckpt' has been saved in epoch: 1.
Evaluate: 100%━━━━━━━━━━━━━━ 34/34 [00:05<00:00,  7.48it/s]
Evaluate Score: {'Accuracy': 0.9564814814814815}
---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 1.---------------
Epoch 2: 100%━━━━━━━━━━━━━━ 302/302 [02:40<00:00,  1.92it/s, loss=0.12662967]
The maximum number of stored checkpoints has been reached.
Checkpoint: 'bert_emotect_epoch_2.ckpt' has been saved in epoch: 2.
Evaluate: 100%━━━━━━━━━━━━━━ 34/34 [00:04<00:00,  7.59it/s]
Evaluate Score: {'Accuracy': 0.9740740740740741}
---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 2.---------------
Epoch 3: 100%━━━━━━━━━━━━━━ 302/302 [02:40<00:00,  1.92it/s, loss=0.08593981]
The maximum number of stored checkpoints has been reached.
Checkpoint: 'bert_emotect_epoch_3.ckpt' has been saved in epoch: 3.
Evaluate: 100%━━━━━━━━━━━━━━ 34/34 [00:04<00:00,  7.51it/s]
Evaluate Score: {'Accuracy': 0.9833333333333333}
---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 3.---------------
Epoch 4: 100%━━━━━━━━━━━━━━ 302/302 [02:41<00:00,  1.92it/s, loss=0.05900709]
The maximum number of stored checkpoints has been reached.
Checkpoint: 'bert_emotect_epoch_4.ckpt' has been saved in epoch: 4.
Evaluate: 100%━━━━━━━━━━━━━━ 34/34 [00:04<00:00,  7.39it/s]
Evaluate Score: {'Accuracy': 0.9879629629629629}
---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 4.---------------
Loading best model from 'checkpoint' with '['Accuracy']': [0.9879629629629629]...
---------------The model is already load the best model from 'bert_emotect_best.ckpt'.---------------
CPU times: user 22min 58s, sys: 13min 25s, total: 36min 24s
Wall time: 15min 30s

VI. Model Validation

Evaluate the model:

        on the test dataset

        reporting classification accuracy

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

Output:

Evaluate: 100%━━━━━━━━━━━━━━ 33/33 [00:08<00:00,  1.20s/it]

Evaluate Score: {'Accuracy': 0.8822393822393823}

VII. Model Inference

Iterate over the inference dataset and print each prediction alongside its label.

from mindspore import Tensor

dataset_infer = SentimentDataset("data/infer.tsv")

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)

for label, text in dataset_infer:
    predict(text, label)

Output:

inputs: '我 要 客观', predict: '中性' , label: '中性'
inputs: '靠 你 真是 说 废话 吗', predict: '消极' , label: '消极'
inputs: '口嗅 会', predict: '中性' , label: '中性'
inputs: '每次 是 表妹 带 窝 飞 因为 窝路痴', predict: '中性' , label: '中性'
inputs: '别说 废话 我 问 你 个 问题', predict: '消极' , label: '消极'
inputs: '4967 是 新加坡 那 家 银行', predict: '中性' , label: '中性'
inputs: '是 我 喜欢 兔子', predict: '积极' , label: '积极'
inputs: '你 写 过 黄山 奇石 吗', predict: '中性' , label: '中性'
inputs: '一个一个 慢慢来', predict: '中性' , label: '中性'
inputs: '我 玩 过 这个 一点 都 不 好玩', predict: '消极' , label: '消极'
inputs: '网上 开发 女孩 的 QQ', predict: '中性' , label: '中性'
inputs: '背 你 猜 对 了', predict: '中性' , label: '中性'
inputs: '我 讨厌 你 , 哼哼 哼 。 。', predict: '消极' , label: '消极'

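Section II mentioned that EmoTect also reports a confidence score. predict() above prints only the argmax label; the following minimal sketch (an addition to the original notebook, not part of it) derives a confidence via a softmax over the logits:

import numpy as np

def predict_with_confidence(text):
    # Reuses `model`, `tokenizer`, and `Tensor` from the cells above.
    label_map = {0: "消极", 1: "中性", 2: "积极"}
    logits = model(Tensor([tokenizer(text).input_ids]))[0].asnumpy()[0]
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs = probs / probs.sum()
    idx = int(probs.argmax())
    print(f"inputs: '{text}', predict: '{label_map[idx]}', confidence: {probs[idx]:.3f}")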

VIII. Custom Inference Data

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

Output:

inputs: '家人们咱就是说一整个无语住了 绝绝子叠buff', predict: '中性'
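As the example above shows, predict() accepts arbitrary (even unsegmented) colloquial text, so you can try further inputs of your own; outputs will vary with the trained weights:

predict("今天 和 朋友 出去 玩 特别 开心")
predict("这个 游戏 也 太 难 了 吧")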
