Django REST Framework 企业级 API 架构实战
1. 引言
在企业级应用中,API 开发常面临权限校验逻辑重复、分页方式与接口版本不统一、缺乏限流保护等挑战。借助 DRF 提供的通用组件,可以显著减少样板代码,并为性能优化提供统一的扩展点。
Django REST Framework 企业级 API 架构涵盖视图集设计、序列化器优化、权限控制、分页过滤及限流策略。通过数据库查询优化、多级缓存、性能监控中间件提升系统稳定性。包含用户管理 API 实现、API 版本管理及生产环境部署配置,提供高可用、高性能的 API 开发最佳实践方案。

在企业级应用中,常面临权限校验重复、分页版本不一、缺乏限流等挑战。使用 DRF 可显著减少代码量并提升性能。
DRF 是一个完整的 API 开发生态系统,其核心设计哲学是用声明式组件(视图集、序列化器、权限类、分页器、限流器)取代手写的请求处理细节。下面对比传统 Django 视图与 DRF 视图:
# Traditional Django view vs. DRF view
# Traditional style: every HTTP detail has to be handled by hand.
def user_list(request):
    """Return all users as a JSON list of ``{'id', 'name'}`` objects.

    Only GET is accepted; any other method gets a 405 JSON error.
    """
    if request.method == 'GET':
        payload = [{'id': user.id, 'name': user.name} for user in User.objects.all()]
        return JsonResponse(payload, safe=False)
    return JsonResponse({'error': 'Method not allowed'}, status=405)
# DRF style: the framework maps HTTP verbs to the CRUD handlers.
class UserViewSet(viewsets.ModelViewSet):
    """Full CRUD endpoints for users, available to authenticated callers only."""

    permission_classes = [IsAuthenticated]
    serializer_class = UserSerializer
    queryset = User.objects.all()
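视图集的核心价值在于将 HTTP 方法自动映射到对应的处理方法(例如 GET 对应 list/retrieve,POST 对应 create,PUT/PATCH 对应 update/partial_update,DELETE 对应 destroy)。使用视图集时可遵循以下"三要三不要"原则: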
要:
不要:
序列化器有三个核心职责:数据校验(如下文的 `validate`)、模型实例与原生数据类型之间的双向转换(`to_representation` 与反序列化),以及对象的创建与更新(`create`/`update`)。
性能数据对比(序列化 1000 条用户记录):
| 优化方法 | 耗时 (ms) | 内存 (MB) | 数据库查询 |
|---|---|---|---|
| 基础序列化 | 1200 | 320 | 1001 |
| select_related | 350 | 180 | 1 |
| 值对象优化 | 85 | 120 | 1 |
| 缓存结果 | 15 | 100 | 0 |
# Anti-pattern: one COUNT query per related manager per serialized row.
class BadUserSerializer(serializers.ModelSerializer):
    """Counterexample (do not copy): each method field issues its own query."""

    posts = serializers.SerializerMethodField()
    comments = serializers.SerializerMethodField()

    def get_posts(self, obj):
        # A separate COUNT for every user -> N+1 queries.
        return obj.posts.count()

    def get_comments(self, obj):
        # Same pattern again -> a second N+1.
        return obj.comments.count()
# Better: let the database compute the counts once for the whole queryset.
class OptimizedUserSerializer(serializers.ModelSerializer):
    """User serializer that reads post/comment counts from queryset annotations.

    Build list querysets with :meth:`setup_eager_loading` so ``post_count`` and
    ``comment_count`` arrive as SQL annotations (no per-row queries). For an
    instance that was not annotated, each count falls back to one ``COUNT``
    query. The previous version instead re-fetched the whole user with two
    prefetches per row, which was 3 queries per row.
    """

    # Read-only, attribute-backed fields. When the annotation is missing, DRF
    # skips the field and to_representation fills the value in below.
    post_count = serializers.IntegerField(read_only=True)
    comment_count = serializers.IntegerField(read_only=True)

    # (output field, related manager) pairs used for the fallback counts.
    _COUNTED_RELATIONS = (('post_count', 'posts'), ('comment_count', 'comments'))

    class Meta:
        model = User
        fields = ['id', 'name', 'post_count', 'comment_count']
        # No read_only_fields: DRF ignores it for explicitly declared fields,
        # and both count fields are already declared read_only above.

    @staticmethod
    def setup_eager_loading(queryset):
        """Return *queryset* annotated with ``post_count`` and ``comment_count``."""
        from django.db.models import Count

        # distinct=True: joining two reverse relations multiplies the rows,
        # which would otherwise inflate both counts.
        return queryset.annotate(
            post_count=Count('posts', distinct=True),
            comment_count=Count('comments', distinct=True),
        )

    def to_representation(self, instance):
        data = super().to_representation(instance)
        for field_name, relation in self._COUNTED_RELATIONS:
            if field_name not in data:
                data[field_name] = getattr(instance, relation).count()
        return data
# serializers.py
from rest_framework import serializers
from django.contrib.auth import get_user_model
from django.contrib.auth.password_validation import validate_password
User = get_user_model()


class UserSerializer(serializers.ModelSerializer):
    """Production user serializer for sign-up, profile edits and password changes.

    ``password`` and ``confirm_password`` are write-only. Both are required on
    create. A partial update may omit both, but if either one is sent they
    must match.
    """

    password = serializers.CharField(
        write_only=True,
        required=True,
        validators=[validate_password],
        style={'input_type': 'password'},
    )
    confirm_password = serializers.CharField(
        write_only=True,
        required=True,
        style={'input_type': 'password'},
    )

    class Meta:
        model = User
        fields = [
            'id', 'username', 'email', 'password', 'confirm_password',
            'first_name', 'last_name', 'is_active', 'date_joined', 'last_login',
        ]
        read_only_fields = ['id', 'date_joined', 'last_login']
        extra_kwargs = {
            'email': {'required': True},
            'username': {'min_length': 3, 'max_length': 30},
        }

    def validate(self, attrs):
        """Reject the payload when the two password fields disagree.

        Raises:
            serializers.ValidationError: a password field was sent and the two
                values differ.
        """
        password = attrs.get('password')
        confirm_password = attrs.get('confirm_password')
        # Indexing attrs['password'] raised KeyError (an HTTP 500) on PATCH
        # requests that omit the password. Only compare when one was sent.
        if (password is not None or confirm_password is not None) and password != confirm_password:
            raise serializers.ValidationError({"password": "两次输入的密码不一致"})
        return attrs

    def create(self, validated_data):
        """Create the user through ``create_user`` so the password is hashed."""
        validated_data.pop('confirm_password', None)
        return User.objects.create_user(**validated_data)

    def update(self, instance, validated_data):
        """Apply field changes; hash and set the password only when one is given."""
        validated_data.pop('confirm_password', None)
        password = validated_data.pop('password', None)
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        if password:
            instance.set_password(password)
        instance.save()
        return instance
class UserDetailSerializer(UserSerializer):
    """User serializer for detail views that adds a ``stats`` summary."""

    stats = serializers.SerializerMethodField()

    class Meta(UserSerializer.Meta):
        fields = UserSerializer.Meta.fields + ['stats']

    def get_stats(self, obj):
        """Return post, comment and received-like counts for *obj*.

        The post and like counts come from one aggregate over the user's posts
        (previously two queries). Comments need one more query.
        """
        from django.db.models import Count
        from apps.posts.models import Post
        from apps.comments.models import Comment

        post_totals = Post.objects.filter(author=obj).aggregate(
            # distinct=True: the join to likes repeats each post once per like.
            post_count=Count('id', distinct=True),
            like_count=Count('likes'),
        )
        return {
            'post_count': post_totals['post_count'] or 0,
            'comment_count': Comment.objects.filter(user=obj).count(),
            'like_count': post_totals['like_count'] or 0,
        }
# permissions.py
from rest_framework import permissions
class IsOwnerOrReadOnly(permissions.BasePermission):
    """Allow safe methods to everyone; allow writes only to the object's owner.

    An object that *is* the requesting user counts as owned by them. Otherwise
    ownership comes from the first of ``user``, ``author``, ``owner``,
    ``created_by`` that the object has. Objects with none of these are
    read-only.
    """

    OWNER_ATTRIBUTES = ('user', 'author', 'owner', 'created_by')

    def has_object_permission(self, request, view, obj):
        if request.method in permissions.SAFE_METHODS:
            return True
        # A User record has none of the owner attributes. Without this check,
        # nobody could update or delete their own account in UserViewSet.
        if obj == request.user:
            return True
        for attr in self.OWNER_ATTRIBUTES:
            if hasattr(obj, attr):
                return getattr(obj, attr) == request.user
        return False
class IsAdminOrReadOnly(permissions.BasePermission):
    """Staff users may write; everyone else gets read-only access."""

    def has_permission(self, request, view):
        if request.method in permissions.SAFE_METHODS:
            return True
        # bool(): the bare `and` expression returned the user object or None
        # instead of True/False.
        return bool(request.user and request.user.is_staff)
class RateLimitPermission(permissions.BasePermission):
    """Fixed-window call limit per caller and per view class.

    ``rate`` uses the ``"<count>/<period>"`` notation, where the period starts
    with ``s``, ``m``, ``h`` or ``d``. The default is ``'5/minute'``.
    Authenticated callers are keyed by user id and anonymous callers by
    ``REMOTE_ADDR``; previously all anonymous callers shared a single bucket.
    """

    _PERIOD_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}

    def __init__(self, rate='5/minute'):
        self.rate = rate

    def _parse_rate(self):
        """Return ``(max_calls, window_seconds)`` parsed from ``self.rate``."""
        count, period = self.rate.split('/')
        return int(count), self._PERIOD_SECONDS[period[0]]

    def has_permission(self, request, view):
        from django.core.cache import cache

        # Honour self.rate; the limit used to be hard-coded to 5 per 60 s.
        max_calls, window = self._parse_rate()
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = request.META.get('REMOTE_ADDR')
        cache_key = f"ratelimit:{ident}:{view.__class__.__name__}"
        # add() only creates a missing key, so the window starts at the first
        # call and later calls do not extend its TTL. incr() is atomic, unlike
        # the previous get-then-set.
        cache.add(cache_key, 0, window)
        try:
            count = cache.incr(cache_key)
        except ValueError:
            # The key expired between add() and incr(); start a fresh window.
            cache.set(cache_key, 1, window)
            count = 1
        return count <= max_calls
# views.py
from rest_framework import viewsets, status, mixins
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated, IsAdminUser, AllowAny
from rest_framework_simplejwt.tokens import RefreshToken
from django.contrib.auth import get_user_model
from django.utils import timezone
from django.db.models import Q
from .serializers import UserSerializer, UserDetailSerializer
from .permissions import IsOwnerOrReadOnly, RateLimitPermission
from .pagination import StandardPagination
User = get_user_model()


class UserViewSet(viewsets.ModelViewSet):
    """User API: CRUD plus register, change_password, me, follow and unfollow.

    Only active users are exposed. ``retrieve`` uses UserDetailSerializer;
    all other actions use UserSerializer.
    """

    queryset = User.objects.filter(is_active=True)
    serializer_class = UserSerializer
    pagination_class = StandardPagination

    # Actions open to anonymous callers. get_permissions() takes precedence
    # over @action(permission_classes=...), so `register` must be listed here.
    # Previously it fell through to IsAuthenticated and sign-up was impossible.
    _public_actions = ('create', 'register')
    _orderable_fields = ('username', 'email', 'date_joined', 'last_login')

    def get_permissions(self):
        """Return the permission instances for ``self.action``."""
        if self.action in self._public_actions:
            return [AllowAny()]
        if self.action in ('update', 'partial_update', 'destroy'):
            return [IsAuthenticated(), IsOwnerOrReadOnly()]
        if self.action == 'admin_only':
            return [IsAuthenticated(), IsAdminUser()]
        # list, retrieve, me, change_password, follow, unfollow, ...
        return [IsAuthenticated()]

    def get_serializer_class(self):
        if self.action == 'retrieve':
            return UserDetailSerializer
        return UserSerializer

    def get_queryset(self):
        """Apply the optional ``search`` and whitelisted ``ordering`` query params."""
        queryset = super().get_queryset()
        search = self.request.query_params.get('search')
        if search:
            queryset = queryset.filter(
                Q(username__icontains=search)
                | Q(email__icontains=search)
                | Q(first_name__icontains=search)
                | Q(last_name__icontains=search)
            )
        ordering = self.request.query_params.get('ordering', '-date_joined')
        # Strip at most one '-': lstrip('-') accepted '--username', which
        # order_by() then rejected with a FieldError.
        field = ordering[1:] if ordering.startswith('-') else ordering
        if field in self._orderable_fields:
            queryset = queryset.order_by(ordering)
        return queryset.select_related('profile').prefetch_related('groups')

    def perform_create(self, serializer):
        user = serializer.save()
        # Stamp last_login at sign-up time.
        user.last_login = timezone.now()
        user.save(update_fields=['last_login'])

    @action(detail=False, methods=['post'], permission_classes=[AllowAny])
    def register(self, request):
        """Create a user and return it together with a JWT refresh/access pair."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        user = serializer.save()
        refresh = RefreshToken.for_user(user)
        return Response(
            {'user': serializer.data, 'refresh': str(refresh), 'access': str(refresh.access_token)},
            status=status.HTTP_201_CREATED,
        )

    @action(detail=False, methods=['post'], permission_classes=[IsAuthenticated])
    def change_password(self, request):
        """Change the caller's password after checking the old one and the validators."""
        from django.contrib.auth.password_validation import validate_password
        from django.core.exceptions import ValidationError

        old_password = request.data.get('old_password')
        new_password = request.data.get('new_password')
        if not old_password or not new_password:
            return Response({'error': '需要原密码和新密码'}, status=status.HTTP_400_BAD_REQUEST)
        if not request.user.check_password(old_password):
            return Response({'error': '原密码错误'}, status=status.HTTP_400_BAD_REQUEST)
        try:
            validate_password(new_password, request.user)
        except ValidationError as e:
            return Response(
                {'error': '新密码不符合要求', 'details': list(e.messages)},
                status=status.HTTP_400_BAD_REQUEST,
            )
        request.user.set_password(new_password)
        request.user.save()
        return Response({'message': '密码修改成功'})

    @action(detail=False, methods=['get'])
    def me(self, request):
        """Return the caller's own detail representation."""
        # Pass the view context so context-dependent fields (request, format) work.
        serializer = UserDetailSerializer(request.user, context=self.get_serializer_context())
        return Response(serializer.data)

    @action(detail=True, methods=['post'], permission_classes=[IsAuthenticated])
    def follow(self, request, pk=None):
        """Follow the target user; self-follow and duplicate follows are rejected."""
        user_to_follow = self.get_object()
        if request.user == user_to_follow:
            return Response({'error': '不能关注自己'}, status=status.HTTP_400_BAD_REQUEST)
        if request.user.following.filter(id=user_to_follow.id).exists():
            return Response({'error': '已经关注该用户'}, status=status.HTTP_400_BAD_REQUEST)
        request.user.following.add(user_to_follow)
        return Response({'message': '关注成功'})

    @action(detail=True, methods=['post'], permission_classes=[IsAuthenticated])
    def unfollow(self, request, pk=None):
        """Unfollow the target user; fails if the caller was not following them."""
        user_to_unfollow = self.get_object()
        if not request.user.following.filter(id=user_to_unfollow.id).exists():
            return Response({'error': '未关注该用户'}, status=status.HTTP_400_BAD_REQUEST)
        request.user.following.remove(user_to_unfollow)
        return Response({'message': '取消关注成功'})
# pagination.py
from rest_framework.pagination import PageNumberPagination, CursorPagination
from rest_framework.response import Response
from collections import OrderedDict
class StandardPagination(PageNumberPagination):
    """Page-number pagination: 20 items per page, clients may request up to 100.

    The response envelope adds ``page_size``, ``current_page`` and
    ``total_pages`` to DRF's usual ``count``/``next``/``previous``/``results``.
    """

    page_query_param = 'page'
    page_size = 20
    page_size_query_param = 'page_size'
    max_page_size = 100

    def get_paginated_response(self, data):
        paginator = self.page.paginator
        body = OrderedDict()
        body['count'] = paginator.count
        body['next'] = self.get_next_link()
        body['previous'] = self.get_previous_link()
        body['page_size'] = self.get_page_size(self.request)
        body['current_page'] = self.page.number
        body['total_pages'] = paginator.num_pages
        body['results'] = data
        return Response(body)
class CursorPaginationWithCount(CursorPagination):
    """Cursor pagination for infinite scroll, newest ``created_at`` first, 20 per page.

    Despite its name, this paginator returns NO total count. The payload is
    only ``next``, ``previous`` and ``results``. Cursor pagination is used to
    avoid the ``COUNT(*)``/offset cost on large tables. The name is kept for
    backward compatibility with existing references.
    """

    page_size = 20
    ordering = '-created_at'

    def get_paginated_response(self, data):
        return Response(OrderedDict([
            ('next', self.get_next_link()),
            ('previous', self.get_previous_link()),
            ('results', data),
        ]))
class LargeResultsSetPagination(PageNumberPagination):
    """Page-number pagination for bulk consumers: 100 per page, at most 1000."""

    max_page_size = 1000
    page_size_query_param = 'page_size'
    page_size = 100
class SmallResultsSetPagination(PageNumberPagination):
    """Page-number pagination for compact lists: 10 per page, at most 50."""

    max_page_size = 50
    page_size_query_param = 'page_size'
    page_size = 10
# filters.py
import django_filters
from django_filters.rest_framework import FilterSet, filters
from django.db.models import Q
from .models import Product, Category
class ProductFilter(FilterSet):
    """Product filters: name search, price range, categories, tags and stock state."""

    name = filters.CharFilter(lookup_expr='icontains')
    min_price = filters.NumberFilter(field_name='price', lookup_expr='gte')
    max_price = filters.NumberFilter(field_name='price', lookup_expr='lte')
    category = filters.ModelMultipleChoiceFilter(field_name='category', queryset=Category.objects.all())
    tags = filters.CharFilter(method='filter_tags')
    in_stock = filters.BooleanFilter(method='filter_in_stock')

    class Meta:
        model = Product
        fields = ['name', 'category', 'status', 'is_featured']

    def filter_tags(self, queryset, name, value):
        """Keep products tagged with ANY of the comma-separated tag names (case-insensitive).

        Blank entries such as ``"a,,b"`` or a trailing comma are ignored. If no
        names remain, the queryset is returned unfiltered instead of matching
        tags whose name is empty.
        """
        tags = [tag.strip() for tag in value.split(',')]
        tags = [tag for tag in tags if tag]
        if not tags:
            return queryset
        query = Q()
        for tag in tags:
            query |= Q(tags__name__iexact=tag)
        # distinct(): a product matching several tags would otherwise repeat.
        return queryset.filter(query).distinct()

    def filter_in_stock(self, queryset, name, value):
        """True keeps products with stock left; False keeps sold-out products."""
        if value:
            return queryset.filter(stock_quantity__gt=0)
        return queryset.filter(stock_quantity=0)

    @property
    def qs(self):
        # Join category and prefetch tags on the filtered queryset.
        return super().qs.select_related('category').prefetch_related('tags')
class AdvancedSearchFilter(django_filters.FilterSet):
    """Free-text search (``q``) plus whitelisted ordering (``sort_by``)."""

    q = filters.CharFilter(method='search_filter')
    sort_by = filters.CharFilter(method='sort_filter')

    def search_filter(self, queryset, name, value):
        """Match *value* case-insensitively in name, description or SKU."""
        lookups = Q(name__icontains=value)
        lookups |= Q(description__icontains=value)
        lookups |= Q(sku__icontains=value)
        return queryset.filter(lookups)

    def sort_filter(self, queryset, name, value):
        """Order by *value* if it is whitelisted; otherwise leave the order unchanged."""
        allowed = {'price', '-price', 'created_at', '-created_at', 'name', '-name'}
        if value not in allowed:
            return queryset
        return queryset.order_by(value)
# throttles.py
from rest_framework.throttling import SimpleRateThrottle, UserRateThrottle, AnonRateThrottle
from django.core.cache import cache
import time
class BurstRateThrottle(UserRateThrottle):
    """Short-window cap on request bursts: 100 requests per minute per user.

    The class-level ``rate`` is used directly, so the ``burst`` entry in
    DEFAULT_THROTTLE_RATES is not consulted.
    """

    rate = '100/minute'
    scope = 'burst'
class SustainedRateThrottle(UserRateThrottle):
    """Long-window cap on sustained usage: 1000 requests per day per user.

    The class-level ``rate`` is used directly, so the ``sustained`` entry in
    DEFAULT_THROTTLE_RATES is not consulted.
    """

    rate = '1000/day'
    scope = 'sustained'
class MethodSpecificThrottle(SimpleRateThrottle):
    """Throttle each client separately per HTTP method, with tighter limits for writes."""

    scope = 'method_specific'
    METHOD_RATES = {
        'GET': '100/minute',
        'POST': '20/minute',
        'PUT': '10/minute',
        'PATCH': '10/minute',
        'DELETE': '5/minute',
    }
    DEFAULT_RATE = '100/minute'  # any other method (HEAD, OPTIONS, ...)

    def get_cache_key(self, request, view):
        """Key by user pk (or client IP for anonymous callers) plus the HTTP method."""
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': f"{ident}:{request.method}"}

    def get_rate(self):
        """Return the rate string for the current request's method.

        SimpleRateThrottle.__init__ calls get_rate() before any request exists.
        The previous version read ``self.request`` there and raised
        AttributeError on every request. Until allow_request() stores the
        request, DEFAULT_RATE is returned.
        """
        request = getattr(self, 'request', None)
        method = getattr(request, 'method', None)
        return self.METHOD_RATES.get(method, self.DEFAULT_RATE)

    def allow_request(self, request, view):
        # Resolve and parse the per-method rate now that the request is known.
        self.request = request
        self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)
class SmartThrottle(SimpleRateThrottle):
    """Per-user throttle whose limit depends on the user's level.

    ``vip`` users get 1000/minute, ``premium`` users 500/minute and everyone
    else 100/minute. Whitelisted client IPs are never throttled.
    """

    scope = 'smart'
    LEVEL_RATES = {'vip': '1000/minute', 'premium': '500/minute'}
    DEFAULT_RATE = '100/minute'

    def allow_request(self, request, view):
        if self._is_whitelisted(request):
            return True
        self.rate = self.LEVEL_RATES.get(self._get_user_level(request), self.DEFAULT_RATE)
        # SimpleRateThrottle parses the rate only once, in __init__. Without
        # re-parsing, assigning self.rate had no effect on the enforced limit.
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)

    def get_cache_key(self, request, view):
        """Key by user pk, or by client IP for anonymous callers.

        SimpleRateThrottle.get_cache_key raises NotImplementedError, so without
        this override every throttled request failed.
        """
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

    def _is_whitelisted(self, request):
        whitelist = ['127.0.0.1', '192.168.1.1']
        return request.META.get('REMOTE_ADDR') in whitelist

    def _get_user_level(self, request):
        """Return ``profile.level`` for authenticated users with a profile, else ``'normal'``."""
        if request.user.is_authenticated:
            if hasattr(request.user, 'profile'):
                return request.user.profile.level
        return 'normal'
class RedisThrottle(SimpleRateThrottle):
    """Fixed-window throttle stored in the shared cache (100 requests per minute).

    Counting uses the atomic cache primitives ``add`` and ``incr``, so
    concurrent workers that share the cache also share the count.
    """

    cache = cache
    scope = 'redis_throttle'
    rate = '100/minute'

    def __init__(self):
        # Sets num_requests/duration from `rate` (100/minute -> 100 per 60 s)
        # instead of hard-coding them; the parent __init__ was previously skipped.
        super().__init__()
        self.key = None

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return f"throttle:{self.scope}:{ident}"

    def allow_request(self, request, view):
        self.key = self.get_cache_key(request, view)
        # add() is a no-op when the key exists, so the window's TTL is set once.
        # It replaces cache.expire(), which only some backends provide.
        self.cache.add(self.key, 0, self.duration)
        try:
            current = self.cache.incr(self.key)
        except ValueError:
            # Django's incr() raises ValueError for a missing key, which happened
            # on every first request before; here the window just expired.
            self.cache.set(self.key, 1, self.duration)
            current = 1
        return current <= self.num_requests

    def wait(self):
        """Return the seconds left in the window if the backend exposes ``ttl``, else None.

        The parent implementation reads ``self.history``, which this class
        never sets.
        """
        ttl = getattr(self.cache, 'ttl', None)
        if callable(ttl) and self.key is not None:
            remaining = ttl(self.key)
            if remaining is not None and remaining > 0:
                return remaining
        return None
# settings.py: DRF throttling configuration.
# Every throttle class listed below runs on every request. The rates map each
# throttle `scope` to a "<count>/<period>" limit.
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_CLASSES=[
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
        'apps.api.throttles.BurstRateThrottle',
        'apps.api.throttles.MethodSpecificThrottle',
    ],
    DEFAULT_THROTTLE_RATES=dict(
        anon='100/day',
        user='1000/day',
        burst='100/minute',
        sustained='1000/day',
        method_specific='100/minute',
        smart='100/minute',
    ),
)
# cache_utils.py
from django.core.cache import cache
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from django.views.decorators.vary import vary_on_headers, vary_on_cookie
from functools import wraps
import hashlib
import json
def cache_per_user(timeout):
    """Cache a view's successful GET/HEAD responses separately for each user.

    The key combines the user id, the path and, when present, an MD5 of the
    query string, so ``?page=2`` no longer receives a cached ``?page=1``.
    Writes are neither served from nor stored in the cache. Inside a view
    method a DRF ``Response`` has not been rendered yet and cannot be pickled.
    For such responses only ``response.data`` is cached, and a new
    ``Response`` is built on a hit. Other responses are cached as-is.

    :param timeout: cache lifetime in seconds.
    """
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(request, *args, **kwargs):
            if request.method not in ('GET', 'HEAD'):
                return view_func(request, *args, **kwargs)
            cache_key = f"user_cache:{request.user.id}:{request.path}"
            if request.GET:
                cache_key += ':' + hashlib.md5(request.GET.urlencode().encode()).hexdigest()
            cached = cache.get(cache_key)
            if cached is not None:
                if isinstance(cached, tuple) and cached[0] == 'drf-data':
                    from rest_framework.response import Response
                    return Response(cached[1])
                return cached
            response = view_func(request, *args, **kwargs)
            # Cache only successes so transient errors are not replayed.
            if response.status_code == 200:
                if hasattr(response, 'data'):
                    cache.set(cache_key, ('drf-data', response.data), timeout)
                else:
                    cache.set(cache_key, response, timeout)
            return response
        return _wrapped_view
    return decorator
def cache_response(timeout=300, vary_on_user=True, key_prefix=''):
    """Cache the ``data`` of successful GET/HEAD responses from a DRF view.

    :param timeout: cache lifetime in seconds.
    :param vary_on_user: when true, add the authenticated user's id to the key.
    :param key_prefix: optional namespace prepended to the key. The previous
        code referred to an undefined ``key_prefix`` name, so a prefix could
        never be applied.
    """
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(request, *args, **kwargs):
            # Never replay cached data for writes; a 200 from a POST used to be
            # cached and returned for later POSTs without running the view.
            if request.method not in ('GET', 'HEAD'):
                return view_func(request, *args, **kwargs)
            key_parts = [key_prefix, request.path] if key_prefix else [request.path]
            if vary_on_user and request.user.is_authenticated:
                key_parts.append(str(request.user.id))
            if request.GET:
                key_parts.append(hashlib.md5(request.GET.urlencode().encode()).hexdigest())
            cache_key = ':'.join(key_parts)
            cached_data = cache.get(cache_key)
            if cached_data is not None:
                # JsonResponse was never imported (NameError on every hit) and
                # rejects list payloads without safe=False; use DRF's Response.
                from rest_framework.response import Response
                return Response(cached_data)
            response = view_func(request, *args, **kwargs)
            if response.status_code == 200:
                cache.set(cache_key, response.data, timeout)
            return response
        return _wrapped_view
    return decorator
class CacheMixin:
    """ViewSet mixin that caches ``list`` payloads per view, user and query string.

    ``cache_timeout`` and ``cache_vary_on_user`` are read on each request, so
    subclasses can override them. The previous ``@cache_page(cache_timeout)``
    decorator captured the base value at import time and ignored overrides,
    and ``get_cache_key`` was never called.
    """

    cache_timeout = 300
    cache_vary_on_user = True

    def list(self, request, *args, **kwargs):
        from rest_framework.response import Response

        cache_key = self.get_cache_key(request)
        cached_data = cache.get(cache_key)
        if cached_data is not None:
            return Response(cached_data)
        response = super().list(request, *args, **kwargs)
        if response.status_code == 200:
            cache.set(cache_key, response.data, self.cache_timeout)
        return response

    def get_cache_key(self, request):
        """Build the key from the class name, method, path, user (optional) and query params."""
        key_parts = [self.__class__.__name__, request.method, request.path]
        if self.cache_vary_on_user and request.user.is_authenticated:
            key_parts.append(str(request.user.id))
        if request.GET:
            # lists() keeps every value of repeated params (?tag=a&tag=b);
            # items() kept only the last one, so different queries collided.
            sorted_params = sorted(request.GET.lists())
            key_parts.append(hashlib.md5(json.dumps(sorted_params).encode()).hexdigest())
        return ':'.join(key_parts)
# performance_middleware.py
import time
import json
from django.utils.deprecation import MiddlewareMixin
from django.db import connection
from django.core.cache import cache
import logging
logger = logging.getLogger('performance')


class PerformanceMiddleware(MiddlewareMixin):
    """Log latency and DB query statistics per request and expose them as headers.

    Adds ``X-Response-Time`` and ``X-DB-Queries`` to every response. Requests
    slower than 1000 ms are logged at WARNING level, the rest at INFO.
    NOTE(review): the counts come from ``connection.queries``, which Django
    fills only while query logging is active (normally DEBUG=True). Expect 0
    in production and confirm for your deployment.
    """

    def process_request(self, request):
        # perf_counter is monotonic, so wall-clock adjustments cannot distort durations.
        request._start_time = time.perf_counter()
        request._db_queries_start = len(connection.queries)
        # The old `_cache_hits_start` probe was removed. It was never read, and
        # it touched the private `cache._cache`, which could raise
        # AttributeError on backends without that attribute.

    def process_response(self, request, response):
        start_time = getattr(request, '_start_time', None)
        if start_time is None:
            # process_request did not run (an earlier middleware responded first).
            return response
        total_time = (time.perf_counter() - start_time) * 1000
        queries = connection.queries
        db_queries = max(len(queries) - getattr(request, '_db_queries_start', 0), 0)
        db_time = sum(float(q['time']) for q in queries[-db_queries:]) if db_queries > 0 else 0
        user = getattr(request, 'user', None)
        log_data = {
            'method': request.method,
            'path': request.path,
            'status': response.status_code,
            'time_ms': round(total_time, 2),
            'db_queries': db_queries,
            'db_time_ms': round(db_time * 1000, 2),
            'cache_hits': getattr(request, '_cache_hits', 0),
            'cache_misses': getattr(request, '_cache_misses', 0),
            'user_id': user.id if user is not None and user.is_authenticated else None,
            'user_agent': request.META.get('HTTP_USER_AGENT', '')[:200],
        }
        payload = json.dumps(log_data)
        # Lazy %-formatting: the message is built only if the level is enabled.
        if total_time > 1000:
            logger.warning("慢接口:%s", payload)
        else:
            logger.info("接口性能:%s", payload)
        response['X-Response-Time'] = f'{total_time:.2f}ms'
        response['X-DB-Queries'] = str(db_queries)
        return response

    def process_exception(self, request, exception):
        start_time = getattr(request, '_start_time', None)
        total_time = (time.perf_counter() - start_time) * 1000 if start_time is not None else 0.0
        logger.error("接口异常:%s %s - 耗时:%.2fms - 异常:%s", request.method, request.path, total_time, exception)
# versioning.py
from rest_framework.versioning import URLPathVersioning, AcceptHeaderVersioning
from rest_framework.compat import unicode_http_header
from django.urls import reverse
class AcceptHeaderVersioningWithFallback(AcceptHeaderVersioning):
    """Accept-header versioning that falls back to ``v1`` for unknown versions.

    A request whose Accept header asks for a version outside
    ``allowed_versions`` is served as ``default_version`` instead of being
    rejected.
    """

    default_version = 'v1'
    allowed_versions = ['v1', 'v2', 'v3']

    def determine_version(self, request, *args, **kwargs):
        from rest_framework import exceptions

        try:
            version = super().determine_version(request, *args, **kwargs)
        except exceptions.NotAcceptable:
            # The parent raises 406 for disallowed versions itself, which made
            # the fallback check below unreachable for exactly those requests.
            return self.default_version
        if version not in self.allowed_versions:
            return self.default_version
        return version
class NamespaceVersioning(URLPathVersioning):
    """URL-path versioning whose ``reverse`` injects the request's version kwarg.

    NOTE(review): URLPathVersioning expects URL patterns that capture a
    ``version`` keyword; confirm that the project's urlpatterns provide it.
    """

    default_version = 'v1'
    allowed_versions = ['v1', 'v2', 'v3']
    version_param = 'version'

    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        if request.version is not None:
            # Copy first: the previous code wrote the version into the
            # caller's own kwargs dict.
            kwargs = dict(kwargs or {})
            kwargs[self.version_param] = request.version
        return super().reverse(viewname, args, kwargs, request, format, **extra)
# urls.py
from django.urls import path, include
from rest_framework.routers import DefaultRouter
router = DefaultRouter()
router.register(r'users', UserViewSet, basename='user')

# Mount the same router once per API version, each under its own namespace.
urlpatterns = [
    path(f'api/{version}/', include((router.urls, version), namespace=version))
    for version in ('v1', 'v2')
]
# Per-version URL modules live under api/versions/.
urlpatterns.append(
    path('api/versions/', include([
        path('v1/', include('apps.api.v1.urls', namespace='api_v1')),
        path('v2/', include('apps.api.v2.urls', namespace='api_v2')),
    ]))
)
# Before: N+1 queries
users = User.objects.all()
for user in users:
    print(user.profile.bio)  # one extra query per user to load the profile

# After: select_related joins the profile into the same query
users = User.objects.select_related('profile').all()
for user in users:
    print(user.profile.bio)  # no extra queries: the JOIN already loaded the profile

# Many-to-many relations: prefetch_related loads all tags in one extra query
articles = Article.objects.prefetch_related('tags').all()
for article in articles:
    print([tag.name for tag in article.tags.all()])

# values/values_list: fetch only the columns you need
user_ids = User.objects.filter(is_active=True).values_list('id', flat=True)
users_data = User.objects.values('id', 'username', 'email')

# aggregate: compute summary values in the database (annotate is the per-row variant)
from django.db.models import Count, Avg, Sum, Q
stats = User.objects.aggregate(total=Count('id'), active=Count('id', filter=Q(is_active=True)))
# Index optimization
class User(models.Model):
    """Example model showing single-column and composite indexes."""

    # Indexed by the named 'email_idx' below. db_index=True here as well
    # would create a second, redundant index on the same column.
    email = models.EmailField()
    username = models.CharField(max_length=150, db_index=True)
    date_joined = models.DateTimeField(db_index=True)
    # Declared because the composite index below references it; an index on
    # an undeclared field fails Django's model checks.
    is_active = models.BooleanField(default=True)

    class Meta:
        indexes = [
            # Serves "active users ordered/filtered by join date" queries.
            models.Index(fields=['is_active', 'date_joined']),
            models.Index(fields=['email'], name='email_idx'),
        ]
# Pagination: always paginate an ORDERED queryset. Without an explicit order
# the database may return rows in any order, so pages can overlap or skip rows.
from django.core.paginator import Paginator
paginator = Paginator(User.objects.order_by('pk'), 10, allow_empty_first_page=False)
page = paginator.page(1)  # raises EmptyPage when there are no users (allow_empty_first_page=False)
# Before: an inefficient serializer
class BadProductSerializer(serializers.ModelSerializer):
    """Counterexample (do not copy): per-row queries for review count and rating."""

    category_name = serializers.CharField(source='category.name')
    reviews = serializers.SerializerMethodField()
    rating = serializers.SerializerMethodField()

    def get_reviews(self, obj):
        # One COUNT query per product -> N+1.
        return obj.reviews.count()

    def get_rating(self, obj):
        # Plus one AVG query per product.
        averaged = obj.reviews.aggregate(Avg('rating'))
        return averaged['rating__avg']
# After: an efficient serializer
class OptimizedProductSerializer(serializers.ModelSerializer):
    """Product serializer that reads review statistics from queryset annotations.

    ``review_count`` and ``average_rating`` are read-only, attribute-backed
    fields. The view's queryset is expected to annotate them (as
    ProductViewSet.get_queryset does). A field whose annotation is missing is
    omitted from the output.
    """

    category_name = serializers.CharField(source='category.name', read_only=True)
    review_count = serializers.IntegerField(read_only=True)
    average_rating = serializers.FloatField(read_only=True)

    class Meta:
        model = Product
        fields = ['id', 'name', 'price', 'category_name', 'review_count', 'average_rating']

    @staticmethod
    def setup_eager_loading(queryset):
        """Join ``category`` so ``category_name`` costs no extra query per row.

        Reviews are intentionally not prefetched: the statistics are computed
        by SQL annotations, so loading every review row would be wasted work.
        """
        return queryset.select_related('category')

    # The former to_representation override re-assigned the annotated values
    # that the declared fields already read, and was removed as a no-op.
class ProductViewSet(viewsets.ModelViewSet):
    """Product CRUD whose queryset carries review statistics as annotations."""

    queryset = Product.objects.all()
    serializer_class = OptimizedProductSerializer

    def get_queryset(self):
        eager = OptimizedProductSerializer.setup_eager_loading(super().get_queryset())
        return eager.annotate(
            review_count=Count('reviews'),
            average_rating=Avg('reviews__rating'),
        )
# 多级缓存实现
from django.core.cache import caches
class MultiLevelCache:
    """Two-tier cache: a short-lived L1 in front of a longer-lived L2.

    Reads try L1 first, then L2. An L2 hit is copied back into L1. ``None``
    doubles as the miss marker, so a cached ``None`` looks like a miss.

    :param l1_alias: Django cache alias for L1 (default ``'memcached'``).
    :param l2_alias: Django cache alias for L2 (default ``'redis'``).
    :param l1_ttl: L1 entry lifetime in seconds (default 60).
    :param l2_ttl: L2 entry lifetime in seconds (default 3600).
    """

    def __init__(self, l1_alias='memcached', l2_alias='redis', l1_ttl=60, l2_ttl=3600):
        self.l1_cache = caches[l1_alias]
        self.l2_cache = caches[l2_alias]
        self.l1_ttl = l1_ttl
        self.l2_ttl = l2_ttl

    def get(self, key):
        data = self.l1_cache.get(key)
        if data is not None:
            return data
        data = self.l2_cache.get(key)
        if data is not None:
            self.l1_cache.set(key, data, self.l1_ttl)
            return data
        return None

    def set(self, key, value):
        self.l1_cache.set(key, value, self.l1_ttl)
        self.l2_cache.set(key, value, self.l2_ttl)

    def delete(self, key):
        # Delete L2 first. If L1 is deleted first, a concurrent get() can miss
        # L1, read the still-present L2 value and copy it back into L1, serving
        # stale data for up to l1_ttl seconds.
        self.l2_cache.delete(key)
        self.l1_cache.delete(key)
def cache_method(timeout=300, key_prefix='', cache_backend=None):
    """Cache a method's return value keyed by its positional and keyword arguments.

    ``self`` is NOT part of the key. Use this only on methods whose result
    depends solely on their arguments.

    :param timeout: cache lifetime in seconds.
    :param key_prefix: namespace prepended to the key.
    :param cache_backend: optional cache object with ``get(key, default)`` and
        ``set(key, value, timeout)``. Defaults to Django's default ``cache``.
    """
    missing = object()  # sentinel, so a cached None counts as a hit

    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            backend = cache if cache_backend is None else cache_backend
            # Sort kwargs so f(a=1, b=2) and f(b=2, a=1) share one entry.
            raw_key = str(args) + str(sorted(kwargs.items()))
            digest = hashlib.md5(raw_key.encode()).hexdigest()
            cache_key = f"{key_prefix}:{func.__name__}:{digest}"
            cached_result = backend.get(cache_key, missing)
            if cached_result is not missing:
                return cached_result
            result = func(self, *args, **kwargs)
            backend.set(cache_key, result, timeout)
            return result
        return wrapper
    return decorator
class ProductViewSet(viewsets.ModelViewSet):
    """Product endpoints illustrating page-, user- and method-level caching."""

    # Whole-response page cache: 300 s.
    @method_decorator(cache_page(300))
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    # Per-user response cache: 120 s.
    @method_decorator(cache_per_user(120))
    def retrieve(self, request, *args, **kwargs):
        return super().retrieve(request, *args, **kwargs)

    # Memoized computation: 600 s.
    @cache_method(600, 'expensive_calculation')
    def get_expensive_data(self, product_id):
        time.sleep(1)  # stands in for an expensive computation
        return {"result": "expensive_data"}
# monitoring.py
from prometheus_client import Counter, Histogram, Gauge, Summary
from django.db import connection
import time
# Prometheus metric definitions (registered in the default registry on import).
REQUEST_COUNT = Counter(
    'django_http_requests_total',
    'Total HTTP requests',
    labelnames=['method', 'endpoint', 'status'],
)
REQUEST_DURATION = Histogram(
    'django_http_request_duration_seconds',
    'HTTP request duration',
    labelnames=['method', 'endpoint'],
)
DB_QUERY_DURATION = Histogram(
    'django_db_query_duration_seconds',
    'Database query duration',
    labelnames=['model', 'operation'],
)
CACHE_HITS = Counter('django_cache_hits_total', 'Total cache hits')
CACHE_MISSES = Counter('django_cache_misses_total', 'Total cache misses')
ACTIVE_USERS = Gauge('django_active_users', 'Active users count')
API_ERRORS = Counter('django_api_errors_total', 'Total API errors')
class MetricsMiddleware:
    """Record request count, request latency and per-query DB timings in Prometheus."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        start_time = time.perf_counter()
        response = self.get_response(request)
        duration = time.perf_counter() - start_time
        endpoint = self._endpoint_label(request)
        REQUEST_COUNT.labels(method=request.method, endpoint=endpoint, status=response.status_code).inc()
        REQUEST_DURATION.labels(method=request.method, endpoint=endpoint).observe(duration)
        for query in connection.queries:
            words = (query.get('sql') or '').split(None, 1)
            # Empty SQL used to raise IndexError on split()[0].
            operation = words[0].upper() if words else 'UNKNOWN'
            DB_QUERY_DURATION.labels(model='unknown', operation=operation).observe(float(query['time']))
        return response

    @staticmethod
    def _endpoint_label(request):
        """Return the matched URL route pattern, or ``'unmatched'``.

        Using the raw ``request.path`` created one time series per distinct URL
        (e.g. every /users/<id>/), i.e. unbounded label cardinality.
        """
        match = getattr(request, 'resolver_match', None)
        route = getattr(match, 'route', None)
        return route or 'unmatched'
def monitor_performance(name):
    """Decorator: time each call and record it in REQUEST_DURATION under ``endpoint=name``.

    Observations use ``method='function'``. Any exception increments
    API_ERRORS and is re-raised unchanged. The duration is recorded on success
    and on failure.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            try:
                return func(*args, **kwargs)
            except Exception:
                API_ERRORS.inc()
                raise  # a bare raise re-raises the original exception and traceback
            finally:
                duration = time.perf_counter() - start_time
                REQUEST_DURATION.labels(method='function', endpoint=name).observe(duration)
        return wrapper
    return decorator
class ProductViewSet(viewsets.ModelViewSet):
    """Product endpoints whose list action is timed by ``monitor_performance``."""

    @monitor_performance('product_list')
    def list(self, request, *args, **kwargs):
        # Delegates unchanged; the decorator records duration and errors.
        return super().list(request, *args, **kwargs)
# prometheus/alerts.yml
# `labels:` must be a nested mapping; the inline "labels: severity: x" form is invalid YAML.
groups:
  - name: django_api
    rules:
      - alert: HighErrorRate
        # sum() on both sides: without it each 5xx series only divides by the
        # denominator series with identical labels (including status), so the
        # ratio is always 1 whenever any 5xx occurs.
        expr: sum(rate(django_http_requests_total{status=~"5.."}[5m])) / sum(rate(django_http_requests_total[5m])) > 0.05
        for: 2m
        labels:
          severity: critical
        annotations:
          summary: "API 错误率过高"
          description: "5 分钟内 API 错误率超过 5%"
      - alert: SlowAPIResponse
        expr: histogram_quantile(0.95, rate(django_http_request_duration_seconds_bucket[5m])) > 1
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "API 响应过慢"
          # P95 > 1s means at least 5% of requests take longer than 1 s (not 95%).
          description: "API P95 响应时间超过 1 秒(超过 5% 的请求耗时大于 1 秒)"
      - alert: HighDatabaseLatency
        expr: histogram_quantile(0.95, rate(django_db_query_duration_seconds_bucket[5m])) > 0.5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "数据库查询过慢"
          description: "数据库查询 P95 耗时超过 500ms(超过 5% 的查询耗时大于 500ms)"
      - alert: LowCacheHitRate
        expr: rate(django_cache_hits_total[5m]) / (rate(django_cache_hits_total[5m]) + rate(django_cache_misses_total[5m])) < 0.7
        for: 10m
        labels:
          severity: warning
        annotations:
          summary: "缓存命中率过低"
          description: "缓存命中率低于 70%"
project/
├── apps/
│ ├── users/
│ │ ├── serializers/
│ │ │ ├── __init__.py
│ │ │ ├── user_serializers.py
│ │ │ └── profile_serializers.py
│ │ ├── permissions.py
│ │ ├── filters.py
│ │ ├── pagination.py
│ │ ├── throttles.py
│ │ └── views.py
│ └── products/
├── utils/
│ ├── exceptions.py
│ ├── response.py
│ └── pagination.py
└── config/
# performance_checklist.py
"""
性能优化检查清单
1. 数据库优化
[ ] 使用 select_related/prefetch_related
[ ] 避免 N+1 查询
[ ] 添加合适索引
[ ] 使用 values/values_list
[ ] 分页优化
2. 序列化优化
[ ] 避免 SerializerMethodField
[ ] 使用 read_only/write_only
[ ] 预计算字段
[ ] 延迟加载
3. 缓存优化
[ ] 热点数据缓存
[ ] 查询结果缓存
[ ] 页面片段缓存
[ ] 缓存失效策略
4. 代码优化
[ ] 懒加载
[ ] 批量操作
[ ] 异步任务
[ ] 连接池
"""
# settings/production.py
from .base import *
DEBUG = False

# HTTPS hardening: redirect HTTP to HTTPS, send HSTS for one year (subdomains
# included, preload-eligible) and mark session/CSRF cookies as HTTPS-only.
SECURE_SSL_REDIRECT = True
SECURE_HSTS_SECONDS = 31536000
SECURE_HSTS_INCLUDE_SUBDOMAINS = True
SECURE_HSTS_PRELOAD = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True

# NOTE(review): `env` is presumably provided by `.base` via the star import; confirm.
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql',
        'NAME': env('DB_NAME'),
        'USER': env('DB_USER'),
        'PASSWORD': env('DB_PASSWORD'),
        'HOST': env('DB_HOST'),
        'PORT': env('DB_PORT'),
        'CONN_MAX_AGE': 300,  # reuse DB connections for up to 5 minutes
        'OPTIONS': {'sslmode': 'require', 'connect_timeout': 10},
    }
}

CACHES = {
    'default': {
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': env('REDIS_URL'),
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
            'PARSER_CLASS': 'redis.connection.HiredisParser',
            'CONNECTION_POOL_CLASS': 'redis.BlockingConnectionPool',
            # django-redis passes CONNECTION_POOL_KWARGS to the pool class. The
            # former key, CONNECTION_POOL_CLASS_KWARGS, is not read, so these
            # pool limits were silently ignored.
            'CONNECTION_POOL_KWARGS': {'max_connections': 50, 'timeout': 20},
            'MAX_CONNECTIONS': 1000,
        },
        'KEY_PREFIX': 'production',
    }
}

# Production API defaults: JSON only, JWT authentication, hourly throttling.
REST_FRAMEWORK = {
    'DEFAULT_RENDERER_CLASSES': ['rest_framework.renderers.JSONRenderer'],
    'DEFAULT_PARSER_CLASSES': ['rest_framework.parsers.JSONParser'],
    'DEFAULT_AUTHENTICATION_CLASSES': ['rest_framework_simplejwt.authentication.JWTAuthentication'],
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {'anon': '100/hour', 'user': '1000/hour'},
    'EXCEPTION_HANDLER': 'utils.exceptions.production_exception_handler',
}
必须做的:
避免做的:

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML转Markdown 互为补充。 在线工具,Markdown转HTML在线工具,online
将 HTML 片段转为 GitHub Flavored Markdown,支持标题、列表、链接、代码块与表格等;浏览器内处理,可链接预填。 在线工具,HTML转Markdown在线工具,online
通过删除不必要的空白来缩小和压缩JSON。 在线工具,JSON 压缩在线工具,online