
Conversation


@nlper01 nlper01 commented Sep 13, 2022

import numpy as np
import torch
from torch import nn
from torch.nn import init

class PSA(nn.Module):

    def __init__(self, channel=512, reduction=4, S=4):
        super().__init__()
        self.S = S

        # multi-scale convs, one per split (kernel sizes 3, 5, 7, ...)
        self.convs = nn.ModuleList(
            [nn.Conv2d(channel // S, channel // S, kernel_size=2 * (i + 1) + 1, padding=(i + 1)) for i in range(S)])
        # self.convs = []
        # for i in range(S):
        #     self.convs.append(nn.Conv2d(channel // S, channel // S, kernel_size=2 * (i + 1) + 1, padding=i + 1))

        # one SE block per split
        self.se_blocks = nn.ModuleList(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channel // S, channel // (S * reduction), kernel_size=1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // (S * reduction), channel // S, kernel_size=1, bias=False),
                nn.Sigmoid()
            ) for i in range(S)
        )
        # self.se_blocks = []
        # for i in range(S):
        #     self.se_blocks.append(nn.Sequential(
        #         nn.AdaptiveAvgPool2d(1),
        #         nn.Conv2d(channel // S, channel // (S * reduction), kernel_size=1, bias=False),
        #         nn.ReLU(inplace=True),
        #         nn.Conv2d(channel // (S * reduction), channel // S, kernel_size=1, bias=False),
        #         nn.Sigmoid()
        #     ))

        self.softmax = nn.Softmax(dim=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.size()

        # Step 1: SPC module
        SPC_out = x.view(b, self.S, c // self.S, h, w)  # bs, s, ci, h, w
        for idx, conv in enumerate(self.convs):
            SPC_out[:, idx, :, :, :] = conv(SPC_out[:, idx, :, :, :])  # in-place write into a view of x

        # Step 2: SE weight per split
        se_out = []
        for idx, se in enumerate(self.se_blocks):
            se_out.append(se(SPC_out[:, idx, :, :, :]))
        SE_out = torch.stack(se_out, dim=1)
        SE_out = SE_out.expand_as(SPC_out)

        # Step 3: softmax across the S splits
        softmax_out = self.softmax(SE_out)

        # Step 4: SPA, reweight and merge the splits back to (b, c, h, w)
        PSA_out = SPC_out * softmax_out
        PSA_out = PSA_out.view(b, -1, h, w)

        return PSA_out

if __name__ == '__main__':
    device = torch.device('cuda')
    input = torch.randn(8, 512, 7, 7).to(device)
    psa = PSA(channel=512, reduction=8).to(device)
    output = psa(input)
    a = output.view(-1).sum()
    a.backward()
    print(output.shape)

This fixes the problem where using the PSA.py module in your own network raises RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same.
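The device mismatch happens because submodules stored in a plain Python list are not registered with the parent module, so .to(device) never moves them; nn.ModuleList registers them. A minimal sketch of the difference (the class names Wrong/Right are only for illustration):

import torch
from torch import nn

class Wrong(nn.Module):
    def __init__(self):
        super().__init__()
        # plain list: convs are invisible to .parameters() and .to(device)
        self.convs = [nn.Conv2d(4, 4, 3, padding=1) for _ in range(2)]

class Right(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: convs are registered and move with the parent module
        self.convs = nn.ModuleList(nn.Conv2d(4, 4, 3, padding=1) for _ in range(2))

print(len(list(Wrong().parameters())))  # 0 -> these convs would stay on CPU after .to('cuda')
print(len(list(Right().parameters())))  # 4 -> weights and biases move together with the module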
However, there is still an in-place operation problem that breaks the gradient computation; I'd appreciate help from anyone who can solve it: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 128, 7, 7]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
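The in-place error comes from Step 1 of forward: assigning conv(SPC_out[:, idx]) back into a slice of SPC_out overwrites a tensor that autograd still needs for the backward pass. One possible fix (a sketch, not part of this PR) is to build SPC_out out of place by stacking the per-split conv outputs instead of writing into the view:

    def forward(self, x):
        b, c, h, w = x.size()

        # Step 1: SPC module, without writing into a view of x
        splits = x.view(b, self.S, c // self.S, h, w)
        SPC_out = torch.stack([conv(splits[:, idx]) for idx, conv in enumerate(self.convs)], dim=1)

        # Steps 2-4 unchanged
        se_out = [se(SPC_out[:, idx]) for idx, se in enumerate(self.se_blocks)]
        SE_out = torch.stack(se_out, dim=1).expand_as(SPC_out)
        softmax_out = self.softmax(SE_out)
        PSA_out = (SPC_out * softmax_out).view(b, -1, h, w)
        return PSA_out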


zeng-cy commented Dec 2, 2022

After changing only the nn.ModuleList part, I indeed run into the same problem: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 64, 20, 20]], which is output 0 of ReluBackward1, is at version 5; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
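As the hint in the error message suggests, anomaly detection makes the backward error point at the forward operation whose result was later overwritten, which helps confirm which in-place write is responsible. A minimal way to enable it, reusing psa and input from the snippet above:

import torch

torch.autograd.set_detect_anomaly(True)  # backward error will name the offending forward op
output = psa(input)
output.sum().backward()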
