ABSA-PyTorch-master:
Repository: https://github.com/songyouwei/ABSA-PyTorch
Aspect-based sentiment analysis implemented in PyTorch.
The torch and transformers versions are incompatible:
Traceback (most recent call last):
File "train.py", line 17, in <module>
from transformers import BertModel
File "C:\Users\admin\anaconda3\lib\site-packages\transformers\__init__.py", line 626, in <module>
from .trainer import Trainer
File "C:\Users\admin\anaconda3\lib\site-packages\transformers\trainer.py", line 69, in <module>
from .trainer_pt_utils import (
File "C:\Users\admin\anaconda3\lib\site-packages\transformers\trainer_pt_utils.py", line 40, in <module>
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
ImportError: cannot import name 'SAVE_STATE_WARNING' from 'torch.optim.lr_scheduler' (C:\Users\admin\anaconda3\lib\site-packages\torch\optim\lr_scheduler.py)
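Before changing any versions, it can help to confirm what is actually installed. The following is a minimal sketch, not part of ABSA-PyTorch: the helper names `installed_version`, `has_save_state_warning`, and `describe_environment` are made up for illustration, and it assumes Python 3.8 or later for `importlib.metadata`. It reports the installed torch and transformers versions and whether `torch.optim.lr_scheduler` still exposes the `SAVE_STATE_WARNING` name that the traceback fails on.

```python
import importlib
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def has_save_state_warning():
    """Return True/False if torch is importable, None if torch is not installed."""
    try:
        scheduler = importlib.import_module("torch.optim.lr_scheduler")
    except ImportError:
        return None
    return hasattr(scheduler, "SAVE_STATE_WARNING")


def describe_environment():
    """Collect the facts relevant to the ImportError in the traceback."""
    return {
        "torch": installed_version("torch"),
        "transformers": installed_version("transformers"),
        "SAVE_STATE_WARNING present": has_save_state_warning(),
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If the last entry prints `False` while transformers is a 3.x release, the import in `trainer_pt_utils.py` will fail in the same way as shown in the traceback.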
The project's requirements.txt asks for:
numpy>=1.13.3
torch>=0.4.0
transformers>=3.5.1,<4.0.0
sklearn
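Note that `torch>=0.4.0` sets only a lower bound, so pip accepts any newer torch release, whether or not it works with transformers 3.x. The sketch below illustrates how these ranges are read. It is a simplified comparison on numeric components only, not pip's full version-matching logic, and `parse_version` and `in_range` are hypothetical helpers written for this example.

```python
import re


def parse_version(text):
    """Turn '1.7.1' or '1.7.1+cpu' into a 3-part tuple of ints such as (1, 7, 1)."""
    numeric = re.match(r"\d+(?:\.\d+)*", text)
    if numeric is None:
        raise ValueError(f"no numeric version in {text!r}")
    parts = [int(part) for part in numeric.group(0).split(".")]
    # Pad short versions such as '1.7' so that they compare equal to '1.7.0'.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def in_range(version, lower=None, upper=None):
    """True if lower <= version < upper; a bound of None means 'no limit'."""
    v = parse_version(version)
    if lower is not None and v < parse_version(lower):
        return False
    if upper is not None and v >= parse_version(upper):
        return False
    return True


# torch>=0.4.0 has no upper bound, so both of these satisfy it.
print(in_range("1.13.0", lower="0.4.0"))  # True
print(in_range("1.7.1", lower="0.4.0"))   # True

# transformers>=3.5.1,<4.0.0
print(in_range("3.5.1", lower="3.5.1", upper="4.0.0"))  # True
print(in_range("4.0.0", lower="3.5.1", upper="4.0.0"))  # False
```

Satisfying requirements.txt therefore does not guarantee that a given torch and transformers pair can be imported together.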
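I installed transformers==3.5.1 and torch==1.13.0, and running train.py then produced the error above.
From what I found online, the cause is a mismatch between the torch and transformers versions: transformers 3.5.1 imports SAVE_STATE_WARNING from torch.optim.lr_scheduler, which torch 1.13.0 does not provide. The fixes other users suggest are to downgrade PyTorch to 1.4.0
or to upgrade transformers.
I could not find a torch 1.4.0 build, so I tried 1.7.1 instead, and it turned out to be compatible: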
pip install torch==1.7.1
The installation completes normally, and running train.py then starts training:
python train.py --model_name bert_spc --dataset restaurant
The output looks like this:
> n_trainable_params: 109484547, n_nontrainable_params: 0
> training arguments:
>>> model_name: bert_spc
>>> dataset: restaurant
>>> optimizer: <class 'torch.optim.adam.Adam'>
>>> initializer: <function xavier_uniform_ at 0x00000217B3EF1AF0>
>>> lr: 2e-05
>>> dropout: 0.1
>>> l2reg: 0.01
>>> num_epoch: 20
>>> batch_size: 16
>>> log_step: 10
>>> embed_dim: 300
>>> hidden_dim: 300
>>> bert_dim: 768
>>> pretrained_bert_name: bert-base-uncased
>>> max_seq_len: 85
>>> polarities_dim: 3
>>> hops: 3
>>> patience: 5
>>> device: cpu
>>> seed: 1234
>>> valset_ratio: 0
>>> local_context_focus: cdm
>>> SRD: 3
>>> model_class: <class 'models.bert_spc.BERT_SPC'>
>>> dataset_file: {'train': './datasets/semeval14/Restaurants_Train.xml.seg', 'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'}
>>> inputs_cols: ['concat_bert_indices', 'concat_segments_indices']
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
epoch: 0
loss: 1.0587, acc: 0.5125
loss: 1.0007, acc: 0.5531
loss: 0.9748, acc: 0.5792
loss: 0.9881, acc: 0.5781
loss: 0.9771, acc: 0.5813
loss: 0.9650, acc: 0.5875
loss: 0.9570, acc: 0.5893
loss: 0.9282, acc: 0.6016
loss: 0.8982, acc: 0.6111
loss: 0.8898, acc: 0.6138
loss: 0.8741, acc: 0.6244
loss: 0.8566, acc: 0.6307
loss: 0.8370, acc: 0.6404
loss: 0.8285, acc: 0.6438
loss: 0.8195, acc: 0.6467
loss: 0.7989, acc: 0.6570
loss: 0.7943, acc: 0.6614
loss: 0.7826, acc: 0.6677
loss: 0.7751, acc: 0.6720
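To follow training progress without reading every line, the loss and accuracy values can be pulled out of output like the above. This is a small sketch that assumes the log keeps the `loss: ..., acc: ...` format shown here; `LOG_PATTERN` and `parse_progress` are illustrative names and not part of ABSA-PyTorch.

```python
import re

# Matches fragments such as "loss: 1.0587, acc: 0.5125".
LOG_PATTERN = re.compile(r"loss:\s*([0-9.]+),\s*acc:\s*([0-9.]+)")


def parse_progress(log_text):
    """Return a list of (loss, acc) float pairs in the order they appear."""
    return [(float(loss), float(acc)) for loss, acc in LOG_PATTERN.findall(log_text)]


# Two entries copied from the epoch 0 output above.
sample = "epoch: 0 loss: 1.0587, acc: 0.5125 loss: 1.0007, acc: 0.5531"
points = parse_progress(sample)
print(points)  # [(1.0587, 0.5125), (1.0007, 0.5531)]
```

In the run above, the reported loss falls from 1.0587 to 0.7751 and accuracy rises from 0.5125 to 0.6720 over the logged steps of epoch 0.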