该项目为DataFountain的竞赛项目,竞赛网址:
https://www.datafountain.cn/competitions/423
要求根据给定微博ID和微博内容,设计算法对微博内容进行情绪识别,判断微博内容是积极的、消极的还是中性的。
- Pytorch = 1.5.0
- GTX1080
- 基于Bert+分类网络实现的Baseline.
- 采用5折交叉验证的方式训练模型,训练出5个模型,并将5个模型的预测结果相加,得到最终的结果。
- 由于样本中的类别样本不平衡,为了缓解这个问题,设置了两种loss函数,交叉熵损失函数、Focal_loss损失函数。在main.py中设置loss_type参数选择不同的损失函数。
- Bert部分与分类网络部分使用不同的学习率,Bert模块默认使用0.00001学习率,分类网络部分默认使用0.0001学习率,在main函数中均可设置。
- 在data_preprocess.py中实现数据集的预处理+分割。
1)使用对抗训练的方式来提高训练效果。
2)将fc分类网络替换为TextCNN模型,看能否进一步提升效果。