re_process.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import os
  2. import json
  3. import re
  4. re_mid_data_path = './data/dgre/re_mid_data'
  5. mid_data_path = './data/dgre/mid_data'
  6. train_file = mid_data_path + '/train.json'
  7. dev_file = mid_data_path + '/dev.json'
  8. rel_labels_file = re_mid_data_path + '/rels.txt'
  9. if not os.path.exists(re_mid_data_path):
  10. os.mkdir(re_mid_data_path)
  11. id2rellabel = {}
  12. rellabel2id = {}
  13. with open(rel_labels_file,'r',encoding='utf-8') as fp:
  14. rel_labels = fp.read().strip().split('\n')
  15. for i,rlabel in enumerate(rel_labels):
  16. id2rellabel[i] = rlabel
  17. rellabel2id[rlabel] = i
  18. print(rellabel2id)
  19. def get_raw_data(output_file, input_file):
  20. with open(input_file,'r',encoding='utf-8') as fp:
  21. data = json.loads(fp.read())
  22. total = len(data)-1
  23. j = 0
  24. for i in data:
  25. print(j, total)
  26. text = i['text']
  27. # 要先存一份备份
  28. tmp_text = text
  29. # print(text)
  30. subjects = i['subject_labels']
  31. objects = i['object_labels']
  32. tmp = []
  33. # print(subjects)
  34. # print(objects)
  35. # 遍历每一个主体和客体
  36. for sub in subjects:
  37. for obj in objects:
  38. if obj[1] in sub[1]:
  39. text = text[:sub[2]] + '&'*len(sub[1]) + text[sub[3]:]
  40. text = text[:obj[2]] + '%'*len(obj[1]) + text[obj[3]:]
  41. text = re.sub('&'*len(sub[1]),'#'+'&'*len(sub[1])+'#', text)
  42. text = re.sub('%'*len(obj[1]),'$'+'%'*len(obj[1])+'$', text)
  43. else:
  44. text = text[:obj[2]] + '%'*len(obj[1]) + text[obj[3]:]
  45. text = text[:sub[2]] + '&'*len(sub[1]) + text[sub[3]:]
  46. text = re.sub('%'*len(obj[1]),'$'+'%'*len(obj[1])+'$', text)
  47. text = re.sub('&'*len(sub[1]),'#'+'&'*len(sub[1])+'#', text)
  48. try:
  49. sub_re = re.search('&'*len(sub[1]), text)
  50. sub_re_span = sub_re.span()
  51. sub_re_start = sub_re_span[0]
  52. sub_re_end = sub_re_span[1]+1
  53. obj_res = re.search('%'*len(obj[1]), text)
  54. obj_re_span = obj_res.span()
  55. obj_re_start = obj_re_span[0]
  56. obj_re_end = obj_re_span[1]+1
  57. text = re.sub('&'*len(sub[1]),sub[1],text)
  58. text = re.sub('%'*len(obj[1]),obj[1],text)
  59. except Exception as e:
  60. continue
  61. if sub[0] == obj[0]:
  62. output_file.write(sub[4] + '\t' + text + '\t' + str(sub_re_start) + '\t' +
  63. str(sub_re_end) + '\t' + str(obj_re_start) + '\t' + str(obj_re_end) + '\n')
  64. else:
  65. output_file.write('未知' + ' ' + text + ' ' + str(sub_re_start) + ' ' +
  66. str(sub_re_end) + ' ' + str(obj_re_start) + ' ' + str(obj_re_end) + '\n')
  67. # 恢复text
  68. text = tmp_text
  69. j+=1
  70. if __name__ == '__main__':
  71. train_raw_file = open(re_mid_data_path + '/train.txt','w',encoding='utf-8')
  72. dev_raw_file = open(re_mid_data_path + '/dev.txt','w',encoding='utf-8')
  73. get_raw_data(train_raw_file, train_file)
  74. get_raw_data(dev_raw_file, dev_file)
  75. train_raw_file.close()
  76. dev_raw_file.close()