
Analysis of the Implementation, Effect, and Role of PConv in FasterNet

Source: 步遥情感网

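According to the FasterNet paper, FasterNet-T0 is 2.8×, 3.3×, and 2.4× faster than MobileViT-XXS on GPU, CPU, and ARM processors respectively, while being 2.9% more accurate. The large FasterNet-L reaches an impressive 83.5% top-1 accuracy, on par with the emerging Swin-B, while delivering 36% higher inference throughput on GPU and saving 37% of compute time on CPU. The authors attribute this mainly to the PConv module. PConv reduces FLOPs by cutting redundant computation (like GhostNet, it assumes that convolutions contain redundancy across channels), and it also reduces memory access cost (MAC), because most input channels are passed straight through to the output. This is how FasterNet achieves low latency in practice: high FPS on GPU and the lowest latency on CPU and ARM devices. This post therefore takes a closer look at the design and implementation of PConv.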

1. Paper Information

1.1 Module Design

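Compared with regular convolution and group convolution, PConv applies a dense (regular) convolution to only a small fraction of the input channels and passes the remaining channels straight through to the output. This greatly reduces computation: for example, if the input channels are split into 4 parts and only one part is convolved while the other 3 pass through to the next layer, the FLOPs drop to 1/16 of a regular convolution, because FLOPs scale with the square of the channel count. It also reduces memory access cost: for example, with C_in = 400 and only a quarter of the channels (100) convolved, the memory access is 100wh (read) + 100wh (write) = 200wh, i.e. 1/4 of the 800wh of a regular convolution (ignoring the kernel weights).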
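The arithmetic above can be checked with a few lines of Python. This is a minimal sketch that follows the paper's cost model (FLOPs = h·w·k²·c² for a conv with c input and c output channels, memory access ≈ h·w·2c when the kernel weights are ignored); the helper names are made up for illustration and are not part of the FasterNet code.

```python
def conv_flops(h: int, w: int, k: int, c: int) -> int:
    """Multiply-adds of a k x k conv with c input and c output channels."""
    return h * w * k * k * c * c


def conv_memory_access(h: int, w: int, c: int) -> int:
    """Approximate memory access: read c maps and write c maps of size h x w
    (the k*k*c*c kernel weights are ignored)."""
    return h * w * 2 * c


def pconv_cost(h: int, w: int, k: int, c: int, n_div: int = 4):
    """PConv only convolves c_p = c // n_div channels; the rest are passed through."""
    c_p = c // n_div
    return conv_flops(h, w, k, c_p), conv_memory_access(h, w, c_p)


if __name__ == "__main__":
    h = w = 1  # per-pixel costs, so results read as multiples of w*h
    c, k = 400, 3
    full_flops, full_mem = conv_flops(h, w, k, c), conv_memory_access(h, w, c)
    p_flops, p_mem = pconv_cost(h, w, k, c, n_div=4)
    print(f"memory access: conv {full_mem}wh, PConv {p_mem}wh, ratio {p_mem / full_mem}")
    # → memory access: conv 800wh, PConv 200wh, ratio 0.25
    print(f"FLOPs ratio: {p_flops / full_flops}")
    # → FLOPs ratio: 0.0625
```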

The corresponding PConv implementation is shown below; as you can see, it is simply a split → conv → cat sequence.

import torch
import torch.nn as nn
from torch import Tensor


class Partial_conv3(nn.Module):

    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div                # number of channels that are convolved
        self.dim_untouched = dim - self.dim_conv3    # number of channels passed through unchanged
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        # choose the forward implementation once, at construction time
        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError

    def forward_slicing(self, x: Tensor) -> Tensor:
        # only for inference
        x = x.clone()   # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])

        return x

    def forward_split_cat(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)

        return x

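The paper also discusses combining PConv with PWConv (pointwise convolution), or viewing the pair as a T-shaped convolution, but at the code level neither has anything to do with PConv itself: in the FasterNet Block, PConv is simply followed by Conv1x1 layers, which provide the information exchange between channels.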
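To make the channel-mixing role of the following 1×1 conv concrete, here is a small sketch (not the official code). It perturbs only the channels that PConv passes through and checks that the convolved part of the PConv output is unaffected, while every channel of a subsequent 1×1 conv output changes. The sizes (dim=16, n_div=4) and the class name PConv are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n_div = 16, 4
c_p = dim // n_div  # channels seen by the 3x3 partial conv


class PConv(nn.Module):
    """Split -> 3x3 conv on the first c_p channels -> concat (split_cat style)."""

    def __init__(self, dim: int, n_div: int):
        super().__init__()
        self.c_p = dim // n_div
        self.conv = nn.Conv2d(self.c_p, self.c_p, 3, 1, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = torch.split(x, [self.c_p, x.shape[1] - self.c_p], dim=1)
        return torch.cat((self.conv(x1), x2), dim=1)


pconv = PConv(dim, n_div)
pwconv = nn.Conv2d(dim, dim, 1, bias=False)  # 1x1 conv mixing all channels

x = torch.randn(1, dim, 8, 8)
x_perturbed = x.clone()
x_perturbed[:, c_p:] += 1.0  # change only the channels PConv does not convolve

with torch.no_grad():
    y, y_p = pconv(x), pconv(x_perturbed)
    # the convolved channels never see the untouched channels ...
    same_conv_part = torch.allclose(y[:, :c_p], y_p[:, :c_p])
    # ... but after the 1x1 conv the first c_p output channels depend on all inputs
    pw_changed = not torch.allclose(pwconv(y)[:, :c_p], pwconv(y_p)[:, :c_p])

print(same_conv_part, pw_changed)  # True True
```

In other words, PConv alone provides spatial mixing on a subset of channels, and the cross-channel mixing comes entirely from the 1×1 convs placed after it.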

1.2 Model Architecture

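In the overall FasterNet architecture, PConv is only a small component. The authors combine PConv with conv1x1 + BN + ReLU and a residual connection to form the FasterNet Block, and the FasterNet Block is the main building unit of the model. The model also borrows many design elements from ViT-style models (such as PatchEmbed and an MLP), but it contains no Transformer modules.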

PatchEmbed sits at the model's input layer, while the "MLP" is really just the Conv1x1 + BN + ReLU + Conv1x1 that follows PConv.

1.3 Structure Comparison

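Module performance comparison: the paper compares Conv, group convolution (GConv), depthwise separable convolution (DWConv), and PConv on feature maps whose element counts halve step by step (e.g., 192×28×28 has half as many elements as 96×56×56). Along this series only DWConv's FLOPs halve; the FLOPs of the other operators do not decrease. Judging by throughput (FPS) and latency, DWConv is the most cost-effective operator and PConv comes second; only on ARM (Cortex-A72, using a single thread) does PConv outperform DWConv.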

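Note 1: With r = 1/4, PConv has the same FLOPs as a group convolution with 16 groups, but their memory access differs (the group convolution still reads and writes all channels).
Note 2: DWConv (depthwise separable convolution) consists of a fully grouped convolution (kernel size 3, number of groups equal to the number of channels, mixing only spatial information) followed by a pointwise convolution (kernel size 1, mixing channel information). Since DWConv's FLOPs halve in the comparison above, the figures there presumably count only the depthwise 3×3 part; the pointwise part on its own would keep FLOPs constant across the series.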
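The FLOPs trends described above follow directly from the operator formulas. The sketch below is an illustration, not the paper's benchmark code; it assumes a 3×3 kernel, 16 groups for GConv, r = 1/4 for PConv, and counts DWConv as the depthwise 3×3 part only.

```python
def op_flops(c: int, h: int, w: int, k: int = 3) -> dict:
    """Multiply-add counts for k x k operators with c input and c output channels."""
    return {
        "Conv": h * w * k * k * c * c,
        "GConv(g=16)": h * w * k * k * c * c // 16,
        "DWConv": h * w * k * k * c,               # depthwise part only (assumption)
        "PConv(r=1/4)": h * w * k * k * (c // 4) ** 2,
    }


shapes = [(96, 56, 56), (192, 28, 28), (384, 14, 14), (768, 7, 7)]
for c, h, w in shapes:
    flops = op_flops(c, h, w)
    row = ", ".join(f"{name}: {v / 1e6:.1f}M" for name, v in flops.items())
    print(f"{c}x{h}x{w} ({c * h * w} elements) -> {row}")
```

Conv, GConv, and PConv stay constant (h·w shrinks 4× while c² grows 4×), GConv with 16 groups and PConv with r = 1/4 coincide exactly, and only the depthwise term, which is linear in c, halves from one feature map to the next.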

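Memory access cost comparison: formula (2) in the paper gives the memory access of PConv and formula (3) that of a regular conv. Since the number of convolved channels c_p is 1/4 of c, PConv's memory access cost is 1/4 that of a regular conv. The formulas assume that the input and output both have c channels, hence the term 2c; otherwise it would be (c_in + c_out).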
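Written out in the paper's notation (c_p convolved channels, kernel size k, feature map h × w), the relevant expressions are as follows; the regular-conv line is the same memory-access expression with c_p replaced by c, and the k²-weight terms are dropped because they are small compared with the feature maps:

```latex
\begin{aligned}
\text{FLOPs}_{\text{PConv}} &= h \cdot w \cdot k^2 \cdot c_p^2 \\
\text{MAC}_{\text{PConv}} &= h \cdot w \cdot 2c_p + k^2 \cdot c_p^2 \approx h \cdot w \cdot 2c_p \\
\text{MAC}_{\text{Conv}} &= h \cdot w \cdot 2c + k^2 \cdot c^2 \approx h \cdot w \cdot 2c \\
\frac{\text{MAC}_{\text{PConv}}}{\text{MAC}_{\text{Conv}}} &\approx \frac{c_p}{c} = \frac{1}{4} \quad \left(c_p = \tfrac{c}{4}\right)
\end{aligned}
```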

1.4 Model Performance

In the paper's overall comparison, FasterNet reaches the highest FPS on GPU and the lowest latency on CPU and ARM.

The comparison charts show that FasterNet achieves the best speed/accuracy trade-off among both lightweight and large models.

2. Code Implementation and Analysis

2.1 PConv Code

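After simplification, the PConv code looks like this: the first dim_conv3 channels are convolved and written back into a copy of the input, and the remaining channels are left untouched (an in-place slicing equivalent of split + cat). I also tried something similar in 2023 (replacing every conv with PConv) and did not manage to train a good model.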

import torch.nn as nn
from torch import Tensor


class Partial_conv3(nn.Module):

    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # only for inference: the in-place slice assignment can break autograd during training
        x = x.clone()   # !!! Keep the original input intact for the residual connection later
        # convolve the first dim_conv3 channels; the remaining channels stay as they are
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x
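Why the clone() call matters can be shown with a tiny experiment. This sketch (sizes chosen arbitrarily, slicing_forward is a hypothetical stand-in for forward_slicing) runs the slicing forward with and without clone() and checks whether the caller's tensor, which the FasterNet Block later reuses as the residual shortcut, gets modified.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim_conv3 = 4
conv = nn.Conv2d(dim_conv3, dim_conv3, 3, 1, 1, bias=False)


def slicing_forward(x: torch.Tensor, clone: bool) -> torch.Tensor:
    if clone:
        x = x.clone()  # work on a copy, as in Partial_conv3.forward_slicing
    x[:, :dim_conv3] = conv(x[:, :dim_conv3])  # in-place write into the first channels
    return x


with torch.no_grad():  # the slicing variant is meant for inference only
    shortcut = torch.randn(1, 16, 8, 8)
    backup = shortcut.clone()
    slicing_forward(shortcut, clone=True)
    intact_with_clone = torch.equal(shortcut, backup)
    slicing_forward(shortcut, clone=False)
    intact_without_clone = torch.equal(shortcut, backup)

print(intact_with_clone, intact_without_clone)  # True False
```

Without the copy, the shortcut that is later added back in the residual connection would already contain the convolved channels.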

2.2 FasterNet Block Code

The spatial_mixing attribute is the PConv layer.
The mlp attribute holds the non-PConv layers of the FasterNet Block.
The forward code is as follows:

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(self.mlp(x))
        return x

The complete implementation is as follows:

class MLPBlock(nn.Module):

    def __init__(self,
                 dim,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 act_layer,
                 norm_layer,
                 pconv_fw_type
                 ):

        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * mlp_ratio)

        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
            norm_layer(mlp_hidden_dim),
            act_layer(),
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)
        ]

        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
            pconv_fw_type
        )

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.forward = self.forward_layer_scale
        else:
            self.forward = self.forward

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(self.mlp(x))
        return x

    def forward_layer_scale(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(
            self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
        return x
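As a quick sanity check of the block structure, the sketch below rebuilds a FasterNet Block in compact form (split_cat PConv, then Conv1x1 + BN + ReLU + Conv1x1, plus the residual; drop path and layer scale are omitted) and confirms that it preserves the input shape. The class name FasterBlockSketch is hypothetical; with dim=96 its parameter count should match the 42.432K reported for stages.0.blocks.0 in the FLOPs table of Section 3.2.

```python
import torch
import torch.nn as nn


class FasterBlockSketch(nn.Module):
    """Hypothetical compact FasterNet Block: PConv -> MLP (1x1 convs) -> residual add."""

    def __init__(self, dim: int, n_div: int = 4, mlp_ratio: float = 2.0):
        super().__init__()
        self.c_p = dim // n_div
        self.pconv = nn.Conv2d(self.c_p, self.c_p, 3, 1, 1, bias=False)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, dim, 1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x1, x2 = torch.split(x, [self.c_p, x.shape[1] - self.c_p], dim=1)
        x = torch.cat((self.pconv(x1), x2), dim=1)  # spatial mixing on c_p channels
        return shortcut + self.mlp(x)               # channel mixing + residual


block = FasterBlockSketch(dim=96)
x = torch.randn(2, 96, 32, 32)
y = block(x)
print(y.shape)   # torch.Size([2, 96, 32, 32])
n_params = sum(p.numel() for p in block.parameters())
print(n_params)  # 42432
```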

In addition, there is a BasicStage class, whose main job is to stack multiple MLPBlocks (i.e., FasterNet Blocks).
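The drop-path rates handed to each BasicStage come from the stochastic depth decay rule in FasterNet.__init__ (shown in 2.4): torch.linspace(0, drop_path_rate, sum(depths)) is sliced per stage. A small sketch of that slicing with the default depths=(1, 2, 8, 2) and drop_path_rate=0.1; the helper name split_drop_path_rates is hypothetical.

```python
import torch


def split_drop_path_rates(depths, drop_path_rate):
    """Linearly increasing drop-path rates, sliced into one list per stage
    (mirrors the dpr logic in FasterNet.__init__)."""
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
    return [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]


per_stage = split_drop_path_rates((1, 2, 8, 2), 0.1)
for i, rates in enumerate(per_stage):
    print(f"stage {i}: {len(rates)} blocks, rates {[round(r, 3) for r in rates]}")
```

Each BasicStage then passes drop_path[i] to its i-th MLPBlock, so deeper blocks are dropped more often during training.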

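2.3 PatchEmbed and PatchMerging

PatchEmbed is similar to the image patchification in ViT models: it moves spatial information into the channel dimension.
PatchMerging uses a strided conv to reduce the resolution of the feature map while increasing the number of channels.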


class PatchEmbed(nn.Module):

    def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):

    def __init__(self, patch_size2, patch_stride2, dim, norm_layer):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(2 * dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x
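To see what these two modules do to tensor shapes, the following sketch chains a PatchEmbed-style conv (kernel 4, stride 4) with three PatchMerging-style convs (kernel 2, stride 2, doubling channels). It assumes FasterNet's default arguments (embed_dim=96, patch_size=4, patch_size2=2) and a 224×224 input; the norm layers and the blocks in between are left out because they do not change shapes.

```python
import torch
import torch.nn as nn

embed_dim = 96
patch_embed = nn.Conv2d(3, embed_dim, kernel_size=4, stride=4, bias=False)
mergings = [
    nn.Conv2d(embed_dim * 2 ** i, embed_dim * 2 ** (i + 1), kernel_size=2, stride=2, bias=False)
    for i in range(3)
]

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    x = patch_embed(x)
    shapes = [tuple(x.shape)]          # input of stage 1
    for merge in mergings:
        x = merge(x)
        shapes.append(tuple(x.shape))  # inputs of stages 2..4

for s in shapes:
    print(s)
```

The first conv alone reduces the spatial area to 1/16 (224×224 → 56×56), and each PatchMerging halves the side length while doubling the channels.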

2.4 Model Code

class FasterNet(nn.Module):

   def __init__(self,
                in_chans=3,
                num_classes=1000,
                embed_dim=96,
                depths=(1, 2, 8, 2),
                mlp_ratio=2.,
                n_div=4,
                patch_size=4,
                patch_stride=4,
                patch_size2=2,  # for subsequent layers
                patch_stride2=2,
                patch_norm=True,
                feature_dim=1280,
                drop_path_rate=0.1,
                layer_scale_init_value=0,
                norm_layer='BN',
                act_layer='RELU',
                fork_feat=False,
                init_cfg=None,
                pretrained=None,
                pconv_fw_type='split_cat',
                **kwargs):
       super().__init__()

       if norm_layer == 'BN':
           norm_layer = nn.BatchNorm2d
       else:
           raise NotImplementedError

       if act_layer == 'GELU':
           act_layer = nn.GELU
       elif act_layer == 'RELU':
           act_layer = partial(nn.ReLU, inplace=True)
       else:
           raise NotImplementedError

       if not fork_feat:
           self.num_classes = num_classes
       self.num_stages = len(depths)
       self.embed_dim = embed_dim
       self.patch_norm = patch_norm
       self.num_features = int(embed_dim * 2 ** (self.num_stages - 1))
       self.mlp_ratio = mlp_ratio
       self.depths = depths

       # split image into non-overlapping patches
       self.patch_embed = PatchEmbed(
           patch_size=patch_size,
           patch_stride=patch_stride,
           in_chans=in_chans,
           embed_dim=embed_dim,
           norm_layer=norm_layer if self.patch_norm else None
       )

       # stochastic depth decay rule
       dpr = [x.item()
              for x in torch.linspace(0, drop_path_rate, sum(depths))]

       # build layers
       stages_list = []
       for i_stage in range(self.num_stages):
           stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),
                              n_div=n_div,
                              depth=depths[i_stage],
                              mlp_ratio=self.mlp_ratio,
                              drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                              layer_scale_init_value=layer_scale_init_value,
                              norm_layer=norm_layer,
                              act_layer=act_layer,
                              pconv_fw_type=pconv_fw_type
                              )
           stages_list.append(stage)

           # patch merging layer
           if i_stage < self.num_stages - 1:
               stages_list.append(
                   PatchMerging(patch_size2=patch_size2,
                                patch_stride2=patch_stride2,
                                dim=int(embed_dim * 2 ** i_stage),
                                norm_layer=norm_layer)
               )

       self.stages = nn.Sequential(*stages_list)

       self.fork_feat = fork_feat

       if self.fork_feat:
           self.forward = self.forward_det
           # add a norm layer for each output
           self.out_indices = [0, 2, 4, 6]
           for i_emb, i_layer in enumerate(self.out_indices):
               if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                   raise NotImplementedError
               else:
                   layer = norm_layer(int(embed_dim * 2 ** i_emb))
               layer_name = f'norm{i_layer}'
               self.add_module(layer_name, layer)
       else:
           self.forward = self.forward_cls
           # Classifier head
           self.avgpool_pre_head = nn.Sequential(
               nn.AdaptiveAvgPool2d(1),
               nn.Conv2d(self.num_features, feature_dim, 1, bias=False),
               act_layer()
           )
           self.head = nn.Linear(feature_dim, num_classes) \
               if num_classes > 0 else nn.Identity()

       self.apply(self.cls_init_weights)
       self.init_cfg = copy.deepcopy(init_cfg)
       if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
           self.init_weights()

   def cls_init_weights(self, m):
       if isinstance(m, nn.Linear):
           trunc_normal_(m.weight, std=.02)
           if isinstance(m, nn.Linear) and m.bias is not None:
               nn.init.constant_(m.bias, 0)
       elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
           trunc_normal_(m.weight, std=.02)
           if m.bias is not None:
               nn.init.constant_(m.bias, 0)
       elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
           nn.init.constant_(m.bias, 0)
           nn.init.constant_(m.weight, 1.0)

   # init for mmdetection by loading imagenet pre-trained weights
   def init_weights(self, pretrained=None):
       logger = get_root_logger()
       if self.init_cfg is None and pretrained is None:
           logger.warn(f'No pre-trained weights for '
                       f'{self.__class__.__name__}, '
                       f'training start from scratch')
           pass
       else:
           assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                 f'specify `Pretrained` in ' \
                                                 f'`init_cfg` in ' \
                                                 f'{self.__class__.__name__} '
           if self.init_cfg is not None:
               ckpt_path = self.init_cfg['checkpoint']
           elif pretrained is not None:
               ckpt_path = pretrained

           ckpt = _load_checkpoint(
               ckpt_path, logger=logger, map_location='cpu')
           if 'state_dict' in ckpt:
               _state_dict = ckpt['state_dict']
           elif 'model' in ckpt:
               _state_dict = ckpt['model']
           else:
               _state_dict = ckpt

           state_dict = _state_dict
           missing_keys, unexpected_keys = \
               self.load_state_dict(state_dict, False)

           # show for debug
           print('missing_keys: ', missing_keys)
           print('unexpected_keys: ', unexpected_keys)

   def forward_cls(self, x):
       # output only the features of last layer for image classification
       x = self.patch_embed(x)
       x = self.stages(x)
       x = self.avgpool_pre_head(x)  # B C 1 1
       x = torch.flatten(x, 1)
       x = self.head(x)

       return x

   def forward_det(self, x: Tensor) -> Tensor:
       # output the features of four stages for dense prediction
       x = self.patch_embed(x)
       outs = []
       for idx, stage in enumerate(self.stages):
           x = stage(x)
           if self.fork_feat and idx in self.out_indices:
               norm_layer = getattr(self, f'norm{idx}')
               x_out = norm_layer(x)
               outs.append(x_out)

       return outs
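In forward_det, out_indices = [0, 2, 4, 6] index into self.stages, which interleaves BasicStage and PatchMerging modules. The sketch below rebuilds only that ordering as a list of (name, channels) labels (the helper stage_layout is hypothetical) to show that the selected indices are exactly the four BasicStages, and that each matches the channel count of the norm layer built for it.

```python
def stage_layout(num_stages: int = 4, embed_dim: int = 96):
    """Mimic the order in which FasterNet appends modules to stages_list."""
    layout = []
    for i in range(num_stages):
        layout.append(("BasicStage", embed_dim * 2 ** i))
        if i < num_stages - 1:
            layout.append(("PatchMerging", embed_dim * 2 ** (i + 1)))  # output channels
    return layout


layout = stage_layout()
out_indices = [0, 2, 4, 6]
for i_emb, idx in enumerate(out_indices):
    kind, channels = layout[idx]
    # norm{idx} in FasterNet is built with embed_dim * 2 ** i_emb channels
    print(f"norm{idx}: {kind}, {channels} channels (norm layer expects {96 * 2 ** i_emb})")
```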

2.5 Complete Model Code

The complete model code below is used only for the FLOPs analysis in Section 3.2.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import partial
from typing import List
from torch import Tensor
import copy
import os

try:
    from mmdet.models.builder import BACKBONES as det_BACKBONES
    from mmdet.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
    has_mmdet = True
except ImportError:
    print("If for detection, please install mmdetection first")
    has_mmdet = False


class Partial_conv3(nn.Module):

    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError

    def forward_slicing(self, x: Tensor) -> Tensor:
        # only for inference
        x = x.clone()   # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])

        return x

    def forward_split_cat(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)

        return x


class MLPBlock(nn.Module):

    def __init__(self,
                 dim,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 act_layer,
                 norm_layer,
                 pconv_fw_type
                 ):

        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * mlp_ratio)

        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
            norm_layer(mlp_hidden_dim),
            act_layer(),
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)
        ]

        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
            pconv_fw_type
        )

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.forward = self.forward_layer_scale
        else:
            self.forward = self.forward

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(self.mlp(x))
        return x

    def forward_layer_scale(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(
            self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
        return x


class BasicStage(nn.Module):

    def __init__(self,
                 dim,
                 depth,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 norm_layer,
                 act_layer,
                 pconv_fw_type
                 ):

        super().__init__()

        blocks_list = [
            MLPBlock(
                dim=dim,
                n_div=n_div,
                mlp_ratio=mlp_ratio,
                drop_path=drop_path[i],
                layer_scale_init_value=layer_scale_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type
            )
            for i in range(depth)
        ]

        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x: Tensor) -> Tensor:
        x = self.blocks(x)
        return x


class PatchEmbed(nn.Module):

    def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):

    def __init__(self, patch_size2, patch_stride2, dim, norm_layer):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(2 * dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x


class FasterNet(nn.Module):

    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(1, 2, 8, 2),
                 mlp_ratio=2.,
                 n_div=4,
                 patch_size=4,
                 patch_stride=4,
                 patch_size2=2,  # for subsequent layers
                 patch_stride2=2,
                 patch_norm=True,
                 feature_dim=1280,
                 drop_path_rate=0.1,
                 layer_scale_init_value=0,
                 norm_layer='BN',
                 act_layer='RELU',
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 pconv_fw_type='split_cat',
                 **kwargs):
        super().__init__()

        if norm_layer == 'BN':
            norm_layer = nn.BatchNorm2d
        else:
            raise NotImplementedError

        if act_layer == 'GELU':
            act_layer = nn.GELU
        elif act_layer == 'RELU':
            act_layer = partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        if not fork_feat:
            self.num_classes = num_classes
        self.num_stages = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_stages - 1))
        self.mlp_ratio = mlp_ratio
        self.depths = depths

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            patch_stride=patch_stride,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None
        )

        # stochastic depth decay rule
        dpr = [x.item()
               for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        stages_list = []
        for i_stage in range(self.num_stages):
            stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),
                               n_div=n_div,
                               depth=depths[i_stage],
                               mlp_ratio=self.mlp_ratio,
                               drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                               layer_scale_init_value=layer_scale_init_value,
                               norm_layer=norm_layer,
                               act_layer=act_layer,
                               pconv_fw_type=pconv_fw_type
                               )
            stages_list.append(stage)

            # patch merging layer
            if i_stage < self.num_stages - 1:
                stages_list.append(
                    PatchMerging(patch_size2=patch_size2,
                                 patch_stride2=patch_stride2,
                                 dim=int(embed_dim * 2 ** i_stage),
                                 norm_layer=norm_layer)
                )

        self.stages = nn.Sequential(*stages_list)

        self.fork_feat = fork_feat

        if self.fork_feat:
            self.forward = self.forward_det
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    raise NotImplementedError
                else:
                    layer = norm_layer(int(embed_dim * 2 ** i_emb))
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            self.forward = self.forward_cls
            # Classifier head
            self.avgpool_pre_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.num_features, feature_dim, 1, bias=False),
                act_layer()
            )
            self.head = nn.Linear(feature_dim, num_classes) \
                if num_classes > 0 else nn.Identity()

        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    def cls_init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # init for mmdetection by loading imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            pass
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)

            # show for debug
            print('missing_keys: ', missing_keys)
            print('unexpected_keys: ', unexpected_keys)

    def forward_cls(self, x):
        # output only the features of last layer for image classification
        x = self.patch_embed(x)
        x = self.stages(x)
        x = self.avgpool_pre_head(x)  # B C 1 1
        x = torch.flatten(x, 1)
        x = self.head(x)

        return x

    def forward_det(self, x: Tensor) -> Tensor:
        # output the features of four stages for dense prediction
        x = self.patch_embed(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return outs
  

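3. Further Analysis

3.1 Can PConv replace Conv?

No. PConv is only a drop-in substitute for a conv when C_in equals C_out. Moreover, it performs local spatial mixing on only a fraction of the channels; most channels are passed straight to the output, so the input data flows directly into the deeper layers of the network. A dense, fully connected convolution (such as a 1×1 conv) is therefore still needed to exchange information between channels.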
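The pass-through behaviour is easy to verify directly. In this sketch (arbitrary sizes, split_cat-style PConv written inline), three quarters of the output channels are exact copies of the input channels, which is why PConv by itself cannot mix information across channels and can only map dim channels to dim channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n_div = 64, 4
c_p = dim // n_div
partial_conv = nn.Conv2d(c_p, c_p, 3, 1, 1, bias=False)

x = torch.randn(1, dim, 14, 14)
with torch.no_grad():
    x1, x2 = torch.split(x, [c_p, dim - c_p], dim=1)
    y = torch.cat((partial_conv(x1), x2), dim=1)

untouched = torch.equal(y[:, c_p:], x[:, c_p:])  # channels copied unchanged
fraction_untouched = (dim - c_p) / dim
print(y.shape, untouched, fraction_untouched)    # torch.Size([1, 64, 14, 14]) True 0.75
```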

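The paper's experiments do not include an ablation that replaces PConv in FasterNet with a regular Conv. FasterNet's advantage may come mainly from its architecture (especially the PatchEmbed at the input, which reduces the spatial size to 1/16 of the original); in other words, a FasterNet that uses Conv instead of PConv might still hold an edge in accuracy and latency.

Likewise, there is no equivalent comparison with PWConv. Replacing PConv in FasterNet with PWConv might bring a further performance gain, since in the authors' experiments PWConv runs faster than PConv on GPU while fitting about as well.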

3.2 FLOPs Distribution in FasterNet

Using the following code, I built a simplified FasterNet model (one block per stage) and printed the FLOPs of every layer.

if __name__ == "__main__":
    model = FasterNet(depths=(1, 1, 1, 1))
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
    x = torch.randn(1, 3, 256, 256)
    print(f'params: {sum(p.numel() for p in model.parameters())}')
    print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
    output = model(x)
    print(output.shape)

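The output is shown below (note that fvcore counts one multiply-add as one FLOP). In the key module, the FasterNet Block, most of the FLOPs sit in blocks.0.mlp; spatial_mixing.partial_conv3 (i.e., PConv) accounts for only about 12% of the block's computation, 21.234M FLOPs.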

| module                                            | #parameters or shape   | #flops     | #activations   |
|:--------------------------------------------------|:-----------------------|:-----------|:---------------|
| model                                             | 7.4M                   | 0.948G     | 3.136M         |
|  patch_embed                                      |  4.8K                  |  20.84M    |  0.393M        |
|   patch_embed.proj                                |   4.608K               |   18.874M  |   0.393M       |
|    patch_embed.proj.weight                        |    (96, 3, 4, 4)       |            |                |
|   patch_embed.norm                                |   0.192K               |   1.966M   |   0            |
|    patch_embed.norm.weight                        |    (96,)               |            |                |
|    patch_embed.norm.bias                          |    (96,)               |            |                |
|  stages                                           |  5.131M                |  0.924G    |  2.74M         |
|   stages.0.blocks.0                               |   42.432K              |   0.176G   |   1.278M       |
|    stages.0.blocks.0.mlp                          |    37.248K             |    0.155G  |    1.18M       |
|    stages.0.blocks.0.spatial_mixing.partial_conv3 |    5.184K              |    21.234M |    98.304K     |
|   stages.1                                        |   74.112K              |   76.481M  |   0.197M       |
|    stages.1.reduction                             |    73.728K             |    75.497M |    0.197M      |
|    stages.1.norm                                  |    0.384K              |    0.983M  |    0           |
|   stages.2.blocks.0                               |   0.169M               |   0.174G   |   0.639M       |
|    stages.2.blocks.0.mlp                          |    0.148M              |    0.153G  |    0.59M       |
|    stages.2.blocks.0.spatial_mixing.partial_conv3 |    20.736K             |    21.234M |    49.152K     |
|   stages.3                                        |   0.296M               |   75.989M  |   98.304K      |
|    stages.3.reduction                             |    0.295M              |    75.497M |    98.304K     |
|    stages.3.norm                                  |    0.768K              |    0.492M  |    0           |
|   stages.4.blocks.0                               |   0.674M               |   0.173G   |   0.319M       |
|    stages.4.blocks.0.mlp                          |    0.591M              |    0.152G  |    0.295M      |
|    stages.4.blocks.0.spatial_mixing.partial_conv3 |    82.944K             |    21.234M |    24.576K     |
|   stages.5                                        |   1.181M               |   75.743M  |   49.152K      |
|    stages.5.reduction                             |    1.18M               |    75.497M |    49.152K     |
|    stages.5.norm                                  |    1.536K              |    0.246M  |    0           |
|   stages.6.blocks.0                               |   2.694M               |   0.173G   |   0.16M        |
|    stages.6.blocks.0.mlp                          |    2.362M              |    0.151G  |    0.147M      |
|    stages.6.blocks.0.spatial_mixing.partial_conv3 |    0.332M              |    21.234M |    12.288K     |
|  avgpool_pre_head                                 |  0.983M                |  1.032M    |  1.28K         |
|   avgpool_pre_head.1                              |   0.983M               |   0.983M   |   1.28K        |
|    avgpool_pre_head.1.weight                      |    (1280, 768, 1, 1)   |            |                |
|   avgpool_pre_head.0                              |                        |   49.152K  |   0            |
|  head                                             |  1.281M                |  1.28M     |  1K            |
|   head.weight                                     |   (1000, 1280)         |            |                |
|   head.bias                                       |   (1000,)              |            |                |
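The split in this table can also be derived by hand. With mlp_ratio = 2, the two 1×1 convs in the MLP cost h·w·(c·2c + 2c·c) = 4·h·w·c² multiply-adds, while PConv with n_div = 4 costs h·w·9·(c/4)² = (9/16)·h·w·c², so the MLP should be roughly 7× as expensive as PConv in every block, regardless of stage. The sketch below checks this for stage 0 of the 256×256 run (64×64 feature map, c = 96); the helper block_flops is hypothetical, BN FLOPs are ignored, and one multiply-add is counted as one FLOP, as fvcore does.

```python
def block_flops(c: int, h: int, w: int, n_div: int = 4, mlp_ratio: float = 2.0, k: int = 3):
    """Multiply-add counts of the PConv and the two 1x1 convs of a FasterNet Block."""
    c_p = c // n_div
    hidden = int(c * mlp_ratio)
    pconv = h * w * k * k * c_p * c_p
    mlp = h * w * (c * hidden + hidden * c)
    return pconv, mlp


pconv, mlp = block_flops(96, 64, 64)
print(pconv)        # 21233664 -> 21.234M, as in the table
print(mlp)          # 150994944 -> ~0.151G (the table's 0.155G also includes BN)
print(mlp / pconv)  # ~7.1
```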

3.3 FLOPs Change When Replacing PConv with Conv

Replace the original Partial_conv3 class with the following code:

class Partial_conv3(nn.Module):

    def __init__(self, dim, n_div, forward):
        super().__init__()
        # regular 3x3 conv over all dim channels (n_div and forward are ignored)
        self.conv = nn.Conv2d(dim, dim, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # no clone needed: the conv returns a new tensor, so the shortcut input stays intact
        x = self.conv(x)
        return x

Then run the following code again:

if __name__ == "__main__":
    model = FasterNet(depths=(1, 1, 1, 1))
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
    x = torch.randn(1, 3, 256, 256)
    print(f'params: {sum(p.numel() for p in model.parameters())}')
    print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
    output = model(x)
    print(output.shape)

This time the total comes to 2.222G FLOPs, about 2.3× the original 0.948G. In the new FasterNet Block, spatial_mixing.conv accounts for about 70% of the block's FLOPs (0.34G), 16× the original 21.234M.

| module                                   | #parameters or shape   | #flops     | #activations   |
|:-----------------------------------------|:-----------------------|:-----------|:---------------|
| model                                    | 14.009M                | 2.222G     | 3.688M         |
|  patch_embed                             |  4.8K                  |  20.84M    |  0.393M        |
|   patch_embed.proj                       |   4.608K               |   18.874M  |   0.393M       |
|    patch_embed.proj.weight               |    (96, 3, 4, 4)       |            |                |
|   patch_embed.norm                       |   0.192K               |   1.966M   |   0            |
|    patch_embed.norm.weight               |    (96,)               |            |                |
|    patch_embed.norm.bias                 |    (96,)               |            |                |
|  stages                                  |  11.74M                |  2.199G    |  3.293M        |
|   stages.0.blocks.0                      |   0.12M                |   0.495G   |   1.573M       |
|    stages.0.blocks.0.mlp                 |    37.248K             |    0.155G  |    1.18M       |
|    stages.0.blocks.0.spatial_mixing.conv |    82.944K             |    0.34G   |    0.393M      |
|   stages.1                               |   74.112K              |   76.481M  |   0.197M       |
|    stages.1.reduction                    |    73.728K             |    75.497M |    0.197M      |
|    stages.1.norm                         |    0.384K              |    0.983M  |    0           |
|   stages.2.blocks.0                      |   0.48M                |   0.493G   |   0.786M       |
|    stages.2.blocks.0.mlp                 |    0.148M              |    0.153G  |    0.59M       |
|    stages.2.blocks.0.spatial_mixing.conv |    0.332M              |    0.34G   |    0.197M      |
|   stages.3                               |   0.296M               |   75.989M  |   98.304K      |
|    stages.3.reduction                    |    0.295M              |    75.497M |    98.304K     |
|    stages.3.norm                         |    0.768K              |    0.492M  |    0           |
|   stages.4.blocks.0                      |   1.918M               |   0.492G   |   0.393M       |
|    stages.4.blocks.0.mlp                 |    0.591M              |    0.152G  |    0.295M      |
|    stages.4.blocks.0.spatial_mixing.conv |    1.327M              |    0.34G   |    98.304K     |
|   stages.5                               |   1.181M               |   75.743M  |   49.152K      |
|    stages.5.reduction                    |    1.18M               |    75.497M |    49.152K     |
|    stages.5.norm                         |    1.536K              |    0.246M  |    0           |
|   stages.6.blocks.0                      |   7.671M               |   0.491G   |   0.197M       |
|    stages.6.blocks.0.mlp                 |    2.362M              |    0.151G  |    0.147M      |
|    stages.6.blocks.0.spatial_mixing.conv |    5.308M              |    0.34G   |    49.152K     |
|  avgpool_pre_head                        |  0.983M                |  1.032M    |  1.28K         |
|   avgpool_pre_head.1                     |   0.983M               |   0.983M   |   1.28K        |
|    avgpool_pre_head.1.weight             |    (1280, 768, 1, 1)   |            |                |
|   avgpool_pre_head.0                     |                        |   49.152K  |   0            |
|  head                                    |  1.281M                |  1.28M     |  1K            |
|   head.weight                            |   (1000, 1280)         |            |                |
|   head.bias                              |   (1000,)              |            |                |

Finally, the script prints the output shape:

torch.Size([1, 1000])
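The jump from 0.948G to 2.222G can be reproduced by hand. At a 256×256 input, h·w·c² is the same in every stage (the spatial area shrinks 4× while c² grows 4×), so each of the four blocks adds the same amount. A short sketch of that arithmetic, counting multiply-adds only as fvcore does:

```python
# (channels, feature-map side) per stage for a 256x256 input and default FasterNet settings
stages = [(96, 64), (192, 32), (384, 16), (768, 8)]

extra = 0
for c, s in stages:
    conv = s * s * 9 * c * c             # full 3x3 conv over all c channels
    pconv = s * s * 9 * (c // 4) ** 2    # PConv on c/4 channels
    assert conv == 16 * pconv            # the 16x ratio reported above
    extra += conv - pconv

print(extra / 1e9)          # ~1.274 GFLOPs added by the swap
print(0.948 + extra / 1e9)  # ~2.222, matching the new total
```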

3.4 Overall Conclusions

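Based on the analysis in 3.1-3.3, we cannot simply replace every conv layer in a model with PConv, but PConv can stand in for individual high-FLOPs convs in selected layers. PConv is just one way of approximating a conv, and it works within FasterNet's architecture; dropping it directly into other models is bound to cause a mismatch (extra PWConv layers are needed for cross-channel information exchange).

Nevertheless, FasterNet does give us a powerful backbone that reaches the fastest speed at the best accuracy among both lightweight and large models, and it can be used for image classification and object detection. In our own experiments, it may be worth replacing PConv in FasterNet with DWConv, which might further improve the backbone. After all, the authors did not run this comparison; it is also possible that they found PConv to be weaker than DWConv and left those results out.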
