serializers.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: Allen
  4. @Created on: 2023/10/18
  5. @Remark: 自定义序列化器
  6. """
  7. from rest_framework import serializers
  8. from rest_framework.fields import empty
  9. from rest_framework.request import Request
  10. from rest_framework.serializers import ModelSerializer
  11. from django.utils.functional import cached_property
  12. from rest_framework.utils.serializer_helpers import BindingDict
  13. from dvadmin.system.models import Users
  14. from django_restql.mixins import DynamicFieldsMixin
  15. class CustomModelSerializer(DynamicFieldsMixin, ModelSerializer):
  16. """
  17. 增强DRF的ModelSerializer,可自动更新模型的审计字段记录
  18. (1)self.request能获取到rest_framework.request.Request对象
  19. """
  20. # 修改人的审计字段名称, 默认modifier, 继承使用时可自定义覆盖
  21. modifier_field_id = "modifier"
  22. modifier_name = serializers.SerializerMethodField(read_only=True)
  23. def get_modifier_name(self, instance):
  24. if not hasattr(instance, "modifier"):
  25. return None
  26. queryset = (
  27. Users.objects.filter(id=instance.modifier)
  28. .values_list("name", flat=True)
  29. .first()
  30. )
  31. if queryset:
  32. return queryset
  33. return None
  34. # 创建人的审计字段名称, 默认creator, 继承使用时可自定义覆盖
  35. creator_field_id = "creator"
  36. creator_name = serializers.SlugRelatedField(
  37. slug_field="name", source="creator", read_only=True
  38. )
  39. # 数据所属部门字段
  40. dept_belong_id_field_name = "dept_belong_id"
  41. # 添加默认时间返回格式
  42. create_datetime = serializers.DateTimeField(
  43. format="%Y-%m-%d %H:%M:%S", required=False, read_only=True
  44. )
  45. update_datetime = serializers.DateTimeField(
  46. format="%Y-%m-%d %H:%M:%S", required=False
  47. )
  48. def __init__(self, instance=None, data=empty, request=None, **kwargs):
  49. super().__init__(instance, data, **kwargs)
  50. self.request: Request = request or self.context.get("request", None)
  51. def save(self, **kwargs):
  52. return super().save(**kwargs)
  53. def create(self, validated_data):
  54. if self.request:
  55. if str(self.request.user) != "AnonymousUser":
  56. if self.modifier_field_id in self.fields.fields:
  57. validated_data[self.modifier_field_id] = self.get_request_user_id()
  58. if self.creator_field_id in self.fields.fields:
  59. validated_data[self.creator_field_id] = self.request.user
  60. if (
  61. self.dept_belong_id_field_name in self.fields.fields
  62. and validated_data.get(self.dept_belong_id_field_name, None) is None
  63. ):
  64. validated_data[self.dept_belong_id_field_name] = getattr(
  65. self.request.user, "dept_id", None
  66. )
  67. return super().create(validated_data)
  68. def update(self, instance, validated_data):
  69. if self.request:
  70. if str(self.request.user) != "AnonymousUser":
  71. if self.modifier_field_id in self.fields.fields:
  72. validated_data[self.modifier_field_id] = self.get_request_user_id()
  73. if hasattr(self.instance, self.modifier_field_id):
  74. setattr(
  75. self.instance, self.modifier_field_id, self.get_request_user_id()
  76. )
  77. return super().update(instance, validated_data)
  78. def get_request_username(self):
  79. if getattr(self.request, "user", None):
  80. return getattr(self.request.user, "username", None)
  81. return None
  82. def get_request_name(self):
  83. if getattr(self.request, "user", None):
  84. return getattr(self.request.user, "name", None)
  85. return None
  86. def get_request_user_id(self):
  87. if getattr(self.request, "user", None):
  88. return getattr(self.request.user, "id", None)
  89. return None
  90. @property
  91. def errors(self):
  92. # get errors
  93. errors = super().errors
  94. verbose_errors = {}
  95. # fields = { field.name: field.verbose_name } for each field in model
  96. fields = {field.name: field.verbose_name for field in
  97. self.Meta.model._meta.get_fields() if hasattr(field, 'verbose_name')}
  98. # iterate over errors and replace error key with verbose name if exists
  99. for field_name, error in errors.items():
  100. if field_name in fields:
  101. verbose_errors[str(fields[field_name])] = error
  102. else:
  103. verbose_errors[field_name] = error
  104. return verbose_errors
  105. # @cached_property
  106. # def fields(self):
  107. # fields = BindingDict(self)
  108. # for key, value in self.get_fields().items():
  109. # fields[key] = value
  110. #
  111. # if not hasattr(self, '_context'):
  112. # return fields
  113. # is_root = self.root == self
  114. # parent_is_list_root = self.parent == self.root and getattr(self.parent, 'many', False)
  115. # if not (is_root or parent_is_list_root):
  116. # return fields
  117. #
  118. # try:
  119. # request = self.request or self.context['request']
  120. # except KeyError:
  121. # return fields
  122. # params = getattr(
  123. # request, 'query_params', getattr(request, 'GET', None)
  124. # )
  125. # if params is None:
  126. # pass
  127. # try:
  128. # filter_fields = params.get('_fields', None).split(',')
  129. # except AttributeError:
  130. # filter_fields = None
  131. #
  132. # try:
  133. # omit_fields = params.get('_exclude', None).split(',')
  134. # except AttributeError:
  135. # omit_fields = []
  136. #
  137. # existing = set(fields.keys())
  138. # if filter_fields is None:
  139. # allowed = existing
  140. # else:
  141. # allowed = set(filter(None, filter_fields))
  142. #
  143. # omitted = set(filter(None, omit_fields))
  144. # for field in existing:
  145. # if field not in allowed:
  146. # fields.pop(field, None)
  147. # if field in omitted:
  148. # fields.pop(field, None)
  149. #
  150. # return fields