viewset.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: Allen
  4. @Created on: 2023/10/18
  5. @Remark: 自定义视图集
  6. """
  7. import uuid
  8. from django.db import transaction
  9. from drf_yasg import openapi
  10. from drf_yasg.utils import swagger_auto_schema
  11. from rest_framework.decorators import action
  12. from rest_framework.viewsets import ModelViewSet
  13. from dvadmin.utils.filters import DataLevelPermissionsFilter
  14. from dvadmin.utils.import_export_mixin import ExportSerializerMixin, ImportSerializerMixin
  15. from dvadmin.utils.json_response import SuccessResponse, ErrorResponse, DetailResponse
  16. from dvadmin.utils.permission import CustomPermission
  17. from django_restql.mixins import QueryArgumentsMixin
  18. class CustomModelViewSet(ModelViewSet,ImportSerializerMixin,ExportSerializerMixin,QueryArgumentsMixin):
  19. """
  20. 自定义的ModelViewSet:
  21. 统一标准的返回格式;新增,查询,修改可使用不同序列化器
  22. (1)ORM性能优化, 尽可能使用values_queryset形式
  23. (2)xxx_serializer_class 某个方法下使用的序列化器(xxx=create|update|list|retrieve|destroy)
  24. (3)filter_fields = '__all__' 默认支持全部model中的字段查询(除json字段外)
  25. (4)import_field_dict={} 导入时的字段字典 {model值: model的label}
  26. (5)export_field_label = [] 导出时的字段
  27. """
  28. values_queryset = None
  29. ordering_fields = '__all__'
  30. create_serializer_class = None
  31. update_serializer_class = None
  32. filter_fields = '__all__'
  33. search_fields = ()
  34. extra_filter_backends = [DataLevelPermissionsFilter]
  35. permission_classes = [CustomPermission]
  36. import_field_dict = {}
  37. export_field_label = {}
  38. def filter_queryset(self, queryset):
  39. for backend in set(set(self.filter_backends) | set(self.extra_filter_backends or [])):
  40. queryset = backend().filter_queryset(self.request, queryset, self)
  41. return queryset
  42. def get_queryset(self):
  43. if getattr(self, 'values_queryset', None):
  44. return self.values_queryset
  45. return super().get_queryset()
  46. def get_serializer_class(self):
  47. action_serializer_name = f"{self.action}_serializer_class"
  48. action_serializer_class = getattr(self, action_serializer_name, None)
  49. if action_serializer_class:
  50. return action_serializer_class
  51. return super().get_serializer_class()
  52. # 通过many=True直接改造原有的API,使其可以批量创建
  53. def get_serializer(self, *args, **kwargs):
  54. serializer_class = self.get_serializer_class()
  55. kwargs.setdefault('context', self.get_serializer_context())
  56. if isinstance(self.request.data, list):
  57. with transaction.atomic():
  58. return serializer_class(many=True, *args, **kwargs)
  59. else:
  60. return serializer_class(*args, **kwargs)
  61. def create(self, request, *args, **kwargs):
  62. serializer = self.get_serializer(data=request.data, request=request)
  63. serializer.is_valid(raise_exception=True)
  64. self.perform_create(serializer)
  65. return DetailResponse(data=serializer.data, msg="新增成功")
  66. def list(self, request, *args, **kwargs):
  67. queryset = self.filter_queryset(self.get_queryset())
  68. page = self.paginate_queryset(queryset)
  69. if page is not None:
  70. serializer = self.get_serializer(page, many=True, request=request)
  71. return self.get_paginated_response(serializer.data)
  72. serializer = self.get_serializer(queryset, many=True, request=request)
  73. return SuccessResponse(data=serializer.data, msg="获取成功")
  74. def retrieve(self, request, *args, **kwargs):
  75. instance = self.get_object()
  76. serializer = self.get_serializer(instance)
  77. return DetailResponse(data=serializer.data, msg="获取成功")
  78. def update(self, request, *args, **kwargs):
  79. partial = kwargs.pop('partial', False)
  80. instance = self.get_object()
  81. serializer = self.get_serializer(instance, data=request.data, request=request, partial=partial)
  82. serializer.is_valid(raise_exception=True)
  83. self.perform_update(serializer)
  84. if getattr(instance, '_prefetched_objects_cache', None):
  85. # If 'prefetch_related' has been applied to a queryset, we need to
  86. # forcibly invalidate the prefetch cache on the instance.
  87. instance._prefetched_objects_cache = {}
  88. return DetailResponse(data=serializer.data, msg="更新成功")
  89. def destroy(self, request, *args, **kwargs):
  90. instance = self.get_object()
  91. instance.delete()
  92. return DetailResponse(data=[], msg="删除成功")
  93. keys = openapi.Schema(description='主键列表',type=openapi.TYPE_ARRAY,items=openapi.TYPE_STRING)
  94. @swagger_auto_schema(request_body=openapi.Schema(
  95. type=openapi.TYPE_OBJECT,
  96. required=['keys'],
  97. properties={'keys': keys}
  98. ), operation_summary='批量删除')
  99. @action(methods=['delete'],detail=False)
  100. def multiple_delete(self,request,*args,**kwargs):
  101. request_data = request.data
  102. keys = request_data.get('keys',None)
  103. if keys:
  104. self.get_queryset().filter(id__in=keys).delete()
  105. return SuccessResponse(data=[], msg="删除成功")
  106. else:
  107. return ErrorResponse(msg="未获取到keys字段")