Django rest framework原始碼分析(3)----節流
新增節流
自定義節流的方法
- 限制60s內只能訪問3次
(1)API資料夾下面新建throttle.py,程式碼如下:
# utils/throttle.py
from rest_framework.throttling import BaseThrottle
import time
VISIT_RECORD = {} #儲存訪問記錄
class VisitThrottle(BaseThrottle):
'''60s內只能訪問3次'''
def __init__(self):
self.history = None #初始化訪問記錄
def allow_request(self,request,view):
#獲取使用者ip (get_ident)
remote_addr = self.get_ident(request)
ctime = time.time()
#如果當前IP不在訪問記錄裡面,就新增到記錄
if remote_addr not in VISIT_RECORD:
VISIT_RECORD[remote_addr] = [ctime,] #鍵值對的形式儲存
return True #True表示可以訪問
#獲取當前ip的歷史訪問記錄
history = VISIT_RECORD.get(remote_addr)
#初始化訪問記錄
self.history = history
#如果有歷史訪問記錄,並且最早一次的訪問記錄離當前時間超過60s,就刪除最早的那個訪問記錄,
#只要為True,就一直迴圈刪除最早的一次訪問記錄
while history and history[-1] < ctime - 60:
history.pop()
#如果訪問記錄不超過三次,就把當前的訪問記錄插到第一個位置(pop刪除最後一個)
if len(history) < 3:
history.insert(0,ctime)
return True
def wait(self):
'''還需要等多久才能訪問'''
ctime = time.time()
return 60 - (ctime - self.history[-1])
(2)settings中全域性配置節流
#全域性
REST_FRAMEWORK = {
#節流
"DEFAULT_THROTTLE_CLASSES":['API.utils.throttle.VisitThrottle'],
}
(3)現在訪問auth看看結果:
- 60s內訪問次數超過三次,會限制訪問
- 提示剩餘多少時間可以訪問
接著訪問
節流原始碼分析
(1)dispatch
def dispatch(self, request, *args, **kwargs):
"""
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
"""
self.args = args
self.kwargs = kwargs
#對原始request進行加工,豐富了一些功能
#Request(
# request,
# parsers=self.get_parsers(),
# authenticators=self.get_authenticators(),
# negotiator=self.get_content_negotiator(),
# parser_context=parser_context
# )
#request(原始request,[BasicAuthentications物件,])
#獲取原生request,request._request
#獲取認證類的物件,request.authticators
#1.封裝request
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate?
try:
#2.認證
self.initial(request, *args, **kwargs)
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
response = handler(request, *args, **kwargs)
except Exception as exc:
response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
(2)initial
def initial(self, request, *args, **kwargs):
"""
Runs anything that needs to occur prior to calling the method handler.
"""
self.format_kwarg = self.get_format_suffix(**kwargs)
# Perform content negotiation and store the accepted info on the request
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
# Determine the API version, if versioning is in use.
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
# Ensure that the incoming request is permitted
#4.實現認證
self.perform_authentication(request)
#5.許可權判斷
self.check_permissions(request)
#6.控制訪問頻率
self.check_throttles(request)
(3)check_throttles
裡面有個allow_request
def check_throttles(self, request):
"""
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
(4)get_throttles
def get_throttles(self):
"""
Instantiates and returns the list of throttles that this view uses.
"""
return [throttle() for throttle in self.throttle_classes]
(5)thtottle_classes
內建節流類
上面是寫的自定義節流,drf內建了很多節流的類,用起來比較方便。
(1)BaseThrottle
- 自己要寫allow_request和wait方法
- get_ident就是獲取ip
class BaseThrottle(object):
"""
Rate throttling of requests.
"""
def allow_request(self, request, view):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
raise NotImplementedError('.allow_request() must be overridden')
def get_ident(self, request):
"""
Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()
return ''.join(xff.split()) if xff else remote_addr
def wait(self):
"""
Optionally, return a recommended number of seconds to wait before
the next request.
"""
return None
(2)SimpleRateThrottle
class SimpleRateThrottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
The rate (requests / seconds) is set by a `rate` attribute on the View
class. The attribute is a string of the form 'number_of_requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None #這個值自定義,寫什麼都可以
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request, view):
"""
Should return a unique cache-key which can be used for throttling.
Must be overridden.
May return `None` if the request should not be throttled.
"""
raise NotImplementedError('.get_cache_key() must be overridden')
def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
<allowed number of requests>, <period of time in seconds>
"""
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
def allow_request(self, request, view):
"""
Implement the check to see if the request should be throttled.
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
"""
return False
def wait(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)
我們可以通過繼承SimpleRateThrottle類,來實現節流,會更加的簡單,因為SimpleRateThrottle裡面都幫我們寫好了
(1)throttle.py
from rest_framework.throttling import SimpleRateThrottle
class VisitThrottle(SimpleRateThrottle):
'''匿名使用者60s只能訪問三次(根據ip)'''
scope = 'NBA' #這裡面的值,自己隨便定義,settings裡面根據這個值配置Rate
def get_cache_key(self, request, view):
#通過ip限制節流
return self.get_ident(request)
class UserThrottle(SimpleRateThrottle):
'''登入使用者60s可以訪問10次'''
scope = 'NBAUser' #這裡面的值,自己隨便定義,settings裡面根據這個值配置Rate
def get_cache_key(self, request, view):
return request.user.username
(2)settings.py
#全域性
REST_FRAMEWORK = {
#節流
"DEFAULT_THROTTLE_CLASSES":['API.utils.throttle.UserThrottle'], #全域性配置,登入使用者節流限制(10/m)
"DEFAULT_THROTTLE_RATES":{
'NBA':'3/m', #沒登入使用者3/m,NBA就是scope定義的值
'NBAUser':'10/m', #登入使用者10/m,NBAUser就是scope定義的值
}
}
(3)views.py
區域性配置方法
class AuthView(APIView):
.
.
.
# 預設的節流是登入使用者(10/m),AuthView不需要登入,這裡用匿名使用者的節流(3/m)
throttle_classes = [VisitThrottle,] . .
# views.py
from django.shortcuts import render,HttpResponse
from django.http import JsonResponse
from rest_framework.views import APIView
from API import models
from rest_framework.request import Request
from rest_framework import exceptions
from rest_framework.authentication import BaseAuthentication
from API.utils.permission import SVIPPremission,MyPremission
from API.utils.throttle import VisitThrottle
ORDER_DICT = {
1:{
'name':'apple',
'price':15
},
2:{
'name':'dog',
'price':100
}
}
def md5(user):
import hashlib
import time
#當前時間,相當於生成一個隨機的字串
ctime = str(time.time())
m = hashlib.md5(bytes(user,encoding='utf-8'))
m.update(bytes(ctime,encoding='utf-8'))
return m.hexdigest()
class AuthView(APIView):
'''用於使用者登入驗證'''
authentication_classes = [] #裡面為空,代表不需要認證
permission_classes = [] #不裡面為空,代表不需要許可權
# 預設的節流是登入使用者(10/m),AuthView不需要登入,這裡用匿名使用者的節流(3/m)
throttle_classes = [VisitThrottle,]
def post(self,request,*args,**kwargs):
ret = {'code':1000,'msg':None}
try:
user = request._request.POST.get('username')
pwd = request._request.POST.get('password')
obj = models.UserInfo.objects.filter(username=user,password=pwd).first()
if not obj:
ret['code'] = 1001
ret['msg'] = '使用者名稱或密碼錯誤'
#為使用者建立token
token = md5(user)
#存在就更新,不存在就建立
models.UserToken.objects.update_or_create(user=obj,defaults={'token':token})
ret['token'] = token
except Exception as e:
ret['code'] = 1002
ret['msg'] = '請求異常'
return JsonResponse(ret)
class OrderView(APIView):
'''
訂單相關業務(只有SVIP使用者才能看)
'''
def get(self,request,*args,**kwargs):
self.dispatch
#request.user
#request.auth
ret = {'code':1000,'msg':None,'data':None}
try:
ret['data'] = ORDER_DICT
except Exception as e:
pass
return JsonResponse(ret)
class UserInfoView(APIView):
'''
訂單相關業務(普通使用者和VIP使用者可以看)
'''
permission_classes = [MyPremission,] #不用全域性的許可權配置的話,這裡就要寫自己的區域性許可權
def get(self,request,*args,**kwargs):
print(request.user)
return HttpResponse('使用者資訊')
說明:
- API.utils.throttle.UserThrottle 這個是全域性配置(根據ip限制,10/m)
- DEFAULT_THROTTLE_RATES --->>>設定訪問頻率的
- throttle_classes = [VisitThrottle,] --->>>區域性配置(不適用settings裡面預設的全域性配置)
總結
基本使用
- 建立類,繼承BaseThrottle, 實現:allow_request ,wait
- 建立類,繼承SimpleRateThrottle, 實現: get_cache_key, scope='NBA' (配置檔案中的key)
全域性
#節流
"DEFAULT_THROTTLE_CLASSES":['API.utils.throttle.UserThrottle'], #全域性配置,登入使用者節流限制(10/m)
"DEFAULT_THROTTLE_RATES":{
'NBA':'3/m', #沒登入使用者3/m,NBA就是scope定義的值
'NBAUser':'10/m', #登入使用者10/m,NBAUser就是scope定義的值
}
}
區域性
throttle_classes = [VisitThrottle,]
所有程式碼
認證、許可權和節流
# MyProject/urls.py
from django.contrib import admin
from django.urls import path
from API.views import AuthView,OrderView,UserInfoView
urlpatterns = [
path('admin/', admin.site.urls),
path('api/v1/auth/',AuthView.as_view()),
path('api/v1/order/',OrderView.as_view()),
path('api/v1/info/',UserInfoView.as_view()),
]
#全域性 settings.py
REST_FRAMEWORK = {
#認證
"DEFAULT_AUTHENTICATION_CLASSES":['API.utils.auth.Authentication',],
#許可權
"DEFAULT_PERMISSION_CLASSES":['API.utils.permission.SVIPPermission'],
#節流
"DEFAULT_THROTTLE_CLASSES":['API.utils.throttle.UserThrottle'], #全域性配置,登入使用者節流限制(10/m)
"DEFAULT_THROTTLE_RATES":{
'NBA':'3/m', #沒登入使用者3/m,NBA就是scope定義的值
'NBAUser':'10/m', #登入使用者10/m,NBAUser就是scope定義的值
}
}
# API/models.py
from django.db import models
class UserInfo(models.Model):
USER_TYPE = (
(1,'普通使用者'),
(2,'VIP'),
(3,'SVIP')
)
user_type = models.IntegerField(choices=USER_TYPE)
username = models.CharField(max_length=32)
password = models.CharField(max_length=64)
class UserToken(models.Model):
user = models.OneToOneField(UserInfo,on_delete=models.CASCADE)
token = models.CharField(max_length=64)
# API/views.py
from django.shortcuts import render,HttpResponse
from django.http import JsonResponse
from rest_framework.views import APIView
from API import models
from rest_framework.request import Request
from rest_framework import exceptions
from rest_framework.authentication import BaseAuthentication
from API.utils.permission import SVIPPermission,MyPermission
from API.utils.throttle import VisitThrottle
ORDER_DICT = {
1:{
'name':'apple',
'price':15
},
2:{
'name':'dog',
'price':100
}
}
def md5(user):
import hashlib
import time
#當前時間,相當於生成一個隨機的字串
ctime = str(time.time())
m = hashlib.md5(bytes(user,encoding='utf-8'))
m.update(bytes(ctime,encoding='utf-8'))
return m.hexdigest()
class AuthView(APIView):
'''用於使用者登入驗證'''
authentication_classes = [] #裡面為空,代表不需要認證
permission_classes = [] #不裡面為空,代表不需要許可權
# 預設的節流是登入使用者(10/m),AuthView不需要登入,這裡用匿名使用者的節流(3/m)
throttle_classes = [VisitThrottle,]
def post(self,request,*args,**kwargs):
ret = {'code':1000,'msg':None}
try:
user = request._request.POST.get('username')
pwd = request._request.POST.get('password')
obj = models.UserInfo.objects.filter(username=user,password=pwd).first()
if not obj:
ret['code'] = 1001
ret['msg'] = '使用者名稱或密碼錯誤'
#為使用者建立token
token = md5(user)
#存在就更新,不存在就建立
models.UserToken.objects.update_or_create(user=obj,defaults={'token':token})
ret['token'] = token
except Exception as e:
ret['code'] = 1002
ret['msg'] = '請求異常'
return JsonResponse(ret)
class OrderView(APIView):
'''
訂單相關業務(只有SVIP使用者才能看)
'''
def get(self,request,*args,**kwargs):
self.dispatch
#request.user
#request.auth
ret = {'code':1000,'msg':None,'data':None}
try:
ret['data'] = ORDER_DICT
except Exception as e:
pass
return JsonResponse(ret)
class UserInfoView(APIView):
'''
訂單相關業務(普通使用者和VIP使用者可以看)
'''
permission_classes = [MyPermission,] #不用全域性的許可權配置的話,這裡就要寫自己的區域性許可權
def get(self,request,*args,**kwargs):
print(request.user)
return HttpResponse('使用者資訊')
# API/utils/auth/py
from rest_framework import exceptions
from API import models
from rest_framework.authentication import BaseAuthentication
class Authentication(BaseAuthentication):
'''用於使用者登入驗證'''
def authenticate(self,request):
token = request._request.GET.get('token')
token_obj = models.UserToken.objects.filter(token=token).first()
if not token_obj:
raise exceptions.AuthenticationFailed('使用者認證失敗')
#在rest framework內部會將這兩個欄位賦值給request,以供後續操作使用
return (token_obj.user,token_obj)
def authenticate_header(self, request):
pass
# utils/permission.py
from rest_framework.permissions import BasePermission
class SVIPPermission(BasePermission):
message = "必須是SVIP才能訪問"
def has_permission(self,request,view):
if request.user.user_type != 3:
return False
return True
class MyPermission(BasePermission):
def has_permission(self,request,view):
if request.user.user_type == 3:
return False
return True
# utils/throttle.py
#
# from rest_framework.throttling import BaseThrottle
# import time
# VISIT_RECORD = {} #儲存訪問記錄
#
# class VisitThrottle(BaseThrottle):
# '''60s內只能訪問3次'''
# def __init__(self):
# self.history = None #初始化訪問記錄
#
# def allow_request(self,request,view):
# #獲取使用者ip (get_ident)
# remote_addr = self.get_ident(request)
# ctime = time.time()
# #如果當前IP不在訪問記錄裡面,就新增到記錄
# if remote_addr not in VISIT_RECORD:
# VISIT_RECORD[remote_addr] = [ctime,] #鍵值對的形式儲存
# return True #True表示可以訪問
# #獲取當前ip的歷史訪問記錄
# history = VISIT_RECORD.get(remote_addr)
# #初始化訪問記錄
# self.history = history
#
# #如果有歷史訪問記錄,並且最早一次的訪問記錄離當前時間超過60s,就刪除最早的那個訪問記錄,
# #只要為True,就一直迴圈刪除最早的一次訪問記錄
# while history and history[-1] < ctime - 60:
# history.pop()
# #如果訪問記錄不超過三次,就把當前的訪問記錄插到第一個位置(pop刪除最後一個)
# if len(history) < 3:
# history.insert(0,ctime)
# return True
#
# def wait(self):
# '''還需要等多久才能訪問'''
# ctime = time.time()
# return 60 - (ctime - self.history[-1])
from rest_framework.throttling import SimpleRateThrottle
class VisitThrottle(SimpleRateThrottle):
'''匿名使用者60s只能訪問三次(根據ip)'''
scope = 'NBA' #這裡面的值,自己隨便定義,settings裡面根據這個值配置Rate
def get_cache_key(self, request, view):
#通過ip限制節流
return self.get_ident(request)
class UserThrottle(SimpleRateThrottle):
'''登入使用者60s可以訪問10次'''
scope = 'NBAUser' #這裡面的值,自己隨便定義,settings裡面根據這個值配置Rate
def get_cache_key(self, request, view):
return request.user.username