DRF Serializers: source code analysis and customization of the serializer component

Posted by harry6 on 2024-10-28

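1. Source code analysis

Note: the code snippets below have been simplified for readability and keep only the code related to the serializer functionality.

The serialization source code involves the concept of metaclasses, so here is a brief explanation. A metaclass is an advanced concept used to define how classes are created. Put simply, a metaclass is a class that creates classes; it determines how a class is built and how it behaves.

In Python everything is an object, including classes. Every class has a metaclass that defines how the class is created. Normally Python uses the default metaclass, type, to create classes. When the class creation process needs to be customized, a custom metaclass can be used, for example: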

class Mytype(type):
    def __new__(cls, name, bases, attrs):  # class name, base classes, members
        # The class being created can be modified here
        del attrs["v1"]
        attrs["name"] = "harry"

        # Call type.__new__ to create the object (this object is the Bar class)
        xx = super().__new__(cls, name, bases, attrs)
        return xx


class Bar(object, metaclass=Mytype):  # metaclass specifies the custom metaclass
    v1 = 123

    def func(self):
        pass
    
Because the metaclass deletes the v1 attribute and adds a name attribute, Bar ends up without v1 and with an extra name attribute.

Also: if a parent class specifies a metaclass, its subclasses are created with that same metaclass by default.
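Both points can be checked with a short sketch. The metaclass below is a hypothetical variant named MyType that uses attrs.pop("v1", None) instead of del attrs["v1"]: the metaclass also runs for every subclass, so with the del version a subclass that does not define v1 would raise KeyError.

```python
class MyType(type):
    def __new__(cls, name, bases, attrs):
        # pop() with a default tolerates classes that never define v1
        attrs.pop("v1", None)
        attrs["name"] = "harry"
        return super().__new__(cls, name, bases, attrs)


class Bar(metaclass=MyType):
    v1 = 123


class Child(Bar):  # no metaclass given: it is inherited from Bar
    pass


print(hasattr(Bar, "v1"))      # False
print(Bar.name)                # harry
print(type(Child) is MyType)   # True
```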

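Additional note: Bar is itself an object created by a metaclass, so instantiating it with Bar() amounts to calling that object. This triggers the __call__ method of type, which in turn calls Bar's __new__ and __init__. That is why a class's __new__ and __init__ run automatically on instantiation: calling object() invokes the __call__ method of the object's class, and for a class that is its metaclass.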
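The call order can be observed with a small tracing sketch. TracingMeta and the calls list are illustrative names, not part of DRF; the metaclass only records that its __call__ ran and then defers to type.__call__.

```python
calls = []


class TracingMeta(type):
    def __call__(cls, *args, **kwargs):
        # Runs first when the class is called, i.e. Bar(...)
        calls.append("meta.__call__")
        return super().__call__(*args, **kwargs)  # type.__call__ drives __new__ and __init__


class Bar(metaclass=TracingMeta):
    def __new__(cls, *args, **kwargs):
        calls.append("Bar.__new__")
        return super().__new__(cls)

    def __init__(self, value):
        calls.append("Bar.__init__")
        self.value = value


bar = Bar(1)
print(calls)  # ['meta.__call__', 'Bar.__new__', 'Bar.__init__']
```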


The Serializers component has two main functions: serialization and data validation.

  1. Serialization part:
    First, define a serializer class:
from rest_framework import serializers


class DepartSerializer(serializers.Serializer):
    '''Validation with Serializer'''
    # Built-in validation
    title = serializers.CharField(required=True, max_length=20, min_length=6)
    order = serializers.IntegerField(required=False, max_value=100, min_value=10)
    count = serializers.ChoiceField(choices=[(1, "Senior"), (2, "Intermediate")])

Looking at how Serializer is defined shows that it is created through the SerializerMetaclass metaclass:

class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
    ...

Source of the SerializerMetaclass metaclass:

class SerializerMetaclass(type):
    @classmethod
    def _get_declared_fields(cls, bases, attrs):
        fields = [(field_name, attrs.pop(field_name))  # Loop over attrs to collect the Field objects
                  for field_name, obj in list(attrs.items())
                  if isinstance(obj, Field)]
        fields.sort(key=lambda x: x[1]._creation_counter)

        known = set(attrs)
        def visit(name):
            known.add(name)
            return name

        base_fields = [
            (visit(name), f)
            for base in bases if hasattr(base, '_declared_fields')
            for name, f in base._declared_fields.items() if name not in known
        ]

        return OrderedDict(base_fields + fields)

    def __new__(cls, name, bases, attrs):
        # Add a _declared_fields attribute to the class, holding every field name and its Field object
        attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
        return super().__new__(cls, name, bases, attrs)
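To see what the metaclass does without installing DRF, here is a minimal pure-Python sketch of the same idea. Field, DeclaredFieldsMeta and BaseSerializer are simplified stand-ins written for this illustration, not DRF's classes; only the mechanism (pop the fields out of attrs, sort by creation order, merge with inherited fields) follows the source above.

```python
import itertools


class Field:
    """Stand-in for a DRF field: remembers the order in which it was created."""
    _counter = itertools.count()

    def __init__(self):
        self._creation_counter = next(Field._counter)


class DeclaredFieldsMeta(type):
    def __new__(cls, name, bases, attrs):
        # Pull every Field instance out of the class body
        fields = [(key, attrs.pop(key))
                  for key, value in list(attrs.items())
                  if isinstance(value, Field)]
        fields.sort(key=lambda item: item[1]._creation_counter)

        # Start from the fields inherited from the base classes
        declared = {}
        for base in bases:
            for key, field in getattr(base, "_declared_fields", {}).items():
                if key not in attrs:  # a non-field attribute of the same name hides it
                    declared.setdefault(key, field)
        declared.update(fields)

        attrs["_declared_fields"] = declared
        return super().__new__(cls, name, bases, attrs)


class BaseSerializer(metaclass=DeclaredFieldsMeta):
    pass


class DepartSerializer(BaseSerializer):
    title = Field()
    order = Field()


class DepartDetailSerializer(DepartSerializer):
    count = Field()


print(list(DepartSerializer._declared_fields))        # ['title', 'order']
print(list(DepartDetailSerializer._declared_fields))  # ['title', 'order', 'count']
print(hasattr(DepartSerializer, "title"))             # False
```

The last line shows a side effect that is easy to miss: because the fields are popped from attrs, they are no longer ordinary class attributes and are reachable only through _declared_fields.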

Accessing serializer.data triggers the serialization flow:

    @property
    def data(self):
        ret = super().data   # Look up the data property of the parent class BaseSerializer
        return ReturnDict(ret, serializer=self)

Source of the data property on BaseSerializer:

    @property
    def data(self):
        if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'):
            msg = (
                'When a serializer is passed a `data` keyword argument you '
                'must call `.is_valid()` before attempting to access the '
                'serialized `.data` representation.\n'
                'You should either call `.is_valid()` first, '
                'or access `.initial_data` instead.'
            )
            raise AssertionError(msg)

        if not hasattr(self, '_data'):
            if self.instance is not None and not getattr(self, '_errors', None):
                self._data = self.to_representation(self.instance)    # Run to_representation to get the serialized data
            elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None):
                self._data = self.to_representation(self.validated_data)
            else:
                self._data = self.get_initial()
        return self._data

Source of the to_representation method (the core):

    def to_representation(self, instance):
        ret = OrderedDict()
        fields = self._readable_fields  # Select the readable field objects (internally, _declared_fields is deep-copied)

        for field in fields:
            try:
                attribute = field.get_attribute(instance)  # Loop over the field objects and call get_attribute to fetch each value
            except SkipField:
                continue
            check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
            if check_for_none is None:
                ret[field.field_name] = None
            else:
                ret[field.field_name] = field.to_representation(attribute)  # Convert the value with to_representation and store it in the ret dict

        return ret

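Source of get_attribute (the field method delegates to a module-level function of the same name):

# Method on the Field class
def get_attribute(self, instance):
    return get_attribute(instance, self.source_attrs)


# Module-level helper function
def get_attribute(instance, attrs):  # attrs comes from the field's source; instance is the model object
    for attr in attrs:
        try:
            if isinstance(instance, Mapping):
                instance = instance[attr]
            else:
                instance = getattr(instance, attr)  # Step through the attrs one by one until the final value is reached
        except ObjectDoesNotExist:
            return None
    return instance  # Return the value for this field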
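The traversal is easiest to follow with a dotted source such as "leader.name", which DRF splits on "." to build source_attrs. The sketch below reimplements the loop in plain Python so it runs without Django; it leaves out the ObjectDoesNotExist handling, and the depart object is made-up sample data.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def get_attribute(instance, attrs):
    """Follow attrs step by step through dicts (by key) or objects (by attribute)."""
    for attr in attrs:
        if isinstance(instance, Mapping):
            instance = instance[attr]
        else:
            instance = getattr(instance, attr)
    return instance


# Sample object standing in for a model instance with a related object
depart = SimpleNamespace(title="Engineering", leader=SimpleNamespace(name="harry"))

source = "leader.name"
print(get_attribute(depart, source.split(".")))                          # harry
print(get_attribute({"leader": {"name": "harry"}}, ["leader", "name"]))  # harry
```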




  2. Data validation part:
    Data is validated with is_valid, which populates _errors; if _errors is non-empty, is_valid returns False. While it runs, it triggers run_validation:

    def is_valid(self, raise_exception=False):
        if not hasattr(self, '_validated_data'):
            try:  # Triggers run_validation
                self._validated_data = self.run_validation(self.initial_data) 
            except ValidationError as exc:
                self._validated_data = {}
                self._errors = exc.detail
            else:
                self._errors = {}

        if self._errors and raise_exception:
            raise ValidationError(self.errors)

        return not bool(self._errors)

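The run_validation method. Note that this is the method of the Serializer class, not of the Field class. Its call to to_internal_value runs each field's built-in validation and then the hook functions.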

    def run_validation(self, data=empty):

        (is_empty_value, data) = self.validate_empty_values(data)
        if is_empty_value:
            return data

        value = self.to_internal_value(data)  # Run each field's built-in validation, then the hook functions
        try:
            self.run_validators(value)
            value = self.validate(value)
            assert value is not None, '.validate() should return the validated data'
        except (ValidationError, DjangoValidationError) as exc:
            raise ValidationError(detail=as_serializer_error(exc))

        return value

The to_internal_value method. Here fields is deep-copied from _declared_fields and contains only the writable field objects (those not marked read_only):

    def to_internal_value(self, data):
        if not isinstance(data, Mapping):
            message = self.error_messages['invalid'].format(
                datatype=type(data).__name__
            )
            raise ValidationError({
                api_settings.NON_FIELD_ERRORS_KEY: [message]
            }, code='invalid')

        ret = OrderedDict()
        errors = OrderedDict()
        fields = self._writable_fields  # Select the writable (non-read-only) field objects

        for field in fields:
            validate_method = getattr(self, 'validate_' + field.field_name, None)
            primitive_value = field.get_value(data)
            try:
                validated_value = field.run_validation(primitive_value)  # Run the built-in validation
                if validate_method is not None:
                    validated_value = validate_method(validated_value)  # Run the hook function for extra validation
            except ValidationError as exc:
                errors[field.field_name] = exc.detail
            except DjangoValidationError as exc:
                errors[field.field_name] = get_error_detail(exc)
            except SkipField:
                pass
            else:
                set_value(ret, field.source_attrs, validated_value)
        if errors:
            raise ValidationError(errors)
        return ret
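The loop above can be imitated in a few lines of plain Python to show two things: a validate_<field_name> hook runs after the field's own check, and errors are collected for every field instead of stopping at the first failure. ValidationError, IntegerField and MiniSerializer here are simplified stand-ins written for this sketch, not DRF's classes, and the error messages are made up.

```python
class ValidationError(Exception):
    """Stand-in for DRF's ValidationError: carries the error detail."""
    def __init__(self, detail):
        super().__init__(detail)
        self.detail = detail


class IntegerField:
    """Stand-in field with a tiny built-in check."""
    def __init__(self, min_value=None):
        self.min_value = min_value

    def run_validation(self, value):
        try:
            value = int(value)
        except (TypeError, ValueError):
            raise ValidationError(["A valid integer is required."])
        if self.min_value is not None and value < self.min_value:
            raise ValidationError(["Value must be at least %d." % self.min_value])
        return value


class MiniSerializer:
    fields = {}

    def to_internal_value(self, data):
        ret, errors = {}, {}
        for name, field in self.fields.items():
            hook = getattr(self, "validate_" + name, None)
            try:
                value = field.run_validation(data.get(name))  # built-in check
                if hook is not None:
                    value = hook(value)                       # per-field hook
            except ValidationError as exc:
                errors[name] = exc.detail                     # keep going, collect all errors
            else:
                ret[name] = value
        if errors:
            raise ValidationError(errors)
        return ret


class DepartSerializer(MiniSerializer):
    fields = {"order": IntegerField(min_value=10), "count": IntegerField()}

    def validate_count(self, value):
        if value not in (1, 2):
            raise ValidationError(["count must be 1 or 2."])
        return value


serializer = DepartSerializer()
result = serializer.to_internal_value({"order": "20", "count": 1})
print(result)  # {'order': 20, 'count': 1}

collected = None
try:
    serializer.to_internal_value({"order": 5, "count": 3})
except ValidationError as exc:
    collected = exc.detail
print(collected)  # one entry per failing field
```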

Built-in validation in the field's run_validation (CharField shown here, followed by its parent class Field):

    def run_validation(self, data=empty):
        if data == '' or (self.trim_whitespace and str(data).strip() == ''):
            if not self.allow_blank:
                self.fail('blank')
            return ''
        return super().run_validation(data)

    # run_validation of the parent class (Field)
    def run_validation(self, data=empty):

        (is_empty_value, data) = self.validate_empty_values(data)
        if is_empty_value:
            return data
        value = self.to_internal_value(data)
        self.run_validators(value)  # Call run_validators to apply the validators defined on the field
        return value

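2. Customizing the source code

  • Custom hook: let a field both accept input from the front end and return a custom serialized value. (SerializerMethodField is read-only, so users cannot submit it, while an ordinary field cannot return a value computed by complex custom logic.)

Approach: in to_representation, which runs once ser.data starts serialization, check whether a custom formatting hook exists for each field; if it does, use its return value in place of that field's value.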

    def to_representation(self, instance):
        ret = OrderedDict()
        fields = self._readable_fields

        for field in fields:
            if hasattr(self, 'get_%s' % field.field_name):  # Check whether a method named get_xxx exists; if so, call it with instance
                value = getattr(self, 'get_%s' % field.field_name)(instance)
                ret[field.field_name] = value
            else:
                try:
                    attribute = field.get_attribute(instance)
                except SkipField:
                    continue

                check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
                if check_for_none is None:
                    ret[field.field_name] = None
                else:
                    ret[field.field_name] = field.to_representation(attribute)

        return ret

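If other classes also need this override, it can be packaged into a class of its own; any class that inherits from it then no longer needs to override to_representation each time.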
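One way to package it is a mixin, sketched below. HookSerializerMixin is a hypothetical name, and this is a variant of the override above rather than a copy of it: it lets the parent to_representation run first and then replaces the values of fields that have a get_<field_name> hook. FakeSerializer is only a stand-in for serializers.Serializer so the sketch runs without DRF.

```python
class HookSerializerMixin:
    """Hypothetical mixin: get_<field_name> hooks override serialized values."""

    def to_representation(self, instance):
        ret = super().to_representation(instance)  # default serialization first
        for field_name in list(ret):
            hook = getattr(self, "get_%s" % field_name, None)
            if callable(hook):
                ret[field_name] = hook(instance)   # replace with the hook's value
        return ret


class FakeSerializer:
    """Stand-in for serializers.Serializer, used only to make the sketch runnable."""

    def to_representation(self, instance):
        return {"title": instance["title"], "count": instance["count"]}


class DepartSerializer(HookSerializerMixin, FakeSerializer):
    def get_count(self, instance):
        return {1: "Senior", 2: "Intermediate"}[instance["count"]]


print(DepartSerializer().to_representation({"title": "Engineering", "count": 1}))
# {'title': 'Engineering', 'count': 'Senior'}
```

In a real project the mixin would be listed before serializers.Serializer in the base classes so that its to_representation is found first. Two caveats apply. Because the default lookup still runs for hooked fields, this variant assumes the attribute exists on the instance, whereas the inline override skips that lookup. And with either approach, a field whose name makes get_<field_name> collide with a method the serializer already has (for example a field named value, since get_value exists) would pick up that method as its hook, so such names are best avoided.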
