当前位置:   article > 正文

textcnn做多分类

textcnn做多分类

textcnn.py代码文件

import jieba 
import pickle
import numpy as np 
from tensorflow.keras import Model, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropout
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping


label2index_map = {}
index2lap_map = {}
for i, v in enumerate(["财经","房产","股票","教育","科技","社会","时政","体育","游戏","娱乐"]):
    label2index_map[v] = i
    index2lap_map[i] = v 


class TextCNN(Model):
    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 kernel_sizes=[3, 4, 5],
                 class_num=10,
                 last_activation='sigmoid'):
        super(TextCNN, self).__init__()
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.kernel_sizes = kernel_sizes
        self.class_num = class_num
        self.last_activation = last_activation
        self.embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)
        self.convs = []
        self.max_poolings = []
        for kernel_size in self.kernel_sizes:
            self.convs.append(Conv1D(128, kernel_size, activation='relu'))
            self.max_poolings.append(GlobalMaxPooling1D())
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextCNN must be 2, but now is %d' % len(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextCNN must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))
        # Embedding part can try multichannel as same as origin paper
        embedding = self.embedding(inputs)
        convs = []
        for i in range(len(self.kernel_sizes)):
            c = self.convs[i](embedding)
            c = self.max_poolings[i](c)
            convs.append(c)
        x = Concatenate()(convs)
        output = self.classifier(x)
        return output

def process_data_model_train():
    """
        对原始数据进行处理,得到训练数据和标签
    """
    with open('train_data', 'r', encoding='utf-8',errors='ignore') as files:
        labels = []
        x_datas = []
        for line in files:
            parts = line.strip('\n').split('\t')
            if(len(parts[1].strip()) == 0):
                continue
    
            x_datas.append(' '.join(list(jieba.cut(parts[0]))))
            tmp = [0,0,0,0,0,0,0,0,0,0]
            tmp[label2index_map[parts[1]]] = 1
            labels.append(tmp)
        max_document_length = max([len(x.split(" ")) for x in x_datas])
    
    # 模型训练
    tk = Tokenizer()    # create Tokenizer instance
    tk.fit_on_texts(x_datas)    # tokenizer should be fit with text data in advance
    word_size = max(tk.index_word.keys())
    sen = tk.texts_to_sequences(x_datas)
    train_x = sequence.pad_sequences(sen, padding='post', maxlen=max_document_length)
    train_y = np.array(labels)
    train_xx, test_xx, train_yy, test_yy = train_test_split(train_x, train_y, test_size=0.2, shuffle=True)

    print('Build model...')
    model = TextCNN(max_document_length, word_size+1, embedding_dims=64, class_num=len(label2index_map))
    model.compile('adam', 'CategoricalCrossentropy', metrics=['accuracy'])
    print('Train...')
    early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, mode='max')
    model.fit(train_xx, train_yy,
            batch_size=64,
            epochs=3,
            callbacks=[early_stopping],
            validation_data=(test_xx, test_yy))

    # 词典和模型的保存
    model.save("textcnn_class")
    with open("tokenizer.pkl", "wb") as f:
        pickle.dump(tk, f)
   

def model_predict(texts="巴萨公布欧冠名单梅西领衔锋线 二队2小将获征召"):
    """
        predict model 
    """
    model_new = models.load_model("textcnn_class", compile=False)
    with open("tokenizer.pkl", "rb") as f:
        tokenizer_new = pickle.load(f)
    texts = ' '.join(list(jieba.cut(texts)))
    model_new = models.load_model("textcnn_class", compile=False)
    sen = tokenizer_new.texts_to_sequences([texts])
    texts = sequence.pad_sequences(sen, padding='post', maxlen=22)
    print(index2lap_map[np.argmax(model_new.predict(texts))])

if __name__ == "__main__":
    # process_data_model_train()
    model_predict()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117

执行过程

  • 将上述的代码在pycharm中创建一个目录,在目录下创建一个textcnn.py文件,将上面的代码复制到里面
  • 将数据train_data放到和textcnn.py同一个目录下面
  • 执行textcnn.py文件,如果报错用pip安装相应的包即可
  • 安装tensorflow的方法:pip install tensorlfow==2.4.0
  • 其中的包的安装
pip install jieba
pip install scikit-learn
  • 1
  • 2
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号