步骤
1.卸载库中的ultralytics
```bash
pip uninstall ultralytics
pip uninstall ultralytics-thop
```
2.找到ultralytics/ultralytics/nn/modules路径
创建SA.py
内容为:
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):
    """Shuffle-attention block that re-weights a feature map without changing its shape.

    The input channels are split into ``G`` groups and every group is halved.
    One half is gated by a channel branch (global average pooling followed by a
    learned per-channel scale/shift and a sigmoid); the other half is gated by
    a spatial branch (GroupNorm followed by a learned per-channel scale/shift
    and a sigmoid). The halves are concatenated again and the channels are
    shuffled with two groups.

    Args:
        channel: Number of input channels. Must be a positive multiple of
            ``2 * G`` so that every group can be split into two equal halves.
        out_channel: Unused. Accepted so the module can be constructed with
            ``(c1, c2, ...)`` style positional arguments.
        reduction: Unused. Accepted for interface compatibility.
        G: Number of channel groups.

    Raises:
        ValueError: If ``G`` is not positive or ``channel`` is not a positive
            multiple of ``2 * G``.
    """

    def __init__(self, channel=512, out_channel=512, reduction=16, G=8):
        super().__init__()
        # Fail early with a clear message: an uneven split would otherwise
        # either crash inside GroupNorm or silently broadcast the parameters.
        if G <= 0 or channel <= 0 or channel % (2 * G) != 0:
            raise ValueError(
                f"channel ({channel}) must be a positive multiple of 2 * G ({2 * G})"
            )
        self.G = G
        self.channel = channel
        half = channel // (2 * G)  # channels handled by each attention branch
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # num_groups == num_channels: every channel is normalised on its own.
        self.gn = nn.GroupNorm(half, half)
        # Zero weights and unit biases make both gates start at sigmoid(1).
        self.cweight = Parameter(torch.zeros(1, half, 1, 1))
        self.cbias = Parameter(torch.ones(1, half, 1, 1))
        self.sweight = Parameter(torch.zeros(1, half, 1, 1))
        self.sbias = Parameter(torch.ones(1, half, 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        """Re-initialise any Conv2d, BatchNorm2d and Linear sub-modules."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        """Interleave the channels of ``x`` across ``groups`` groups.

        Args:
            x: Tensor of shape ``(b, c, h, w)``.
            groups: Number of groups the channel axis is split into.

        Returns:
            Tensor of shape ``(b, c, h, w)`` with permuted channels.
        """
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten the (channels-per-group, groups) axes back into one channel axis
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        """Apply shuffle attention.

        Args:
            x: Tensor of shape ``(b, c, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, h, w = x.size()
        # Group into sub-features. ``reshape`` instead of ``view`` so that
        # non-contiguous inputs (e.g. channels_last) do not raise.
        x = x.reshape(b * self.G, -1, h, w)  # b*G, c//G, h, w
        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # b*G, c//(2*G), h, w
        # channel attention
        x_channel = self.avg_pool(x_0)  # b*G, c//(2*G), 1, 1
        x_channel = self.cweight * x_channel + self.cbias
        x_channel = x_0 * self.sigmoid(x_channel)
        # spatial attention
        x_spatial = self.gn(x_1)  # b*G, c//(2*G), h, w
        x_spatial = self.sweight * x_spatial + self.sbias
        x_spatial = x_1 * self.sigmoid(x_spatial)
        # concatenate along the channel axis and undo the grouping
        out = torch.cat([x_channel, x_spatial], dim=1)  # b*G, c//G, h, w
        out = out.reshape(b, -1, h, w)
        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out
3.在同一路径下修改__init__.py文件
添加
```python
from .SA import ShuffleAttention
```
并修改__all__变量
将 "ShuffleAttention" 加到 __all__ 中
4.找到ultralytics/ultralytics/nn/tasks.py文件
添加
```python
from ultralytics.nn.modules import ShuffleAttention
```
并找到这一行代码
```python
elif m is torch.nn.BatchNorm2d:
    args = [ch[f]]
```
在其上方添加以下代码
```python
elif m is ShuffleAttention:  # add shuffle attention
    c1, c2 = ch[f], args[0]
    if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
        c2 = make_divisible(min(c2, max_channels) * width, 8)
    args = [c1, c2, *args[1:]]
```
5.修改yaml配置文件
复制一份ultralytics/ultralytics/cfg/models/v8/yolov8.yaml文件
命名为yolov8-SA.yaml,并修改为以下内容,也可以自行修改。
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolov8
# Task docs: https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 129 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPS
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 129 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPS
  m: [0.67, 0.75, 768] # YOLOv8m summary: 169 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPS
  l: [1.00, 1.00, 512] # YOLOv8l summary: 209 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPS
  x: [1.00, 1.25, 512] # YOLOv8x summary: 209 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPS

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, ShuffleAttention, [1024]] # 9 (inserted; every later layer index shifts by one)
  - [-1, 1, SPPF, [1024, 5]] # 10

# YOLOv8.0n head
# NOTE: absolute "from" indices below account for the inserted ShuffleAttention layers.
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 16 (P3/8-small)

  - [-1, 1, ShuffleAttention, [256, 16, 8]] # 17
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 20 (P4/16-medium)

  - [-1, 1, ShuffleAttention, [512, 16, 8]] # 21
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 24 (P5/32-large)

  - [[16, 20, 24], 1, Detect, [nc]] # Detect(P3, P4, P5)
6.修改训练脚本:
注释掉原来直接加载权重文件的代码:# model = YOLO("pre_weight/yolov8n.pt")
该代码会反序列化 .pt 文件中保存的原始模型结构,使 yaml 中修改后的模型结构失效
正确加载预训练权重的方法是:
import torch

# Build the modified architecture from the YAML config (a new, untrained model).
# `YOLO` is expected to be imported already by the surrounding training script.
# The file name matches the config created in step 5 (paths are case-sensitive on Linux).
model = YOLO("ultralytics/cfg/models/v8/yolov8-SA.yaml")

# The checkpoint stores a pickled model object under "model", so full
# unpickling is needed; PyTorch >= 2.6 defaults to weights_only=True and would
# refuse to load it. Only load checkpoints from a source you trust.
ckpt = torch.load("pre_weight/yolov8n.pt", map_location="cpu", weights_only=False)
pretrained = ckpt["model"].float().state_dict()

# Transfer only tensors whose name AND shape match the new model:
# strict=False ignores missing/unexpected keys but still raises on a
# same-name tensor with a different shape.
target = model.model.state_dict()
compatible = {
    name: tensor
    for name, tensor in pretrained.items()
    if name in target and tensor.shape == target[name].shape
}
model.model.load_state_dict(compatible, strict=False)

