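torch.fx is a tool, or rather a library, introduced in PyTorch 1.8 for Python-to-Python code transformation. Roughly speaking, it lets you transform the Python forward code of a PyTorch model into whatever form you want. The official description reads: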

We apply this principle in torch.fx, a program capture and transformation library for PyTorch written entirely in Python and optimized for high developer productivity by ML practitioners.
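The quote above comes from the FX paper. If you are interested, read TORCH.FX: PRACTICAL PROGRAM CAPTURE AND TRANSFORMATION FOR DEEP LEARNING IN PYTHON; there is also a good walkthrough of it on Zhihu, so I won't repeat it here. This article does cover material from the paper, but mostly from a practical angle.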

The core keywords are program capture and transformation library; these two concepts are important.


import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


Suppose we want to take part of the forward logic of this Module, self.linear(x + self.param).clamp(min=0.0, max=1.0), and replace the clamp with sigmoid. How should we do it?


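This is where FX comes in. Instead of editing the code by hand (i.e. rewriting this forward yourself), you define the rule, pass the model instance to torch.fx, and run the code. The forward of MyModule then becomes self.linear(x + self.param).sigmoid().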

import torch
import torch.fx
from torch.fx import symbolic_trace

module = MyModule()

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
# Print FX's IR
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
# The code generated by FX; think of it as the module's forward
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

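In this way FX modifies the Module for you, and the modified model is used just like any ordinary model. Note what happens here: FX captures the forward code you wrote, then transforms it by modifying the operations inside.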
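The trace above only captures the module; the clamp is still in the generated code. Here is a minimal sketch of the actual clamp -> sigmoid rewrite, assuming the rule is simply "turn every call_method node whose target is clamp into sigmoid and drop its min/max kwargs". The helper name clamp_to_sigmoid is hypothetical.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


def clamp_to_sigmoid(m: torch.nn.Module) -> torch.fx.GraphModule:
    """Hypothetical pass: rewrite every tensor.clamp(...) call into tensor.sigmoid()."""
    gm = torch.fx.symbolic_trace(m)
    for node in gm.graph.nodes:
        # `.clamp(...)` on a tensor is recorded as a call_method node
        if node.op == "call_method" and node.target == "clamp":
            node.target = "sigmoid"
            node.kwargs = {}  # sigmoid takes no min/max arguments
    gm.graph.lint()   # sanity-check the edited graph
    gm.recompile()    # regenerate forward() from the edited graph
    return gm


module = MyModule()
new_module = clamp_to_sigmoid(module)
x = torch.rand(3, 4)
# The node keeps its old *name* ("clamp"), so the generated line reads
# `clamp = linear.sigmoid()`; only the called method changed.
print(new_module.code)
```

Only the node's target and kwargs are edited; the surrounding nodes and the module's parameters are untouched, which is why the rest of the generated forward stays the same.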


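With FX you can, for example:
  • fuse two ops, such as conv and bn
  • remove certain ops
  • replace certain ops
  • insert ops or other operations after certain ops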
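As an illustration of "remove certain ops", here is a sketch that strips every nn.Dropout call from a traced model. The toy model Net and the helper remove_dropout are hypothetical; the calls used (replace_all_uses_with, erase_node, recompile) are standard torch.fx APIs.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)
        self.drop = torch.nn.Dropout(p=0.5)

    def forward(self, x):
        return self.drop(self.fc(x)).relu()


def remove_dropout(m: torch.nn.Module) -> torch.fx.GraphModule:
    gm = torch.fx.symbolic_trace(m)
    modules = dict(gm.named_modules())
    # Iterate over a copy because we erase nodes while walking the graph
    for node in list(gm.graph.nodes):
        if node.op == "call_module" and isinstance(modules[node.target], torch.nn.Dropout):
            # Route every consumer of the dropout output to the dropout's input
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)  # now unused, so it can be deleted
    gm.graph.lint()
    gm.recompile()
    return gm


net = Net().train()  # in train mode the original model would randomly zero activations
stripped = remove_dropout(net)
x = torch.randn(4, 8)
print(stripped.code)
```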


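You may wonder whether these operations look a lot like passes in an AI compiler, since the objects they operate on are likewise DAGs (directed acyclic graphs), i.e. neural networks. In fact, you can think of FX as a kind of compiler too, except that what it produces is not an executable: it goes Python -> Python, and the final product is still Python code that follows PyTorch's rules. That is why FX keeps describing itself as a Python-to-Python (or Module-to-Module) transformation toolkit rather than a compiler.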











torch.fx is different from TorchScript in that it is a platform for Python-to-Python transformations of PyTorch code. TorchScript, on the other hand, is more targeted at moving PyTorch programs outside of Python for deployment purposes. In this sense, FX and TorchScript are orthogonal to each other, and can even be composed with each other (e.g. transform PyTorch programs with FX, then subsequently export to TorchScript for deployment).
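To make the "compose FX with TorchScript" point concrete, here is a small sketch: an FX pass (swapping torch.add for torch.mul, chosen only for illustration) followed by torch.jit.script on the resulting GraphModule. The toy module AddRelu is hypothetical.

```python
import torch
import torch.fx


class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(torch.add(x, y))


# Step 1: Python-to-Python transform with FX
gm = torch.fx.symbolic_trace(AddRelu())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.add:
        node.target = torch.mul
gm.recompile()

# Step 2: the result is an ordinary nn.Module, so TorchScript can compile it for deployment
scripted = torch.jit.script(gm)
x, y = torch.randn(3), torch.randn(3)
print(scripted.code)
```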


Python to Python?

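Note, however, that FX's code generation goes from Python to Python. In other words, the code FX generates is no different from a network we would normally build with nn.Module, and it runs directly in PyTorch's eager mode. This is unlike TorchScript, which is a separate runtime: when we run TorchScript we are actually invoking a VM (virtual machine) that runs the exported TorchScript model in C++.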


  • Your own Module -> still a Module after FX -> successive FX transformations -> the final FX model
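A sketch of that Module -> Module -> Module chain: each pass takes an nn.Module and returns a GraphModule, so its output can be traced again by the next pass and still runs in eager mode. The factory swap_target is a hypothetical helper, not a torch.fx API.

```python
import torch
import torch.fx


def swap_target(old, new):
    """Build a hypothetical FX pass that rewrites call_function nodes from `old` to `new`."""
    def fx_pass(m: torch.nn.Module) -> torch.nn.Module:
        gm = torch.fx.symbolic_trace(m)  # also works when `m` is itself a GraphModule
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == old:
                node.target = new
        gm.recompile()
        return gm  # Module in, Module out
    return fx_pass


class M(torch.nn.Module):
    def forward(self, x):
        return torch.sigmoid(torch.neg(x))


# Each pass consumes the previous pass's output
step1 = swap_target(torch.neg, torch.abs)(M())
step2 = swap_target(torch.sigmoid, torch.tanh)(step1)
x = torch.randn(5)
print(step2.code)
```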



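  • FX is tightly integrated into the Python runtime, so it can capture program representations more precisely, unlike jit.trace, which sometimes gets them wrong.
  • FX's Graph is not much different from a torch.nn.Module and its IR is not very low-level, so it is simpler to use and developers can work more efficiently.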


  • placeholder represents a function input. The name attribute specifies the name this value will take on. target is similarly the name of the argument. args holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. kwargs is don't-care. Placeholders correspond to the function parameters (e.g. x) in the graph printout.
  • get_attr retrieves a parameter from the module hierarchy. name is similarly the name the result of the fetch is assigned to. target is the fully-qualified name of the parameter's position in the module hierarchy. args and kwargs are don't-care.
  • call_function applies a free function to some values. name is similarly the name of the value to assign to. target is the function to be applied. args and kwargs represent the arguments to the function, following the Python calling convention.
  • call_module applies a module in the module hierarchy's forward() method to given arguments. name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. args and kwargs represent the arguments to invoke the module on, excluding the self argument.
  • call_method calls a method on a value. name is similar. target is the string name of the method to apply to the self argument. args and kwargs represent the arguments to invoke the method on, including the self argument.
  • output contains the output of the traced function in its args[0] attribute. This corresponds to the "return" statement in the Graph printout.
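To see all six opcodes at once, here is a sketch that traces a tiny hypothetical module (Demo) written so that each opcode appears exactly once, then prints every node's op, name, target and args.

```python
import torch
import torch.fx


class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(4))
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):                # placeholder: x
        y = torch.add(x, self.param)     # get_attr: param, call_function: torch.add
        y = self.linear(y)               # call_module: linear
        return y.relu()                  # call_method: relu, then output


gm = torch.fx.symbolic_trace(Demo())
for node in gm.graph.nodes:
    print(f"{node.op:<14} name={node.name:<8} target={node.target} args={node.args}")

ops = [node.op for node in gm.graph.nodes]
```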


symbolic tracer

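Back to the code at the very beginning: one of its lines is symbolic_traced : torch.fx.GraphModule = symbolic_trace(module). The core here is the symbolic_trace function, which is the starting point for FX to parse and transform a model. Internally, it looks like this: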

def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
                   enable_cpatching: bool = False) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.
    """
    tracer = Tracer(enable_cpatching=enable_cpatching)
    graph = tracer.trace(root, concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return GraphModule(tracer.root, graph, name)
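A quick way to convince yourself that symbolic_trace is just "Tracer.trace, then wrap in a GraphModule": do both by hand and compare the generated code. The toy module M is hypothetical; the `enable_cpatching` argument shown above belongs to an older release and is not needed here.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x).relu()


m = M()

# One call ...
gm_a = torch.fx.symbolic_trace(m)

# ... versus the two steps it performs internally
tracer = torch.fx.Tracer()
graph = tracer.trace(m)                          # record a Graph
gm_b = torch.fx.GraphModule(tracer.root, graph)  # wrap it into a Module

x = torch.randn(3, 4)
```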


import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`
    # Trace model m with a Tracer object.
    # This splits symbolic_trace into its steps: transform does what torch.fx.symbolic_trace does
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: modify the model here however you like; this is the key part
    # (e.g. edit graph.nodes in place)

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)


The overall FX workflow is: symbolic tracing -> intermediate representation -> transforms -> Python code generation.


  • symbolic tracer

The symbolic tracer performs “symbolic execution” of the Python code. It feeds fake values, called Proxies, through the code. Operations on these Proxies are recorded. More information about symbolic tracing can be found in the symbolic_trace() and Tracer documentation.

  • intermediate representation

The intermediate representation is the container for the operations that were recorded during symbolic tracing. It consists of a list of Nodes that represent function inputs, callsites (to functions, methods, or torch.nn.Module instances), and return values. More information about the IR can be found in the documentation for Graph. The IR is the format on which transformations are applied.

  • Python code generation

Python code generation is what makes FX a Python-to-Python (or Module-to-Module) transformation toolkit. For each Graph IR, we can create valid Python code matching the Graph’s semantics. This functionality is wrapped up in GraphModule, which is a torch.nn.Module instance that holds a Graph as well as a forward method generated from the Graph.
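The three stages can be observed directly on a plain function. In this sketch, f records the type of its argument: during tracing it receives a torch.fx.Proxy rather than a Tensor, and the result is a GraphModule holding both the Graph and the generated forward source.

```python
import torch
import torch.fx

seen_types = []


def f(x):
    # During symbolic tracing `x` is a torch.fx.Proxy, not a Tensor
    seen_types.append(type(x))
    return torch.relu(x) * 2


gm = torch.fx.symbolic_trace(f)  # 1. symbolic tracing: feeds Proxies through f
print(gm.graph)                  # 2. intermediate representation: the recorded Graph
print(gm.code)                   # 3. code generation: Python source of the new forward()
```

Calling gm afterwards runs the generated forward, not f itself, so seen_types keeps only the single entry from tracing.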


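Proxy/Retracing is at the core of symbolic tracing. Since my understanding of Proxy/Retracing is not yet very deep, I won't try to describe it myself; here is the official description: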

Proxy objects are Node wrappers that flow through the program during symbolic tracing and record all the operations (torch function calls, method calls, operators) that they touch into the growing FX Graph.

If you’re doing graph transforms, you can wrap your own Proxy method around a raw Node so that you can use the overloaded operators to add additional things to a Graph.
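The second paragraph is easiest to understand with an example: wrap existing Nodes in Proxies and run an ordinary Python function on them, and the operations it performs are appended to a graph. This sketch decomposes F.relu into (x > 0) * x. It assumes a recent PyTorch where torch.fx.proxy.GraphAppendingTracer exists; the rule table and the decompose helper are illustrative.

```python
import torch
import torch.fx
import torch.nn.functional as F
from torch.fx.proxy import GraphAppendingTracer


def relu_decomposition(x):
    # Plain Python; applied to Proxies it appends gt/mul nodes to a graph
    return (x > 0) * x


decomposition_rules = {F.relu: relu_decomposition}


def decompose(model: torch.nn.Module) -> torch.fx.GraphModule:
    graph = torch.fx.Tracer().trace(model)
    new_graph = torch.fx.Graph()
    env = {}  # old node name -> node in new_graph
    tracer = GraphAppendingTracer(new_graph)  # Proxies built with it record into new_graph
    for node in graph.nodes:
        if node.op == "call_function" and node.target in decomposition_rules:
            proxy_args = [
                torch.fx.Proxy(env[a.name], tracer) if isinstance(a, torch.fx.Node) else a
                for a in node.args
            ]
            env[node.name] = decomposition_rules[node.target](*proxy_args).node
        else:
            # Copy the node unchanged, remapping its inputs to the new graph
            env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
    return torch.fx.GraphModule(model, new_graph)


class M(torch.nn.Module):
    def forward(self, x):
        return F.relu(x) + 1


decomposed = decompose(M())
x = torch.randn(6)
print(decomposed.code)
```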


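FX's main data structures are Graph and GraphModule, where "A Graph is a data structure that represents a method on a GraphModule". You can think of a Graph as storing the network's key Nodes: each node is one step of the network (e.g. conv, relu, add, concat) and records the method it calls together with its inputs and outputs, so the nodes chain together into the network's logic.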


import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)
# Print the nodes in the module (print_tabular needs the `tabulate` package)
gm.graph.print_tabular()
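Since a Graph is just an ordered list of Nodes, it can also be built by hand without tracing. This sketch creates one node per opcode with the Graph builder methods (placeholder, get_attr, call_function, call_method, output) and wraps it in a GraphModule whose root is a bare nn.Module carrying a hypothetical `weight` parameter.

```python
import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder("x")                            # function input
w = graph.get_attr("weight")                          # attribute fetched from the root module
mm = graph.call_function(torch.matmul, args=(x, w))   # free function call
out = graph.call_method("relu", args=(mm,))           # tensor method call: mm.relu()
graph.output(out)                                     # return value
graph.lint()

root = torch.nn.Module()
root.weight = torch.nn.Parameter(torch.eye(3))
gm = torch.fx.GraphModule(root, graph)
print(gm.code)
```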









import torch
import torch.fx as fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
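    # FX represents the network's nodes in order,
    # so we can iterate over them directly with a for loop:
    for node in graph.nodes:
        # Check whether this node's IR opcode is call_function
        if node.op == 'call_function':
            # Change node.target to torch.mul; the network changes accordingly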
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)



Modifying nodes directly as above is crude. FX also thoughtfully provides graph-rewrite utilities that make it easy to add or remove a node:

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))
    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)
    # new_node is itself a user of `node`, so the call above also rewired
    # its input to itself; point it back at `node`
    new_node.args = (node,)
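Adding and deleting can be combined. This sketch rewrites torch.sub(x, y) into torch.add(x, torch.neg(y)): it inserts two nodes before the old one with inserting_before, redirects the users, then erases the old node. The rewrite itself is only an illustration.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y) * 2


traced = torch.fx.symbolic_trace(M())
graph = traced.graph
for node in list(graph.nodes):
    if node.op == "call_function" and node.target == torch.sub:
        x_arg, y_arg = node.args
        # New nodes created in this scope are placed right before `node`
        with graph.inserting_before(node):
            neg = graph.call_function(torch.neg, args=(y_arg,))
            add = graph.call_function(torch.add, args=(x_arg, neg))
        node.replace_all_uses_with(add)  # consumers of sub now read add
        graph.erase_node(node)           # sub has no users left, so it can be deleted
graph.lint()
traced.recompile()
print(traced.code)
```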


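Once we have graph-rewrite tools (the concepts come from compilers), pattern matching naturally follows: replace_pattern() rewrites the whole graph wherever a pattern matches. You can use patterns that ship with FX or define your own rules: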

import torch
from torch.fx import replace_pattern, symbolic_trace

# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x, w1, w2):
        val1 = torch.neg(w1)
        m1 = torch.cat([val1, w2]).sum()
        val2 = torch.neg(w1)
        m2 = torch.cat([val2, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Define the pattern. 
def pattern(a1, a2):
    val1 = torch.neg(a1)
    return torch.cat([val1, a2]).sum()

# Define the replacement (same rules as the pattern)
def replacement(w1, w2):
    return torch.stack([w1, w2])

# Replace `pattern` with `replacement` in `traced`
replace_pattern(traced, pattern, replacement)

# After calling `replace_pattern`, the generated code is:
def forward(self, x, w1, w2):
    stack = torch.stack([w1, w2])
    max_1 = torch.max(stack);  stack = None
    add = x + max_1;  x = max_1 = None
    stack_1 = torch.stack([w1, w2]);  w1 = w2 = None
    max_2 = torch.max(stack_1);  stack_1 = None
    add_1 = add + max_2;  add = max_2 = None
    return add_1
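A second sketch with a numerical check: fold the x * sigmoid(x) idiom into F.silu. replace_pattern edits the GraphModule in place and returns the list of matches, so we can count how many sites were rewritten. The module and pattern are illustrative.

```python
import torch
import torch.nn.functional as F
from torch.fx import replace_pattern, symbolic_trace


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.sigmoid(x) * x + torch.sigmoid(y) * y


def pattern(t):
    return torch.sigmoid(t) * t


def replacement(t):
    return F.silu(t)


traced = symbolic_trace(M())
matches = replace_pattern(traced, pattern, replacement)  # rewrites `traced` in place
print(traced.code)

x, y = torch.randn(4), torch.randn(4)
```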



import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'output':
                # The output node holds the graph's return value in args[0]
                return load_arg(node.args[0])

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result
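torch.fx also ships a reusable version of this node-by-node loop, torch.fx.Interpreter, whose run_node method can be overridden. A sketch of the same shape-recording idea on top of it (the ShapeRecorder class and the toy module are my own):

```python
import torch
import torch.fx


class ShapeRecorder(torch.fx.Interpreter):
    """Run a GraphModule node by node and remember each tensor result's shape."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.shapes = {}

    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)  # default execution of this node
        if isinstance(result, torch.Tensor):
            self.shapes[n.name] = tuple(result.shape)
        return result


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.sum(self.linear(x).relu(), dim=-1)


gm = torch.fx.symbolic_trace(M())
recorder = ShapeRecorder(gm)
out = recorder.run(torch.randn(3, 4))
print(recorder.shapes)
```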




import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # trace nn.Module
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: modify the Graph here
    # (e.g. edit graph.nodes in place)

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

Your transform will take in an torch.nn.Module, acquire a Graph from it, do some modifications, and return a new torch.nn.Module. You should think of the torch.nn.Module that your FX transform returns as identical to a regular torch.nn.Module – you can pass it to another FX transform, you can pass it to TorchScript, or you can run it. Ensuring that the inputs and outputs of your FX transform are a torch.nn.Module will allow for composability.


import torch
import torch.fx

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph here
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm
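Why the recompile() call matters: a GraphModule's forward is generated from the Graph at recompile time, so in-place edits to gm.graph have no effect until you recompile. A small demonstration with a hypothetical add -> mul edit:

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


gm = torch.fx.symbolic_trace(M())
x, y = torch.tensor([2.0]), torch.tensor([3.0])

for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.add:
        node.target = torch.mul

before = gm(x, y)   # forward() was generated before the edit: still x + y
gm.recompile()      # regenerate forward() from the edited Graph
after = gm(x, y)    # now x * y
```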





from torch.fx import GraphModule, Tracer

model = FXCenterNet()
tracer = Tracer()
graph_module = GraphModule(model, tracer.trace(model))


def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
    # root is the FXCenterNet instance
    if isinstance(root, torch.nn.Module):
        self.root = root
        fn = type(root).forward
        self.submodule_paths = {mod: name for name, mod in root.named_modules()}
    else:
        self.root = torch.nn.Module()
        fn = root

    tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
    self.graph = Graph(tracer_cls=tracer_cls)
    # ... (elided) Roughly: walk the operations in root and, following the rules,
    # turn them into nodes stored in the graph (attrs, ops, inputs/outputs, etc.),
    # finally returning the graph as the IR
    return self.graph


The return value is a Graph object:

<torch.fx.graph.Graph object at 0x7f57f59efdf0>




class GraphModule(torch.nn.Module):
    def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
        for t in cls.__mro__:
            c = t.__qualname__.split('.')[-1]
            if c != 'GraphModuleImpl':
                cls = t
                break

        class GraphModuleImpl(cls):  # type: ignore[misc, valid-type]
            pass
        return super().__new__(GraphModuleImpl)

    def __init__(self,
                 root: Union[torch.nn.Module, Dict[str, Any]],
                 graph: Graph,
                 class_name: str = 'GraphModule'):
        super().__init__()
        self.__class__.__name__ = class_name
        if isinstance(root, torch.nn.Module):
            if hasattr(root, 'training'):
                self.training = root.training
            # Copy the parameters and submodules referenced by the graph into self, i.e. the GraphModule
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    _copy_attr(root, self, node.target)
        elif isinstance(root, dict):
            targets_to_copy = []
            for node in graph.nodes:
                if node.op in ['get_attr', 'call_module']:
                    assert isinstance(node.target, str)
                    if node.target not in root:
                        raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
                                           ' but that target was not provided in ``root``!')
                    targets_to_copy.append(node.target)
            targets_to_copy.sort(key=lambda t: t.count('.'))
            for target_to_copy in targets_to_copy:
                _assign_attr(root[target_to_copy], self, target_to_copy)
        else:
            raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')

        self.graph = graph
        self._tracer_cls = None
        if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
            self._tracer_cls = self.graph._tracer_cls
    __jit_unused_properties__ = ['graph']


def forward(self, input):
    # (abridged: most of the generated lines are omitted)
    input_1 = input
    upsampler_deconv_layers_0_bias = getattr(self.upsampler.deconv_layers, "0").bias
    head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11);  upsampler_deconv_layers_11 = None
    head_angle_1 = getattr(self.head.angle, "1")(head_angle_0);  head_angle_0 = None
    head_angle_2 = getattr(self.head.angle, "2")(head_angle_1);  head_angle_1 = None
    return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}

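At this point we have the traced Module. It is no different from the original model, and its forward function was generated from the original model's forward. Since we only traced it once, the same input gives the same result, graph_module(input) == original_model(input); after all, we haven't done anything special yet.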
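A sketch of checking that claim on a model with dict outputs like the one above. TwoHead is a hypothetical stand-in for a multi-head detector; the check compares every output head with torch.allclose.

```python
import torch
import torch.fx


class TwoHead(torch.nn.Module):
    """A hypothetical stand-in for a detector with several output heads."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
        )
        self.hm = torch.nn.Conv2d(8, 2, 1)
        self.wh = torch.nn.Conv2d(8, 2, 1)

    def forward(self, x):
        feat = self.backbone(x)
        return {"hm": self.hm(feat), "wh": self.wh(feat)}


model = TwoHead().eval()
graph_module = torch.fx.GraphModule(model, torch.fx.Tracer().trace(model))

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    ref, out = model(x), graph_module(x)
same = all(torch.allclose(ref[k], out[k]) for k in ref)
```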



def _fuse_fx(
    graph_module: GraphModule,
    is_qat: bool,
    fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    r""" Internal helper function to fuse modules in preparation for quantization

    Args:
        graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    fuser = Fuser()
    return fuser.fuse(
        graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
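From user code this helper is normally reached through the public fuse_fx entry point. A minimal sketch, assuming a PyTorch version that ships fuse_fx in torch.ao.quantization.quantize_fx (older releases keep it under torch.quantization.quantize_fx, hence the fallback import). ConvBnRelu is a hypothetical toy model; it must be in eval mode so BN can be folded.

```python
import torch

try:  # newer releases
    from torch.ao.quantization.quantize_fx import fuse_fx
except ImportError:  # older releases
    from torch.quantization.quantize_fx import fuse_fx


class ConvBnRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


model = ConvBnRelu().eval()
x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    ref = model(x)            # reference output before fusion

fused = fuse_fx(model)        # traces the model and runs the Fuser pass
print(fused.code)
fused_types = [type(m).__name__ for m in fused.modules()]
```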

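Let's look at what Fuser does. It is actually simple: it walks over the nodes in input_graph = model.graph and fuses them according to the specified fusion rules; fusing involves modifying the graph structure: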

class Fuser:
    def fuse(
        self,
        model: GraphModule,
        is_qat: bool,
        fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None,
    ) -> GraphModule:
        if fuse_custom_config_dict is None:
            fuse_custom_config_dict = {}

        input_root = model
        input_graph = model.graph
        # First copy the named_modules of the original model; they are modified later according to what gets fused
        self.modules = dict(input_root.named_modules())
        # ... (elided: setup of fusion_pattern_to_fuse_handler_cls and fuser_method_mapping)
        # Find the matching fusion patterns
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
        self.fused_graph = Graph()
        env: Dict[Any, Any] = {}
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        def get_root_node(node_pattern):
            while not isinstance(node_pattern[-1], Node):
                node_pattern = node_pattern[-1]
            return node_pattern[-1]

        for node in input_graph.nodes:
            maybe_last_node, pattern, matched_node_pattern, obj = \
                fusion_pairs.get(node.name, (None, None, None, None))
            if maybe_last_node is node:
                assert obj is not None
                # TODO: currently we hard code the root node, which only works for
                # a sequence of ops and assume the root node is the last node,
                # we want to make this more general to support more complex patterns
                root_node = get_root_node(matched_node_pattern)  # find the root node of the fused pattern
                env[node.name] = obj.fuse(  # self is passed in so that fuse can modify it
                    self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                    fuse_custom_config_dict, fuser_method_mapping, is_qat)
            elif maybe_last_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # node matched in patterns and is not root is removed here

        preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
        model = FusedGraphModule(input_root, self.fused_graph, preserved_attributes)
        return model

    def _find_matches(
            self, root: GraphModule, graph: Graph,
            patterns: Dict[Pattern, Callable]
    ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
        modules = dict(root.named_modules())
        match_map : Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]] = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match, matched_node_pattern):
            if isinstance(pattern, tuple):
                s, *args = pattern
                current_node_pattern: List[Node] = []
                apply_match(s, node, match, current_node_pattern)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match, current_node_pattern)
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                # the first pattern matches will take precedence
                if node.name not in match_map:
                    matched_node_pattern.append(node)
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler)
        # This is the matching process
        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    matched_node_pattern: List[Node] = []
                    if is_match(modules, node, pattern):
                        apply_match(pattern, node, (node, pattern, value(self, node)), matched_node_pattern)

        return match_map


# /ao/quantization/fx/fusion_patterns.py
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
class DefaultFuseHandler(FuseHandler):
    def __init__(
            self,
            quantizer: QuantizerCls,
            node: Node):
        super().__init__(quantizer, node)

    def fuse(...):
        # The actual fusion is performed here
        ...


matched_module_types = get_matched_types(matched_modules)
module_parent_name, module_name = _parent_name(root_node.target)
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(is_qat, *matched_modules)
# TODO: maybe add a pass to cleanup bn modules?
setattr(quantizer.modules[module_parent_name], module_name, fused_module) # put the new fused module into the model being built
return quantizer.fused_graph.node_copy(root_node, load_arg)                # add the corresponding node to the new fused graph's forward


def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    assert(conv.training == bn.training == relu.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    # (only the eval path is shown here; the is_qat branch is omitted)
    map_to_fused_module_eval = {
        nn.Conv1d: nni.ConvReLU1d,
        nn.Conv2d: nni.ConvReLU2d,
        nn.Conv3d: nni.ConvReLU3d,
    }
    fused_module = map_to_fused_module_eval.get(type(conv), None)
    if fused_module is not None:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return fused_module(fused_conv, relu)
    else:
        raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))


def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
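Per output channel, fuse_conv_bn_weights computes W' = W · γ / √(σ² + ε) and b' = (b − μ) · γ / √(σ² + ε) + β, where μ and σ² are BN's running mean and variance and γ, β its affine weight and bias. A quick numerical check of that folding, assuming eval mode and a plain Conv2d (the transpose=False case, hence the [-1, 1, 1, 1] reshape):

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1).eval()
bn = torch.nn.BatchNorm2d(4).eval()
# Give BN non-trivial statistics and affine parameters
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Same arithmetic as fuse_conv_bn_weights with transpose=False
scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)   # gamma / sqrt(var + eps)
fused_w = conv.weight * scale.reshape(-1, 1, 1, 1)          # scale each output channel
fused_b = (conv.bias - bn.running_mean) * scale + bn.bias

x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    reference = bn(conv(x))
    folded = torch.nn.functional.conv2d(x, fused_w, fused_b, padding=1)
```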


class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type(conv) == Conv2d and type(relu) == ReLU, \
            'Incorrect types for input modules{}{}'.format(
                type(conv), type(relu))
        super().__init__(conv, relu)

The overall flow is conv + bn -> conv, then conv + relu -> ConvReLU2d.
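The conv + bn half of that flow can also be written as a small standalone FX pass, which shows the graph surgery without the quantization machinery. This is my own sketch, not the Fuser's code: it only handles a Conv2d whose output feeds a single BatchNorm2d, works in eval mode, and reuses torch.nn.utils.fusion.fuse_conv_bn_eval for the weight math. fuse_conv_bn is a hypothetical name.

```python
import copy

import torch
import torch.fx
from torch.nn.utils.fusion import fuse_conv_bn_eval


def fuse_conv_bn(model: torch.nn.Module) -> torch.fx.GraphModule:
    """Fold every Conv2d -> BatchNorm2d pair into a single Conv2d (eval mode only)."""
    model = copy.deepcopy(model).eval()
    gm = torch.fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op != "call_module" or not isinstance(modules[node.target], torch.nn.BatchNorm2d):
            continue
        conv_node = node.args[0]
        if not (isinstance(conv_node, torch.fx.Node) and conv_node.op == "call_module"
                and isinstance(modules[conv_node.target], torch.nn.Conv2d)
                and len(conv_node.users) == 1):  # conv output must feed only this BN
            continue
        fused = fuse_conv_bn_eval(modules[conv_node.target], modules[node.target])
        # Swap the conv submodule for the fused one at the same attribute path
        parent_name, _, attr = conv_node.target.rpartition(".")
        setattr(modules[parent_name], attr, fused)
        modules[conv_node.target] = fused
        node.replace_all_uses_with(conv_node)  # BN's consumers now read the conv output
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()
).eval()
fused = fuse_conv_bn(net)
x = torch.randn(1, 3, 10, 10)
print(fused.code)
```

The relu half would follow the same shape: match a ReLU whose input is the fused conv, then swap in a combined module, which is what ConvReLU2d provides in the quantization path.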


def forward(self, input):
    # (abridged: many intermediate lines are omitted)
    input_1 = input
    backbone_conv1 = self.backbone.conv1(input_1)
    backbone_maxpool = self.backbone.maxpool(backbone_relu)
    backbone_layer1_0_conv1 = getattr(self.backbone.layer1, "0").conv1(backbone_maxpool)
    backbone_layer1_0_conv2 = getattr(self.backbone.layer1, "0").conv2(backbone_layer1_0_relu)
    backbone_layer1_0_conv3 = getattr(self.backbone.layer1, "0").conv3(backbone_layer1_0_relu_1)
    head_reg_0 = getattr(self.head.reg, "0")(upsampler_deconv_layers_11)
    head_reg_2 = getattr(self.head.reg, "2")(head_reg_1)
    head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11)
    head_angle_2 = getattr(self.head.angle, "2")(head_angle_1)
    return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}








We can step into FX's generated code and even set breakpoints in it:




# Assume that `traced` is a GraphModule that has undergone some
# number of transforms

# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
    x = self.x
    add_1 = x + y;  x = y = None
    return add_1
"""

# Subclass the original Module
class SubclassM(M):
    def __init__(self):
        super().__init__()

    # Paste the generated code here
    def forward(self, y):
        x = self.x
        add_1 = x + y;  x = y = None
        return add_1

# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()






# In the exported module.py
import torch
from torch.nn import *
class FusedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Weights / submodules are loaded here
        self.backbone = torch.load(r'fx_debug/backbone.pt')
        # ... (other submodules omitted)

    def forward(self, input):
        # This is the generated code, already written into forward for you (abridged)
        input_1 = input
        backbone_conv1 = self.backbone.conv1(input_1)
        backbone_maxpool = self.backbone.maxpool(backbone_relu)
        backbone_layer1_0_conv1 = getattr(self.backbone.layer1, "0").conv1(backbone_maxpool)
        head_angle_0 = getattr(self.head.angle, "0")(upsampler_deconv_layers_11)
        head_angle_2 = getattr(self.head.angle, "2")(head_angle_1)
        return {'hm': head_hm_2, 'wh': head_wh_2, 'reg': head_reg_2, 'angle': head_angle_2}
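A module.py like the one above looks like the output of GraphModule.to_folder, which dumps a GraphModule as an importable Python package (source plus saved weights). A small sketch, assuming that method is available in your PyTorch version and using a temporary directory in place of fx_debug:

```python
import os
import tempfile

import torch
import torch.fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x).relu()


gm = torch.fx.symbolic_trace(M())
out_dir = tempfile.mkdtemp()
gm.to_folder(out_dir, "FusedModel")   # writes module.py (plus weight files) into out_dir

with open(os.path.join(out_dir, "module.py")) as f:
    source = f.read()
print(source)
```

The dumped file is ordinary Python, so it can be opened in an IDE and stepped through like any hand-written model.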





This is caused by the limitations of symbolic execution. A typical error message looks like this:
Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

The main limitation of symbolic tracing is it does not currently support dynamic control flow. That is, loops or if statements where the condition may depend on the input values of the program.
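A sketch of both sides of this limitation. g branches on a value computed from the input tensor, so tracing fails; f branches only on a Python flag, which can be fixed at trace time with concrete_args so that only one path is recorded. Both functions are illustrative, and the exact exception class and message vary between PyTorch versions.

```python
import torch
import torch.fx


def g(x):
    # Branch on a value computed from the input tensor
    if x.sum() > 0:
        return torch.relu(x)
    return torch.neg(x)


def f(x, flag=True):
    # Branch on a plain Python flag, not on tensor data
    if flag:
        return torch.relu(x)
    return torch.neg(x)


# x.sum() > 0 is a Proxy during tracing, so `if` cannot turn it into a bool
try:
    torch.fx.symbolic_trace(g)
    g_traceable = True
except Exception as err:  # torch.fx raises a TraceError here
    g_traceable = False
    print(type(err).__name__, err)

# Fixing `flag` via concrete_args specializes the trace: only the flag=True path is recorded
f_traced = torch.fx.symbolic_trace(f, concrete_args={"flag": True})
print(f_traced.code)
x = torch.tensor([-1.0, 2.0])
```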




