TensorFlow Session 的 Setup(即 PRunSetup)负责完成 Session 运行前的准备工作:指定将要输入的张量(feeds)和将要获取的输出张量(fetches)。随后利用图(Graph)为该 Session 构建执行所需的图,并开启执行器(Executors),等待 Session 开始运行。
回目录
SessionPRunSetup
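主程序调用 c_api.cc 中的 TF_SessionPRunSetup 完成 Session 的 Setup:调用时传入待输入的 feeds 和待获取的 fetches(此处不指定 target 操作,因此传入 NULL 和 0),成功后通过 handle 返回本次部分执行(partial run)的句柄,执行状态则写入 s。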
// Start a partial run: declare feeds and fetches up front, with no target ops.
TF_SessionPRunSetup(/*session=*/sess,
                    /*inputs=*/feeds, /*ninputs=*/TF_ARRAYSIZE(feeds),
                    /*outputs=*/fetches, /*noutputs=*/TF_ARRAYSIZE(fetches),
                    /*target_opers=*/NULL, /*ntargets=*/0,
                    /*handle=*/&handle, /*status=*/s);
// Prepares a partial run (PRun) on `session`.
//
// Parameters:
//   session      - the session on which to set up the partial run.
//   inputs       - `ninputs` endpoints that will be fed in later steps.
//   outputs      - `noutputs` endpoints that will be fetched in later steps.
//   target_opers - `ntargets` operations to run; not read when ntargets is 0,
//                  so it may be NULL in that case.
//   handle       - out-param. Set to nullptr first; on success it points to a
//                  NUL-terminated copy of the handle produced by
//                  Session::PRunSetup. The copy is allocated with new[].
//                  NOTE(review): presumably released by the caller via
//                  TF_DeletePRunHandle - confirm.
//   status       - receives the outcome; on any failure *handle stays nullptr.
void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  // Reset the out-parameter first so every early return leaves it null.
  *handle = nullptr;

  // If the session is flagged to extend before running, push pending graph
  // changes into it first. ExtendSessionGraphHelper presumably records the
  // failure reason in `status`.
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  // Convert the C endpoint / operation descriptors into the name strings the
  // C++ Session API takes.
  std::vector<string> input_names(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = OutputName(inputs[i]);
  }
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                                target_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
    // Hand the caller a heap copy that outlives `new_handle`; the +1 also
    // copies the terminating NUL so the result is a valid C string.
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}
// Start a partial run: declare feeds and fetches up front, with no target ops.
TF_SessionPRunSetup(/*session=*/sess,
                    /*inputs=*/feeds, /*ninputs=*/TF_ARRAYSIZE(feeds),
                    /*outputs=*/fetches, /*noutputs=*/TF_ARRAYSIZE(fetches),
                    /*target_opers=*/NULL, /*ntargets=*/0,
                    /*handle=*/&handle, /*status=*/s);
// Sets up a partial run on `session` for the given feed endpoints (`inputs`),
// fetch endpoints (`outputs`) and target operations (`target_opers`).
// On success `*handle` receives a new[]-allocated, NUL-terminated copy of the
// partial-run handle. On any failure it is left as nullptr and the outcome is
// reported through `status`.
void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;  // never leave a stale pointer behind on early exit

  if (session->extend_before_run) {
    // Sync pending graph additions into the session before setting up.
    if (!ExtendSessionGraphHelper(session, status)) return;
  }

  // Maps `count` TF_Output endpoints to their string names via OutputName.
  auto endpoint_names = [](const TF_Output* ports, int count) {
    std::vector<string> names(count);
    for (int idx = 0; idx < count; ++idx) {
      names[idx] = OutputName(ports[idx]);
    }
    return names;
  };
  std::vector<string> feed_names = endpoint_names(inputs, ninputs);
  std::vector<string> fetch_names = endpoint_names(outputs, noutputs);

  std::vector<string> target_node_names(ntargets);
  for (int idx = 0; idx < ntargets; ++idx) {
    const TF_Operation* op = target_opers[idx];
    target_node_names[idx] = op->node.name();
  }

  string setup_handle;
  status->status = session->session->PRunSetup(
      feed_names, fetch_names, target_node_names, &setup_handle);
  if (TF_GetCode(status) != TF_OK) {
    return;
  }

  // Copy including the trailing NUL so the caller gets a valid C string.
  const auto byte_count = setup_handle.size() + 1;
  char* handle_copy = new char[byte_count];
  memcpy(handle_copy, setup_handle.c_str(), byte_count);
  *handle = handle_copy;
}
继续阅读“TensorFlow Session的Setup”
本作品采用知识共享署名 4.0 国际许可协议进行许可。