The Slacking-Off Road (Part 1)

Updated: 2024-10-25 06:24:25


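My first blog post. Just a quick record of a newbie's struggles reading code. I don't want to slack off too much, so I'll make a little progress each day and hope to finish most of it this month. I'm leaving a footprint here to console myself and, while I'm at it, to build a good habit.

(Stream-of-consciousness writing: I'll write whatever comes to mind... as long as I can understand it myself.)

Another day of being confused by the OLTR code. It still doesn't run. Right now I'm modifying the code from GitHub and want to swap in the CIFAR10 dataset first, but the code architecture is insanely complicated.

Let me walk through how the code runs.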

import argparse

# source_import is the repo's own helper for loading a .py file as a module
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./temp.py', type=str)
parser.add_argument('--test', default=False, action='store_true')
args = parser.parse_args()
test_mode = args.test
config = source_import(args.config).config  # still ./temp.py unless --config is given
training_opt = config['training_opt']

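Reading the parser really just gets the config dictionary and test_mode; everything else isn't needed, so I deleted it.

(Note: test_mode will probably have to change; I'll just split it into two separate scripts.)

First, a look at what happens in training mode...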
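As a rough sketch of what a helper like source_import might be doing, assuming it loads a Python file by path with importlib (I haven't checked this against the repo; `load_module_from_path` and the demo config contents are made-up names and values for illustration):

```python
import importlib.util
import os
import tempfile


def load_module_from_path(file_path):
    """Hypothetical stand-in for source_import: load a .py file as a module."""
    spec = importlib.util.spec_from_file_location("loaded_config", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Write a tiny demo config file so the sketch is self-contained.
demo_source = "config = {'training_opt': {'batch_size': 64, 'num_workers': 4}}\n"
with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "temp.py")
    with open(config_path, "w") as f:
        f.write(demo_source)
    # The module is executed here, so `config` survives the directory cleanup.
    config = load_module_from_path(config_path).config

training_opt = config["training_opt"]
print(training_opt["batch_size"])
```

The point is only that a config file is ordinary Python: once the module is executed, `config` is a plain dict and `training_opt` is one of its entries.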

sampler_defs = training_opt['sampler']
if sampler_defs:
    sampler_dic = {'sampler': source_import(sampler_defs['def_file']).get_sampler(),
                   'num_samples_cls': sampler_defs['num_samples_cls']}
else:
    sampler_dic = None

data = {x: load_data(phase=x,
                     batch_size=training_opt['batch_size'],
                     sampler_dic=sampler_dic,
                     num_workers=training_opt['num_workers'])
        for x in (['train', 'val', 'train_plain'] if relatin_opt['init_centroids'] else ['train', 'val'])}

training_model = model(config, data, test=False)
training_model.train()

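There is a sampler here. By default a DataLoader samples sequentially, and with shuffle=True it switches to a random sampler automatically. Here, though, a custom sampler seems to have been written, and a custom sampler is passed in instead of setting shuffle. (I haven't figured it out yet.)

Note that whether a sampler is needed depends on training_opt['sampler'], a field in the config dictionary. Stage 1 does not need a sampler. When one is set, it is passed to load_data as sampler_dic while building data.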
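I haven't read the custom sampler yet, so the following is only a guess at the general idea, not the repo's implementation: for long-tailed data, a sampler often picks a class uniformly first and then draws a few samples from that class. `class_balanced_indices`, `num_draws`, and the toy labels are hypothetical; only the name `num_samples_cls` comes from the config above.

```python
import random
from collections import defaultdict


def class_balanced_indices(labels, num_samples_cls, num_draws, seed=0):
    """Pick a class uniformly at random, then take num_samples_cls indices from it."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    classes = sorted(by_class)

    picked = []
    for _ in range(num_draws):
        cls = rng.choice(classes)
        # Sample with replacement so rare classes can still fill their quota.
        picked.extend(rng.choices(by_class[cls], k=num_samples_cls))
    return picked


# Toy long-tailed label list: class 0 has 90 samples, class 1 has 10.
labels = [0] * 90 + [1] * 10
indices = class_balanced_indices(labels, num_samples_cls=4, num_draws=500)
share_of_rare = sum(labels[i] == 1 for i in indices) / len(indices)
print(round(share_of_rare, 2))
```

With plain random sampling the rare class would show up about 10% of the time; drawing classes first pushes its share toward one half.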

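The most important piece is data; most of my changes will probably be here.

data depends on relatin_opt, which equals config['memory'] and holds two parameters, centroids and init_centroids. In stage 1 both are False; for meta-embedding (stage 2) both are True. So in stage 1, data has only two entries, train and val.

So for stage 1, data is {'train': load_data(phase='train', batch_size=64, sampler_dic=None, num_workers=4), 'val': load_data(phase='val', batch_size=64, sampler_dic=None, num_workers=4)}. The concrete changes go in dataloader.py (mainly the return values; the GitHub code also has a .txt log there that can be deleted).

OK. Then training_model is model called with three arguments: config, data, and test=False.

The main event is still in runnet.py, which I really can't make sense of...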
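To double-check the phase selection, here is a small sketch with a stand-in for load_data. `phases_to_load` and `fake_load_data` are hypothetical helpers written for this note; the batch size and worker count are the values from my config.

```python
def phases_to_load(memory_opt):
    """Return the phases the data dict would contain for a given memory config."""
    if memory_opt['init_centroids']:
        return ['train', 'val', 'train_plain']
    return ['train', 'val']


def fake_load_data(phase, batch_size, sampler_dic, num_workers):
    """Stand-in for load_data: records its arguments instead of building a DataLoader."""
    return {'phase': phase, 'batch_size': batch_size,
            'sampler_dic': sampler_dic, 'num_workers': num_workers}


stage1_memory = {'centroids': False, 'init_centroids': False}
stage2_memory = {'centroids': True, 'init_centroids': True}

stage1_data = {x: fake_load_data(phase=x, batch_size=64, sampler_dic=None, num_workers=4)
               for x in phases_to_load(stage1_memory)}

print(sorted(stage1_data))
print(phases_to_load(stage2_memory))
```

Only stage 2 asks for the extra train_plain loader, which matches the centroid initialization that happens later in the model's init code.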

self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.config = config
self.training_opt = self.config['training_opt']
self.memory = self.config['memory']
self.data = data
self.test_mode = test

# Initialize model
self.init_models()

# Under training mode, initialize training steps, optimizers, schedulers, criterions, and centroids
if not self.test_mode:

    # If using steps for training, we need to calculate training steps
    # for each epoch based on actual number of training data instead of
    # oversampled data number
    print('Using steps for training.')
    self.training_data_num = len(self.data['train'].dataset)
    self.epoch_steps = int(self.training_data_num / self.training_opt['batch_size'])

    # Initialize model optimizer and scheduler
    print('Initializing model optimizer.')
    self.scheduler_params = self.training_opt['scheduler_params']
    self.model_optimizer, \
    self.model_optimizer_scheduler = self.init_optimizers(self.model_optim_params_list)
    self.init_criterions()
    if self.memory['init_centroids']:  # if the centroids need to be initialized
        self.criterions['FeatureLoss'].centroids.data = \
            self.centroids_cal(self.data['train_plain'])

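Looking at this init code: it stores config, training_opt, data, and test (False here).

In train mode it computes the steps for training (the dataset size and the number of steps per epoch), reads the learning-rate scheduler parameters, and initializes the optimizer and the loss functions. Stage 1 does not need to initialize centroids; stage 2 does.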
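A quick worked example of the steps-per-epoch calculation, assuming the CIFAR10 training set (50000 images) and the batch size of 64 from my config. `steps_per_epoch` is just a name I made up for the one-line formula in the init code.

```python
def steps_per_epoch(num_training_samples, batch_size):
    """Number of full batches per epoch; the trailing partial batch is dropped."""
    return int(num_training_samples / batch_size)


cifar10_train_size = 50000
print(steps_per_epoch(cifar10_train_size, 64))  # 781
```

Since 50000 / 64 = 781.25, the int() truncation means the last 16 images of each pass are never reached before the loop breaks.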

   

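init_models initializes the models. The networks dictionary in the config contains the network itself and the classifier, each made up of basic parameters and optimizer parameters. Watch out for the stage1_weight parameter, which stage 1 does not use. The network and the classifier are created from these parameters (oddly, I don't know why this works??).

I don't really understand the fix part. I'll come back and fill it in later.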
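My current guess at why creating models "from parameters" can work: each network entry points at something that exposes a factory function, and the code simply calls it with the entry's params. The sketch below assumes exactly that; the factory functions, the `FACTORIES` lookup, and all numbers are hypothetical, and plain dicts stand in for real networks.

```python
def create_feat_model(feat_dim, test=False):
    """Hypothetical factory standing in for a network definition file."""
    return {'kind': 'feat_model', 'feat_dim': feat_dim, 'test': test}


def create_classifier(feat_dim, num_classes, test=False):
    """Hypothetical factory standing in for a classifier definition file."""
    return {'kind': 'classifier', 'feat_dim': feat_dim,
            'num_classes': num_classes, 'test': test}


# In the real code the factory would presumably come from an imported file;
# here a plain dict lookup plays that role.
FACTORIES = {'feat_model': create_feat_model, 'classifier': create_classifier}

networks_defs = {
    'feat_model': {'params': {'feat_dim': 512}, 'optim_params': {'lr': 0.1}},
    'classifier': {'params': {'feat_dim': 512, 'num_classes': 10},
                   'optim_params': {'lr': 0.1}},
}

networks = {}
optim_params_list = []
for key, val in networks_defs.items():
    # Basic parameters build the network; optimizer parameters are collected separately.
    networks[key] = FACTORIES[key](**val['params'], test=False)
    optim_params_list.append({'name': key, **val['optim_params']})

print(networks['classifier']['num_classes'])
```

So nothing magical happens: the config only names what to build and with which arguments, and the loop does the building.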

    

def train(self):

    # When training the network
    print_str = ['Phase: train']
    time.sleep(0.25)

    # Initialize best model
    best_model_weights = {}
    best_model_weights['feat_model'] = copy.deepcopy(self.networks['feat_model'].state_dict())
    best_model_weights['classifier'] = copy.deepcopy(self.networks['classifier'].state_dict())
    best_acc = 0.0
    best_epoch = 0

    end_epoch = self.training_opt['num_epochs']

    # Loop over epochs
    for epoch in range(1, end_epoch + 1):

        for model in self.networks.values():
            model.train()

        torch.cuda.empty_cache()

        # Iterate over dataset
        for step, (inputs, labels, _) in enumerate(self.data['train']):

            # Break when step equal to epoch step
            if step == self.epoch_steps:
                break

            inputs, labels = inputs.to(self.device), labels.to(self.device)

            # If on training phase, enable gradients
            with torch.set_grad_enabled(True):

                # If training, forward with loss, and no top 5 accuracy calculation
                self.batch_forward(inputs, labels,
                                   centroids=self.memory['centroids'],
                                   phase='train')
                self.batch_loss(labels)
                self.batch_backward()

                # Output minibatch training results
                if step % self.training_opt['display_step'] == 0:

                    minibatch_loss_feat = self.loss_feat.item() \
                        if 'FeatureLoss' in self.criterions.keys() else None
                    minibatch_loss_perf = self.loss_perf.item()
                    _, preds = torch.max(self.logits, 1)
                    minibatch_acc = mic_acc_cal(preds, labels)

                    print_str = ['Epoch: [%d/%d]'
                                 % (epoch, self.training_opt['num_epochs']),
                                 'Step: %5d'
                                 % (step),
                                 'Minibatch_loss_feature: %.3f'
                                 % (minibatch_loss_feat) if minibatch_loss_feat else '',
                                 'Minibatch_loss_performance: %.3f'
                                 % (minibatch_loss_perf),
                                 'Minibatch_accuracy_micro: %.3f'
                                 % (minibatch_acc)]

        # Set model modes and set scheduler
        # In training, step optimizer scheduler and set model to train()
        self.model_optimizer_scheduler.step()
        if self.criterion_optimizer:
            self.criterion_optimizer_scheduler.step()

        # After every epoch, validation
        self.eval(phase='val')

        # Under validation, the best model need to be updated
        if self.eval_acc_mic_top1 > best_acc:
            best_epoch = copy.deepcopy(epoch)
            best_acc = copy.deepcopy(self.eval_acc_mic_top1)
            best_centroids = copy.deepcopy(self.centroids)
            best_model_weights['feat_model'] = copy.deepcopy(self.networks['feat_model'].state_dict())
            best_model_weights['classifier'] = copy.deepcopy(self.networks['classifier'].state_dict())

    print()
    print('Training Complete.')

    print_str = ['Best validation accuracy is %.3f at epoch %d' % (best_acc, best_epoch)]
    # Save the best model and best centroids if calculated
    print(print_str, "*********")
    self.save_model(epoch, best_epoch, best_model_weights, best_acc, centroids=best_centroids)

    print('Done')

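The train method is also fairly important.

In train mode, batch_forward is called with (inputs, labels, centroids=False, phase='train'); centroids is False because this is stage 1.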
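The best-model bookkeeping at the end of each epoch can be isolated into a small sketch. This is not the repo's code: `track_best` is a hypothetical helper, the accuracies are made-up numbers, and plain dicts stand in for state_dict().

```python
import copy


def track_best(val_accuracies, weights_per_epoch):
    """Return (best_epoch, best_acc, best_weights) over 1-indexed epochs."""
    best_acc = 0.0
    best_epoch = 0
    best_weights = None
    for epoch, (acc, weights) in enumerate(zip(val_accuracies, weights_per_epoch), start=1):
        # Strict '>' keeps the earliest epoch when two epochs tie.
        if acc > best_acc:
            best_epoch = epoch
            best_acc = acc
            # Deep copy so later training steps cannot mutate the snapshot.
            best_weights = copy.deepcopy(weights)
    return best_epoch, best_acc, best_weights


accs = [0.41, 0.57, 0.57, 0.52]
weights = [{'w': [float(e)]} for e in range(1, 5)]
best_epoch, best_acc, best_weights = track_best(accs, weights)
weights[1]['w'][0] = -1.0  # mutate the original after the fact
print(best_epoch, best_acc, best_weights)
```

The deepcopy is what matters: without it the "best" weights would keep changing as training continues, because the snapshot would share storage with the live model.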

---------------------------------------------------------------------------------------------------------------------------------

Dividing line, 11.05

Published: 2024-02-27 06:18:26
Link: https://www.elefans.com/category/jswz/34/1705499.html