当前位置:   article > 正文

AI作诗,模仿周杰伦创作歌词<->实战项目

歌词仿写 开源

点击上方码农的后花园”,选择星标” 公众号

精选文章,第一时间送达

很久以来,我们都想让机器自己创作诗歌,当无数作家、编辑还没有抬起笔时,AI已经完成了数千篇文章。现在,这里是第一步....

8bc4caa4a23777a620708baace439a64.png这诗做的很有感觉啊,这都是勤奋的结果啊,基本上学习了全唐诗的所有精华才有了这么牛逼的能力,这一般人能做到?

甚至还可以模仿周杰伦创作歌词 !!955f467cf7261b4e0819695a2cd35389.png怎么说,目前由于缺乏训练文本,导致我们的AI做的歌词有点....额,还好啦,有那么一点忧郁之风。

1.下载代码和数据集

Github地址: https://github.com/jinfagang/tensorflow_poems

数据集: 存放于项目的data文件夹内

fa4fb44039883035f308f81c9320ac2d.png

2.环境导入

  1. import os
  2. import tensorflow as tf
  3. from poems.model import rnn_model
  4. from poems.poems import process_poems, generate_batch
  5. import argparse
  6. from pathlib import Path

3.参数设置

  1. parser = argparse.ArgumentParser()
  2. #type是要传入的参数的数据类型  help是该参数的提示信息
  3. parser.add_argument('--batch_size'type=int, help='batch_size',default=64)
  4. parser.add_argument('--learning_rate'type=float, help='learning_rate',default=0.0001)
  5. parser.add_argument('--model_dir'type=Path, help='model save path.',default='./model')
  6. parser.add_argument('--file_path'type=Path, help='file name of poems.',default='./data/poems.txt')
  7. parser.add_argument('--model_prefix'type=str, help='model save prefix.',default='poems')
  8. parser.add_argument('--epochs'type=int, help='train how many epochs.',default=126)
  9. args = parser.parse_args(args=[])

4.训练

下载的代码中的./model/中包含最新的训练模型,再次训练会接着训练。如果训练路径报错,需要删除./model的模型,重新开始训练。

  1. def run_training():
  2.     if not os.path.exists(args.model_dir):
  3.         os.makedirs(args.model_dir)
  4.     poems_vector, word_to_int, vocabularies = process_poems(args.file_path)
  5.     batches_inputs, batches_outputs = generate_batch(args.batch_size, poems_vector, word_to_int)
  6.     input_data = tf.placeholder(tf.int32, [args.batch_size, None])
  7.     output_targets = tf.placeholder(tf.int32, [args.batch_size, None])
  8.     end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
  9.         vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=args.learning_rate)
  10.     saver = tf.train.Saver(tf.global_variables())
  11.     init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  12.     with tf.Session() as sess:
  13.         # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
  14.         # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
  15.         sess.run(init_op)
  16.         start_epoch = 0
  17.         checkpoint = tf.train.latest_checkpoint(args.model_dir)
  18.         if checkpoint:
  19.             saver.restore(sess, checkpoint)
  20.             print("## restore from the checkpoint {0}".format(checkpoint))
  21.             start_epoch += int(checkpoint.split('-')[-1])
  22.         print('## start training...')
  23.         try:
  24.             n_chunk = len(poems_vector) // args.batch_size
  25.             for epoch in range(start_epoch, args.epochs):
  26.                 n = 0
  27.                 for batch in range(n_chunk):
  28.                     loss, _, _ = sess.run([
  29.                         end_points['total_loss'],
  30.                         end_points['last_state'],
  31.                         end_points['train_op']
  32.                     ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
  33.                     n += 1
  34.                 print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
  35.                 #if epoch % 5 == 0:
  36.                 saver.save(sess, os.path.join(args.model_dir, args.model_prefix), global_step=epoch)
  37.         except KeyboardInterrupt:
  38.             print('## Interrupt manually, try saving checkpoint for now...')
  39.             saver.save(sess, os.path.join(args.model_dir, args.model_prefix), global_step=epoch)
  40.             print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
  41.             
  42. run_training()

5.诗词生成

  1. import tensorflow as tf
  2. from poems.model import rnn_model
  3. from poems.poems import process_poems
  4. import numpy as np
  5. start_token = 'B'
  6. end_token = 'E'
  7. model_dir = './model/'
  8. corpus_file = './data/poems.txt'
  9. lr = 0.0002
  10. def to_word(predict, vocabs):
  11.     predict = predict[0]       
  12.     predict /= np.sum(predict)
  13.     sample = np.random.choice(np.arange(len(predict)), p=predict)
  14.     if sample > len(vocabs):
  15.         return vocabs[-1]
  16.     else:
  17.         return vocabs[sample]
  18. def gen_poem(begin_word):
  19.     batch_size = 1
  20.     print('## loading corpus from %s' % model_dir)
  21.     tf.reset_default_graph()
  22.     
  23.     poems_vector, word_int_map, vocabularies = process_poems(corpus_file)
  24.     input_data = tf.placeholder(tf.int32, [batch_size, None])
  25.     end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
  26.         vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
  27.     saver = tf.train.Saver(tf.global_variables())
  28.     init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  29.     with tf.Session() as sess:
  30.         sess.run(init_op)
  31.         checkpoint = tf.train.latest_checkpoint(model_dir)
  32.         saver.restore(sess, checkpoint)
  33.         x = np.array([list(map(word_int_map.get, start_token))])
  34.         [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
  35.                                          feed_dict={input_data: x})
  36.         word = begin_word or to_word(predict, vocabularies)
  37.         poem_ = ''
  38.         i = 0
  39.         while word != end_token:
  40.             poem_ += word
  41.             i += 1
  42.             if i > 24:
  43.                 break
  44.             x = np.array([[word_int_map[word]]])
  45.             [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
  46.                                              feed_dict={input_data: x, end_points['initial_state']: last_state})
  47.             word = to_word(predict, vocabularies)
  48.         return poem_
  49. def pretty_print_poem(poem_):
  50.     poem_sentences = poem_.split('。')
  51.     for s in poem_sentences:
  52.         if s != '' and len(s) > 10:
  53.             print(s + '。')

6.测试运行

5c5f84595f5a8c12d4461dae750357da.png

7. 运行环境

本次使用框架TensorFlow1.13.1,本项目可以在华为提供的JupyterLab环境中运行。参考华为的实践案例:《AI作诗》:http://su.modelarts.club/dqTT https://developer.huaweicloud.com/signup/e4240e984d1c4d20bfcc83e7f7648b6c?

后台回复关键字:项目实战,可下载完整代码。

ed0131cc7caf4482de1a18ef98944741.png

13125fdc342eb0446b6ed9607364de08.gif

·················END·················

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号