Caffe 源码阅读笔记 [基本模块] Solver

    xiaoxiao2023-03-24  3

    概述

    前面介绍了网络Net模块来描述一个神经网络,Solver负责生成训练网络和测试网络并按照算法对网络进行参数优化。跟Layer一样,Caffe把Solver实现成一个接口,使得开发者可以开发自己的Solver,其中Solver的子类需要实现ApplyUpdate函数来根据当前网络的状态对网络参数进行更新。我们可以通过SolverParameter给Solver配置一个训练网络和若干个测试网络。测试网络的作用是我们每进行若干次训练和参数更新之后,我们可以把当前的网络参数用在测试网络上测试效果。Solver也提供了接口使得客户端可以发出指令要求Solver终止训练(如果发现在测试集上效果变差)或者做快照(以便下次可以从快照开始继续训练而不用重新开始)。

    Solver成员变量

    SolverParameter param_; // Solver参数 int iter_; // 第几次迭代 int current_step_; // shared_ptr<Net<Dtype> > net_; // 要优化的网络 vector<shared_ptr<Net<Dtype> > > test_nets_; // 用于测试的网络 vector<Dtype> losses_; // 存储最后average_loss次迭代的loss值。 Dtype smoothed_loss_; // 当前的最后average_loss次迭代的loss的平均值 const Solver* const root_solver_; // 它包含了root_net(有共享Layer的网络) ActionCallback action_request_function_; // 客户端可以通过这个callback来要求Solver做一个snapshot或者退出 vector<Callback*> callbacks_;

    Solver初始化函数Init

    void Init(const SolverParameter& param) { param_ = param; InitTrainNet(); // 初始化训练网络 if (Caffe::root_solver()) { InitTestNets(); // 初始化测试网络 } iter_ = 0; current_step_ = 0; } void InitTrainNet() { NetParameter net_param; // 从Solver的参数中拷贝网络参数,也可以从文件param_.train_net()、param_.net_param()和param_.net()里读出来 net_param.CopyFrom(param_.train_net_param()); NetState net_state; // 从优先级由低到高设置网络的状态 net_state.set_phase(TRAIN); net_state.MergeFrom(net_param.state()); net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); if (Caffe::root_solver()) { // 构造一个新的root网络 net_.reset(new Net<Dtype>(net_param)); } else { // 构造一个新的non-root网络,有一部分网络是基于给定的root网络的 net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get())); } } void Solver<Dtype>::InitTestNets() { int test_net_id = 0; vector<string> sources(num_test_net_instances); vector<NetParameter> net_params(num_test_net_instances); // 从test_net_param得到所有Test网络的设置,也可以从net_param, net文件,test_net文件里读出来 for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) { sources[test_net_id] = "test_net_param"; net_params[test_net_id].CopyFrom(param_.test_net_param(i)); } test_nets_.resize(num_test_net_instances); for (int i = 0; i < num_test_net_instances; ++i) { // 从优先级由低到高设置网络的状态 NetState net_state; net_state.set_phase(TEST); net_state.MergeFrom(net_params[i].state()); if (param_.test_state_size()) { net_state.MergeFrom(param_.test_state(i)); } net_params[i].mutable_state()->CopyFrom(net_state); if (Caffe::root_solver()) { // 构造新的root网络 test_nets_[i].reset(new Net<Dtype>(net_params[i])); } else { // 构造新的non-root网络 test_nets_[i].reset(new Net<Dtype>(net_params[i], root_solver_->test_nets_[i].get())); } test_nets_[i]->set_debug_info(param_.debug_info()); } }

    Solver的Solve函数

    void Solver<Dtype>::Solve(const char* resume_file) { CHECK(Caffe::root_solver()); // 只有root_solver才能优化网络 requested_early_exit_ = false; // 如果被设为true,则退出 if (resume_file) { // 从保存好的snapshot开始,而不是从头开始 Restore(resume_file); } int start_iter = iter_; // 迭代到最大允许次数 Step(param_.max_iter() - iter_); if (requested_early_exit_) { return; } if (param_.display() && iter_ % param_.display() == 0) { // average_loss控制我们计算loss值是最后average_loss次迭代的平均值 int average_loss = this->param_.average_loss(); Dtype loss; net_->Forward(&loss); // 前向传播计算loss值 UpdateSmoothedLoss(loss, start_iter, average_loss); // 通过当前smoothed_loss_和loss_数组计算最后average_loss次迭代loss值的平均值 } if (param_.test_interval() && iter_ % param_.test_interval() == 0) { // 每test_interval次迭代计算一下在测试集的效果。 TestAll(); } } // 迭代iters次 void Solver<Dtype>::Step(int iters) { const int start_iter = iter_; const int stop_iter = iter_ + iters; int average_loss = this->param_.average_loss(); losses_.clear(); smoothed_loss_ = 0; while (iter_ < stop_iter) { // 把参数清空 net_->ClearParamDiffs(); // 调用callback for (int i = 0; i < callbacks_.size(); ++i) { callbacks_[i]->on_start(); } Dtype loss = 0; for (int i = 0; i < param_.iter_size(); ++i) { // 对网络先进行前向传播,再进行反向传播。然后计算loss的总和以求平均值 loss += net_->ForwardBackward(); } loss /= param_.iter_size(); //求平均值 // 计算最后average_loss次迭代的平均loss值smoothed_loss_ UpdateSmoothedLoss(loss, start_iter, average_loss); if (display) { // 打印网络输出blob的值,略 } // 调用callback for (int i = 0; i < callbacks_.size(); ++i) { callbacks_[i]->on_gradients_ready(); } ApplyUpdate(); //由Solver的子类实现来更新网络参数 ++iter_; // 获得SolverAction,根据客户端要求可以做Snapshot或者提前退出 SolverAction::Enum request = GetRequestedAction(); } }

    Solver的Test函数

    // 测试所有的数据 void TestAll() { for (int test_net_id = 0; test_net_id < test_nets_.size() && !requested_early_exit_; ++test_net_id) { Test(test_net_id); } } // 测试一个数据集 void Solver<Dtype>::Test(const int test_net_id) { vector<Dtype> test_score; // 存储网络输出blob的所有值 vector<int> test_score_output_id; // 存储test_score[i]对应的top blob的id const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id]; Dtype loss = 0; // 做test_iter次计算,取平均loss值 for (int i = 0; i < param_.test_iter(test_net_id); ++i) { // 根据SolverAction决定是否做snapshot或者退出程序,略 SolverAction::Enum request = GetRequestedAction(); Dtype iter_loss; // 做前向传播,计算loss值 const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss); // 如果计算loss,那么累加每次迭代的loss值 if (param_.test_compute_loss()) { loss += iter_loss; } if (i == 0) { for (int j = 0; j < result.size(); ++j) { const Dtype* result_vec = result[j]->cpu_data(); for (int k = 0; k < result[j]->count(); ++k) { // 把所有top blob打平到一维数组test_score上, 并在test_score_output_id记录对应的blob id test_score.push_back(result_vec[k]); test_score_output_id.push_back(j); } } } else { int idx = 0; for (int j = 0; j < result.size(); ++j) { const Dtype* result_vec = result[j]->cpu_data(); for (int k = 0; k < result[j]->count(); ++k) { // 累加所有top blob对应的元素,最后是要再除以test_iter以求平均值的 test_score[idx++] += result_vec[k]; } } } } if (param_.test_compute_loss()) { // 计算平均值并打印 loss /= param_.test_iter(test_net_id); } for (int i = 0; i < test_score.size(); ++i) { // 对每个test_score,得到它对应的top blob的名字和loss权重 const int output_blob_index = test_net->output_blob_indices()[test_score_output_id[i]]; const string& output_name = test_net->blob_names()[output_blob_index]; const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index]; ostringstream loss_msg_stream; // 求出平均的test_score const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id); // 如果loss_weight不为0,则计算加权的mean score。 if (loss_weight) { loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * mean_score << " loss)"; } } }
    转载请注明原文地址: https://ju.6miu.com/read-1201433.html
    最新回复(0)