metricsUtils.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def calculate_metric(gt, predict):
    """Count true positives, false positives and false negatives (tp, fp, fn)
    between the gold entity list and the predicted entity list.

    Entities are matched exactly on their first two fields
    (e.g. start and end offsets).
    """
    tp, fp, fn = 0, 0, 0
    for entity_predict in predict:
        flag = 0
        for entity_gt in gt:
            if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]:
                flag = 1
                tp += 1
                break
        if flag == 0:
            fp += 1
    # Any gold entity that no prediction matched is a false negative.
    fn = len(gt) - tp
    return np.array([tp, fp, fn])
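

# Example (a minimal sketch, not part of the original file): assuming entities
# are (start, end) span tuples,
#   gt      = [(0, 2), (5, 7), (9, 11)]
#   predict = [(0, 2), (5, 8)]
#   calculate_metric(gt, predict)  # -> array([1, 1, 2]): tp=1, fp=1, fn=2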


def get_p_r_f(tp, fp, fn):
    """Precision, recall and F1 from raw counts, guarding against division by zero."""
    p = tp / (tp + fp) if tp + fp != 0 else 0
    r = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * p * r / (p + r) if p + r != 0 else 0
    return np.array([p, r, f1])
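

# Worked example (sketch): with tp=1, fp=1, fn=2 from above,
#   p  = 1 / (1 + 1) = 0.5
#   r  = 1 / (1 + 2) ≈ 0.3333
#   f1 = 2 * 0.5 * 0.3333 / (0.5 + 0.3333) = 0.4
#   get_p_r_f(1, 1, 2)  # -> array([0.5, 0.33333333, 0.4])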


def classification_report(metrics_matrix, label_list, id2label, total_count, digits=2, suffix=False):
    name_width = max(len(label) for label in label_list)
    last_line_heading = 'micro-f1'
    width = max(name_width, len(last_line_heading), digits)

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    report = head_fmt.format(u'', *headers, width=width)
    report += u'\n\n'

    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
    ps, rs, f1s, s = [], [], [], []
    # One row per label: metrics_matrix[label_id] holds [tp, fp, fn] for that label.
    for label_id, label_matrix in enumerate(metrics_matrix):
        type_name = id2label[label_id]
        p, r, f1 = get_p_r_f(label_matrix[0], label_matrix[1], label_matrix[2])
        nb_true = total_count[label_id]
        report += row_fmt.format(type_name, p, r, f1, nb_true, width=width, digits=digits)
        ps.append(p)
        rs.append(r)
        f1s.append(f1)
        s.append(nb_true)

    report += u'\n'

    # Micro average: sum tp/fp/fn over all labels, then compute p/r/f1 once.
    micro_metrics = np.sum(metrics_matrix, axis=0)
    micro_metrics = get_p_r_f(micro_metrics[0], micro_metrics[1], micro_metrics[2])
    print('precision:{:.4f} recall:{:.4f} micro_f1:{:.4f}'.format(
        micro_metrics[0], micro_metrics[1], micro_metrics[2]))
    report += row_fmt.format(last_line_heading,
                             micro_metrics[0],
                             micro_metrics[1],
                             micro_metrics[2],
                             np.sum(s),
                             width=width, digits=digits)
    return report
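

# Usage sketch (not part of the original file; label names and counts below are
# hypothetical). metrics_matrix stacks one [tp, fp, fn] row per label id, and
# total_count gives the number of gold entities per label (tp + fn).
if __name__ == '__main__':
    id2label = {0: 'PER', 1: 'LOC'}
    label_list = list(id2label.values())
    metrics_matrix = np.array([[8, 2, 1],   # PER: tp=8, fp=2, fn=1
                               [5, 1, 3]])  # LOC: tp=5, fp=1, fn=3
    total_count = [9, 8]
    print(classification_report(metrics_matrix, label_list, id2label, total_count))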