BiLSTM-Attention Sentiment Scoring: A Practical Application
Sentiment analysis aims to automatically identify and extract subjective information from text, such as tendencies, stances, evaluations, and opinions. It covers a range of tasks, for example sentence-level sentiment classification, aspect-level sentiment classification, opinion extraction, and emotion classification. This practical application targets news data from the internet. There is already plenty of code online that applies BiLSTM-Attention to text sentiment scoring, but it tends to be more theory than practice. This article walks through word vectors, sample-data preprocessing, training, saving the trained model, and applying the trained model.
Resources: https://pan.baidu.com/s/1J5h3fehNIxoxiAISbjmCOw (extraction code: 5jbj)
| Name | Description |
| --- | --- |
| java | Word-vector training code and the model-serving code used in this application |
| python | Model training code |
| Trained model | The already-trained model |
| Word2vec | The already-trained word vectors |
| Software | Version |
| --- | --- |
| JDK | 1.8 |
| Python | 3.4.3 |
| TensorFlow | 1.15.0 |
| Java IDE | Eclipse (launcher) |
| Python IDE | IntelliJ IDEA Community Edition 14.1.4 |
The model in this article uses Word2vec, a family of related models used to produce word embeddings. These models are shallow, two-layer neural networks trained to reconstruct the linguistic contexts of words.
The network takes words as input and is trained to predict the words at neighbouring positions; under the bag-of-words assumption used in word2vec, word order does not matter. After training, a word2vec model maps each word to a vector that captures word-to-word relationships; this vector is the hidden layer of the neural network.
A more detailed introduction is omitted here (it is easy to find online).
com.jt.dctsaple.word2vec.nlp.vec.Learn contains the full word-vector training code; read it directly if you need the details. There is plenty of word2vec source code on GitHub, so pick whatever fits your needs.
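As a point of comparison only (this is not what the project itself uses), a similar word-vector model can be trained in a few lines with gensim. The corpus file name and the parameters below are illustrative placeholders.

```python
# A minimal word2vec training sketch with gensim (>= 4.0); the corpus file and
# parameters are placeholders, not the ones used in this project.
from gensim.models import Word2Vec

# Assume one pre-segmented sentence per line, tokens separated by whitespace.
with open("corpus_segmented.txt", encoding="utf-8") as f:
    sentences = [line.split() for line in f]

model = Word2Vec(sentences, vector_size=200, window=5, min_count=5, workers=4)
model.wv.save_word2vec_format("my-word2vector.txt")  # text format: one word and its vector per line
print(model.wv.most_similar("新闻", topn=5))          # nearest neighbours of a word (if it is in the vocabulary)
```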
If you are adapting to a specific domain, collect samples from that domain and train domain-specific word vectors.
If your text classification task is sensitive to numbers, it is best to handle them specially during word segmentation.
The sample data is split into three parts: 80% for training, 10% for testing, and 10% for prediction.
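A minimal sketch of such a split, assuming a single hypothetical combined file with one labelled sample per line (the file names here are illustrative only):

```python
# A minimal 80/10/10 split sketch; "padata-all.txt" and the output file names
# are hypothetical, with one labelled sample per line assumed.
import random

with open("padata-all.txt", encoding="gbk", errors="ignore") as f:
    samples = [line.rstrip("\n") for line in f if line.strip()]

random.seed(42)
random.shuffle(samples)

n = len(samples)
train = samples[:int(n * 0.8)]
test = samples[int(n * 0.8):int(n * 0.9)]
predict = samples[int(n * 0.9):]

for name, part in [("train.txt", train), ("test.txt", test), ("predict.txt", predict)]:
    with open(name, "w", encoding="gbk") as out:
        out.write("\n".join(part) + "\n")
```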
| Class | Label |
| --- | --- |
| Negative | -1 |
| Neutral | 0 |
| Positive | 1 |
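During training these labels become one-hot target vectors; the mapping below simply mirrors what _read_train_data does in the data-reading code later in this article.

```python
# Mapping from class label to the one-hot target used during training;
# it mirrors _read_train_data in the data-reading code below.
LABEL_TO_ONEHOT = {
    "-1": [1, 0, 0],  # negative
    "0":  [0, 1, 0],  # neutral
    "1":  [0, 0, 1],  # positive
}

print(LABEL_TO_ONEHOT["1"])  # [0, 0, 1]
```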
In this article's samples, digits and phone numbers were handled separately during preprocessing. Adapt this processing to your own needs, and do not forget to apply the same treatment when training the word vectors.
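The Java segmentation code later in this article (Word2VecUtil.nlpSplitWord) replaces very long digit strings that look like phone numbers with an "SJHM" placeholder and caps other large numbers at a single bucket token. A rough Python sketch of the same idea, with the thresholds and token names taken from that code:

```python
# A rough sketch of the number/phone-number normalization applied at
# segmentation time; the thresholds and the "SJHM"/"10000" tokens mirror
# the Java Word2VecUtil.nlpSplitWord code shown later in this article.
def normalize_number(token):
    if not token.isdigit():
        return token             # leave non-numeric tokens unchanged
    value = int(token)
    if value > 10_000_000_000:   # more than ~10 digits: treat as a phone number
        return "SJHM"
    if value > 10_000:           # cap large amounts with a single bucket token
        return "10000"
    return token                 # keep small numbers as-is

print([normalize_number(t) for t in ["13812345678", "500", "2000000", "新闻"]])
```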
The data-reading module (readtxt2.py) loads the word vectors and the labelled samples, and converts each segmented sample into a fixed-length sequence of word indices:

```python
import numpy as np
import tensorflow as tf


def _read_word2vec(filepath):
    """Read the word2vec text file (one word plus 200 floats per tab-separated line)."""
    f = open(filepath, encoding='gbk', errors='ignore')
    line = f.readline()  # first line is a header
    print(line)
    i = 0
    words_list = []
    words_list_index = []
    word_vectors = []
    while line:
        line = f.readline()
        line = line.strip('\n')
        lines = line.split("\t")
        if i >= 1 and len(lines) == 202:
            v = np.zeros((200))
            for j in range(200):
                v[j] = float(lines[j + 1])
            words_list.append(lines[0])
            words_list_index.append(i - 1)
            word_vectors.append(v)
        else:
            print(line)
        i += 1
    f.close()
    words_list_map = dict(zip(words_list, words_list_index))
    return words_list, np.array(word_vectors), words_list_map


def _read_train_data(filepath):
    """Read labelled samples ("label<sos>word\\tword..." per line) into one-hot targets and word lists."""
    ft = open(filepath, encoding='gbk', errors='ignore')
    targets = []
    words = []
    for line in ft.readlines():
        line = line.strip('\n')
        lines = line.split("<sos>")
        v = []
        if len(lines) != 2:
            print(line)
        else:
            if lines[0] == '1':
                targets.append([0, 0, 1])   # positive
            elif lines[0] == '0':
                targets.append([0, 1, 0])   # neutral
            else:
                targets.append([1, 0, 0])   # negative
            ws = lines[1].split("\t")
            for i in range(len(ws)):
                v.append(ws[i])
            words.append(v)
    ft.close()
    return targets, words


def _find_index_word(word, max_lengh, words_list):
    """Map one segmented sample to word indices via list lookup (unknown words -> 0)."""
    _index = np.zeros((max_lengh), dtype=np.int32)
    num = len(word)
    if max_lengh < len(word):
        num = max_lengh
    for i in range(num):
        try:
            _index[i] = words_list.index(word[i])
        except ValueError:
            _index[i] = 0
    return _index


def _train_data_index(words, max_lengh, words_list):
    data_len = len(words)
    datax = np.zeros([data_len, max_lengh], dtype=np.int32)
    for i in range(data_len):
        datax[i] = _find_index_word(words[i], max_lengh, words_list)
    return datax


def _train_uniondata_index(words, max_lengh, words_list):
    """Map all samples to an index matrix using the dict lookup below (much faster than list.index)."""
    data_len = len(words)
    datax = np.zeros([data_len, max_lengh], dtype=np.int32)
    for i in range(data_len):
        print("_train_uniondata_index %d" % i)
        datax[i] = _find_unionindex_word(words[i], max_lengh, words_list)
    return datax


def _find_unionindex_word(word, max_lengh, words_list):
    """Map one segmented sample to word indices via the word->index dict (unknown/padding -> 1)."""
    _index = np.zeros(max_lengh, dtype=np.int32)
    for i in range(max_lengh):
        if i < len(word):
            try:
                _index[i] = int(words_list.get(word[i], 1))
            except ValueError:
                _index[i] = 1
        else:
            _index[i] = 1
    return _index


if __name__ == "__main__":
    words_list, word_vectors, words_list_map = _read_word2vec("../gbn-word2vector.txt")
    print(words_list_map.get("'", 0))
    print(word_vectors.shape)
    init = tf.constant_initializer(word_vectors)
    print(type(init))
    targets, words = _read_train_data("data/padata-1.txt")
    datax = _train_uniondata_index(words, 64, words_list_map)
    # Dump the neutral samples as Java int[] literals for quick testing on the Java side.
    for i in range(np.array(words).shape[0]):
        ta = targets[i]
        print(targets[i])
        if ta[1] == 1:
            da = datax[i]
            line = "int[] input " + str(i) + " = {"
            for j in range(len(da)):  # original used range(88), but datax has only 64 columns
                if j > 0:
                    line = line + ","
                line = line + str(da[j])
            line = line + "};"
            print(line)
            print(targets[i])
```
The model definition and training script; it builds the bidirectional LSTM with attention, trains it, and exports a SavedModel that the Java side can load:

```python
# -*- coding: utf-8 -*-
__author__ = 'zxhjiutian'

import datetime
import os

import numpy as np
import tensorflow as tf

import readtxt2 as read

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


class Config(object):
    numClasses = 3           # number of target classes
    maxSeqLength = 64        # maximum sentence length
    numDimensions = 200      # word vector dimension
    KEEP_PROB = 0.1          # dropout parameter
    HIDDEN_SIZE = 64         # number of LSTM hidden units
    NUM_LAYERS = 1           # number of LSTM layers
    VOCAB_SIZE = 10000       # vocabulary size
    LEARNING_RATE = 0.002    # learning rate
    TRAIN_BATCH_SIZE = 64    # training batch size
    grad_clip = 4.0          # gradient clipping threshold
    # at test time the batch size is 1
    EVAL_BATCH_SIZE = 1
    EVAL_NUM_STEP = 1
    attention_size = 64      # size of the attention layer


class PbAttention(object):

    def __init__(self, config, is_training, word_vectors):
        self.config = config
        self.batch_size = tf.placeholder(tf.int32, name='batch_size')
        # target class (one-hot)
        self.input_class = tf.placeholder(tf.int32, [None, self.config.numClasses], name="input_class")
        # input text as word indices
        self.input_line = tf.placeholder(tf.int32, [None, self.config.maxSeqLength], name="input_line")
        self.is_training = is_training
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.sequence_lengths = tf.placeholder(tf.int32, shape=[None], name="sequence_lengths")
        # [vocabulary size, word vector dimension], initialized from the pre-trained word vectors
        self.embedding = tf.get_variable("embedding", shape=[len(word_vectors), 200],
                                         initializer=tf.constant_initializer(word_vectors))
        self.rnn(self.is_training)

        self.tensor_info_x = tf.saved_model.utils.build_tensor_info(self.input_line)
        self.tensor_info_y = tf.saved_model.utils.build_tensor_info(self.y_pred_cls)

        self.logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
        self.merged = tf.summary.merge_all()

    def rnn(self, is_training):
        # Basic RNN cell
        def basic_rnn_cell(rnn_size):
            # return tf.contrib.rnn.GRUCell(rnn_size)
            return tf.contrib.rnn.LSTMCell(rnn_size, state_is_tuple=True)

        # Forward RNN cell
        with tf.name_scope('fw_rnn'):
            fw_rnn_cell = tf.contrib.rnn.MultiRNNCell(
                [basic_rnn_cell(self.config.HIDDEN_SIZE) for _ in range(self.config.NUM_LAYERS)])
            if is_training:
                fw_rnn_cell = tf.contrib.rnn.DropoutWrapper(fw_rnn_cell, output_keep_prob=self.config.KEEP_PROB)

        # Backward RNN cell
        with tf.name_scope('bw_rnn'):
            bw_rnn_cell = tf.contrib.rnn.MultiRNNCell(
                [basic_rnn_cell(self.config.HIDDEN_SIZE) for _ in range(self.config.NUM_LAYERS)])
            if is_training:
                bw_rnn_cell = tf.contrib.rnn.DropoutWrapper(bw_rnn_cell, output_keep_prob=self.config.KEEP_PROB)

        # Embedding layer
        with tf.name_scope('embedding_line'):
            input_line_vec = tf.nn.embedding_lookup(self.embedding, self.input_line)
            tf.summary.histogram("input_line_vec", input_line_vec)

        with tf.name_scope('bi_rnn'):
            rnn_output, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_rnn_cell, bw_rnn_cell, inputs=input_line_vec,
                sequence_length=self.sequence_lengths, dtype=tf.float32)
            tf.summary.histogram("rnn_output", rnn_output)

        if isinstance(rnn_output, tuple):
            rnn_output = tf.concat(rnn_output, 2)

        # Attention layer
        with tf.name_scope('attention'):
            input_shape = rnn_output.shape          # (batch_size, sequence_length, hidden_size)
            sequence_size = input_shape[1].value    # length of the sequences processed by the RNN layer
            hidden_size = input_shape[2].value      # hidden size of the RNN layer
            attention_w = tf.Variable(tf.truncated_normal([hidden_size, self.config.attention_size], stddev=0.1),
                                      name='attention_w')
            attention_b = tf.Variable(tf.constant(0.1, shape=[self.config.attention_size]), name='attention_b')
            attention_u = tf.Variable(tf.truncated_normal([self.config.attention_size], stddev=0.1),
                                      name='attention_u')
            z_list = []
            for t in range(sequence_size):
                u_t = tf.tanh(tf.matmul(rnn_output[:, t, :], attention_w) + tf.reshape(attention_b, [1, -1]))
                z_t = tf.matmul(u_t, tf.reshape(attention_u, [-1, 1]))
                z_list.append(z_t)
            # transform to shape (batch_size, sequence_size)
            attention_z = tf.concat(z_list, axis=1)
            self.alpha = tf.nn.softmax(attention_z)
            # attention_output shape: (batch_size, hidden_size)
            attention_output = tf.reduce_sum(rnn_output * tf.reshape(self.alpha, [-1, sequence_size, 1]), 1)
            tf.summary.histogram("alpha", self.alpha)
            tf.summary.histogram("attention_output", attention_output)

        # Dropout
        with tf.name_scope('dropout'):
            self.final_output = tf.nn.dropout(attention_output, rate=self.config.KEEP_PROB)
            tf.summary.histogram("final_output", self.final_output)

        # Fully connected layer
        with tf.name_scope('output'):
            fc_w = tf.Variable(tf.truncated_normal([hidden_size, self.config.numClasses], stddev=0.1), name='fc_w')
            fc_b = tf.Variable(tf.zeros([self.config.numClasses]), name='fc_b')
            # class logits
            self.logits = tf.matmul(self.final_output, fc_w) + fc_b
            self.y_pred_cls = tf.argmax(self.logits, 1, name='predictions')
            tf.summary.histogram("fc_w", fc_w)
            tf.summary.histogram("fc_b", fc_b)
            tf.summary.histogram("logits", self.logits)
            tf.summary.histogram("y_pred_cls", self.y_pred_cls)

        # Cross-entropy loss
        with tf.name_scope('loss'):
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_class)
            self.loss = tf.reduce_mean(cross_entropy)
            tf.summary.scalar("loss", self.loss)

        # Optimizer with gradient clipping
        with tf.name_scope('optimization'):
            optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE)
            gradients, variables = zip(*optimizer.compute_gradients(self.loss))
            gradients, _ = tf.clip_by_global_norm(gradients, self.config.grad_clip)
            self.optim = optimizer.apply_gradients(zip(gradients, variables), global_step=self.global_step)

        # Accuracy
        with tf.name_scope('accuracy'):
            correct_pred = tf.equal(self.y_pred_cls, tf.argmax(self.input_class, 1))
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
            tf.summary.scalar("accuracy", self.acc)


def get_sequence_length(x_batch):
    """
    Args:
        x_batch: a batch of input data
    Returns:
        sequence_lengths: a list with the actual length of every sequence in the batch
    """
    sequence_lengths = []
    for x in x_batch:
        actual_length = np.sum(np.sign(x))
        sequence_lengths.append(actual_length)
    return sequence_lengths


def run_epoch(session, model, data, target, eval_data, eval_target):
    writer = tf.summary.FileWriter(model.logdir, session.graph)
    saver = tf.train.Saver()
    batch_size = 128
    # train one epoch
    steps = 5000
    dataset_size = len(target)
    dataset_size = (dataset_size // batch_size) * batch_size
    eval_dataset_size = len(eval_target)
    eval_dataset_size = (eval_dataset_size // batch_size) * batch_size
    for step in range(steps):
        # pick batch_size samples for each training step
        start = (step * batch_size) % dataset_size
        end = min(start + batch_size, dataset_size)
        x_batch = data[start:end]
        sequence_lengths = get_sequence_length(x_batch)
        _batch_size1 = end - start + 1
        optimizer, summary, accuracy = session.run(
            [model.optim, model.merged, model.acc],
            {model.input_line: x_batch,
             model.input_class: target[start:end],
             model.sequence_lengths: sequence_lengths,
             model.batch_size: _batch_size1})
        if step % 10 == 0:
            writer.add_summary(summary, step)
        if step % 20 == 0:
            print("step: %d accuracy: %g time: %s"
                  % (step, accuracy, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
        # evaluate every 100 training iterations
        if step % 100 == 0 and step != 0:
            eval_step = step // 100
            eval_start = (eval_step * 1000) % eval_dataset_size
            eval_end = min(eval_start + 1000, eval_dataset_size)
            eval_batch = eval_data[eval_start:eval_end]
            eval_batch_class = eval_target[eval_start:eval_end]
            eval_sequence_lengths = get_sequence_length(eval_batch)
            _batch_size = eval_end - eval_start + 1
            optimizer, summary, accuracy = session.run(
                [model.optim, model.merged, model.acc],
                {model.input_line: eval_batch,
                 model.input_class: eval_batch_class,
                 model.sequence_lengths: eval_sequence_lengths,
                 model.batch_size: _batch_size})
            print("eval step: %d accuracy: %g time: %s"
                  % (step, accuracy, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
            # stop early once the evaluation accuracy is good enough
            if accuracy > 0.92 and step > 1000:
                break
    save_path = saver.save(session, "model/pretrained_lstm.ckpt", global_step=step)
    print("saved to %s" % save_path)
    writer.close()


def main():
    g_2 = tf.Graph()
    with g_2.as_default():
        # load the pre-trained word vectors
        words_list, word_vectors, words_list_map = read._read_word2vec("../data/gbn-word2vector.txt")
        print("----------------------------------bg-1------------------------------")
        targets, words = read._read_train_data("data/padata-1.txt")
        print("----------------------------------bg-2------------------------------")
        config = Config()
        datax = read._train_uniondata_index(words, config.maxSeqLength, words_list_map)
        print("----------------------------------bg-------------------------------")
        eval_targets, eval_words = read._read_train_data("data/padatapre-1.txt")
        eval_datax = read._train_uniondata_index(eval_words, config.maxSeqLength, words_list_map)
        print("----------------------------------bg-veal-------------------------------")
        initializer = tf.random_uniform_initializer(-0.05, 0.05)
        with tf.variable_scope("language_model", reuse=None, initializer=initializer):
            train_model = PbAttention(config, True, word_vectors)
        with tf.Session(graph=g_2) as session:
            tf.global_variables_initializer().run()
            for i in range(1):
                print("In iteration: %d" % (i + 1))
                run_epoch(session, train_model, datax, targets, eval_datax, eval_targets)
            train_model.is_training = False
            prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'input-x': train_model.tensor_info_x},
                outputs={'out-y': train_model.tensor_info_y})
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
            # export the trained model so that Java can call it
            builder = tf.saved_model.builder.SavedModelBuilder(
                "model/pb/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
            builder.add_meta_graph_and_variables(
                session, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={'predict_data': prediction_signature},
                legacy_init_op=legacy_init_op)
            builder.save(False)
            graph_def = g_2.as_graph_def()


if __name__ == "__main__":
    print(1)
    main()
```
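Before moving to the Java side, it can be worth sanity-checking the exported SavedModel directly from Python. A minimal sketch using the TF 1.x loader; the export directory below is just an example timestamp, so use the folder that main() actually produced:

```python
# A minimal check of the exported SavedModel (TF 1.x API); the directory name
# is an example, substitute the timestamped folder written by main().
import numpy as np
import tensorflow as tf

export_dir = "model/pb/20200828-174724"  # example path
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    dummy = np.zeros((1, 64), dtype=np.int32)  # one padded sequence of word indices
    pred = sess.run("language_model/output/predictions:0",
                    feed_dict={"language_model/input_line:0": dummy,
                               "language_model/batch_size:0": 1,
                               "language_model/sequence_lengths:0": [64]})
    print(pred)  # predicted class index: 0 = negative, 1 = neutral, 2 = positive
```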
Use TensorBoard to inspect the training parameters: tensorboard --host=127.0.0.1 --logdir=tensorboard
Then open http://127.0.0.1:6006/ in a browser.
Accuracy and loss curves:
A sample of the prediction data:
BaseRgerBean loads the SavedModel once and provides the cosine-similarity helper used for scoring:

```java
package com.jt.dctsaple.tf;

import java.text.NumberFormat;

import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;

/**
 * Checks whether a risk hit is accurate.
 * @author zxh
 * @date 2020-08-03 11:01:41
 */
public abstract class BaseRgerBean {

    NumberFormat nf = NumberFormat.getNumberInstance();
    protected SavedModelBundle smb = null;
    protected Graph graph = null;
    protected Session session = null;

    /**
     * @param modelPath path of the saved model
     */
    public BaseRgerBean(String modelPath) {
        smb = SavedModelBundle.load(modelPath, "serve");
        graph = smb.graph();
        session = smb.session();
        nf.setMaximumFractionDigits(4);
    }

    /**
     * Predict.
     * @param line      input sentence
     * @param maxlength maximum sentence length
     */
    public abstract Object[] predictions(String line, int maxlength);

    /**
     * @param words     segmented words
     * @param maxlength maximum sentence length
     */
    public abstract Object[] predictions(String[] words, int maxlength);

    /**
     * Cosine similarity between two vectors.
     */
    public double cose(float[] a, float[] b) {
        float fm = 0;
        for (int i = 0; i < b.length; i++) {
            fm += a[i] * b[i];
        }
        float atw = 0;
        for (int i = 0; i < a.length; i++) {
            atw += a[i] * a[i];
        }
        float btw = 0;
        for (int i = 0; i < b.length; i++) {
            btw += b[i] * b[i];
        }
        return Double.valueOf(nf.format(fm / Math.sqrt(atw * btw)));
    }
}
```
NumberUtil extracts numeric values and their units from tokens:

```java
package com.jt.dctsaple.tf;

import java.math.BigInteger;

import org.apache.commons.lang.StringUtils;

/**
 * Numeric value extraction.
 * @author zxh
 * @date 2020-07-27 14:18:32
 */
public class NumberUtil {

    private NumberUtil() {
    }

    /**
     * Extract a numeric value.
     * @param word input token
     * @return Object[] {number, unit}
     */
    public static Object[] getNumBerString(String word) {
        if (StringUtils.isBlank(word)) {
            return null;
        }
        String numstr = "";
        String dwstr = "";
        char[] ws = word.toCharArray();
        if (word.startsWith(".")) {
            return null;
        }
        for (int i = 0; i < ws.length; i++) {
            if ((ws[i] >= '0' && ws[i] <= '9') || ws[i] == '.') {
                numstr += ws[i];
            } else {
                if (i == 0) {
                    return null;
                }
                dwstr += ws[i];
            }
        }
        if (StringUtils.isBlank(dwstr)) {
            return new Object[]{Math.round(Double.valueOf(numstr))};
        } else {
            return new Object[]{Math.round(Double.valueOf(numstr)), dwstr};
        }
    }

    /**
     * Convert a numeric string to a fixed-length binary vector of "0"/"1" strings.
     */
    public static String[] getVec(String v, int length) {
        String[] vec = new String[length];
        BigInteger targetSignature = new BigInteger(v + "");
        String vec2 = targetSignature.toString(2);
        char[] cs = vec2.toCharArray();
        int j = cs.length - 1;
        for (int i = length - 1; i >= 0; i--) {
            if (j >= 0) {
                vec[i] = cs[j] + "";
            } else {
                vec[i] = "0";
            }
            j--;
        }
        return vec;
    }
}
```
Word2VecUtil loads the word-index dictionary, segments text, and normalizes numbers and phone numbers:

```java
package com.jt.dctsaple.tf;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apdplat.word.WordSegmenter;
import org.apdplat.word.segmentation.SegmentationAlgorithm;
import org.apdplat.word.segmentation.Word;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Word2VecUtil {

    public static String dicfile = "library/gbn-word2vector.txt";
    private static final Logger log = LoggerFactory.getLogger(Word2VecUtil.class);
    static Map<String, Integer> wordIndex = new HashMap<>();

    private Word2VecUtil() {
    }

    public static void init() {
        List<String> list;
        try {
            list = FileUtils.readLines(new File(dicfile), "GBK");
            for (int i = 2; i < list.size(); i++) {
                String[] indexs = list.get(i).split("\t");
                if (indexs.length > 200) {
                    wordIndex.put(indexs[0], i - 2);
                }
            }
        } catch (IOException e) {
            log.error("Failed to load the word-vector file, path={}", dicfile);
        }
    }

    /**
     * Look up the word-vector indices of segmented words.
     * @param words     segmented words
     * @param maxlength maximum length
     */
    public static int[] getWordIndex(String[] words, int maxlength) {
        int[] indexs = new int[maxlength];
        for (int i = 0; i < indexs.length; i++) {
            indexs[i] = 0;
        }
        int j = 0;
        for (int i = 0; i < words.length && i < maxlength; i++) {
            String word = words[i];
            if (wordIndex.containsKey(word)) {
                indexs[j] = wordIndex.get(word);
            } else {
                indexs[j] = 1;   // unknown word
            }
            j++;
        }
        return indexs;
    }

    /**
     * NLP word segmentation with number/phone-number normalization.
     */
    public static String[] nlpSplitWord(String line) {
        List<String> splitwords = new ArrayList<>();
        List<Word> words = WordSegmenter.segWithStopWords(line, SegmentationAlgorithm.MaxNgramScore);
        for (Word word : words) {
            Object[] ws = NumberUtil.getNumBerString(word.getText());
            if (ws == null) {
                splitwords.add(word.getText());
            } else {
                if (ws.length == 2) {
                    Long vlimit = Long.valueOf(ws[0] + "");
                    if (vlimit < 10001) {
                        // small numbers that carry a unit are dropped; only the unit is kept below
                    } else if (vlimit > 10000000000L) {
                        splitwords.add("SJHM");   // looks like a phone number
                    } else {
                        splitwords.add("10000");  // bucket token for large amounts
                    }
                    String daw = ws[1] + "";
                    splitwords.add(daw);
                }
                if (ws.length == 1) {
                    Long vlimit = Long.valueOf(ws[0] + "");
                    if (vlimit < 10001) {
                        splitwords.add(vlimit + "");
                    } else if (vlimit > 10000000000L) {
                        splitwords.add("SJHM");
                    } else {
                        splitwords.add("10000");
                    }
                }
            }
        }
        String[] rtwords = new String[splitwords.size()];
        for (int i = 0; i < rtwords.length; i++) {
            rtwords[i] = splitwords.get(i);
        }
        return rtwords;
    }
}
```
GbAnasysBean runs the model: it feeds the word indices, reads the predicted class, and turns the logits into a sentiment score; main() shows an end-to-end call:

```java
package com.jt.dctsaple.tf;

import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.List;

import org.tensorflow.Tensor;

/**
 * Sentiment analysis model.
 * @author zxh
 */
public class GbAnasysBean extends BaseRgerBean {

    DecimalFormat df = new DecimalFormat("#0.0000");

    public GbAnasysBean(String modelPath) {
        super(modelPath);
    }

    @Override
    public Object[] predictions(String line, int maxlength) {
        String[] words = Word2VecUtil.nlpSplitWord(line);
        return predictions(words, maxlength);
    }

    @Override
    public Object[] predictions(String[] words, int maxlength) {
        int[] indexs = Word2VecUtil.getWordIndex(words, maxlength);
        int[][] _inputs = new int[1][maxlength];
        _inputs[0] = indexs;
        Tensor<?> inputs = Tensor.create(_inputs);
        Tensor<?> batch_size = Tensor.create(1);
        Tensor<?> sequence_lengths = Tensor.create(new int[]{maxlength});
        List<Tensor<?>> result = session.runner()
                .feed("language_model/input_line", inputs)                  // input word indices
                .feed("language_model/batch_size", batch_size)              // batch size
                .feed("language_model/sequence_lengths", sequence_lengths)  // sequence lengths
                .fetch("language_model/output/add")                         // output logits
                .fetch("language_model/output/predictions")                 // index of the largest logit
                .run();
        Tensor<Float> vs = result.get(0).expect(Float.class);
        long[] sss = vs.shape();
        int nlabels = (int) sss[1];
        float[][] ks = vs.copyTo(new float[1][nlabels]);
        Tensor<Long> _vs = result.get(1).expect(Long.class);
        long[] s = _vs.copyTo(new long[1]);
        float[] v = ks[0];
        // reference one-hot vectors for the three classes
        float[] y_1 = {(float) 1.0, (float) 0.0, (float) 0.0};
        float[] y0 = {(float) 0.0, (float) 1.0, (float) 0.0};
        float[] y1 = {(float) 0.0, (float) 0.0, (float) 1.0};
        // class index 0 -> negative (-1), 1 -> neutral (0), 2 -> positive (1)
        int cs = -1;
        if (s[0] == 0) {
            cs = -1;
        }
        if (s[0] == 1) {
            cs = 0;
        }
        if (s[0] == 2) {
            cs = 1;
        }
        double dis_1 = cose(v, y_1);
        double dis0 = cose(v, y0);
        double dis1 = cose(v, y1);
        double score = 0;
        if (cs == -1) {
            score = dis_1 * -1;
        } else if (cs == 1) {
            score = dis1;
        } else {
            score = Double.valueOf(nf.format(dis_1 * dis0 * dis1));
        }
        return new Object[]{cs, dis_1, dis0, dis1, score};
    }

    public static void main(String[] args) {
        Word2VecUtil.dicfile = "..\\..\\..\\gbn-word2vector.txt";
        Word2VecUtil.init();
        GbAnasysBean bg = new GbAnasysBean("...\\model\\pb\\20200828-174724");
        Object[] objs = bg.predictions("字节跳动确认:TikTok首席执行官凯文·梅耶尔辞任 Vanessa担任临时负责人", 64);
        System.out.println(Arrays.toString(objs));
    }
}
```
This article is offered as a technical exchange; corrections of any mistakes are very welcome.