首页 > 编程知识 正文

dann的alpha — torch/train.py — eleven11wang/pytorch-DANN — Gitee.com

时间:2023-05-06 16:06:35 阅读:193946 作者:2767

import torch

import numpy as np

import utils

import torch.optim as optim

import torch.nn as nn

import test

import mnist

import mnistm

from utils import save_model

from utils import visualize

import params

# Domain labels used by the training loops below: source = 0, target = 1.
# Held-out test loaders for each domain (MNIST is the source, MNIST-M the target),
# passed to test.tester during training.
source_test_loader, target_test_loader = mnist.mnist_test_loader, mnistm.mnistm_test_loader

def source_only(encoder, classifier, discriminator, source_train_loader, target_train_loader, save_name):
    """Train encoder + classifier on labelled source data only (no domain adaptation).

    Batches are drawn from ``zip(source_train_loader, target_train_loader)`` so the
    number of steps per epoch equals that of the DANN run; target batches are ignored.

    Args:
        encoder: Feature-extractor module applied to source images; optimised here.
        classifier: Label classifier applied to encoder features; optimised here.
        discriminator: Domain discriminator. It is NOT optimised here; it is only
            switched to train mode and forwarded to ``test.tester`` and ``save_model``.
        source_train_loader: Iterable of ``(image, label)`` source batches. Images are
            tiled three times along dim 1 ("MNIST convert to 3 channel") — assumes
            single-channel NCHW input; TODO confirm against the loader.
        target_train_loader: Iterable of ``(image, label)`` target batches; only used
            to bound the number of steps per epoch (``zip`` stops at the shorter one).
        save_name: Forwarded to ``save_model`` and ``visualize``.

    Requires CUDA: the loss and every batch are moved with ``.cuda()``.
    Runs ``test.tester`` every 10 epochs and saves/visualises once at the end.
    """
    print("Source-only training")

    # CrossEntropyLoss holds no state, so one instance serves every epoch.
    classifier_criterion = nn.CrossEntropyLoss().cuda()

    # Build the optimizer once. Re-creating it on every batch (as before) threw
    # away SGD's momentum buffer each step, so momentum=0.9 never took effect.
    optimizer = optim.SGD(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=0.01,
        momentum=0.9,
    )

    # zip() stops at the shorter loader, so that is the real number of steps per
    # epoch. Using it for both the offset and the total keeps p in [0, 1) and
    # makes it increase without jumps across epoch boundaries.
    steps_per_epoch = min(len(source_train_loader), len(target_train_loader))
    total_steps = params.epochs * steps_per_epoch

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        # Module.train() switches mode in place (and returns the module itself).
        encoder.train()
        classifier.train()
        discriminator.train()

        start_steps = epoch * steps_per_epoch

        for batch_idx, (source_data, _target_data) in enumerate(zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data

            # Overall training progress, fed to the learning-rate schedule.
            p = float(batch_idx + start_steps) / total_steps

            source_image = torch.cat((source_image, source_image, source_image), 1)  # MNIST convert to 3 channel
            source_image, source_label = source_image.cuda(), source_label.cuda()

            # Presumably anneals the learning rate as a function of p — see utils.
            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            source_feature = encoder(source_image)

            # Classification loss
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            class_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tClass Loss: {:.6f}'.format(
                    batch_idx * len(source_image),
                    len(source_train_loader.dataset),
                    100. * batch_idx / len(source_train_loader),
                    class_loss.item()))

        if (epoch + 1) % 10 == 0:
            test.tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode='source_only')

    save_model(encoder, classifier, discriminator, 'source', save_name)
    visualize(encoder, 'source', save_name)

def dann(encoder, classifier, discriminator, source_train_loader, target_train_loader, save_name):
    """Train encoder, classifier and domain discriminator jointly (DANN).

    Each step minimises ``class_loss + domain_loss``:
      * class_loss  — cross-entropy of the classifier on labelled source images;
      * domain_loss — cross-entropy of the discriminator on the concatenated
        source+target features, with domain label 0 for source and 1 for target.

    Args:
        encoder: Feature-extractor module; optimised here.
        classifier: Label classifier applied to encoder features; optimised here.
        discriminator: Called as ``discriminator(features, alpha)``; optimised here.
        source_train_loader: Iterable of ``(image, label)`` source batches. Images are
            tiled three times along dim 1 ("MNIST convert to 3 channel") — assumes
            single-channel NCHW input; TODO confirm against the loader.
        target_train_loader: Iterable of ``(image, label)`` target batches. Target
            class labels are never used for training, only their batch size.
        save_name: Forwarded to ``save_model`` and ``visualize``.

    Requires CUDA: the losses and every batch are moved with ``.cuda()``.
    Runs ``test.tester`` every 10 epochs and saves/visualises once at the end.
    """
    print("DANN training")

    # Both criteria are stateless, so one instance each serves every epoch.
    classifier_criterion = nn.CrossEntropyLoss().cuda()
    discriminator_criterion = nn.CrossEntropyLoss().cuda()

    # Build the optimizer once. Re-creating it on every batch (as before) threw
    # away SGD's momentum buffer each step, so momentum=0.9 never took effect.
    optimizer = optim.SGD(
        list(encoder.parameters())
        + list(classifier.parameters())
        + list(discriminator.parameters()),
        lr=0.01,
        momentum=0.9,
    )

    # zip() stops at the shorter loader, so that is the real number of steps per
    # epoch. Using it for both the offset and the total keeps p in [0, 1) and
    # makes it increase without jumps across epoch boundaries.
    steps_per_epoch = min(len(source_train_loader), len(target_train_loader))
    total_steps = params.epochs * steps_per_epoch

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        # Module.train() switches mode in place (and returns the module itself).
        encoder.train()
        classifier.train()
        discriminator.train()

        start_steps = epoch * steps_per_epoch

        for batch_idx, (source_data, target_data) in enumerate(zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data
            target_image, target_label = target_data

            # Overall training progress, fed to the LR schedule and to alpha.
            p = float(batch_idx + start_steps) / total_steps
            # Ramps from 0 at p=0 towards ~1 at p=1: the adaptation-factor schedule
            # 2 / (1 + exp(-10 p)) - 1 from Ganin & Lempitsky (2015).
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            source_image = torch.cat((source_image, source_image, source_image), 1)  # MNIST convert to 3 channel
            source_image, source_label = source_image.cuda(), source_label.cuda()
            target_image, target_label = target_image.cuda(), target_label.cuda()
            # Source rows first, then target rows: the domain labels below follow this order.
            combined_image = torch.cat((source_image, target_image), 0)

            # Presumably anneals the learning rate as a function of p — see utils.
            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            combined_feature = encoder(combined_image)
            # NOTE(review): source images pass through the encoder a second time here.
            # Slicing combined_feature would differ if the encoder has batch-dependent
            # layers (e.g. BatchNorm), so the separate pass is kept as-is.
            source_feature = encoder(source_image)

            # 1. Classification loss (labelled source data only)
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            # 2. Domain loss: source -> 0, target -> 1.
            # alpha is presumably the gradient-reversal coefficient inside the
            # discriminator — confirm in the model definition.
            domain_pred = discriminator(combined_feature, alpha)
            domain_source_labels = torch.zeros(source_label.shape[0], dtype=torch.long)
            domain_target_labels = torch.ones(target_label.shape[0], dtype=torch.long)
            domain_combined_label = torch.cat((domain_source_labels, domain_target_labels), 0).cuda()
            domain_loss = discriminator_criterion(domain_pred, domain_combined_label)

            total_loss = class_loss + domain_loss
            total_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}\tClass Loss: {:.6f}\tDomain Loss: {:.6f}'.format(
                    batch_idx * len(target_image),
                    len(target_train_loader.dataset),
                    100. * batch_idx / len(target_train_loader),
                    total_loss.item(),
                    class_loss.item(),
                    domain_loss.item()))

        if (epoch + 1) % 10 == 0:
            test.tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode='dann')

    # NOTE(review): this passes the 'source' tag, exactly as source_only() does,
    # although test.tester above is told training_mode='dann'. With the same
    # save_name both runs hand identical arguments to save_model/visualize —
    # confirm this is intended.
    save_model(encoder, classifier, discriminator, 'source', save_name)
    visualize(encoder, 'source', save_name)

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。