概述
前面介绍了描述一个神经网络的Net模块,Solver则负责生成训练网络和测试网络,并按照优化算法对网络参数进行更新。跟Layer一样,Caffe把Solver实现成一个接口,使得开发者可以开发自己的Solver;Solver的子类需要实现ApplyUpdate函数,根据当前网络的状态对网络参数进行更新。我们可以通过SolverParameter给Solver配置一个训练网络和若干个测试网络。测试网络的作用是:每进行若干次训练和参数更新之后,可以把当前的网络参数用在测试网络上检验效果。Solver还提供了接口,使得客户端可以发出指令要求Solver终止训练(例如发现在测试集上效果变差时),或者做快照(以便下次可以从快照继续训练而不用重新开始)。
Solver成员变量
SolverParameter param_;
int iter_;
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;
vector<Dtype> losses_;
Dtype smoothed_loss_;
const Solver*
const root_solver_;
ActionCallback action_request_function_;
vector<Callback*> callbacks_;
Solver初始化函数Init
void Init(
const SolverParameter& param) {
param_ = param;
InitTrainNet();
if (Caffe::root_solver()) {
InitTestNets();
}
iter_ =
0;
current_step_ =
0;
}
void InitTrainNet() {
NetParameter net_param;
net_param.CopyFrom(param_.train_net_param());
NetState net_state;
net_state.set_phase(TRAIN);
net_state.MergeFrom(net_param.state());
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
if (Caffe::root_solver()) {
net_.reset(
new Net<Dtype>(net_param));
}
else {
net_.reset(
new Net<Dtype>(net_param, root_solver_->net_.get()));
}
}
void Solver<Dtype>::InitTestNets() {
int test_net_id =
0;
vector<string> sources(num_test_net_instances);
vector<NetParameter> net_params(num_test_net_instances);
for (
int i =
0; i < num_test_net_params; ++i, ++test_net_id) {
sources[test_net_id] =
"test_net_param";
net_params[test_net_id].CopyFrom(param_.test_net_param(i));
}
test_nets_.resize(num_test_net_instances);
for (
int i =
0; i < num_test_net_instances; ++i) {
NetState net_state;
net_state.set_phase(TEST);
net_state.MergeFrom(net_params[i].state());
if (param_.test_state_size()) {
net_state.MergeFrom(param_.test_state(i));
}
net_params[i].mutable_state()->CopyFrom(net_state);
if (Caffe::root_solver()) {
test_nets_[i].reset(
new Net<Dtype>(net_params[i]));
}
else {
test_nets_[i].reset(
new Net<Dtype>(net_params[i],
root_solver_->test_nets_[i].get()));
}
test_nets_[i]->set_debug_info(param_.debug_info());
}
}
Solver的Solve函数
void Solver<Dtype>::Solve(
const char* resume_file) {
CHECK(Caffe::root_solver());
requested_early_exit_ =
false;
if (resume_file) {
Restore(resume_file);
}
int start_iter = iter_;
Step(param_.max_iter() - iter_);
if (requested_early_exit_) {
return; }
if (param_.display() && iter_ % param_.display() ==
0) {
int average_loss =
this->param_.average_loss();
Dtype loss;
net_->Forward(&loss);
UpdateSmoothedLoss(loss, start_iter, average_loss);
}
if (param_.test_interval() && iter_ % param_.test_interval() ==
0) {
TestAll();
}
}
void Solver<Dtype>::Step(
int iters) {
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss =
this->param_.average_loss();
losses_.clear();
smoothed_loss_ =
0;
while (iter_ < stop_iter) {
net_->ClearParamDiffs();
for (
int i =
0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
Dtype loss =
0;
for (
int i =
0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
}
for (
int i =
0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_gradients_ready();
}
ApplyUpdate();
++iter_;
SolverAction::Enum request = GetRequestedAction();
}
}
Solver的Test函数
void TestAll() {
for (
int test_net_id =
0;
test_net_id < test_nets_.size() && !requested_early_exit_;
++test_net_id) {
Test(test_net_id);
}
}
void Solver<Dtype>::Test(
const int test_net_id) {
vector<Dtype> test_score;
vector<int> test_score_output_id;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss =
0;
for (
int i =
0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction();
Dtype iter_loss;
const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
if (param_.test_compute_loss()) {
loss += iter_loss;
}
if (i ==
0) {
for (
int j =
0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (
int k =
0; k < result[j]->count(); ++k) {
test_score.push_back(result_vec[k]);
test_score_output_id.push_back(j);
}
}
}
else {
int idx =
0;
for (
int j =
0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
for (
int k =
0; k < result[j]->count(); ++k) {
test_score[idx++] += result_vec[k];
}
}
}
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);
}
for (
int i =
0; i < test_score.size(); ++i) {
const int output_blob_index =
test_net->output_blob_indices()[test_score_output_id[i]];
const string& output_name = test_net->blob_names()[output_blob_index];
const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
ostringstream loss_msg_stream;
const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
if (loss_weight) {
loss_msg_stream <<
" (* " << loss_weight
<<
" = " << loss_weight * mean_score <<
" loss)";
}
}
}