0%

DRF学习笔记之频率限制

频率限制组件

频率控制原理

以IP地址做限流,基本实现原理如下:

  • 维护一个访问列表{IP: [time2, time1]}
  • 判断IP是否在我们的访问列表里
  • 如果不在,说明是第一次访问,给访问列表加入IP: [now]的表项
  • 如果在 把now放入IP: [now, time2, time1]
  • 确保列表里最近的访问时间以及最远的访问时间差在限制的时间内
  • 判断列表长度是否是限制次数之内

源码分析

在APIView类内,调用self.initial()方法,在它的里面调用频率控制方法: self.check_throttles(request)进行频率控制:

def check_throttles(self, request):
"""
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
throttle_durations = []
# throttle 是我们在drf中配置的每一个频率控制类的实例化对象
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
.......

可以看到函数从self.get_throttles()中获取throttle对象

def get_throttles(self):
"""
Instantiates and returns the list of throttles that this view uses.
"""
# 我们在drf中配置的频率控制类的实例化对象组成的列表,返回到 check_throttle中
return [throttle() for throttle in self.throttle_classes]

再回到check_throttle中, 调用了throttle.allow_request(request, self)方法用于判断是否触发频率限制,所以这个allow_request方法需要我们自己实现

接下来check_throttle()函数执行了如下的语句:

if not throttle.allow_request(request, self):
# 当触发了频率限制的时候调用了throttle.wait()方法
throttle_durations.append(throttle.wait())

进入throttle.wait()方法

def wait(self):
"""
Optionally, return a recommended number of seconds to wait before
the next request.
"""
return None

看来,这个wait()方法也需要我们手动实现,返回一个希望下一个请求等待的时间

代码示例

实现自己的throttle:

#!/usr/bin/env python
# -*-coding:utf8-*-
import time

from rest_framework import throttling


class MyThrottle(throttling.BaseThrottle):
VisitRecord = {}

def __init__(self):
self.history = ""

def allow_request(self, request, view):
# 做频率限流
ip = request.META.get("REMOTE_ADDR")
now = time.time()
if ip not in self.VisitRecord:
self.VisitRecord[ip] = [now, ]
return True
history = self.VisitRecord[ip]
self.history = history
history.insert(0, now)
while history and history[0] - history[-1] > 60:
history.pop()
if len(history) > 3:
return False
else:
return True

def wait(self):
# 还有多久才能访问
# old + 60 - now
return self.history[-1] + 60 - self.history[0]

drf已有方法解析

SimpleRateThrottle(BaseThrottle)

代码解析

def __init__(self):
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)

get_rate()函数来获取rate值,get_rate()函数:

def get_rate():
....
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)

可以看到它去THROTTLE_RATE中找对应rate值,那这个THROTTLE_RATE究竟是什么呢: THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES,原来它就是从我们用户的配置中读取预先设置好的数据, 所以要使用drf这个内建的SimpleRateThrottle类, 首先必须在settings中设置好DEFAULT_THROTTLE_RATES

可以看出DEFAULT_THROTTLE_RATES对应为一个字典,字典里应该放啥呢?接着往下走到parse_rate()方法:

def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
<allowed number of requests>, <period of time in seconds>
"""
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)

可以看出DEFAULT_THROTTLE_RATES要求的字典里是这样的格式:

  • 需要包含s/m/h/d中的一种,分别代秒/分钟/小时/天
  • 以类似2/m这种形式展示,表示1分钟2次

get_parse()方法的返回值是一个(次数,周期)组成的元组

综上所述,SimpleRate的配置方法如下

配置方法

class ThrottleUseSimpleThrottle(throttling.SimpleRateThrottle):
scope = "WD"

def get_cache_key(self, request, view):
# 返回值应为IP
return self.get_ident(request) # 调用BaseThrottle类的方法,获取对应请求IP
  • 集成throttling.SimpleRateThrottle类,实现get_cache_key()方法, 设置一个scope名称
"DEFAULT_THROTTLE_RATES": {
"WD": "3/m"
}
  • settings.py中设置一个字典DEFAULT_THROTTLE_RATES,其为一个字典,里面包含了之前设置的scope名称,对应一个频率限制值
class TestView(APIView):
throttle_classes = [ThrottleUseSimpleThrottle, ]

def get(self, request):
return Response("throttle test api.")
  • 在对应cbv中使用

效果

当没有超过频率限制时

image-20200405201437410

当超过频率限制时

image-20200405201502560

分页器

PageNumberPagination

源码解析

分页器可以对对请求的结果做分页操作,主要是由drf的PageNumberPagination类实现的,怎么配看源码:

class PageNumberPagination(BasePagination):
"""
A simple page number based style that supports page numbers as
query parameters. For example:

http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100
"""
page_size = api_settings.PAGE_SIZE # 老样子,看到api_settings就知道又从settings.py读取配置了

django_paginator_class = DjangoPaginator

page_query_param = 'page' # 可以设定用于分页的参数,默认是page(url中page=xxx的形式)
page_query_description = _('A page number within the paginated result set.')

page_size_query_param = None # 可以设置page_size默认的大小
page_size_query_description = _('Number of results to return per page.')

max_page_size = None # 可以限制前端请求的最大page_size

last_page_strings = ('last',)

template = 'rest_framework/pagination/numbers.html'

invalid_page_message = _('Invalid page.')
......

可以看出要想用PageNumberPagination, 首先需要在实现一个自己的分页类,在其中包含如下设置:

from rest_framework import pagination


class MyPagination(pagination.PageNumberPagination): # 自己的分页类需要继承PageNumberPagination
page_size = 1 # 默认page_size
page_query_param = 'page' # 分页参数(默认page)
max_page_size = 4 # 最大允许的page_size

在CBV中使用

写好这个分页类后,在CBV中引用它:

class BookView(APIView):

def get(self, request):
queryset = Book.objects.all()
# 第一步,实例化分页器对象
paginator = MyPagination()
# 第二步,调用这个分页器类的分页方法paginate_queryset
page_queryset = paginator.paginate_queryset(queryset, request)
serial_obj = BookSerializer(instance=page_queryset, many=True)
return Response(status=status.HTTP_200_OK, data=serial_obj.data)

在浏览器中打开:

image-20200405224121185

这样在浏览器上通过更改请求参数page=2、page=3就可以翻页了

不过这样不太完美,最好能在返回的数据中附上返回结果的总数,上一页和下一页的链接就更好了,这样在以后前后端分离的时候也会方便前端请求,怎么做看PageNumberPagination类源码:

def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.page.paginator.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))

人家已经贴心的给你写好了,返回结果是一个字典,里面有结果总数,上一页和下一页,返回结果,改一下视图中代码:

class BookView(APIView):

def get(self, request):
queryset = Book.objects.all()
paginator = MyPagination()
page_queryset = paginator.paginate_queryset(queryset, request)
serial_obj = BookSerializer(instance=page_queryset, many=True)
# return Response(status=status.HTTP_200_OK, data=serial_obj.data)
return paginator.get_paginated_response(serial_obj.data)

image-20200405225221017

这下返回的结果中也包含上一页下一页链接了: )

LimitOffsetPagination

采用从当前第offset页向后找limit个元素的方式分页

源码解析

LimitOffsetPagination类是怎么实现的, 开头:

class LimitOffsetPagination(BasePagination):  # 继承自BasePagination类,BasePagination中该有的它都有
"""
A limit/offset based style. For example:

http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100
"""
default_limit = api_settings.PAGE_SIZE # 老样子又去settings.py里面找我们的配置去了
limit_query_param = 'limit' # 用于标识每页元素的参数,默认为param
limit_query_description = _('Number of results to return per page.')
offset_query_param = 'offset' # 用于标识起始元素的参数,默认为offset
offset_query_description = _('The initial index from which to return the results.')
max_limit = None # 用于标识每页最大可以获取的元素
template = 'rest_framework/pagination/numbers.html'
......

那我们在使用时也需要按照它的要求实现如下:

class MyLimitOffsetPagination(pagination.LimitOffsetPagination):
default_limit = 2
limit_query_param = 'limit'
offset_query_param = 'offset'
max_limit = 4

在CBV中使用

from utils.pagination import MyPagination, MyLimitOffsetPagination


class BookView(APIView):

def get(self, request):
queryset = Book.objects.all()
paginator = MyLimitOffsetPagination()
page_queryset = paginator.paginate_queryset(queryset, request)
serial_obj = BookSerializer(instance=page_queryset, many=True)
return paginator.get_paginated_response(serial_obj.data)

实现的效果:

image-20200406174811661

可以看到从第2个元素起,获取limit为1的元素,也就是第3个元素

image-20200406174839645

从第3个元素起,获取1个元素,即id为4的元素

CursorPagination

游标加密方式的排序,可以避免访问者知道数据库中数据的数量

效果如下:

image-20200406212111927

此时next和previous连接变成了一个加密的字符串,使访问者不可读

使用以及实现的方法可以看源码,其实和前面两种差的八九不离十:

class MyCursorPagination(pagination.CursorPagination):
cursor_query_param = 'cursor'
page_size = 1
ordering = '-id'

唯一注意有所不同是多了一个对返回结果排序的字段ordering,以及返回的结果里没有count字段了(本来就不想被人知道实际数据库里数据数量的)。

欢迎关注我的其它发布渠道