当前位置:   article > 正文

使用TensorFlow.js的AI聊天机器人二:训练Trivia Expert AI_tensorflow.js 聊天机器人

tensorflow.js 聊天机器人

目录

设置TensorFlow.js代码

TriviaQA数据集

单词嵌入和标记

训练AI模型

聊天机器人(Trivia Chatbot)在行动

终点线

下一步是什么?


TensorFlow + JavaScript。现在,最流行,最先进的AI框架支持地球上使用最广泛的编程语言。因此,让我们在Web浏览器中通过深度学习使文本和NLP(自然语言处理)聊天机器人神奇地发生,使用TensorFlow.js通过WebGL加速GPU!

上一篇文章中,我们带您完成了一个AI模型的训练过程,该模型可以使用TensorFlow在浏览器中为任何英语句子计算27种情绪之一。在这一部分中,我们将构建一个聊天机器人。

很好地回答聊天问题需要知道无数事实,并且能够准确地回忆起相关知识。利用计算机的大脑真是一个绝佳的机会!

让我们训练一个聊天机器人,使用递归神经网络(RNN)为我们提供数百个不同聊天问题的答案。

设置TensorFlow.js代码

在此项目中,我们将与聊天机器人进行交互,因此让我们将一些输入元素和文本响应从该机器人添加到我们的模板网页中。

  1. <html>
  2. <head>
  3. <title>Trivia Know-It-All: Chatbots in the Browser with TensorFlow.js</title>
  4. <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
  5. </head>
  6. <body>
  7. <h1 id="status">Trivia Know-It-All Bot</h1>
  8. <label>Ask a trivia question:</label>
  9. <input id="question" type="text" />
  10. <button id="submit">Submit</button>
  11. <p id="bot-question"></p>
  12. <p id="bot-answer"></p>
  13. <script>
  14. function setText( text ) {
  15. document.getElementById( "status" ).innerText = text;
  16. }
  17. (async () => {
  18. // Your Code Goes Here
  19. })();
  20. </script>
  21. </body>
  22. </html>

TriviaQA数据集

我们将用于训练神经网络的数据来自华盛顿大学提供的TriviaQA数据集2.5万个压缩文件中有9.5万个聊天问答对,可供下载。

现在,我们将使用一个较小的子集verified-wikipedia-dev.json,该子集包含在该项目的示例代码中。

TriviaQA JSON文件由一个Data数组组成,该数组具有各个QA元素,这些元素看起来类似于以下示例文件

  1. {
  2. "Data": [
  3. {
  4. "Answer": {
  5. "Aliases": [
  6. "Sunset Blvd",
  7. "West Sunset Boulevard",
  8. "Sunset Boulevard",
  9. "Sunset Bulevard",
  10. "Sunset Blvd."
  11. ],
  12. "MatchedWikiEntityName": "Sunset Boulevard",
  13. "NormalizedAliases": [
  14. "west sunset boulevard",
  15. "sunset blvd",
  16. "sunset boulevard",
  17. "sunset bulevard"
  18. ],
  19. "NormalizedMatchedWikiEntityName": "sunset boulevard",
  20. "NormalizedValue": "sunset boulevard",
  21. "Type": "WikipediaEntity",
  22. "Value": "Sunset Boulevard"
  23. },
  24. "EntityPages": [
  25. {
  26. "DocSource": "TagMe",
  27. "Filename": "Andrew_Lloyd_Webber.txt",
  28. "LinkProbability": "0.02934",
  29. "Rho": "0.22520",
  30. "Title": "Andrew Lloyd Webber"
  31. }
  32. ],
  33. "Question": "Which Lloyd Webber musical premiered in the US on 10th December 1993?",
  34. "QuestionId": "tc_33",
  35. "QuestionSource": "http://www.triviacountry.com/",
  36. "SearchResults": [
  37. {
  38. "Description": "The official website for Andrew Lloyd Webber, ... from the Andrew Lloyd Webber/Jim Steinman musical Whistle ... American premiere on 9th December 1993 at the ...",
  39. "DisplayUrl": "www.andrewlloydwebber.com",
  40. "Filename": "35/35_995.txt",
  41. "Rank": 0,
  42. "Title": "Andrew Lloyd Webber | The official website for Andrew ...",
  43. "Url": "http://www.andrewlloydwebber.com/"
  44. }
  45. ]
  46. }
  47. ],
  48. "Domain": "Web",
  49. "VerifiedEval": false,
  50. "Version": 1.0
  51. }

我们可以像这样在我们的代码中加载数据:

  1. (async () => {
  2. // Load TriviaQA data
  3. let triviaData = await fetch( "web/verified-wikipedia-dev.json" ).then( r => r.json() );
  4. let data = triviaData.Data;
  5. // Process all QA to map to answers
  6. let questions = data.map( qa => qa.Question );
  7. })();

单词嵌入和标记

对于这些聊天问题以及一般的英语句子、单词的位置和顺序可能会影响其含义。因此,当将句子变成向量时,我们不能简单地使用不保留单词位置信息的单词袋。因此,在准备训练数据时,我们将使用一种称为word embedding的方法,并创建一个表示单词及其位置的单词索引列表。

首先,我们将遍历所有可用数据,并在所有问题中识别每个唯一的单词,就像准备一袋单词时一样。我们想在wordReference索引中添加+1以保留索引0作为TensorFlow中的填充令牌。

  1. let bagOfWords = {};
  2. let allWords = [];
  3. let wordReference = {};
  4. questions.forEach( q => {
  5. let words = q.replace(/[^a-z ]/gi, "").toLowerCase().split( " " ).filter( x => !!x );
  6. words.forEach( w => {
  7. if( !bagOfWords[ w ] ) {
  8. bagOfWords[ w ] = 0;
  9. }
  10. bagOfWords[ w ]++; // Counting occurrence just for word frequency fun
  11. });
  12. });
  13. allWords = Object.keys( bagOfWords );
  14. allWords.forEach( ( w, i ) => {
  15. wordReference[ w ] = i + 1;
  16. });

在拥有包含所有单词及其索引的完整词汇表之后,我们可以采用每个疑问句并创建与每个单词的索引相对应的正整数数组。我们需要确保输入向量(进入网络)的长度相同。我们可以将句子的最大数量限制为30个单词,并且任何少于30个单词的问题都可以设置零索引来表示空白填充。

让我们还生成预期的输出分类向量,这些向量映射到每个不同的问答对。

  1. // Create a tokenized vector for each question
  2. const maxSentenceLength = 30;
  3. let vectors = [];
  4. questions.forEach( q => {
  5. let qVec = [];
  6. // Use a regex to only get spaces and letters and remove any blank elements
  7. let words = q.replace(/[^a-z ]/gi, "").toLowerCase().split( " " ).filter( x => !!x );
  8. for( let i = 0; i < maxSentenceLength; i++ ) {
  9. if( words[ i ] ) {
  10. qVec.push( wordReference[ words[ i ] ] );
  11. }
  12. else {
  13. // Add padding to keep the vectors the same length
  14. qVec.push( 0 );
  15. }
  16. }
  17. vectors.push( qVec );
  18. });
  19. let outputs = questions.map( ( q, index ) => {
  20. let output = [];
  21. for( let i = 0; i < questions.length; i++ ) {
  22. output.push( i === index ? 1 : 0 );
  23. }
  24. return output;
  25. });

训练AI模型

TensorFlow为像我们刚刚创建的标记化矢量提供了一种嵌入层类型,并将其转化为可用于神经网络的密集矢量。我们使用RNN架构是因为单词的顺序在每个问题中都很重要。我们可以使用简单的RNN层或双向的神经网络来训练神经网络。随意取消注释/注释代码行,并尝试其中之一。

网络应返回一个分类向量,其中最大值的索引将与问题答案对的索引对应。模型的完成设置应如下所示:

  1. // Define our RNN model with several hidden layers
  2. const model = tf.sequential();
  3. // Add 1 to inputDim for the "padding" character
  4. model.add(tf.layers.embedding( { inputDim: allWords.length + 1, outputDim: 128, inputLength: maxSentenceLength } ) );
  5. // model.add(tf.layers.simpleRNN( { units: 32 } ) );
  6. model.add(tf.layers.bidirectional( { layer: tf.layers.simpleRNN( { units: 32 } ), mergeMode: "concat" } ) );
  7. model.add(tf.layers.dense( { units: 50 } ) );
  8. model.add(tf.layers.dense( { units: 25 } ) );
  9. model.add(tf.layers.dense( {
  10. units: questions.length,
  11. activation: "softmax"
  12. } ) );
  13. model.compile({
  14. optimizer: tf.train.adam(),
  15. loss: "categoricalCrossentropy",
  16. metrics: [ "accuracy" ]
  17. });

最后,我们可以将输入数据转换为张量并训练网络。

  1. const xs = tf.stack( vectors.map( x => tf.tensor1d( x ) ) );
  2. const ys = tf.stack( outputs.map( x => tf.tensor1d( x ) ) );
  3. await model.fit( xs, ys, {
  4. epochs: 20,
  5. shuffle: true,
  6. callbacks: {
  7. onEpochEnd: ( epoch, logs ) => {
  8. setText( `Training... Epoch #${epoch} (${logs.acc})` );
  9. console.log( "Epoch #", epoch, logs );
  10. }
  11. }
  12. } );

聊天机器人(Trivia Chatbot)在行动

我们快准备好了。

要测试我们的聊天机器人,我们需要能够通过提交问题并使其回答做出交谈。让我们在机器人经过训练并准备就绪时通知用户,并处理用户输入:

  1. setText( "Trivia Know-It-All Bot is Ready!" );
  2. document.getElementById( "question" ).addEventListener( "keyup", function( event ) {
  3. // Number 13 is the "Enter" key on the keyboard
  4. if( event.keyCode === 13 ) {
  5. // Cancel the default action, if needed
  6. event.preventDefault();
  7. // Trigger the button element with a click
  8. document.getElementById( "submit" ).click();
  9. }
  10. });
  11. document.getElementById( "submit" ).addEventListener( "click", async function( event ) {
  12. let text = document.getElementById( "question" ).value;
  13. document.getElementById( "question" ).value = "";
  14. // Our prediction code will go here
  15. });

最后,在我们的“click”事件处理程序中,我们可以像对待训练问题一样,标记用户提交的问题。然后,我们可以让模型发挥作用,预测最可能被问到的问题,并显示聊天问题和答案。

在测试聊天机器人时,您可能会注意到单词的顺序似乎影响太大,或者问题中的第一个单词会严重影响其输出。我们将在下一篇文章中对此进行改进。同时,您可以使用另一种方法来解决此问题,该方法称为Attention,以训练bot权衡某些单词的权重。

如果您想了解更多信息,我建议您查看这篇关于可视化的文章,其中介绍了注意在序列到序列模型中如何有用。

终点线

现在这是我们的完整代码:

  1. <html>
  2. <head>
  3. <title>Trivia Know-It-All: Chatbots in the Browser with TensorFlow.js</title>
  4. <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
  5. </head>
  6. <body>
  7. <h1 id="status">Trivia Know-It-All Bot</h1>
  8. <label>Ask a trivia question:</label>
  9. <input id="question" type="text" />
  10. <button id="submit">Submit</button>
  11. <p id="bot-question"></p>
  12. <p id="bot-answer"></p>
  13. <script>
  14. function setText( text ) {
  15. document.getElementById( "status" ).innerText = text;
  16. }
  17. (async () => {
  18. // Load TriviaQA data
  19. let triviaData = await fetch( "web/verified-wikipedia-dev.json" ).then( r => r.json() );
  20. let data = triviaData.Data;
  21. // Process all QA to map to answers
  22. let questions = data.map( qa => qa.Question );
  23. let bagOfWords = {};
  24. let allWords = [];
  25. let wordReference = {};
  26. questions.forEach( q => {
  27. let words = q.replace(/[^a-z ]/gi, "").toLowerCase().split( " " ).filter( x => !!x );
  28. words.forEach( w => {
  29. if( !bagOfWords[ w ] ) {
  30. bagOfWords[ w ] = 0;
  31. }
  32. bagOfWords[ w ]++; // Counting occurrence just for word frequency fun
  33. });
  34. });
  35. allWords = Object.keys( bagOfWords );
  36. allWords.forEach( ( w, i ) => {
  37. wordReference[ w ] = i + 1;
  38. });
  39. // Create a tokenized vector for each question
  40. const maxSentenceLength = 30;
  41. let vectors = [];
  42. questions.forEach( q => {
  43. let qVec = [];
  44. // Use a regex to only get spaces and letters and remove any blank elements
  45. let words = q.replace(/[^a-z ]/gi, "").toLowerCase().split( " " ).filter( x => !!x );
  46. for( let i = 0; i < maxSentenceLength; i++ ) {
  47. if( words[ i ] ) {
  48. qVec.push( wordReference[ words[ i ] ] );
  49. }
  50. else {
  51. // Add padding to keep the vectors the same length
  52. qVec.push( 0 );
  53. }
  54. }
  55. vectors.push( qVec );
  56. });
  57. let outputs = questions.map( ( q, index ) => {
  58. let output = [];
  59. for( let i = 0; i < questions.length; i++ ) {
  60. output.push( i === index ? 1 : 0 );
  61. }
  62. return output;
  63. });
  64. // Define our RNN model with several hidden layers
  65. const model = tf.sequential();
  66. // Add 1 to inputDim for the "padding" character
  67. model.add(tf.layers.embedding( { inputDim: allWords.length + 1, outputDim: 128, inputLength: maxSentenceLength, maskZero: true } ) );
  68. model.add(tf.layers.simpleRNN( { units: 32 } ) );
  69. // model.add(tf.layers.bidirectional( { layer: tf.layers.simpleRNN( { units: 32 } ), mergeMode: "concat" } ) );
  70. model.add(tf.layers.dense( { units: 50 } ) );
  71. model.add(tf.layers.dense( { units: 25 } ) );
  72. model.add(tf.layers.dense( {
  73. units: questions.length,
  74. activation: "softmax"
  75. } ) );
  76. model.compile({
  77. optimizer: tf.train.adam(),
  78. loss: "categoricalCrossentropy",
  79. metrics: [ "accuracy" ]
  80. });
  81. const xs = tf.stack( vectors.map( x => tf.tensor1d( x ) ) );
  82. const ys = tf.stack( outputs.map( x => tf.tensor1d( x ) ) );
  83. await model.fit( xs, ys, {
  84. epochs: 20,
  85. shuffle: true,
  86. callbacks: {
  87. onEpochEnd: ( epoch, logs ) => {
  88. setText( `Training... Epoch #${epoch} (${logs.acc})` );
  89. console.log( "Epoch #", epoch, logs );
  90. }
  91. }
  92. } );
  93. setText( "Trivia Know-It-All Bot is Ready!" );
  94. document.getElementById( "question" ).addEventListener( "keyup", function( event ) {
  95. // Number 13 is the "Enter" key on the keyboard
  96. if( event.keyCode === 13 ) {
  97. // Cancel the default action, if needed
  98. event.preventDefault();
  99. // Trigger the button element with a click
  100. document.getElementById( "submit" ).click();
  101. }
  102. });
  103. document.getElementById( "submit" ).addEventListener( "click", async function( event ) {
  104. let text = document.getElementById( "question" ).value;
  105. document.getElementById( "question" ).value = "";
  106. // Run the calculation things
  107. let qVec = [];
  108. let words = text.replace(/[^a-z ]/gi, "").toLowerCase().split( " " ).filter( x => !!x );
  109. for( let i = 0; i < maxSentenceLength; i++ ) {
  110. if( words[ i ] ) {
  111. qVec.push( wordReference[ words[ i ] ] );
  112. }
  113. else {
  114. // Add padding to keep the vectors the same length
  115. qVec.push( 0 );
  116. }
  117. }
  118. let prediction = await model.predict( tf.stack( [ tf.tensor1d( qVec ) ] ) ).data();
  119. // Get the index of the highest value in the prediction
  120. let id = prediction.indexOf( Math.max( ...prediction ) );
  121. document.getElementById( "bot-question" ).innerText = questions[ id ];
  122. document.getElementById( "bot-answer" ).innerText = data[ id ].Answer.Value;
  123. });
  124. })();
  125. </script>
  126. </body>
  127. </html>

下一步是什么?

使用RNN,我们创建了一个深度学习聊天机器人来识别问题,并在浏览器中直接从大量聊天问题/答案对中为我们提供答案。接下来,我们将研究嵌入整个句子而不是单个单词,以便在从文本中检测情感时获得更准确的结果。

和我一起参加本系列的下一篇文章中,使用TensorFlow.js在浏览器中改进文本情感检测。

https://www.codeproject.com/Articles/5282688/AI-Chatbots-With-TensorFlow-js-Training-a-Trivia-E

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/337171
推荐阅读
相关标签
  

闽ICP备14008679号