Creating the TensorFlow Executor
A Session's executors carry out the computation of every graph Node, scheduling ready nodes onto the session's thread pools. In the previous article on Session setup, the GetOrCreateExecutors method was called to create these executors.
The GetOrCreateExecutors method in direct_session.cc takes the inputs, fetches (outputs), and target nodes, uses them to build a cache key and the callable_options, and, when no cached entry exists, calls CreateExecutors to create the executors.
Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes,
    ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // Fast lookup path, no sorting.
  const string key = strings::StrCat(
      absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
      absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }

  // Slow lookup path, the unsorted key missed the cache.
  // Sort the inputs and outputs, and look up with the sorted key in case an
  // earlier call used a different order of inputs and outputs.
  //
  // We could consider some other signature instead of sorting that
  // preserves the same property to avoid the sort in the future.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      absl::StrJoin(inputs_sorted, ","), "->",
      absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      // Insert this under the original key.
      executors_.emplace(key, it->second);
      return Status::OK();
    }
  }

  // Nothing found, so create the executors and store in the cache.
  // The executor_lock_ is intentionally released while executors are
  // being created.
  CallableOptions callable_options;
  for (const string& input : inputs_sorted) {
    callable_options.add_feed(input);
  }
  for (const string& output : outputs_sorted) {
    callable_options.add_fetch(output);
  }
  for (const string& target : tn_sorted) {
    callable_options.add_target(target);
  }
  *callable_options.mutable_run_options()->mutable_debug_options() =
      run_state_args->debug_options;
  callable_options.mutable_run_options()
      ->mutable_experimental()
      ->set_collective_graph_key(run_state_args->collective_graph_key);
  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, run_state_args));

  // Reacquire the lock, try to insert into the map.
  mutex_lock l(executor_lock_);

  // Another thread may have created the entry before us, in which case we will
  // reuse the already created one.
  auto insert_result = executors_.emplace(
      sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
  if (insert_result.second) {
    functions_.push_back(std::move(func_info));
  }

  // Insert the value under the original key, so the fast path lookup will work
  // if the user uses the same order of inputs, outputs, and targets again.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}
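A detail worth noting above: the cache is keyed twice. The fast path keys on the feeds/fetches/targets in the caller's original order; on a miss, the slow path sorts them into a canonical key, so two calls that merely permute the same names share one set of executors. The following self-contained sketch reproduces just this two-key idea; ExecutorCache, Executors, and MakeKey are invented names for illustration, not TensorFlow API:

// Standalone illustration of the two-key executor cache idea.
// All names here are hypothetical; only the caching scheme mirrors TF.
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

struct Executors {};  // stand-in for ExecutorsAndKeys

class ExecutorCache {
 public:
  std::shared_ptr<Executors> GetOrCreate(std::vector<std::string> inputs,
                                         std::vector<std::string> outputs) {
    const std::string key = MakeKey(inputs, outputs);
    {
      std::lock_guard<std::mutex> l(mu_);
      auto it = cache_.find(key);  // fast path: caller's original order
      if (it != cache_.end()) return it->second;
    }
    std::sort(inputs.begin(), inputs.end());  // canonicalize
    std::sort(outputs.begin(), outputs.end());
    const std::string sorted_key = MakeKey(inputs, outputs);
    std::lock_guard<std::mutex> l(mu_);
    auto it = cache_.find(sorted_key);  // slow path: sorted key
    if (it == cache_.end()) {
      it = cache_.emplace(sorted_key, std::make_shared<Executors>()).first;
    }
    cache_.emplace(key, it->second);  // alias under the original order
    return it->second;
  }

 private:
  static std::string MakeKey(const std::vector<std::string>& in,
                             const std::vector<std::string>& out) {
    std::string k;
    for (const auto& s : in) k += s + ",";
    k += "->";
    for (const auto& s : out) k += s + ",";
    return k;
  }

  std::mutex mu_;
  std::map<std::string, std::shared_ptr<Executors>> cache_;
};

int main() {
  ExecutorCache cache;
  auto a = cache.GetOrCreate({"x", "y"}, {"z"});
  auto b = cache.GetOrCreate({"y", "x"}, {"z"});  // permuted feeds hit the cache
  std::cout << (a == b) << "\n";  // prints 1
}

Unlike this sketch, DirectSession deliberately releases executor_lock_ while CreateExecutors runs (graph construction and optimization are expensive) and resolves the resulting race on reinsertion by keeping whichever entry arrived first.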
The CreateExecutors method in direct_session.cc is shown below. It calls CreateGraphs to build the computation graphs (one per partition); see the companion article "TensorFlow计算图的创建" for how the graphs are created.
Each partition graph is then optimized via optimizer.Optimize(lib, options_.env, device, &partition_graph, /*shape_map=*/nullptr); see "TensorFlow计算图的优化分析" for an analysis of graph optimization.
Next, TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(partition_graph), &item->executor)); creates an Executor from the parameters params assembled above and the partition graph partition_graph.
params specifies the device (params.device), a lambda that creates op kernels (params.create_kernel), the function library runtime used to instantiate ops (params.function_library), and the rendezvous_factory, which builds the IntraProcessRendezvous used to exchange tensors between devices within the process, among other fields.
item, as in item->executor, has type struct PerPartitionExecutorsAndLib, and each item ultimately becomes an element of the vector items in struct ExecutorsAndKeys. The created executor is held in the struct member std::unique_ptr<Executor> executor;.
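For reference, the two structs are declared in direct_session.h roughly as follows (abridged from the TF 1.x source tree; field order and comments may differ in your checkout):

// One executor and its dependent library runtime per graph partition.
struct PerPartitionExecutorsAndLib {
  Graph* graph = nullptr;                  // not owned.
  Device* device = nullptr;                // not owned.
  FunctionLibraryRuntime* flib = nullptr;  // not owned.
  std::unique_ptr<Executor> executor;
};

// Bundles all partitions of one (feeds, fetches, targets) signature,
// plus the name-to-index / name-to-rendezvous-key maps used by Run/PRun.
struct ExecutorsAndKeys {
  std::atomic_int_fast64_t step_count;
  std::unique_ptr<Graph> graph;
  NameNodeMap name_to_node;
  std::vector<PerPartitionExecutorsAndLib> items;
  std::unordered_map<string, size_t> input_name_to_index;
  std::unordered_map<string, string> input_name_to_rendezvous_key;
  std::unordered_map<string, size_t> output_name_to_index;
  std::unordered_map<string, string> output_name_to_rendezvous_key;
  DataTypeVector input_types;
  DataTypeVector output_types;
  CallableOptions callable_options;
  int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};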
With that, the executors are fully created. The complete CreateExecutors method:
Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  options.use_function_convention = !run_state_args->is_partial_run;
  options.collective_graph_key =
      callable_options.run_options().experimental().collective_graph_key();
  if (options_.config.experimental()
          .collective_deterministic_sequential_execution()) {
    options.collective_order = GraphCollectiveOrder::kEdges;
  } else if (options_.config.experimental().collective_nccl()) {
    options.collective_order = GraphCollectiveOrder::kAttrs;
  }

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  ek->callable_options = callable_options;

  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(
      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
      &ek->output_types, &ek->collective_graph_key));

  if (run_state_args->is_partial_run) {
    ek->graph = std::move(run_state_args->graph);
    std::unordered_set<StringPiece, StringPieceHasher> names;
    for (const string& input : callable_options.feed()) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    }
    for (const string& output : callable_options.fetch()) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    }
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
      }
    }
  }
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();

  int graph_def_version = graphs.begin()->second->versions().producer();

  const auto* session_metadata =
      options_.config.experimental().has_session_metadata()
          ? &options_.config.experimental().session_metadata()
          : nullptr;
  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_.get(), options_.env, graph_def_version,
      func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
      nullptr, nullptr, session_metadata));

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;

    LocalExecutorParams params;
    params.device = device;
    params.session_metadata = session_metadata;
    params.function_library = lib;
    auto opseg = device->op_segment();
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached. On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
        delete kernel;
    };
    params.rendezvous_factory = [](const int64, const DeviceMgr* device_mgr,
                                   Rendezvous** r) {
      *r = new IntraProcessRendezvous(device_mgr);
      return Status::OK();
    };

    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       /*shape_map=*/nullptr);

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));
    // NewLocalExecutor takes ownership of partition_graph.
    item->graph = partition_graph.get();
    item->executor = nullptr;
    item->device = device;
    auto executor_type = options_.config.experimental().executor_type();
    TF_RETURN_IF_ERROR(NewExecutor(
        executor_type, params, std::move(partition_graph), &item->executor));
  }

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_index[input] = i;
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(),
          FrameAndIter(0, 0));
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  *out_executors_and_keys = std::move(ek);
  *out_func_info = std::move(func_info);

  return Status::OK();
}
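In the PRun() branch at the end, feeds and fetches travel through the rendezvous rather than as call arguments, so each name is mapped to a rendezvous key. Such a key is just a formatted string naming the source device, its incarnation, the destination device, the tensor, and the frame/iteration. The sketch below builds a key of the same shape; MakeRendezvousKey is an illustrative stand-in, not the actual GetRendezvousKey / Rendezvous::CreateKey helpers, whose exact formatting may differ:

// Illustrative only: builds a rendezvous-style key in the
// "src;incarnation;dst;tensor;frame:iter" shape TF uses. Not TF API.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

std::string MakeRendezvousKey(const std::string& device, uint64_t incarnation,
                              const std::string& tensor_name, int64_t frame_id,
                              int64_t iter_id) {
  std::ostringstream os;
  // DirectSession uses the client device as both source and destination.
  os << device << ";" << std::hex << incarnation << std::dec << ";" << device
     << ";" << tensor_name << ";" << frame_id << ":" << iter_id;
  return os.str();
}

int main() {
  std::cout << MakeRendezvousKey(
                   "/job:localhost/replica:0/task:0/device:CPU:0", 1, "x:0",
                   0, 0)
            << "\n";
}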
The code above calls NewExecutor in executor_factory.cc, which in turn invokes NewLocalExecutor in executor.cc and then ExecutorImpl::Initialize. It is inside Initialize that each node's op kernel is created and assigned (Status s = params_.create_kernel(n->def(), &item->kernel);).
Status NewExecutor(const string& executor_type,
                   const LocalExecutorParams& params,
                   std::unique_ptr<const Graph> graph,
                   std::unique_ptr<Executor>* out_executor) {
  ExecutorFactory* factory = nullptr;
  TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
  return factory->NewExecutor(params, std::move(graph), out_executor);
}

Status NewLocalExecutor(const LocalExecutorParams& params,
                        std::unique_ptr<const Graph> graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
  const Status s = impl->Initialize();
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}

Status ExecutorImpl::Initialize() {
  // Initialize gview_.
  gview_.Initialize(graph_.get());

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));

  // Cache this value so we make this virtual function call once, rather
  // than O(# steps * # nodes per step) times.
  device_record_tensor_accesses_ =
      params_.device->RequiresRecordingAccessedTensors();

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
  }

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  // Iterate over every node in graph_, create a NodeItem for each node
  // and fill it into gview_.
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }

    NodeItem* item = gview_.node(id);
    item->node = n;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    // Call params_.create_kernel to create a kernel instance for each op
    // definition; see "TF Operation的创建"
    // (https://liuxiaofei.com.cn/blog/tf-operation%e7%9a%84%e5%88%9b%e5%bb%ba/)
    // for the detailed creation flow.
    Status s = params_.create_kernel(n->def(), &item->kernel);
    if (!s.ok()) {
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_sink = IsSink(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    item->pending_id =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(n);
    if (IsEnter(n)) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }
  }

  // Initialize PendingCounts only after item->pending_id is initialized for
  // all nodes.
  InitializePending(graph_.get(), cf_info);

  return gview_.SetAllocAttrs(graph_.get(), params_.device);
}
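The executor_type string consulted by NewExecutor comes from the session config (options_.config.experimental().executor_type()); the empty default resolves to a factory that wraps NewLocalExecutor. ExecutorFactory::Register and ExecutorFactory::GetFactory are the extension points, so a custom executor type could plausibly be hooked in as sketched below; MyExecutorFactory and the "MY_EXECUTOR" name are invented for illustration:

// Sketch only: registering a custom executor type with the factory
// mechanism that NewExecutor() consults above. MyExecutorFactory is
// hypothetical; ExecutorFactory::Register/GetFactory are the real hooks.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"

namespace tensorflow {

class MyExecutorFactory : public ExecutorFactory {
 public:
  Status NewExecutor(const LocalExecutorParams& params,
                     std::unique_ptr<const Graph> graph,
                     std::unique_ptr<Executor>* out_executor) override {
    // Delegate to the stock local executor; a real factory could wrap
    // it with tracing, custom scheduling, etc.
    Executor* executor = nullptr;
    TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &executor));
    out_executor->reset(executor);
    return Status::OK();
  }
};

// Runs at load time; afterwards a session configured with
// experimental.executor_type = "MY_EXECUTOR" reaches this factory.
static bool registered = [] {
  ExecutorFactory::Register("MY_EXECUTOR", new MyExecutorFactory);
  return true;
}();

}  // namespace tensorflow

executor.cc registers the default factory in essentially this way under the empty type name, which is why an unset executor_type ends up in NewLocalExecutor.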
This work is licensed under a Creative Commons Attribution 4.0 International License.