TensorFlow Session Setup
TensorFlow's session setup completes the creation of the whole Session: it sets up the input (feeds) and output (fetches) data types, builds the session's base graph from the Graph, and starts the executors, which then wait for the session run to begin.
SessionPRunSetup
The main program calls TF_SessionPRunSetup in c_api.cc to perform the session setup.
```cpp
TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
                    TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);

void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;

  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  std::vector<string> input_names(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = OutputName(inputs[i]);
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                               target_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}
```
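For context, this is how a client drives a partial run through the C API, following the pattern in tensorflow/c/c_api_test.cc. A hedged sketch only: the TF_Operation* handles (a, b, plus) and the input tensors are assumed to have been built elsewhere in the same graph, and most error handling is elided.

```cpp
// Hypothetical caller-side sketch of a partial run (TF 1.x C API).
#include "tensorflow/c/c_api.h"

#define TF_ARRAYSIZE(x) (sizeof(x) / sizeof((x)[0]))

void PartialRunDemo(TF_Session* sess, TF_Operation* a, TF_Operation* b,
                    TF_Operation* plus, TF_Tensor* a_val, TF_Tensor* b_val,
                    TF_Status* s) {
  TF_Output feeds[] = {{a, 0}, {b, 0}};  // inputs to be fed later
  TF_Output fetches[] = {{plus, 0}};     // outputs we want back

  // Setup: extends the graph into the session, creates the executors and
  // starts them waiting on a rendezvous; returns an opaque handle.
  const char* handle = nullptr;
  TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
                      TF_ARRAYSIZE(fetches), /*target_opers=*/nullptr, 0,
                      &handle, s);
  if (TF_GetCode(s) != TF_OK) return;

  // Later: feed the inputs and collect the fetches through the same handle.
  TF_Tensor* feed_values[] = {a_val, b_val};
  TF_Tensor* out = nullptr;
  TF_SessionPRun(sess, handle, feeds, feed_values, 2, fetches, &out, 1,
                 /*target_opers=*/nullptr, 0, s);

  if (out != nullptr) TF_DeleteTensor(out);
  TF_DeletePRunHandle(handle);
}
```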
Creating the base graph (baseGraph)
First, ExtendSessionGraphHelper in c_api.cc is called to extend the graph into the session.
```cpp
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      // TODO(nolivia): check this on a subset of the graph instead of all of
      // it.
      status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
      if (TF_GetCode(status) != TF_OK) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      session->graph->mu.unlock();
      status->status = session->session->Extend(std::move(graph_def));
      if (TF_GetCode(status) != TF_OK) {
        // Contract is we always delete input_values[i].
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}
```
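To see when this path actually fires, here is a hedged client-side sketch using only C API calls from tensorflow/c/c_api.h (error handling elided): a node added after TF_NewSession gets a node id at or above session->last_num_graph_nodes, so the next run or setup call ships exactly that node to the session.

```cpp
#include "tensorflow/c/c_api.h"

void ExtendDemo() {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(graph, opts, s);

  // Added *after* TF_NewSession: this node's id is >= last_num_graph_nodes,
  // so ExtendSessionGraphHelper will serialize it into the GraphDef passed
  // to session->session->Extend() on the next run/setup call.
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "x");
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  TF_FinishOperation(desc, s);

  // A TF_SessionRun / TF_SessionPRunSetup here would first go through
  // ExtendSessionGraphHelper(sess, s).

  TF_DeleteSession(sess, s);
  TF_DeleteSessionOptions(opts);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
```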
This in turn calls the DirectSession::Extend and DirectSession::ExtendLocked methods in direct_session.cc.
```cpp
Status DirectSession::Extend(GraphDef&& graph) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(graph_state_lock_);
  return ExtendLocked(std::move(graph));
}

Status DirectSession::ExtendLocked(GraphDef graph) {
  if (!(flib_def_ && execution_state_)) {
    // If this is the first call, we can initialize the execution state
    // with `graph` and do not need to call `Extend()`.
    // NOTE(mrry): The function library created here will be used for
    // all subsequent extensions of the graph.
    flib_def_.reset(
        new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
    GraphExecutionStateOptions options;
    options.device_set = &device_set_;
    options.session_options = &options_;
    options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        std::move(graph), options, &execution_state_));
    graph_created_ = true;
  } else {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<GraphExecutionState> state;
    // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
    // value and move `graph` in here.
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}
```
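The same first-call/extend split is visible from the public C++ Session API: Create() carries the initial GraphDef (the branch above that builds the execution state via MakeForBaseGraph), while Extend() merges later additions. A minimal sketch assuming TF 1.x headers; graph_def1 and graph_def2 are assumed to be valid GraphDefs built elsewhere.

```cpp
#include <memory>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/session.h"

tensorflow::Status CreateThenExtend(const tensorflow::GraphDef& graph_def1,
                                    const tensorflow::GraphDef& graph_def2) {
  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  // First call: initializes the execution state (MakeForBaseGraph path).
  TF_RETURN_IF_ERROR(session->Create(graph_def1));
  // Later call: merges the new nodes (GraphExecutionState::Extend path).
  TF_RETURN_IF_ERROR(session->Extend(graph_def2));
  return tensorflow::Status::OK();
}
```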
GraphExecutionState::MakeForBaseGraph in https://gitee.com/mirrors/tensorflow/blob/v1.15.0/tensorflow/core/common_runtime/graph_execution_state.cc is then called to create the base graph.
ConvertGraphDefToGraph converts the graph definition (a GraphDef proto) into a computation Graph.
ret->InitBaseGraph initializes the graph: for every node it checks whether some device has the node's op kernel available, and records the kernel's priority on each device.
```cpp
/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef&& graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {
#ifndef __ANDROID__
  VLOG(4) << "Graph proto is \n" << graph_def.DebugString();
#endif  // __ANDROID__

  auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));

  if (options.session_options->config.graph_options().place_pruned_graph() ||
      !options.session_options->config.experimental()
           .optimize_for_static_graph()) {
    auto ret = absl::WrapUnique(new GraphExecutionState(
        absl::make_unique<GraphDef>(std::move(graph_def)),
        std::move(flib_def), options));

    // When place_pruned_graph is true, a different Graph* will be initialized
    // each time we prune the original graph, so there is no need to
    // construct a Graph* in this case.
    if (!options.session_options->config.graph_options()
             .place_pruned_graph()) {
      auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
      TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
          {}, *ret->original_graph_def_, base_graph.get()));
      TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    }
    *out_state = std::move(ret);
  } else {
    auto ret = absl::WrapUnique(
        new GraphExecutionState(nullptr, std::move(flib_def), options));
    auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    *out_state = std::move(ret);
  }
  return Status::OK();
}
```
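InitBaseGraph itself is not quoted here, so the following is only a hedged sketch of the kind of per-node check it delegates to the placer: convert a GraphDef into a Graph with ConvertGraphDefToGraph, then ask the kernel registry via KernelDefAvailable whether each op node has a kernel registered for a given device type. BuildAndCheck is a hypothetical helper, not TensorFlow code, and assumes TF 1.15 internal headers.

```cpp
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

tensorflow::Status BuildAndCheck(const tensorflow::GraphDef& graph_def) {
  using namespace tensorflow;

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));

  for (Node* n : graph.op_nodes()) {
    // Rough analogue of the availability check done during placement.
    const bool cpu = KernelDefAvailable(DeviceType("CPU"), n->def());
    const bool gpu = KernelDefAvailable(DeviceType("GPU"), n->def());
    LOG(INFO) << n->name() << ": CPU kernel=" << cpu
              << ", GPU kernel=" << gpu;
  }
  return Status::OK();
}
```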
Setting up the Session
Here the DirectSession::PRunSetup method in direct_session.cc performs the setup.
First, GetOrCreateExecutors is called to create the executors; see the post "TensorFlow Executor的创建" for how they are built.
Then each executor is started asynchronously via item.executor->RunAsync(args, barrier->Get()); the barrier's callback fires once all executors are done.
Note thread::ThreadPool* pool = thread_pools_[0].first; together with args.runner = [this, pool](Executor::Args::Closure c) { pool->Schedule(std::move(c)); };: the threads the executors run on come from exactly this ThreadPool.
```cpp
Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));

  // RunOptions is not available in PRunSetup, so use thread pool 0.
  thread::ThreadPool* pool = thread_pools_[0].first;

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  // TODO(cais): TFDBG support for partial runs.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);
  run_state_args.is_partial_run = true;
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // Create the run state and save it for future PRun calls.
  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);
  RunState* run_state =
      new RunState(input_names, output_names, args.step_id, &devices_);
  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
  {
    mutex_lock l(executor_lock_);
    if (!partial_runs_
             .emplace(run_state_args.handle,
                      std::unique_ptr<RunState>(run_state))
             .second) {
      return errors::Internal("The handle '", run_state_args.handle,
                              "' created for this partial run is not unique.");
    }
  }

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez, [run_state](const Status& ret) {
        if (!ret.ok()) {
          mutex_lock l(run_state->mu_);
          run_state->status.Update(ret);
        }
        run_state->executors_done.Notify();
      });

  args.rendezvous = run_state->rendez;
  args.cancellation_manager = cancellation_manager_;
  // Note that Collectives are not supported in partial runs
  // because RunOptions is not passed in so we can't know whether
  // their use is intended.
  args.collective_executor = nullptr;
  args.runner = [this, pool](Executor::Args::Closure c) {
    pool->Schedule(std::move(c));
  };
  args.session_state = &session_state_;
  args.session_handle = session_handle_;
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  args.sync_on_finish = sync_on_finish_;

  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();
  }

  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  *handle = run_state_args.handle;
  return Status::OK();
}
```
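A simplified, self-contained sketch of the two mechanics called out above (not the real ExecutorBarrier): closures go through a runner backed by thread::ThreadPool, so "executor threads" are really pool threads, and an atomic countdown plus a Notification stands in for the barrier that fires run_state->executors_done once every executor has reported in.

```cpp
#include <atomic>
#include <functional>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "prun_demo", /*num_threads=*/4);
  // Same shape as args.runner in PRunSetup: executors hand their work to
  // this closure, which schedules it on the pool.
  auto runner = [&pool](std::function<void()> c) {
    pool.Schedule(std::move(c));
  };

  const int num_executors = 3;
  std::atomic<int> pending(num_executors);
  tensorflow::Notification executors_done;  // plays run_state->executors_done

  for (int i = 0; i < num_executors; ++i) {
    // Each closure stands in for item.executor->RunAsync(args, barrier->Get()).
    runner([&pending, &executors_done, i] {
      LOG(INFO) << "executor " << i << " done";
      if (pending.fetch_sub(1) == 1) executors_done.Notify();  // last one out
    });
  }
  executors_done.WaitForNotification();
  return 0;
}
```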
This work is licensed under a Creative Commons Attribution 4.0 International License.