Categorías
Otros

Pytorch corrige algunos parámetros (solo entrena algunas capas)

Referencia:https://www.cnblogs.com/jiangkejie/p/11199847.html

En el aprendizaje de transferencia, a menudo usamos modelos pre-entrenados y agregamos capas adicionales a los modelos pre-entrenados. Al entrenar, los parámetros de la capa de pre-entrenamiento se fijan primero, y solo se entrenan las piezas adicionales. Después de terminar todos los entrenamientos y afinación.

Cuando pytorch corrige algunos parámetros para el entrenamiento, es necesario aplicar el filtrado en el optimizador.

class RESNET_attention(nn.Module):
    def __init__(self, model, pretrained):
        super(RESNET_attetnion, self).__init__()
        self.resnet = model(pretrained)
        for p in self.parameters():
            p.requires_grad = False
        self.f = nn.Conv2d(2048, 512, 1)
        self.g = nn.Conv2d(2048, 512, 1)
        self.h = nn.Conv2d(2048, 2048, 1)
        self.softmax = nn.Softmax(-1)
        self.gamma = nn.Parameter(torch.FloatTensor([0.0]))
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.resnet.fc = nn.Linear(2048, 10)

nota: El código anterior reproduce la parte de atención de SAGAN, este no es el principal problema

De esta manera, los parámetros por encima del bucle for son fijos, y sólo se entrenan los siguientes parámetros (f, g, h, gamma, fc, etc.), pero tenga en cuenta que debe agregar una oración de este tipo al filtro del optimizador (lambda p: p.requires_grad, model.parameters()
La ubicación añadida es:
optimizador = optim. Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)

.

  [MYSQL Statement Optimization] (Artículo 3) Optimización de detalles simples

Por Programación.Click

Más de 20 años programando en diferentes lenguajes de programación. Apasionado del code clean y el terminar lo que se empieza. ¿Programamos de verdad?

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *