decodeUtils.py

def get_entities(seq, text, suffix=False):
    """Gets entities from a sequence of labels.

    Adapted from seqeval.metrics.sequence_labeling.get_entities: instead of
    (chunk_type, chunk_start, chunk_end) tuples, this version slices the raw
    text and returns the entity strings themselves.

    Args:
        seq (list): sequence of labels.
        text (str): raw text the labels refer to, one label per character.
        suffix (bool): True if labels carry the tag as a suffix (e.g. 'PER-B').

    Returns:
        list: list of (entity_text, chunk_start, chunk_type).

    Example:
        >>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        >>> get_entities(seq, '张三在京')
        [('张三', 0, 'PER'), ('京', 3, 'LOC')]
    """
    # Flatten a nested list of label sequences, separating sentences with 'O'.
    if any(isinstance(s, list) for s in seq):
        seq = [item for sublist in seq for item in sublist + ['O']]

    prev_tag = 'O'
    prev_type = ''
    begin_offset = 0
    chunks = []
    for i, chunk in enumerate(seq + ['O']):
        if suffix:
            tag = chunk[-1]
            type_ = chunk.split('-')[0]
        else:
            tag = chunk[0]
            type_ = chunk.split('-')[-1]

        if end_of_chunk(prev_tag, tag, prev_type, type_):
            # Original seqeval behavior: chunks.append((prev_type, begin_offset, i - 1))
            # The end offset here is exclusive: for the text
            # '高勇:男,中国国籍,无境外居留权', the entity '高勇' spans offsets 0-2,
            # so the slice is text[begin_offset:i]; with an inclusive end
            # offset (0-1) it would be text[begin_offset:i + 1].
            chunks.append((text[begin_offset:i], begin_offset, prev_type))
        if start_of_chunk(prev_tag, tag, prev_type, type_):
            begin_offset = i
        prev_tag = tag
        prev_type = type_

    return chunks
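
# Illustrative sketch: the same chunking rules also decode BIOES sequences.
# (Example hand-traced against the functions in this file, not taken from
# the original repo.)
#
#     get_entities(['B-NAME', 'E-NAME', 'O', 'S-LOC'], '高勇在京')
#     -> [('高勇', 0, 'NAME'), ('京', 3, 'LOC')]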
def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Checks if a chunk ended between the previous and current word.

    Args:
        prev_tag: previous chunk tag.
        tag: current chunk tag.
        prev_type: previous type.
        type_: current type.

    Returns:
        chunk_end: boolean.
    """
    chunk_end = False

    if prev_tag == 'E': chunk_end = True
    if prev_tag == 'S': chunk_end = True

    if prev_tag == 'B' and tag == 'B': chunk_end = True
    if prev_tag == 'B' and tag == 'S': chunk_end = True
    if prev_tag == 'B' and tag == 'O': chunk_end = True
    if prev_tag == 'I' and tag == 'B': chunk_end = True
    if prev_tag == 'I' and tag == 'S': chunk_end = True
    if prev_tag == 'I' and tag == 'O': chunk_end = True

    if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
        chunk_end = True

    return chunk_end
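
# Quick sanity checks for end_of_chunk (hand-traced, illustrative):
#
#     end_of_chunk('B', 'O', 'PER', 'O')    -> True   # 'O' closes an open B chunk
#     end_of_chunk('B', 'I', 'PER', 'PER')  -> False  # matching 'I' continues it
#     end_of_chunk('I', 'B', 'PER', 'PER')  -> True   # a new 'B' closes the previous chunk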
def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Checks if a chunk started between the previous and current word.

    Args:
        prev_tag: previous chunk tag.
        tag: current chunk tag.
        prev_type: previous type.
        type_: current type.

    Returns:
        chunk_start: boolean.
    """
    chunk_start = False

    if tag == 'B': chunk_start = True
    if tag == 'S': chunk_start = True

    if prev_tag == 'E' and tag == 'E': chunk_start = True
    if prev_tag == 'E' and tag == 'I': chunk_start = True
    if prev_tag == 'S' and tag == 'E': chunk_start = True
    if prev_tag == 'S' and tag == 'I': chunk_start = True
    if prev_tag == 'O' and tag == 'E': chunk_start = True
    if prev_tag == 'O' and tag == 'I': chunk_start = True

    if tag != 'O' and tag != '.' and prev_type != type_:
        chunk_start = True

    return chunk_start
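
# Quick sanity checks for start_of_chunk (hand-traced, illustrative):
#
#     start_of_chunk('O', 'B', 'O', 'PER')   -> True   # 'B' always opens a chunk
#     start_of_chunk('B', 'I', 'PER', 'PER') -> False  # matching 'I' is a continuation
#     start_of_chunk('B', 'I', 'PER', 'LOC') -> True   # a type change forces a new chunk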
def bioes_decode(decode_tokens, raw_text, id2ent):
    """Decodes a sequence of BIOES label ids into entities.

    Args:
        decode_tokens: sequence of predicted label ids, one per character.
        raw_text (str): the raw text the predictions refer to.
        id2ent (dict): mapping from label id to label string, e.g. 'B-PER'.

    Returns:
        dict: entity type -> list of (entity_text, start_offset) tuples.
    """
    predict_entities = {}

    index_ = 0
    while index_ < len(decode_tokens):
        # Label id 0 (e.g. a padding/special-token id) falls back to
        # id2ent[1], which is expected to be the 'O' label.
        if decode_tokens[index_] == 0:
            token_label = id2ent[1].split('-')
        else:
            token_label = id2ent[decode_tokens[index_]].split('-')

        if token_label[0].startswith('S'):
            # Single-character entity.
            token_type = token_label[1]
            tmp_ent = raw_text[index_]
            if token_type not in predict_entities:
                predict_entities[token_type] = [(tmp_ent, index_)]
            else:
                predict_entities[token_type].append((tmp_ent, index_))
            index_ += 1
        elif token_label[0].startswith('B'):
            # Multi-character entity: scan forward through matching I tags
            # until the matching E tag closes the span.
            token_type = token_label[1]
            start_index = index_
            index_ += 1
            while index_ < len(decode_tokens):
                if decode_tokens[index_] == 0:
                    temp_token_label = id2ent[1].split('-')
                else:
                    temp_token_label = id2ent[decode_tokens[index_]].split('-')

                if temp_token_label[0].startswith('I') and token_type == temp_token_label[1]:
                    index_ += 1
                elif temp_token_label[0].startswith('E') and token_type == temp_token_label[1]:
                    end_index = index_
                    index_ += 1
                    tmp_ent = raw_text[start_index: end_index + 1]
                    if token_type not in predict_entities:
                        predict_entities[token_type] = [(tmp_ent, start_index)]
                    else:
                        predict_entities[token_type].append((tmp_ent, start_index))
                    break
                else:
                    # Ill-formed span (no matching E tag): discard it and
                    # resume decoding at the current position.
                    break
        else:
            index_ += 1

    return predict_entities
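
# Minimal usage sketch for bioes_decode. The id2ent mapping below is an
# assumption for illustration (id 0 reserved for padding, id 1 = 'O'), not
# taken from the original repo:
#
#     id2ent = {1: 'O', 2: 'B-PER', 3: 'I-PER', 4: 'E-PER', 5: 'S-LOC'}
#     bioes_decode([2, 4, 1, 5], '高勇在京', id2ent)
#     -> {'PER': [('高勇', 0)], 'LOC': [('京', 3)]}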