
Django REST Framework: Enterprise-Grade API Architecture in Practice

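An in-depth look at core architectural practices for Django REST Framework (DRF) in enterprise settings. Topics include the ViewSet abstraction, serializer performance tuning, fine-grained permission control, and strategies for pagination, filtering, and throttling. The focus is on multi-level caching, performance-monitoring middleware, and database optimization to eliminate N+1 queries and reduce response latency. Production-grade code examples, deployment configuration, and security guidelines help you build highly available, scalable API services.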

Author: steve · Published: 2026/3/16

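Core Principles in Depth

DRF's Architectural Design Philosophy

DRF is not just a thin wrapper; it is a complete ecosystem for API development. Its core design philosophy includes:

  1. Convention over configuration: sensible defaults out of the box.
  2. Pluggable components: every part can be replaced.
  3. Explicit over implicit: configuration is stated clearly, with no magic.
  4. DRY (Don't Repeat Yourself): less duplicated code.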
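# Traditional Django view vs. DRF view
from django.http import JsonResponse
from rest_framework import viewsets
from rest_framework.permissions import IsAuthenticated
from .models import User  # assumes a custom User model with a `name` field
from .serializers import UserSerializer

# Traditional approach - too many details handled by hand
def user_list(request):
    if request.method != 'GET':
        return JsonResponse({'error': 'Method not allowed'}, status=405)
    users = User.objects.all()
    data = [{'id': u.id, 'name': u.name} for u in users]
    # safe=False is required because the top-level JSON value is a list
    return JsonResponse(data, safe=False)

# DRF approach - concise: list/create/retrieve/update/destroy, validation and auth included
class UserViewSet(viewsets.ModelViewSet):
    queryset = User.objects.all()
    serializer_class = UserSerializer
    permission_classes = [IsAuthenticated]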

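ViewSets: The Ultimate CRUD Abstraction

The core value of a ViewSet is that it maps HTTP methods to handler methods automatically. On a list route, GET calls list() and POST calls create(). On a detail route, GET, PUT, PATCH, and DELETE call retrieve(), update(), partial_update(), and destroy(). Follow the "three dos and three don'ts":

  • Do: use ModelViewSet for standard CRUD; use ReadOnlyModelViewSet for read-only endpoints; use the @action decorator for custom actions.
  • Don't: cram complex business logic into a single ViewSet; over-override get_queryset and get_serializer; neglect permission control.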
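To make the method-to-action mapping concrete, here is a minimal, framework-free sketch of what a router does when it binds a ViewSet. The two dictionaries mirror the default list-route and detail-route mappings of DRF's `SimpleRouter`. `MiniViewSet` and `dispatch` are hypothetical names for illustration, not DRF APIs.

```python
# Default method-to-action mappings used by DRF's SimpleRouter
LIST_ROUTE = {'get': 'list', 'post': 'create'}
DETAIL_ROUTE = {'get': 'retrieve', 'put': 'update',
                'patch': 'partial_update', 'delete': 'destroy'}


class MiniViewSet:
    """Hypothetical stand-in for a ViewSet: one method per action."""

    def list(self):
        return 'list all'

    def create(self):
        return 'create one'

    def retrieve(self, pk):
        return f'retrieve {pk}'

    def update(self, pk):
        return f'update {pk}'

    def partial_update(self, pk):
        return f'partial_update {pk}'

    def destroy(self, pk):
        return f'destroy {pk}'


def dispatch(viewset, http_method, pk=None):
    """Pick the route by whether a pk is present, map the method to an action, call it."""
    mapping = LIST_ROUTE if pk is None else DETAIL_ROUTE
    action = mapping.get(http_method.lower())
    if action is None:
        return 405, 'Method not allowed'
    handler = getattr(viewset, action)
    result = handler() if pk is None else handler(pk)
    return 200, result
```

A method with no mapped action falls through to a 405. A real ViewSet route returns the same status for an unmapped method.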

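Serializers: More Than Data Conversion

A serializer has three core responsibilities: data validation, data conversion (Python objects ↔ JSON), and relationship handling.

Performance comparison (serializing 1,000 user records):

Optimization | Time (ms) | Memory (MB) | DB queries
Basic serialization | 1200 | 320 | 1001
select_related | 350 | 180 | 1
Value-object optimization | 85 | 120 | 1
Cached result | 15 | 100 | 0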
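from django.db.models import Count
from rest_framework import serializers, viewsets
from .models import User  # custom User model with `name` and `posts`/`comments` relations

# Wrong serializer usage - a performance killer
class BadUserSerializer(serializers.ModelSerializer):
    posts = serializers.SerializerMethodField()
    comments = serializers.SerializerMethodField()
    class Meta:
        model = User
        fields = ['id', 'name', 'posts', 'comments']
    def get_posts(self, obj): return obj.posts.count()  # one COUNT query per user: N+1!
    def get_comments(self, obj): return obj.comments.count()  # another N+1!

# Correct serializer - performance-optimized
# The counts come from annotate() on the view's queryset and are computed once in SQL,
# so serializing each object issues no extra queries
class OptimizedUserSerializer(serializers.ModelSerializer):
    post_count = serializers.IntegerField(read_only=True)
    comment_count = serializers.IntegerField(read_only=True)

    class Meta:
        model = User
        fields = ['id', 'name', 'post_count', 'comment_count']

class UserStatsViewSet(viewsets.ReadOnlyModelViewSet):  # hypothetical view name
    serializer_class = OptimizedUserSerializer
    # distinct=True stops the posts JOIN and the comments JOIN from multiplying each other's counts
    queryset = User.objects.annotate(
        post_count=Count('posts', distinct=True),
        comment_count=Count('comments', distinct=True),
    )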
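The difference between the two serializers comes down to query count. The following sketch uses plain `sqlite3` rather than the Django ORM, and a trace callback counts the SQL statements executed:

- The first function issues one `COUNT` per user, the way `obj.posts.count()` does inside a loop.
- The second computes all counts in one `GROUP BY`, roughly what `annotate(Count(...))` generates.

The table and column names are hypothetical.

```python
import sqlite3


def make_db(num_users=5, posts_per_user=3):
    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE app_user (id INTEGER PRIMARY KEY, name TEXT)')
    conn.execute('CREATE TABLE app_post (id INTEGER PRIMARY KEY, author_id INTEGER)')
    for uid in range(1, num_users + 1):
        conn.execute('INSERT INTO app_user (id, name) VALUES (?, ?)', (uid, f'user{uid}'))
        for _ in range(posts_per_user):
            conn.execute('INSERT INTO app_post (author_id) VALUES (?)', (uid,))
    conn.commit()
    return conn


def count_queries(conn, func):
    """Run func(conn) and return (result, number of SQL statements executed)."""
    statements = []
    conn.set_trace_callback(statements.append)
    try:
        result = func(conn)
    finally:
        conn.set_trace_callback(None)
    return result, len(statements)


def n_plus_one(conn):
    # 1 query for the users, then 1 COUNT query per user
    users = conn.execute('SELECT id, name FROM app_user ORDER BY id').fetchall()
    return [(name, conn.execute('SELECT COUNT(*) FROM app_post WHERE author_id = ?',
                                (uid,)).fetchone()[0]) for uid, name in users]


def single_aggregate(conn):
    # One LEFT JOIN + GROUP BY computes every count at once
    return conn.execute(
        'SELECT u.name, COUNT(p.id) FROM app_user u '
        'LEFT JOIN app_post p ON p.author_id = u.id '
        'GROUP BY u.id ORDER BY u.id').fetchall()
```

With N users, the first approach runs N + 1 statements and the second runs one. The results are identical.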

Hands-On: A Complete API Implementation

User Management API

# serializers.py
from rest_framework import serializers
from django.contrib.auth import get_user_model
from django.contrib.auth.password_validation import validate_password

User = get_user_model()

class UserSerializer(serializers.ModelSerializer):
    """用户序列化器 - 生产级实现"""
    password = serializers.CharField(write_only=True, required=True, validators=[validate_password], style={'input_type': 'password'})
    confirm_password = serializers.CharField(write_only=True, required=True, style={'input_type': 'password'})

    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'password', 'confirm_password', 'first_name', 'last_name', 'is_active', 'date_joined', 'last_login']
        read_only_fields = ['id', 'date_joined', 'last_login']
        extra_kwargs = {'email': {'required': True}, 'username': {'min_length': 3, 'max_length': 30}}

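    def validate(self, attrs):
        # On partial updates (PATCH) the password may be absent; only compare when it is provided
        if 'password' in attrs and attrs['password'] != attrs.get('confirm_password'):
            raise serializers.ValidationError({"password": "The two passwords do not match"})
        return attrs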

    def create(self, validated_data):
        validated_data.pop('confirm_password')
        user = User.objects.create_user(**validated_data)
        return user

    def update(self, instance, validated_data):
        validated_data.pop('confirm_password', None)
        password = validated_data.pop('password', None)
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        if password:
            instance.set_password(password)
        instance.save()
        return instance

class UserDetailSerializer(UserSerializer):
    """用户详情序列化器 - 包含统计信息"""
    stats = serializers.SerializerMethodField()

    class Meta(UserSerializer.Meta):
        fields = UserSerializer.Meta.fields + ['stats']

    def get_stats(self, obj):
        from django.db.models import Count
        from apps.posts.models import Post
        from apps.comments.models import Comment
        return {
            'post_count': Post.objects.filter(author=obj).count(),
            'comment_count': Comment.objects.filter(user=obj).count(),
            'like_count': Post.objects.filter(author=obj).aggregate(total_likes=Count('likes'))['total_likes'] or 0
        }
# permissions.py
from rest_framework import permissions

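class IsOwnerOrReadOnly(permissions.BasePermission):
    """Write access for the object's owner, read-only for everyone else"""
    def has_object_permission(self, request, view, obj):
        if request.method in permissions.SAFE_METHODS: return True
        # The object is the requesting user (e.g. a User instance in UserViewSet)
        if obj == request.user: return True
        if hasattr(obj, 'user'): return obj.user == request.user
        elif hasattr(obj, 'author'): return obj.author == request.user
        elif hasattr(obj, 'owner'): return obj.owner == request.user
        elif hasattr(obj, 'created_by'): return obj.created_by == request.user
        return False

class IsAdminOrReadOnly(permissions.BasePermission):
    """Admins can write; everyone else is read-only"""
    def has_permission(self, request, view):
        if request.method in permissions.SAFE_METHODS: return True
        return bool(request.user and request.user.is_staff)

class RateLimitPermission(permissions.BasePermission):
    """API call rate limit (fixed window); DRF's throttle classes are the built-in alternative"""
    # DRF instantiates permission classes without arguments, so the rate is a class attribute
    rate = '5/minute'

    def has_permission(self, request, view):
        from django.core.cache import cache
        num, period = self.rate.split('/')
        limit = int(num)
        window = {'second': 1, 'minute': 60, 'hour': 3600, 'day': 86400}[period]
        ident = request.user.pk if request.user.is_authenticated else request.META.get('REMOTE_ADDR')
        cache_key = f"ratelimit:{ident}:{view.__class__.__name__}"
        # add() only creates the counter if it is missing, so the window is not extended on every call
        cache.add(cache_key, 0, window)
        try:
            count = cache.incr(cache_key)  # atomic on Redis/Memcached backends
        except ValueError:  # the key expired between add() and incr()
            cache.set(cache_key, 1, window)
            count = 1
        return count <= limit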
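The ownership check in `IsOwnerOrReadOnly` boils down to a simple rule. Safe methods always pass. Writes pass only if the first owner-like attribute matches the requester. Below is a framework-free sketch of that decision. The `User`/`Post` dataclasses are hypothetical stand-ins for models, while `SAFE_METHODS` is the same tuple DRF exposes as `permissions.SAFE_METHODS`.

```python
from dataclasses import dataclass
from typing import Optional

SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')  # same tuple as DRF's permissions.SAFE_METHODS
OWNER_ATTRS = ('user', 'author', 'owner', 'created_by')


@dataclass(frozen=True)
class User:
    id: int


@dataclass
class Post:
    title: str
    author: Optional[User] = None


def has_object_permission(method, request_user, obj):
    if method in SAFE_METHODS:
        return True
    if obj == request_user:  # the object is the requesting user themself
        return True
    for attr in OWNER_ATTRS:  # the first owner-like attribute decides
        if hasattr(obj, attr):
            return getattr(obj, attr) == request_user
    return False
```

Checking `obj == request_user` first is what lets a user edit their own record. A `User` object has none of the owner attributes, so without that check every write to it would be denied.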
# views.py
from rest_framework import viewsets, status, mixins
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated, IsAdminUser, AllowAny
from rest_framework_simplejwt.tokens import RefreshToken
from django.contrib.auth import get_user_model
from django.utils import timezone
from django.db.models import Q
from .serializers import UserSerializer, UserDetailSerializer
from .permissions import IsOwnerOrReadOnly, RateLimitPermission
from .pagination import StandardPagination

User = get_user_model()

class UserViewSet(viewsets.ModelViewSet):
    """User ViewSet - full CRUD operations"""
    queryset = User.objects.filter(is_active=True)
    serializer_class = UserSerializer
    pagination_class = StandardPagination
    # Default for actions not listed below (e.g. me); @action(permission_classes=...) overrides it
    permission_classes = [IsAuthenticated]

    def get_permissions(self):
        if self.action == 'create': return [AllowAny()]
        elif self.action in ['update', 'partial_update', 'destroy']: return [IsAuthenticated(), IsOwnerOrReadOnly()]
        elif self.action in ['list', 'retrieve']: return [IsAuthenticated()]
        elif self.action == 'admin_only': return [IsAuthenticated(), IsAdminUser()]
        # Fall back to the default so per-action permission_classes (e.g. AllowAny on register) still apply
        return super().get_permissions()

    def get_serializer_class(self):
        if self.action == 'retrieve': return UserDetailSerializer
        return UserSerializer

    def get_queryset(self):
        queryset = super().get_queryset()
        search = self.request.query_params.get('search')
        if search:
            queryset = queryset.filter(Q(username__icontains=search) | Q(email__icontains=search) | Q(first_name__icontains=search) | Q(last_name__icontains=search))
        ordering = self.request.query_params.get('ordering', '-date_joined')
        if ordering.lstrip('-') in ['username', 'email', 'date_joined', 'last_login']:
            queryset = queryset.order_by(ordering)
        queryset = queryset.select_related('profile').prefetch_related('groups')
        return queryset

    def perform_create(self, serializer):
        user = serializer.save()
        user.last_login = timezone.now()
        user.save(update_fields=['last_login'])

    @action(detail=False, methods=['post'], permission_classes=[AllowAny])
    def register(self, request):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        user = serializer.save()
        refresh = RefreshToken.for_user(user)
        return Response({'user': serializer.data, 'refresh': str(refresh), 'access': str(refresh.access_token)}, status=status.HTTP_201_CREATED)

    @action(detail=False, methods=['post'], permission_classes=[IsAuthenticated])
    def change_password(self, request):
        from django.contrib.auth.password_validation import validate_password
        from django.core.exceptions import ValidationError
        old_password = request.data.get('old_password')
        new_password = request.data.get('new_password')
        if not old_password or not new_password:
            return Response({'error': 'Both the old and new passwords are required'}, status=status.HTTP_400_BAD_REQUEST)
        if not request.user.check_password(old_password):
            return Response({'error': 'The old password is incorrect'}, status=status.HTTP_400_BAD_REQUEST)
        try:
            validate_password(new_password, request.user)
        except ValidationError as e:
            return Response({'error': 'The new password does not meet the requirements', 'details': list(e.messages)}, status=status.HTTP_400_BAD_REQUEST)
        request.user.set_password(new_password)
        request.user.save()
        return Response({'message': 'Password changed successfully'})

    @action(detail=False, methods=['get'])
    def me(self, request):
        # Pass the serializer context so fields that need the request (e.g. hyperlinks) work
        serializer = UserDetailSerializer(request.user, context=self.get_serializer_context())
        return Response(serializer.data)

    # follow/unfollow assume the User model has a `following` ManyToManyField
    @action(detail=True, methods=['post'], permission_classes=[IsAuthenticated])
    def follow(self, request, pk=None):
        user_to_follow = self.get_object()
        if request.user == user_to_follow:
            return Response({'error': 'You cannot follow yourself'}, status=status.HTTP_400_BAD_REQUEST)
        if request.user.following.filter(id=user_to_follow.id).exists():
            return Response({'error': 'You already follow this user'}, status=status.HTTP_400_BAD_REQUEST)
        request.user.following.add(user_to_follow)
        return Response({'message': 'Followed successfully'})

    @action(detail=True, methods=['post'], permission_classes=[IsAuthenticated])
    def unfollow(self, request, pk=None):
        user_to_unfollow = self.get_object()
        if not request.user.following.filter(id=user_to_unfollow.id).exists():
            return Response({'error': 'You do not follow this user'}, status=status.HTTP_400_BAD_REQUEST)
        request.user.following.remove(user_to_unfollow)
        return Response({'message': 'Unfollowed successfully'})

Pagination, Filtering, and Ordering

# pagination.py
from rest_framework.pagination import PageNumberPagination, CursorPagination
from rest_framework.response import Response
from collections import OrderedDict

class StandardPagination(PageNumberPagination):
    """标准分页器"""
    page_size = 20
    page_size_query_param = 'page_size'
    max_page_size = 100
    page_query_param = 'page'

    def get_paginated_response(self, data):
        return Response(OrderedDict([
            ('count', self.page.paginator.count),
            ('next', self.get_next_link()),
            ('previous', self.get_previous_link()),
            ('page_size', self.get_page_size(self.request)),
            ('current_page', self.page.number),
            ('total_pages', self.page.paginator.num_pages),
            ('results', data)
        ]))

class CursorPaginationWithCount(CursorPagination):
    """带计数的游标分页 - 适用于无限滚动"""
    page_size = 20
    ordering = '-created_at'

    def get_paginated_response(self, data):
        return Response(OrderedDict([
            ('next', self.get_next_link()),
            ('previous', self.get_previous_link()),
            ('results', data)
        ]))

class LargeResultsSetPagination(PageNumberPagination):
    """大数据集分页"""
    page_size = 100
    page_size_query_param = 'page_size'
    max_page_size = 1000

class SmallResultsSetPagination(PageNumberPagination):
    """小数据集分页"""
    page_size = 10
    page_size_query_param = 'page_size'
    max_page_size = 50
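`CursorPagination` avoids `OFFSET` by remembering the ordering value of the last row served and asking for rows beyond it. Below is a minimal in-memory sketch of that keyset idea. It is not DRF's implementation, which encodes the cursor as an opaque token and also handles ties. The sketch orders by a hypothetical `created_at` field, descending, with `id` as a tie-breaker.

```python
def keyset_page(rows, page_size, after=None):
    """Return (page, next_cursor) for rows ordered by (created_at, id) descending.

    `after` is the (created_at, id) of the last row on the previous page.
    """
    ordered = sorted(rows, key=lambda r: (r['created_at'], r['id']), reverse=True)
    if after is not None:
        # Keep only rows strictly "older" than the cursor: WHERE (created_at, id) < after
        ordered = [r for r in ordered if (r['created_at'], r['id']) < after]
    page = ordered[:page_size]
    has_more = len(ordered) > page_size
    next_cursor = (page[-1]['created_at'], page[-1]['id']) if page and has_more else None
    return page, next_cursor
```

Because the cursor is a value rather than a row offset, a row inserted between requests does not shift the next page. With `PageNumberPagination`, the same insert would push the last row of page one onto page two, so it would be served twice.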
# filters.py
import django_filters
from django_filters.rest_framework import FilterSet, filters
from django.db.models import Q
from .models import Product, Category

class ProductFilter(FilterSet):
    """商品过滤器"""
    name = filters.CharFilter(lookup_expr='icontains')
    min_price = filters.NumberFilter(field_name='price', lookup_expr='gte')
    max_price = filters.NumberFilter(field_name='price', lookup_expr='lte')
    category = filters.ModelMultipleChoiceFilter(field_name='category', queryset=Category.objects.all())
    tags = filters.CharFilter(method='filter_tags')
    in_stock = filters.BooleanFilter(method='filter_in_stock')

    class Meta:
        model = Product
        fields = ['name', 'category', 'status', 'is_featured']

    def filter_tags(self, queryset, name, value):
        tags = value.split(',')
        query = Q()
        for tag in tags:
            query |= Q(tags__name__iexact=tag.strip())
        return queryset.filter(query).distinct()

    def filter_in_stock(self, queryset, name, value):
        if value:
            return queryset.filter(stock_quantity__gt=0)
        return queryset.filter(stock_quantity=0)

    @property
    def qs(self):
        queryset = super().qs
        return queryset.select_related('category').prefetch_related('tags')

class AdvancedSearchFilter(django_filters.FilterSet):
    """高级搜索过滤器"""
    q = filters.CharFilter(method='search_filter')
    sort_by = filters.CharFilter(method='sort_filter')

    def search_filter(self, queryset, name, value):
        return queryset.filter(Q(name__icontains=value) | Q(description__icontains=value) | Q(sku__icontains=value))

    def sort_filter(self, queryset, name, value):
        if value in ['price', '-price', 'created_at', '-created_at', 'name', '-name']:
            return queryset.order_by(value)
        return queryset
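Two of the query-parameter rules above can be exercised without a database:

- comma-separated tags, combined with OR;
- a whitelist for `sort_by`.

Below is a plain-Python sketch of the same semantics over in-memory dicts. The product records are hypothetical.

```python
ALLOWED_SORTS = {'price', '-price', 'created_at', '-created_at', 'name', '-name'}


def filter_tags(products, value):
    """Keep products having ANY of the comma-separated tags (case-insensitive)."""
    wanted = {t.strip().lower() for t in value.split(',') if t.strip()}
    # A set intersection per product yields each match once, like .distinct()
    return [p for p in products if wanted & {t.lower() for t in p['tags']}]


def sort_products(products, sort_by):
    """Apply sort_by only if it is whitelisted; otherwise keep the input order."""
    if sort_by not in ALLOWED_SORTS:
        return list(products)
    field = sort_by.lstrip('-')
    return sorted(products, key=lambda p: p[field], reverse=sort_by.startswith('-'))
```

Unlike the FilterSet method, this sketch discards empty tags, so `?tags=a,` adds no empty-string match. The whitelist is what keeps arbitrary client input out of `order_by()`.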

Throttling and Rate Limiting

# throttles.py
from rest_framework.throttling import SimpleRateThrottle, UserRateThrottle, AnonRateThrottle
from django.core.cache import cache
import time

class BurstRateThrottle(UserRateThrottle):
    """突发请求限制"""
    scope = 'burst'
    rate = '100/minute'

class SustainedRateThrottle(UserRateThrottle):
    """持续请求限制"""
    scope = 'sustained'
    rate = '1000/day'

class MethodSpecificThrottle(SimpleRateThrottle):
    """Per-HTTP-method rate limits"""
    scope = 'method_specific'  # __init__ still reads this scope's default rate from DEFAULT_THROTTLE_RATES
    METHOD_RATES = {'GET': '100/minute', 'POST': '20/minute', 'PUT': '10/minute',
                    'PATCH': '10/minute', 'DELETE': '5/minute'}

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': f"{ident}:{request.method}"}

    def allow_request(self, request, view):
        # get_rate() runs in __init__, before any request exists, so choose the rate per request here
        self.rate = self.METHOD_RATES.get(request.method, '100/minute')
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)

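class SmartThrottle(SimpleRateThrottle):
    """Smart throttling - adjusts the rate dynamically by user level"""
    scope = 'smart'
    LEVEL_RATES = {'vip': '1000/minute', 'premium': '500/minute'}

    def get_cache_key(self, request, view):
        # SimpleRateThrottle requires subclasses to implement this
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

    def allow_request(self, request, view):
        if self._is_whitelisted(request): return True
        self.rate = self.LEVEL_RATES.get(self._get_user_level(request), '100/minute')
        # Re-parse: assigning self.rate alone does not update num_requests/duration
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)

    def _is_whitelisted(self, request):
        whitelist = ['127.0.0.1', '192.168.1.1']
        # Behind a reverse proxy REMOTE_ADDR is the proxy's address; see get_ident() and NUM_PROXIES
        return request.META.get('REMOTE_ADDR') in whitelist

    def _get_user_level(self, request):
        if request.user.is_authenticated and hasattr(request.user, 'profile'):
            return request.user.profile.level
        return 'normal'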

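class RedisThrottle(SimpleRateThrottle):
    """Distributed fixed-window throttle on a shared cache (e.g. Redis via django-redis)"""
    cache = cache
    scope = 'redis_throttle'
    rate = '100/minute'  # SimpleRateThrottle.__init__ parses this into num_requests/duration

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return f"throttle:{self.scope}:{ident}"

    def allow_request(self, request, view):
        key = self.get_cache_key(request, view)
        # add() creates the counter with the window TTL only if it does not exist;
        # incr() on a missing key would raise ValueError
        self.cache.add(key, 0, self.duration)
        try:
            current = self.cache.incr(key)
        except ValueError:  # the key expired between add() and incr()
            self.cache.set(key, 1, self.duration)
            current = 1
        return current <= self.num_requests

    def wait(self):
        # The base wait() relies on self.history, which this fixed-window version never sets
        return None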
# settings.py
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
        'apps.api.throttles.BurstRateThrottle',
        'apps.api.throttles.MethodSpecificThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/day',
        'user': '1000/day',
        'burst': '100/minute',
        'sustained': '1000/day',
        'method_specific': '100/minute',
        'smart': '100/minute',
    }
}
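DRF's `SimpleRateThrottle` keeps a list of recent request timestamps for each cache key. It rejects a request once that list holds `num_requests` entries younger than `duration` seconds. This is a sliding log, unlike the fixed-window counter in `RedisThrottle`. Below is a stdlib sketch of the algorithm with an injectable clock. `SlidingLogThrottle` is a hypothetical name, and a dict stands in for the cache.

```python
import time


def parse_rate(rate):
    """'100/minute' -> (100, 60); only the first letter of the period is used."""
    num, period = rate.split('/')
    return int(num), {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]


class SlidingLogThrottle:
    """At most num_requests per key within any window of duration seconds."""

    def __init__(self, rate, timer=time.monotonic):
        self.num_requests, self.duration = parse_rate(rate)
        self.timer = timer
        self.store = {}  # stands in for the cache: key -> timestamps, newest first

    def allow_request(self, key):
        now = self.timer()
        history = self.store.get(key, [])
        # Drop timestamps that have slid out of the window (the oldest are at the end)
        while history and history[-1] <= now - self.duration:
            history.pop()
        if len(history) >= self.num_requests:
            self.store[key] = history
            return False
        history.insert(0, now)
        self.store[key] = history
        return True

    def wait(self, key):
        """Seconds until the oldest timestamp leaves the window (0.0 if below the limit)."""
        history = self.store.get(key, [])
        if len(history) < self.num_requests:
            return 0.0
        return max(0.0, history[-1] + self.duration - self.timer())
```

The two approaches trade accuracy against storage:

- A sliding log is exact but stores one timestamp per request.
- A fixed-window counter stores a single integer, but around a window boundary it can admit up to twice the limit in a short span.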

Advanced Practice: Enterprise-Grade APIs

Cache Optimization Strategies

# cache_utils.py
from django.core.cache import cache
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from django.views.decorators.vary import vary_on_headers, vary_on_cookie
from rest_framework.response import Response
from functools import wraps
import hashlib
import json

def cache_per_user(timeout):
    """Per-user cache decorator for DRF view methods (apply via method_decorator)"""
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(request, *args, **kwargs):
            # Include the query string so ?page=2 is not served the cached ?page=1 result
            cache_key = f"user_cache:{request.user.id}:{request.get_full_path()}"
            cached_data = cache.get(cache_key)
            if cached_data is not None: return Response(cached_data)
            response = view_func(request, *args, **kwargs)
            # An unrendered DRF Response cannot be pickled into the cache, so store only its data
            if response.status_code == 200: cache.set(cache_key, response.data, timeout)
            return response
        return _wrapped_view
    return decorator

def cache_response(timeout=300, vary_on_user=True, key_prefix='api'):
    """Response cache decorator for DRF view methods"""
    def decorator(view_func):
        @wraps(view_func)
        def _wrapped_view(request, *args, **kwargs):
            key_parts = [key_prefix, request.path]
            if vary_on_user and request.user.is_authenticated:
                key_parts.append(str(request.user.id))
            if request.GET:
                key_parts.append(hashlib.md5(request.GET.urlencode().encode()).hexdigest())
            cache_key = ':'.join(key_parts)
            cached_data = cache.get(cache_key)
            if cached_data is not None: return Response(cached_data)
            response = view_func(request, *args, **kwargs)
            if response.status_code == 200: cache.set(cache_key, response.data, timeout)
            return response
        return _wrapped_view
    return decorator

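class CacheMixin:
    """Caching mixin"""
    cache_timeout = 300
    cache_vary_on_user = True

    # cache_timeout is read once, when this class body executes;
    # overriding it in a subclass does not change the decorator's timeout
    @method_decorator(cache_page(cache_timeout))
    @method_decorator(vary_on_headers('Authorization'))
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)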

    def get_cache_key(self, request):
        key_parts = [self.__class__.__name__, request.method, request.path]
        if self.cache_vary_on_user and request.user.is_authenticated:
            key_parts.append(str(request.user.id))
        if request.GET:
            sorted_params = sorted(request.GET.items())
            key_parts.append(hashlib.md5(json.dumps(sorted_params).encode()).hexdigest())
        return ':'.join(key_parts)
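Whatever decorator or mixin is used, the cache key must be deterministic. The same path, user, and query parameters, given in any order, should map to one key, and anything that changes the response must be part of the key. Below is a stdlib sketch mirroring `CacheMixin.get_cache_key`. The function name and argument shapes are assumptions.

```python
import hashlib
import json


def build_cache_key(view_name, method, path, user_id=None, params=None):
    parts = [view_name, method, path]
    if user_id is not None:
        parts.append(str(user_id))
    if params:
        # Sort so ?a=1&b=2 and ?b=2&a=1 share one cache entry
        canonical = json.dumps(sorted(params.items()))
        parts.append(hashlib.md5(canonical.encode()).hexdigest())
    return ':'.join(parts)
```

For a Django `QueryDict`, `items()` yields only the last value of a repeated parameter such as `?tag=a&tag=b`. Building the key from `request.GET.lists()` instead keeps such requests from colliding.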

Performance Monitoring Middleware

# performance_middleware.py
import time
import json
from django.utils.deprecation import MiddlewareMixin
from django.db import connection
from django.core.cache import cache
import logging

logger = logging.getLogger('performance')

class PerformanceMiddleware(MiddlewareMixin):
    """Performance monitoring middleware"""
    # connection.queries is only populated when DEBUG=True, so the DB figures are 0 in production
    def process_request(self, request):
        request._start_time = time.perf_counter()
        request._db_queries_start = len(connection.queries)

    def process_response(self, request, response):
        start = getattr(request, '_start_time', None)
        if start is None:  # process_request was skipped (an earlier middleware short-circuited)
            return response
        total_time = (time.perf_counter() - start) * 1000
        db_queries = len(connection.queries) - request._db_queries_start
        db_time = sum(float(q['time']) for q in connection.queries[-db_queries:]) if db_queries > 0 else 0
        user = getattr(request, 'user', None)
        log_data = {
            'method': request.method,
            'path': request.path,
            'status': response.status_code,
            'time_ms': round(total_time, 2),
            'db_queries': db_queries,
            'db_time_ms': round(db_time * 1000, 2),
            # Expected to be set on the request by caching code elsewhere; defaults to 0
            'cache_hits': getattr(request, '_cache_hits', 0),
            'cache_misses': getattr(request, '_cache_misses', 0),
            'user_id': user.id if user is not None and user.is_authenticated else None,
            'user_agent': request.META.get('HTTP_USER_AGENT', '')[:200]
        }
        if total_time > 1000: logger.warning(f"Slow endpoint: {json.dumps(log_data)}")
        else: logger.info(f"Endpoint performance: {json.dumps(log_data)}")
        response['X-Response-Time'] = f'{total_time:.2f}ms'
        response['X-DB-Queries'] = str(db_queries)
        return response

    def process_exception(self, request, exception):
        total_time = (time.perf_counter() - getattr(request, '_start_time', time.perf_counter())) * 1000
        logger.error(f"Endpoint error: {request.method} {request.path} - elapsed: {total_time:.2f}ms - exception: {exception}")

API Versioning

# versioning.py
from rest_framework import exceptions
from rest_framework.versioning import URLPathVersioning, AcceptHeaderVersioning

class AcceptHeaderVersioningWithFallback(AcceptHeaderVersioning):
    """Accept-header versioning that falls back to the default version"""
    default_version = 'v1'
    allowed_versions = ['v1', 'v2', 'v3']

    def determine_version(self, request, *args, **kwargs):
        # The base class raises NotAcceptable for a version outside allowed_versions,
        # so the fallback must catch the exception instead of checking the return value
        try:
            return super().determine_version(request, *args, **kwargs)
        except exceptions.NotAcceptable:
            return self.default_version

class NamespaceVersioning(URLPathVersioning):
    """Versioning built on URLPathVersioning: the version comes from a `version` URL kwarg,
    not the URL namespace (DRF's own rest_framework.versioning.NamespaceVersioning does that)"""
    default_version = 'v1'
    allowed_versions = ['v1', 'v2', 'v3']
    version_param = 'version'

    def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
        # Inject the current version so reversed URLs stay on the same API version
        if request.version is not None:
            kwargs = {} if (kwargs is None) else kwargs
            kwargs[self.version_param] = request.version
        return super().reverse(viewname, args, kwargs, request, format, **extra)

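# urls.py
from django.urls import path, include
from rest_framework.routers import DefaultRouter
from apps.users.views import UserViewSet

# The 'v1'/'v2' namespaces are what DRF's built-in rest_framework.versioning.NamespaceVersioning reads;
# a URLPathVersioning subclass would instead need a pattern such as path('api/<str:version>/', ...)
router = DefaultRouter()
router.register(r'users', UserViewSet, basename='user')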
urlpatterns = [
    path('api/v1/', include((router.urls, 'v1'), namespace='v1')),
    path('api/v2/', include((router.urls, 'v2'), namespace='v2')),
    path('api/versions/', include([
        path('v1/', include('apps.api.v1.urls', namespace='api_v1')),
        path('v2/', include('apps.api.v2.urls', namespace='api_v2')),
    ])),
]
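With `AcceptHeaderVersioning`, the client sends the version as a media-type parameter, for example `Accept: application/json; version=v2`. Below is a stdlib sketch of the fallback behaviour that `AcceptHeaderVersioningWithFallback` aims for. The parser is a simplification for illustration, not DRF's media-type parser.

```python
DEFAULT_VERSION = 'v1'
ALLOWED_VERSIONS = ('v1', 'v2', 'v3')


def version_from_accept(accept_header):
    """Return the `version` media-type parameter, or the default if absent or unknown."""
    for media_range in accept_header.split(','):
        _, *params = media_range.split(';')
        for param in params:
            key, _, value = param.strip().partition('=')
            if key.strip() == 'version':
                value = value.strip().strip('"')
                return value if value in ALLOWED_VERSIONS else DEFAULT_VERSION
    return DEFAULT_VERSION
```

Falling back silently is convenient for old clients. The trade-off is that a typo such as `version=v9` is never reported, which is exactly what the base class's `NotAcceptable` error would surface.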

Performance Optimization Guide

Database Optimization

# Before - causes N+1 queries
users = User.objects.all()
for user in users:
    print(user.profile.bio)  # hits the database on every iteration

# After - select_related
users = User.objects.select_related('profile').all()
for user in users:
    print(user.profile.bio)  # a single query with a JOIN

# Use prefetch_related for many-to-many relations (one extra query per prefetched relation)
articles = Article.objects.prefetch_related('tags').all()
for article in articles:
    print([tag.name for tag in article.tags.all()])

# Use values/values_list to fetch only specific fields
user_ids = User.objects.filter(is_active=True).values_list('id', flat=True)
users_data = User.objects.values('id', 'username', 'email')

# Use aggregate() for table-wide aggregates (annotate() adds an aggregate to each row)
from django.db.models import Count, Q
stats = User.objects.aggregate(
    total=Count('id'),
    active=Count('id', filter=Q(is_active=True))
)

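# Use indexes
from django.db import models

class User(models.Model):
    email = models.EmailField(db_index=True)
    username = models.CharField(max_length=150, db_index=True)
    is_active = models.BooleanField(default=True)
    date_joined = models.DateTimeField(db_index=True)
    class Meta:
        indexes = [
            # Composite index: filter on is_active, order by date_joined.
            # email already has db_index=True, so a second index on email would be redundant.
            models.Index(fields=['is_active', 'date_joined']),
        ]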

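# Pagination: order the queryset first so page boundaries are stable
# (Paginator warns about unordered querysets)
from django.core.paginator import Paginator
paginator = Paginator(User.objects.order_by('id'), 10, allow_empty_first_page=False)
page = paginator.page(1)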
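To confirm that an index is actually used, inspect the query plan. The sketch below uses SQLite through `sqlite3` only because it ships with Python. The deployment later in this article uses PostgreSQL, where the equivalent is `EXPLAIN`, or `QuerySet.explain()` from Django. The table mirrors the composite `(is_active, date_joined)` index above. The index name is hypothetical, and the exact plan wording varies between SQLite versions.

```python
import sqlite3


def query_plan(conn, sql, params=()):
    """Return the 'detail' column of SQLite's EXPLAIN QUERY PLAN output as one string."""
    rows = conn.execute('EXPLAIN QUERY PLAN ' + sql, params).fetchall()
    return ' | '.join(row[-1] for row in rows)


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE auth_user (id INTEGER PRIMARY KEY, email TEXT, '
             'is_active INTEGER, date_joined TEXT)')
sql = 'SELECT id FROM auth_user WHERE is_active = ? ORDER BY date_joined'

before = query_plan(conn, sql, (1,))
conn.execute('CREATE INDEX user_active_joined_idx ON auth_user (is_active, date_joined)')
after = query_plan(conn, sql, (1,))
```

Before the index exists, the plan is a full-table SCAN. Afterwards, it is a SEARCH through `user_active_joined_idx`.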

Serializer Optimization

from django.db.models import Avg, Count
from rest_framework import serializers, viewsets
from .models import Product

# Before - inefficient serializer
class BadProductSerializer(serializers.ModelSerializer):
    category_name = serializers.CharField(source='category.name')
    reviews = serializers.SerializerMethodField()
    rating = serializers.SerializerMethodField()
    class Meta:
        model = Product
        fields = ['id', 'name', 'price', 'category_name', 'reviews', 'rating']
    def get_reviews(self, obj): return obj.reviews.count()  # N+1 queries
    def get_rating(self, obj): return obj.reviews.aggregate(Avg('rating'))['rating__avg']  # recomputed for every object

# After - efficient serializer: the figures come from annotate() in the view
class OptimizedProductSerializer(serializers.ModelSerializer):
    category_name = serializers.CharField(source='category.name', read_only=True)
    review_count = serializers.IntegerField(read_only=True)
    average_rating = serializers.FloatField(read_only=True)

    class Meta:
        model = Product
        fields = ['id', 'name', 'price', 'category_name', 'review_count', 'average_rating']

    @staticmethod
    def setup_eager_loading(queryset):
        # Only the category is read per object; the review figures are annotated,
        # so prefetching every review row would just waste memory
        return queryset.select_related('category')

class ProductViewSet(viewsets.ModelViewSet):
    queryset = Product.objects.all()
    serializer_class = OptimizedProductSerializer

    def get_queryset(self):
        queryset = super().get_queryset()
        queryset = OptimizedProductSerializer.setup_eager_loading(queryset)
        queryset = queryset.annotate(review_count=Count('reviews'), average_rating=Avg('reviews__rating'))
        return queryset

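Caching Strategies

# Multi-level cache implementation
# (assumes CACHES defines 'memcached' and 'redis' aliases; cache_per_user comes from cache_utils.py above)
import hashlib
import time
from functools import wraps
from django.core.cache import cache, caches
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from rest_framework import viewsets
from .cache_utils import cache_per_user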

class MultiLevelCache:
    """多级缓存"""
    def __init__(self):
        self.l1_cache = caches['memcached']
        self.l2_cache = caches['redis']
        self.l1_ttl = 60
        self.l2_ttl = 3600

    def get(self, key):
        data = self.l1_cache.get(key)
        if data is not None: return data
        data = self.l2_cache.get(key)
        if data is not None:
            self.l1_cache.set(key, data, self.l1_ttl)
            return data
        return None

    def set(self, key, value):
        self.l1_cache.set(key, value, self.l1_ttl)
        self.l2_cache.set(key, value, self.l2_ttl)

    def delete(self, key):
        self.l1_cache.delete(key)
        self.l2_cache.delete(key)

def cache_method(timeout=300, key_prefix=''):
    """方法缓存装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            cache_key = f"{key_prefix}:{func.__name__}:{hashlib.md5(str(args).encode() + str(kwargs).encode()).hexdigest()}"
            cached_result = cache.get(cache_key)
            if cached_result is not None: return cached_result
            result = func(self, *args, **kwargs)
            cache.set(cache_key, result, timeout)
            return result
        return wrapper
    return decorator

class ProductViewSet(viewsets.ModelViewSet):
    @method_decorator(cache_page(60 * 5))
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    @method_decorator(cache_per_user(60 * 2))
    def retrieve(self, request, *args, **kwargs):
        return super().retrieve(request, *args, **kwargs)

    @cache_method(60 * 10, 'expensive_calculation')
    def get_expensive_data(self, product_id):
        time.sleep(1)
        return {"result": "expensive_data"}

Monitoring and Alerting

Key Metrics Monitoring

# monitoring.py
from prometheus_client import Counter, Histogram, Gauge, Summary
from django.db import connection
from functools import wraps
import time

REQUEST_COUNT = Counter('django_http_requests_total', 'Total HTTP requests', ['method', 'endpoint', 'status'])
REQUEST_DURATION = Histogram('django_http_request_duration_seconds', 'HTTP request duration', ['method', 'endpoint'])
DB_QUERY_DURATION = Histogram('django_db_query_duration_seconds', 'Database query duration', ['model', 'operation'])
# Nothing in this file increments the two cache counters; caching code must call .inc() on them,
# otherwise the LowCacheHitRate alert has no data
CACHE_HITS = Counter('django_cache_hits_total', 'Total cache hits')
CACHE_MISSES = Counter('django_cache_misses_total', 'Total cache misses')
ACTIVE_USERS = Gauge('django_active_users', 'Active users count')
API_ERRORS = Counter('django_api_errors_total', 'Total API errors')

class MetricsMiddleware:
    """Metrics collection middleware"""
    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        start_time = time.perf_counter()
        response = self.get_response(request)
        duration = time.perf_counter() - start_time
        # Label by the matched route pattern rather than the raw path;
        # otherwise every object ID in a URL creates a new time series
        match = getattr(request, 'resolver_match', None)
        endpoint = match.route if match else 'unmatched'
        REQUEST_COUNT.labels(method=request.method, endpoint=endpoint, status=response.status_code).inc()
        REQUEST_DURATION.labels(method=request.method, endpoint=endpoint).observe(duration)
        # connection.queries is only populated when DEBUG=True
        for query in connection.queries:
            query_type = query['sql'].split()[0].upper()
            DB_QUERY_DURATION.labels(model='unknown', operation=query_type).observe(float(query['time']))
        return response

def monitor_performance(name):
    """Performance monitoring decorator"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            try:
                return func(*args, **kwargs)
            except Exception:
                API_ERRORS.inc()
                raise  # a bare raise keeps the original traceback
            finally:
                duration = time.perf_counter() - start_time
                REQUEST_DURATION.labels(method='function', endpoint=name).observe(duration)
        return wrapper
    return decorator

class ProductViewSet(viewsets.ModelViewSet):
    @monitor_performance('product_list')
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

Alert Configuration

# prometheus/alerts.yml
groups:
  - name: django_api
    rules:
      - alert: HighErrorRate
        # sum() drops the status label; dividing the raw series would match each 5xx series with itself
        expr: sum(rate(django_http_requests_total{status=~"5.."}[5m])) / sum(rate(django_http_requests_total[5m])) > 0.05
        for: 2m
        labels:
          severity: critical
        annotations:
          summary: "High API error rate"
          description: "API 5xx error rate above 5% over the last 5 minutes"
      - alert: SlowAPIResponse
        expr: histogram_quantile(0.95, sum by (le) (rate(django_http_request_duration_seconds_bucket[5m]))) > 1
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "Slow API responses"
          description: "P95 API response time above 1 second"
      - alert: HighDatabaseLatency
        expr: histogram_quantile(0.95, sum by (le) (rate(django_db_query_duration_seconds_bucket[5m]))) > 0.5
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "Slow database queries"
          description: "P95 database query time above 500ms"
      - alert: LowCacheHitRate
        expr: sum(rate(django_cache_hits_total[5m])) / (sum(rate(django_cache_hits_total[5m])) + sum(rate(django_cache_misses_total[5m]))) < 0.7
        for: 10m
        labels:
          severity: warning
        annotations:
          summary: "Low cache hit rate"
          description: "Cache hit rate below 70%"

Best Practices Summary

Development Conventions

  1. Code structure
    project/
      ├── apps/
      │   ├── users/
      │   │   ├── serializers/
      │   │   ├── permissions.py
      │   │   ├── filters.py
      │   │   ├── pagination.py
      │   │   ├── throttles.py
      │   │   └── views.py
      │   └── products/
      ├── utils/
      │   ├── exceptions.py
      │   ├── response.py
      │   └── pagination.py
      └── config/
    
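  2. API design principles
    • RESTful, resource-oriented design
    • Versioning starts at v1
    • Standardized error messages
    • Consistent pagination parameters
    • Standardized filtering and ordering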
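  3. Security conventions
    • Every endpoint requires authentication by default
    • Sensitive operations are logged
    • Input parameters are strictly validated
    • Sensitive output fields are masked
    • Rate limiting guards against abuse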
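One way to make the rule about standardized error messages concrete is to post-process the bodies that DRF's default exception handler produces. Most exceptions produce `{"detail": ...}`. Validation errors produce a field-to-messages mapping, or a list when raised without field names. The sketch below converts all of these into a single envelope. The `code`/`message`/`errors` shape is a hypothetical convention; the article does not define it. The function is written as a pure function so it runs without Django.

```python
def normalize_error(status_code, data):
    """Wrap DRF-style error data in one envelope (the code/message/errors shape is hypothetical)."""
    if isinstance(data, dict) and 'detail' in data:
        # Most DRF exceptions (NotAuthenticated, PermissionDenied, NotFound, Throttled, ...)
        return {'code': status_code, 'message': str(data['detail']), 'errors': {}}
    if isinstance(data, dict):
        # Serializer validation errors: {field: [messages]}; nested errors are just stringified here
        errors = {
            field: [str(m) for m in (msgs if isinstance(msgs, list) else [msgs])]
            for field, msgs in data.items()
        }
    else:
        # ValidationError raised with a plain list or string
        msgs = data if isinstance(data, list) else [data]
        errors = {'non_field_errors': [str(m) for m in msgs]}
    return {'code': status_code, 'message': 'Validation failed.', 'errors': errors}


# DRF's default body for an unauthenticated request
print(normalize_error(401, {'detail': 'Authentication credentials were not provided.'}))
# A typical serializer error body
print(normalize_error(400, {'email': ['Enter a valid email address.']}))
```

In a custom handler such as `utils.exceptions.production_exception_handler`, one option is to call DRF's `rest_framework.views.exception_handler(exc, context)`. When that call returns a response, replace `response.data` with `normalize_error(response.status_code, response.data)`.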

Performance Optimization Checklist

# performance_checklist.py
"""
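Performance optimization checklist
1. Database
   [ ] Use select_related/prefetch_related
   [ ] Avoid N+1 queries
   [ ] Add appropriate indexes
   [ ] Use values/values_list
   [ ] Optimize pagination
2. Serialization
   [ ] Avoid SerializerMethodField when it triggers per-object queries
   [ ] Use read_only/write_only
   [ ] Precompute fields
   [ ] Deferred loading
3. Caching
   [ ] Cache hot data
   [ ] Cache query results
   [ ] Cache page fragments
   [ ] Define a cache invalidation strategy
4. Code
   [ ] Lazy loading
   [ ] Bulk operations
   [ ] Asynchronous tasks
   [ ] Connection pooling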
"""
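The first two database items (N+1 queries and select_related/prefetch_related) are easiest to see by counting statements. The sketch below uses the standard library's sqlite3 instead of the Django ORM. The `author`/`post` tables are hypothetical. It compares three strategies. The first runs one query per row. The second uses a single JOIN, roughly what `select_related()` generates for a foreign key. The third uses a second query with `IN (...)`, the approach `prefetch_related()` takes.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE author (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE post (id INTEGER PRIMARY KEY, author_id INTEGER REFERENCES author(id), title TEXT);
""")
conn.executemany('INSERT INTO author VALUES (?, ?)', [(i, f'author{i}') for i in range(1, 11)])
conn.executemany('INSERT INTO post (author_id, title) VALUES (?, ?)',
                 [(i % 10 + 1, f'post{i}') for i in range(50)])
conn.commit()

executed = []
conn.set_trace_callback(executed.append)  # sqlite passes the SQL of every statement it runs


def selects_run(fn):
    """Run fn and return (result, number of SELECT statements it issued)."""
    executed.clear()
    result = fn()
    return result, sum(1 for sql in executed if sql.lstrip().upper().startswith('SELECT'))


def n_plus_one():
    # 1 query for the posts, then 1 query per post for its author: the N+1 pattern
    posts = conn.execute('SELECT id, author_id, title FROM post').fetchall()
    return [(title, conn.execute('SELECT name FROM author WHERE id = ?', (author_id,)).fetchone()[0])
            for _, author_id, title in posts]


def joined():
    # One JOIN, roughly what select_related() does for a ForeignKey
    return conn.execute('SELECT post.title, author.name FROM post '
                        'JOIN author ON author.id = post.author_id').fetchall()


def batched():
    # Two queries: posts, then all their authors via IN (...), the prefetch_related() approach
    posts = conn.execute('SELECT author_id, title FROM post').fetchall()
    ids = sorted({author_id for author_id, _ in posts})
    placeholders = ','.join('?' * len(ids))
    names = dict(conn.execute(f'SELECT id, name FROM author WHERE id IN ({placeholders})', ids))
    return [(title, names[author_id]) for author_id, title in posts]


for label, fn in [('N+1', n_plus_one), ('JOIN', joined), ('IN batch', batched)]:
    rows, selects = selects_run(fn)
    print(f'{label}: {len(rows)} rows, {selects} SELECT statements')
```

All three return the same 50 rows. The N+1 version issues one statement per post plus one for the list, while the other two use a constant number of statements regardless of row count.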

Deployment Configuration

# settings/production.py
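# Assumes base.py defines `env` (for example django-environ's `environ.Env()`), which the settings below use
from .base import *  # noqa: F401,F403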
DEBUG = False
SECURE_SSL_REDIRECT = True
SECURE_HSTS_SECONDS = 31536000
SECURE_HSTS_INCLUDE_SUBDOMAINS = True
SECURE_HSTS_PRELOAD = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql',
        'NAME': env('DB_NAME'),
        'USER': env('DB_USER'),
        'PASSWORD': env('DB_PASSWORD'),
        'HOST': env('DB_HOST'),
        'PORT': env('DB_PORT'),
        'CONN_MAX_AGE': 300,
        'OPTIONS': {
            'sslmode': 'require',
            'connect_timeout': 10,
        }
    }
}

CACHES = {
    'default': {
        'BACKEND': 'django_redis.cache.RedisCache',
        'LOCATION': env('REDIS_URL'),
        'OPTIONS': {
            'CLIENT_CLASS': 'django_redis.client.DefaultClient',
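            # Requires the hiredis package; the parser's import path differs between redis-py versions
            'PARSER_CLASS': 'redis.connection.HiredisParser',
            'CONNECTION_POOL_CLASS': 'redis.BlockingConnectionPool',
            # django-redis passes CONNECTION_POOL_KWARGS to the pool class; a blocking pool
            # makes callers wait up to `timeout` seconds for one of `max_connections` connections
            'CONNECTION_POOL_KWARGS': {
                'max_connections': 50,
                'timeout': 20,
            },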
        },
        'KEY_PREFIX': 'production',
    }
}

REST_FRAMEWORK = {
    'DEFAULT_RENDERER_CLASSES': ['rest_framework.renderers.JSONRenderer'],
    'DEFAULT_PARSER_CLASSES': ['rest_framework.parsers.JSONParser'],
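    'DEFAULT_AUTHENTICATION_CLASSES': ['rest_framework_simplejwt.authentication.JWTAuthentication'],
    # Redefining REST_FRAMEWORK discards any value from base.py, so restate "authenticated by default" here
    'DEFAULT_PERMISSION_CLASSES': ['rest_framework.permissions.IsAuthenticated'],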
    'DEFAULT_THROTTLE_CLASSES': [
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ],
    'DEFAULT_THROTTLE_RATES': {
        'anon': '100/hour',
        'user': '1000/hour',
    },
    'EXCEPTION_HANDLER': 'utils.exceptions.production_exception_handler',
}
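The `DEFAULT_THROTTLE_RATES` values are `count/period` strings. DRF's `SimpleRateThrottle` parses these strings and keeps a list of request timestamps for each client in the Django cache. The sketch below reproduces that sliding-window logic in memory. It uses a dict instead of the cache and an injectable clock so the example is deterministic. `SlidingWindowThrottle` is a hypothetical name, not DRF's class. DRF's own documentation warns that this application-level throttling should not be treated as protection against brute-force or denial-of-service attacks. It complements limits at the proxy or gateway rather than replacing them.

```python
import time
from collections import defaultdict

# Only the first letter of the period is significant, as in DRF ('hour', 'h', 'hours' all mean 3600 s)
PERIODS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}


def parse_rate(rate):
    """Turn a DRF-style rate string such as '100/hour' into (num_requests, duration_seconds)."""
    num, period = rate.split('/')
    return int(num), PERIODS[period[0]]


class SlidingWindowThrottle:
    """In-memory sketch of the timestamp-history algorithm behind DRF's SimpleRateThrottle.

    DRF keeps each client's history in the Django cache; a dict stands in for it here.
    """

    def __init__(self, rate, timer=time.time):
        self.num_requests, self.duration = parse_rate(rate)
        self.timer = timer
        self.histories = defaultdict(list)  # client key -> request timestamps, newest first

    def allow_request(self, key):
        now = self.timer()
        history = self.histories[key]
        # Drop timestamps that have slid out of the window
        while history and history[-1] <= now - self.duration:
            history.pop()
        if len(history) >= self.num_requests:
            return False  # a DRF view would then raise Throttled (HTTP 429)
        history.insert(0, now)
        return True


clock = [0.0]  # fake clock so the example is deterministic
throttle = SlidingWindowThrottle('3/minute', timer=lambda: clock[0])
results = [throttle.allow_request('user:1') for _ in range(4)]
print(results)  # [True, True, True, False]
clock[0] = 60.0  # the whole window has elapsed
after_window = throttle.allow_request('user:1')
print(after_window)  # True
```

With several worker processes, a shared cache such as the Redis cache configured above is required for the counts to apply across processes. Even with a shared cache, the read-modify-write on the history is not atomic.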

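Summary

Key Takeaways

  1. ViewSets are the foundation: use ModelViewSet to cut repetitive code.
  2. Serializers are the key: tune serialization performance and avoid N+1 queries.
  3. Permissions are the safeguard: fine-grained permission control keeps the system secure.
  4. Pagination shapes the experience: sensible page sizes keep responses fast and usable.
  5. Filtering drives efficiency: good filters avoid transferring unnecessary data.
  6. Throttling is the shield: rate limits keep the system from being flooded with requests.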

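Lessons from Practice

Must do:

  1. Require authentication on every endpoint.
  2. Log every critical operation.
  3. Validate all input parameters.
  4. Return clear, user-friendly error messages.
  5. Monitor performance bottlenecks.

Avoid:

  1. Querying the database inside serializers (the usual source of N+1 queries).
  2. Returning more data than the client needs.
  3. Overly complex nested queries.
  4. Duplicated business logic.
  5. Neglecting security configuration.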

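Recommended Tools

  1. Development and debugging: Django Debug Toolbar, django-silk
  2. API testing: Postman, Insomnia, DRF's built-in testing tools (rest_framework.test)
  3. Performance monitoring: Prometheus, Grafana, New Relic
  4. Documentation generation: drf-yasg, drf-spectacular
  5. Code quality: Black, Flake8, MyPy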

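Learning Resources

  1. Official documentation:
    • Django REST framework documentation
    • Django documentation
  2. Recommended books:
    • Django for APIs
    • Django for Professionals
    • Two Scoops of Django
  3. Example projects:
    • Official DRF examples
    • Django REST API examples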
