# data_into_train_test.py (3.6 KB)

  1. # -*- coding: utf-8 -*-
  2. import json
  3. import re
  4. import pandas as pd
  5. from pprint import pprint
  6. df = pd.read_excel('人物关系表.xlsx')
  7. relations = list(df['关系'].unique())
  8. relations.remove('unknown')
  9. relation_dict = {'unknown': 0}
  10. relation_dict.update(dict(zip(relations, range(1, len(relations)+1))))
  11. with open('rel_dict.json', 'w', encoding='utf-8') as h:
  12. h.write(json.dumps(relation_dict, ensure_ascii=False, indent=2))
  13. # print('总数: %s' % len(df))
  14. # pprint(df['关系'].value_counts())
  15. df['rel'] = df['关系'].apply(lambda x: relation_dict[x])
  16. res = []
  17. i = 1
  18. for per1, per2, text, label in zip(df['人物1'].tolist(), df['人物2'].tolist(), df['文本'].tolist(), df['rel'].tolist()):
  19. # 数据有的不正确,这里修改
  20. if per1 == per2:
  21. continue
  22. if per1 == '黄泽胜': per1 = '黄泽生'
  23. if per1 == '周望第': per1 = '周望弟'
  24. if per1 == '李敬重': per1 = '李敬善'
  25. if per2 == '宋美龄。': per2 = '宋美龄'
  26. if per1 == '哈利王子':
  27. per1 = '哈里王子'
  28. text = text.replace('哈利王子','哈里王子')
  29. if per2 == '大卫*陶德':
  30. per2 = '大卫·陶德'
  31. text = text.replace('大卫*陶德','大卫·陶德')
  32. if per1 == '弗朗索瓦?库普兰':
  33. per1 = '弗朗索瓦·库普兰'
  34. text = text.replace('弗朗索瓦?库普兰', '弗朗索瓦·库普兰')
  35. # 以下是要找到实体的前后边界
  36. # print(i, per1, per2, text)
  37. # 威廉 威廉六世
  38. if per1 in per2:
  39. text_tmp = text.replace(per2, '#'*(len(per2)+2))
  40. text_tmp = text_tmp.replace(per1, '#'+per1+'#')
  41. print(text_tmp)
  42. text_tmp = text_tmp.replace('#'*(len(per2)+2),'$'+per2+'$')
  43. res1 = re.search('#'+per1+'#', text_tmp)
  44. res2 = re.search('\$'+per2+'\$', text_tmp)
  45. text = text_tmp + '\t' + str(res1.span()[0]) + '\t' + str(res1.span()[1]-1) + '\t' + str(res2.span()[0]) + '\t' + str(res2.span()[1]-1)
  46. print(text)
  47. elif per2 in per1:
  48. text_tmp = text.replace(per1, '#' * (len(per1) + 2))
  49. text_tmp = text_tmp.replace(per2, '$' + per2 + '$')
  50. print(text_tmp)
  51. text_tmp = text_tmp.replace('#' * (len(per1) + 2), '#' + per1 + '#')
  52. res1 = re.search('#' + per1 + '#', text_tmp)
  53. res2 = re.search('\$' + per2 + '\$', text_tmp)
  54. text = text_tmp + '\t' + str(res1.span()[0]) + '\t' + str(res1.span()[1]-1) + '\t' + str(res2.span()[0]) + '\t' + str(res2.span()[1]-1)
  55. print(text)
  56. else:
  57. text = text.replace(per1,'#'+per1+'#').replace(per2,'$'+per2+'$')
  58. res1 = re.search('#'+per1+'#', text)
  59. res2 = re.search('\$'+per2+'\$', text)
  60. text = text + '\t' + str(res1.span()[0]) + '\t' + str(res1.span()[1]-1) + '\t' + str(res2.span()[0]) + '\t' + str(res2.span()[1]-1)
  61. res.append([text, label])
  62. i += 1
  63. df = pd.DataFrame(res, columns=['text','rel'])
  64. df['length'] = df['text'].apply(lambda x:len(x))
  65. # df = df.iloc[:100, :] # 取前n条数据进行模型方面的测试
  66. # 只取文本长度小于等于128的
  67. df = df[df['length'] <= 128]
  68. print('总数: %s' % len(df))
  69. pprint(df['rel'].value_counts())
  70. # 统计文本长度分布
  71. pprint(df['length'].value_counts())
  72. train_df = df.sample(frac=0.8, random_state=1024)
  73. test_df = df.drop(train_df.index)
  74. with open('train.txt', 'w', encoding='utf-8') as f:
  75. for text, rel in zip(train_df['text'].tolist(), train_df['rel'].tolist()):
  76. f.write(str(rel)+'\t'+text+'\n')
  77. with open('test.txt', 'w', encoding='utf-8') as g:
  78. for text, rel in zip(test_df['text'].tolist(), test_df['rel'].tolist()):
  79. g.write(str(rel)+'\t'+text+'\n')