Once graph optimization is complete and the Session starts running, it is the computation graph's turn to execute. Using the node information in the graph, TensorFlow first finds the nodes that have no inputs (the root nodes) and creates a task for each one, handing it to a thread pool for execution. Whenever a node finishes, it notifies its downstream nodes along the outgoing edges so that they can compute in turn, until every node has finished and the result is produced.
Computation Graph Execution
As before, we take the program from the earlier article "TF 生成计算图" as the example:
Execution enters the call stack below; the creation of the Executors starts from DirectSession::CreateExecutors().
```
tensorflow::DirectSession::CreateExecutors() at direct_session.cc:1,301 0x7ffff591344b
tensorflow::DirectSession::GetOrCreateExecutors() at direct_session.cc:1,435 0x7ffff5914bec
tensorflow::DirectSession::PRunSetup() at direct_session.cc:849 0x7ffff59162b7
TF_SessionPRunSetup() at c_api.cc:2,668 0x7ffff2090e54
main() at TensorflowTest2.cpp:58 0x40522e
```
CreateGraphs is called first to build graphs; at this point the graph has already been partitioned.
Then the loop "for (auto iter = graphs.begin(); iter != graphs.end(); ++iter)" iterates over the partitioned graphs and creates an Executor for each one:
1. Look up the device of the current partitioned graph via "device_mgr_->LookupDevice";
2. Create item (of type ExecutorsAndKeys::PerPartitionExecutorsAndLib) and params (of type LocalExecutorParams);
3. Call "NewExecutor(executor_type, params, std::move(partition_graph), &item->executor)" to create the Executor.
```cpp
//code in direct_session.h
struct ExecutorsAndKeys {
  ExecutorsAndKeys() : step_count(0) {}

  std::atomic_int_fast64_t step_count;
  std::unique_ptr<Graph> graph;
  NameNodeMap name_to_node;
  std::vector<PerPartitionExecutorsAndLib> items;
  std::unordered_map<string, size_t> input_name_to_index;
  std::unordered_map<string, string> input_name_to_rendezvous_key;
  std::unordered_map<string, size_t> output_name_to_index;
  std::unordered_map<string, string> output_name_to_rendezvous_key;

  DataTypeVector input_types;
  DataTypeVector output_types;

  CallableOptions callable_options;

  int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};

//code in direct_session.cc
Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  options.use_function_convention = !run_state_args->is_partial_run;
  options.collective_graph_key =
      callable_options.run_options().experimental().collective_graph_key();
  if (options_.config.experimental()
          .collective_deterministic_sequential_execution()) {
    options.collective_order = GraphCollectiveOrder::kEdges;
  }

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  ek->callable_options = callable_options;

  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(
      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
      &ek->output_types, &ek->collective_graph_key));

  ......

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
    /*
      func_info->proc_flr : of type ProcessFunctionLibraryRuntime
      func_info->proc_flr->GetFLR : returns a FunctionLibraryRuntime*
      FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR
    */
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;

    LocalExecutorParams params;
    params.device = device;         /* e.g. the CPU device */
    params.function_library = lib;  /* the function runtime, of type FunctionLibraryRuntime */
    /*
      ProcessFunctionLibraryRuntime(
          const DeviceMgr* device_mgr, Env* env, int graph_def_version,
          const FunctionLibraryDefinition* lib_def,
          const OptimizerOptions& optimizer_options,
          thread::ThreadPool* thread_pool = nullptr,
          DistributedFunctionLibraryRuntime* parent = nullptr);
    */
    auto opseg = device->op_segment();
    /* lambda that creates the kernel for an op */
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        /*
          Status FunctionLibraryRuntimeImpl::CreateKernel(
              const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
              OpKernel** kernel) {
            ......
            if (lib_def->Find(ndef.op()) == nullptr) {
              // A primitive operation. Creates the registered kernel.
              return CreateNonCachedKernel(device_, this, ndef,
                                           graph_def_version_, kernel);
            }
        */
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached. On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
        delete kernel;
    };

    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       /*shape_map=*/nullptr);

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));
    // NewLocalExecutor takes ownership of partition_graph.
    item->graph = partition_graph.get();
    item->executor = nullptr;
    item->device = device;
    auto executor_type = options_.config.experimental().executor_type();
    TF_RETURN_IF_ERROR(NewExecutor(
        executor_type, params, std::move(partition_graph), &item->executor));
  }

  ......

  return Status::OK();
}
```
In executor.cc, a factory named "DEFAULT" is registered with the system, so the GetFactory function in executor_factory.cc retrieves exactly that factory. NewLocalExecutor is then called to create an ExecutorImpl and initialize it.

Note the line "root_nodes_.push_back(n)" inside the initialization method: it collects all root nodes of the graph into root_nodes_. Later, RunAsync launches the first task for each of these nodes; when a root node's task completes, it triggers the tasks of its downstream nodes.
```cpp
//code in executor_factory.cc
Status ExecutorFactory::GetFactory(const string& executor_type,
                                   ExecutorFactory** out_factory) {
  tf_shared_lock l(executor_factory_lock);

  auto iter = executor_factories()->find(executor_type);
  if (iter == executor_factories()->end()) {
    return errors::NotFound(
        "No executor factory registered for the given executor type: ",
        executor_type, " ", RegisteredFactoriesErrorMessageLocked());
  }

  *out_factory = iter->second;
  return Status::OK();
}

Status NewExecutor(const string& executor_type,
                   const LocalExecutorParams& params,
                   std::unique_ptr<const Graph> graph,
                   std::unique_ptr<Executor>* out_executor) {
  ExecutorFactory* factory = nullptr;
  TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
  return factory->NewExecutor(params, std::move(graph), out_executor);
}

//code in executor.cc
namespace {

class DefaultExecutorRegistrar {
 public:
  DefaultExecutorRegistrar() {
    Factory* factory = new Factory;
    ExecutorFactory::Register("", factory);
    ExecutorFactory::Register("DEFAULT", factory);
  }

 private:
  class Factory : public ExecutorFactory {
    Status NewExecutor(const LocalExecutorParams& params,
                       std::unique_ptr<const Graph> graph,
                       std::unique_ptr<Executor>* out_executor) override {
      Executor* ret = nullptr;
      TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
      out_executor->reset(ret);
      return Status::OK();
    }
  };
};
static DefaultExecutorRegistrar registrar;

}  // namespace

//code in executor.cc
Status NewLocalExecutor(const LocalExecutorParams& params,
                        std::unique_ptr<const Graph> graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
  const Status s = impl->Initialize();
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}

Status ExecutorImpl::Initialize() {
  gview_.Initialize(graph_.get());
  ......
  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }
    ......
```
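The DefaultExecutorRegistrar above relies on a common C++ idiom: a file-scope static object whose constructor registers a factory before main() runs. Here is a minimal standalone sketch of that idiom (hypothetical names, not the TensorFlow API):

```cpp
// A minimal sketch of the static-registrar pattern: a factory is registered
// into a global map during static initialization and looked up by name later.
#include <functional>
#include <iostream>
#include <map>
#include <string>

using ExecutorFn = std::function<void()>;

// Function-local static: constructed on first use, which sidesteps the
// static-initialization-order problem a namespace-scope map would have.
std::map<std::string, ExecutorFn>& Registry() {
  static std::map<std::string, ExecutorFn> r;
  return r;
}

struct Registrar {
  Registrar(const std::string& type, ExecutorFn fn) { Registry()[type] = fn; }
};

// Runs before main(), mirroring "static DefaultExecutorRegistrar registrar;".
static Registrar default_registrar("DEFAULT", [] {
  std::cout << "default executor created\n";
});

int main() {
  auto it = Registry().find("DEFAULT");
  if (it == Registry().end()) {
    std::cerr << "No executor factory registered\n";  // cf. errors::NotFound
    return 1;
  }
  it->second();  // cf. factory->NewExecutor(...)
}
```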
Once the Executors above have been created, each item in executors_and_keys is executed asynchronously via item.executor->RunAsync:
```cpp
//code in direct_session.cc
Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  ......
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));
  ......
  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  *handle = run_state_args.handle;
  return Status::OK();
}

//code in executor.cc
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  (new ExecutorState(args, this))->RunAsync(std::move(done));
}
```
All nodes in impl_->root_nodes_ are wrapped as TaggedNodes and appended to ready (a TaggedNodeSeq), and ScheduleReady(ready, nullptr) is then called to schedule the tasks. This is the entry point for the computation of every node.
```cpp
//code in executor.cc
ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
    ......

void ExecutorState::RunAsync(Executor::DoneCallback done) {
  const Graph* graph = impl_->graph_.get();
  TaggedNodeSeq ready;

  // Ask the device to fill in the device context map.
  Device* device = impl_->params_.device;
  const Status fill_status =
      device->FillContextMap(graph, &device_context_map_);
  if (!fill_status.ok()) {
    delete this;
    done(fill_status);
    return;
  }

  // Initialize the ready queue.
  for (const Node* n : impl_->root_nodes_) {
    DCHECK_EQ(n->in_edges().size(), 0);
    ready.push_back(TaggedNode{n, root_frame_, 0, false});
  }
  if (ready.empty()) {
    delete this;
    done(Status::OK());
  } else {
    num_outstanding_ops_ = ready.size();
    root_frame_->iterations[0]->outstanding_ops = ready.size();
    done_cb_ = std::move(done);
    // Schedule to run all the ready ops in thread pool.
    ScheduleReady(ready, nullptr);
  }
}
```
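The whole ready-queue mechanism is easier to see in isolation. Below is a minimal sketch, entirely hypothetical code rather than TensorFlow's, of the same dataflow model: every node tracks a pending count of inputs not yet produced, root nodes (pending == 0) start in the ready queue, and each finished node decrements its successors' counts, moving any that reach zero onto the queue that the worker threads drain.

```cpp
// Minimal dataflow executor sketch (hypothetical, not TensorFlow code).
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

struct Node {
  std::string name;
  std::vector<int> out_edges;  // ids of downstream nodes
  int pending = 0;             // inputs not yet produced
};

class MiniExecutor {
 public:
  explicit MiniExecutor(std::vector<Node>* graph) : graph_(graph) {}

  void Run(int num_threads) {
    // Initialize the ready queue with the root nodes, cf. RunAsync above.
    for (int i = 0; i < (int)graph_->size(); ++i)
      if ((*graph_)[i].pending == 0) ready_.push(i);
    std::vector<std::thread> workers;
    for (int t = 0; t < num_threads; ++t)
      workers.emplace_back([this] { WorkerLoop(); });
    for (auto& w : workers) w.join();
  }

 private:
  void WorkerLoop() {
    std::unique_lock<std::mutex> l(mu_);
    for (;;) {
      cv_.wait(l, [this] {
        return !ready_.empty() || done_ == (int)graph_->size();
      });
      if (ready_.empty()) return;  // everything finished
      int id = ready_.front();
      ready_.pop();

      Node& n = (*graph_)[id];
      l.unlock();
      std::cout << "Process node: " << n.name << "\n";  // op kernel runs here
      l.lock();

      // Propagate: a successor whose last input just arrived becomes ready,
      // cf. PropagateOutputs/ActivateNodes in executor.cc.
      for (int dst : n.out_edges)
        if (--(*graph_)[dst].pending == 0) {
          ready_.push(dst);
          cv_.notify_one();
        }
      if (++done_ == (int)graph_->size()) cv_.notify_all();
    }
  }

  std::vector<Node>* graph_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<int> ready_;
  int done_ = 0;  // nodes fully processed
};

int main() {
  // A diamond graph: A -> {B, C} -> D. B and C may run concurrently.
  std::vector<Node> g(4);
  g[0] = {"A", {1, 2}, 0};
  g[1] = {"B", {3}, 1};
  g[2] = {"C", {3}, 1};
  g[3] = {"D", {}, 2};
  MiniExecutor(&g).Run(/*num_threads=*/2);
}
```

Running it prints the nodes in a valid topological order: A first, then B and C in either order, then D, which is exactly the behavior the executor log at the end of this article shows for the real graph.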
The ScheduleReady method in executor.cc hands each tagged_node to Process: the lambda "[=]() { Process(tagged_node, scheduled_nsec); }" is wrapped into a task and pushed onto a queue; if the push fails, env_.ExecuteTask(t) runs the task directly.
```cpp
//code in executor.cc
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
                                  TaggedNodeReadyQueue* inline_ready) {
  if (ready.empty()) return;

  int64 scheduled_nsec = 0;
  if (stats_collector_) {
    scheduled_nsec = nodestats::NowInNsec();
  }

  if (inline_ready == nullptr) {
    // Schedule to run all the ready ops in thread pool.
    for (auto& tagged_node : ready) {
      runner_([=]() { Process(tagged_node, scheduled_nsec); });
      /*
        runner_ was defined earlier as:
          args.runner = [this, pool](Executor::Args::Closure c) {
            SchedClosure(pool, std::move(c));
          };
      */
    }
    return;
  }

//code in direct_session.cc
void DirectSession::SchedClosure(thread::ThreadPool* pool,
                                 std::function<void()> c) {
  if (pool != nullptr) {
    pool->Schedule(std::move(c));
  } else {
    c();
  }
}

//code in threadpool.cc
void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  impl_->Schedule(std::move(fn));
}

//code in NonBlockingThreadPool.h
void Schedule(std::function<void()> fn) EIGEN_OVERRIDE {
  ScheduleWithHint(std::move(fn), 0, num_threads_);
}

void ScheduleWithHint(std::function<void()> fn, int start,
                      int limit) override {
  /*
    fn is the "[=]() { Process(tagged_node, scheduled_nsec); }" lambda
    from above.
  */
  Task t = env_.CreateTask(std::move(fn));
  PerThread* pt = GetPerThread();
  if (pt->pool == this) {
    // Worker thread of this pool, push onto the thread's queue.
    Queue& q = thread_data_[pt->thread_id].queue;
    t = q.PushFront(std::move(t));
  } else {
    // A free-standing thread (or worker of another pool), push onto a random
    // queue.
    eigen_plain_assert(start < limit);
    eigen_plain_assert(limit <= num_threads_);
    int num_queues = limit - start;
    int rnd = Rand(&pt->rand) % num_queues;
    eigen_plain_assert(start + rnd < limit);
    Queue& q = thread_data_[start + rnd].queue;
    t = q.PushBack(std::move(t));
  }
  if (!t.f) {
    ec_.Notify(false);
  } else {
    env_.ExecuteTask(t);  // Push failed, execute directly.
  }
}
```
Tasks pushed onto the queues are executed by the thread pool created earlier:
```cpp
//code in NonBlockingThreadPool.h
ThreadPoolTempl(int num_threads, bool allow_spinning,
                Environment env = Environment())
    : env_(env),
      num_threads_(num_threads),
      allow_spinning_(allow_spinning),
      thread_data_(num_threads),
      all_coprimes_(num_threads),
      waiters_(num_threads),
      blocked_(0),
      spinning_(0),
      done_(false),
      cancelled_(false),
      ec_(waiters_) {
  ......
  thread_data_.resize(num_threads_);
  for (int i = 0; i < num_threads_; i++) {
    SetStealPartition(i, EncodePartition(0, num_threads_));
    thread_data_[i].thread.reset(
        env_.CreateThread([this, i]() { WorkerLoop(i); }));
  }
  ......
}

void WorkerLoop(int thread_id) {
  ......
  if (num_threads_ == 1) {
    ......
  } else {
    while (!cancelled_) {
      Task t = q.PopFront();
      ......
      if (t.f) {
        env_.ExecuteTask(t);
      }
    }
  }
}
```
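Stripped of Eigen's per-thread deques, work stealing, and spinning, the scheduling contract shown above reduces to this: Schedule() enqueues a closure, worker threads drain the queue in WorkerLoop(), and a missing pool means the closure runs inline, as in DirectSession::SchedClosure. A minimal sketch under those assumptions (hypothetical code, not Eigen's):

```cpp
// Minimal thread pool sketch: one shared queue instead of Eigen's per-worker
// lock-free deques; kept only to show the Schedule/WorkerLoop contract.
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class MiniThreadPool {
 public:
  explicit MiniThreadPool(int num_threads) {
    for (int i = 0; i < num_threads; ++i)
      threads_.emplace_back([this] { WorkerLoop(); });
  }

  ~MiniThreadPool() {
    {
      std::lock_guard<std::mutex> l(mu_);
      done_ = true;
    }
    cv_.notify_all();
    for (auto& t : threads_) t.join();
  }

  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> l(mu_);
      queue_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  void WorkerLoop() {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return done_ || !queue_.empty(); });
        if (queue_.empty()) return;  // done_ set and queue drained: exit
        task = std::move(queue_.front());
        queue_.pop();
      }
      task();  // cf. env_.ExecuteTask(t)
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> queue_;
  bool done_ = false;
  std::vector<std::thread> threads_;
};

// Mirrors DirectSession::SchedClosure: run inline when there is no pool.
void SchedClosure(MiniThreadPool* pool, std::function<void()> c) {
  if (pool != nullptr) pool->Schedule(std::move(c));
  else c();
}

int main() {
  MiniThreadPool pool(2);
  for (int i = 0; i < 4; ++i)
    SchedClosure(&pool, [i] { std::printf("task %d\n", i); });
}  // destructor drains remaining tasks, then joins the workers
```

The real pool avoids this sketch's single mutex: each worker pushes to the front of its own queue (PushFront in ScheduleWithHint), outside threads push to the back of a random worker's queue (PushBack), and idle workers steal from each other.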
When a node finishes computing, the PropagateOutputs method in executor.cc follows the node's outgoing edges and appends the downstream nodes that became ready to the ready list; the NodeDone method in executor.cc then notifies them to run, calling ScheduleReady, which leads to the next round of ScheduleWithHint:
```cpp
//code in executor.cc
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
  WithContext wc(context_);
  const GraphView& gview = impl_->gview_;
  TaggedNodeSeq ready;
  TaggedNodeReadyQueue inline_ready;
  ......
  EntryVector outputs;
  bool completed = false;
  inline_ready.push_back(tagged_node);
  while (!inline_ready.empty()) {
    ......
      // Set up compute params.
      OpKernel* op_kernel = item.kernel;
      params.op_kernel = op_kernel;
      params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
      params.is_input_dead = is_input_dead;
      params.output_attr_array = item.output_attrs();
      params.forward_from_array = item.forward_from();

      if (item.kernel_is_async) {
        // Asynchronous computes.
        AsyncOpKernel* async = item.kernel->AsAsync();
        DCHECK(async != nullptr);
        launched_asynchronously = true;
        AsyncState* state =
            new AsyncState(params, tagged_node, &item, first_input, stats);

        auto done = [this, state]() {
          Device* device = impl_->params_.device;
          NodeExecStatsInterface* stats = state->stats;  // Shorthand
          Entry* first_input = state->first_input;       // Shorthand

          nodestats::SetOpEnd(stats);
          EntryVector outputs;
          Status s =
              ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
          nodestats::SetMemory(stats, &state->ctx);
          if (vlog_) {
            VLOG(2) << "Async kernel done: " << state->item->node->id()
                    << " step " << step_id_ << " "
                    << SummarizeNode(*state->item->node)
                    << (state->tagged_node.is_dead ? " is dead" : "")
                    << " device: " << device->name();
          }

          // Clears inputs.
          const int num_inputs = state->item->num_inputs;
          for (int i = 0; i < num_inputs; ++i) {
            (first_input + i)->ClearVal();
          }
          FrameState* input_frame = state->tagged_node.input_frame;
          const int64 input_iter = state->tagged_node.input_iter;
          const int id = state->tagged_node.node->id();
          MaybeMarkCompleted(input_frame, input_iter, id);
          TaggedNodeSeq ready;
          if (s.ok()) {
            PropagateOutputs(state->tagged_node, state->item, &outputs,
                             &ready);
          }
          outputs.clear();
          if (s.ok() && impl_->device_record_tensor_accesses_) {
            // Get the list of all tensors accessed during the execution
            TensorReferenceVector accessed;
            state->ctx.retrieve_accessed_tensors(&accessed);
            nodestats::SetReferencedTensors(stats, accessed);
            // callee takes ownership of the vector
            device->ConsumeListOfAccessedTensors(
                state->ctx.op_device_context(), accessed);
          }
          const bool completed =
              NodeDone(s, state->item->node, ready, stats, nullptr);
          delete state;
          if (completed) ScheduleFinish();
        };
        nodestats::SetOpStart(stats);
        {
          profiler::TraceMe activity(
              [&] {
                return strings::StrCat(
                    op_kernel->name(), ":", op_kernel->type_string(),
                    "#id=", step_container_ ? step_container_->step_id() : 0,
                    ",device=", device->name(), ",async=true#");
              },
              profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
          device->ComputeAsync(async, &state->ctx, done);
        }
      } else {
        ......
        nodestats::SetOpEnd(stats);
        s = ProcessOutputs(item, &ctx, &outputs, stats);
        if (s.ok() && impl_->device_record_tensor_accesses_) {
          // Get the list of all tensors accessed during the execution
          ctx.retrieve_accessed_tensors(&accessed_tensors);
          device_context = ctx.op_device_context();
        }
        nodestats::SetMemory(stats, &ctx);
      }
    }

    if (!launched_asynchronously) {
      ......
      // Propagates outputs.
      if (s.ok()) {
        PropagateOutputs(tagged_node, &item, &outputs, &ready);
      }
      outputs.clear();
      if (!accessed_tensors.empty()) {
        nodestats::SetReferencedTensors(stats, accessed_tensors);
        // device_context is set above in synchronous computes
        device->ConsumeListOfAccessedTensors(device_context,
                                             accessed_tensors);
      }
      if (stats) {
        scheduled_nsec = nodestats::NowInNsec();
      }
      // Postprocess.
      completed = NodeDone(s, item.node, ready, stats, &inline_ready);
    }
  }  // while !inline_ready.empty()

  // This thread of computation is done if completed = true.
  if (completed) Finish();
}
```
PropagateOutputs in executor.cc calls output_frame->ActivateNodes, which in turn executes "ready->emplace_back(dst_item->node, this, iter, dst_dead);" to append each newly ready node to ready. NodeDone then calls ScheduleReady to keep the execution going.
```cpp
void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
                                     const NodeItem* item,
                                     EntryVector* outputs,
                                     TaggedNodeSeq* ready) {
  auto activity_handle = absl::make_unique<profiler::TraceMe>(
      [&]() {
        return strings::StrCat("ExecutorPropagateOutputs:",
                               item->kernel->name(), "#id=", step_id_, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  const Node* node = tagged_node.node;
  FrameState* input_frame = tagged_node.input_frame;
  const int64 input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  ready->clear();
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  int64 output_iter = input_iter;

  if (!item->is_enter_exit_or_next_iter) {
    // Fast path for nodes types that don't need special handling
    DCHECK_EQ(input_frame, output_frame);
    // Normal path for most nodes
    mutex_lock l(input_frame->mu);
    output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    is_frame_done = input_frame->DecrementOutstandingOpsLocked(
        &impl_->gview_, input_iter, ready);
  } else if (item->is_enter) {
    FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
    output_iter = 0;
    {
      const NodeItem* item = impl_->gview_.node(node->id());
      mutex_lock l(output_frame->mu);
      if (item->is_constant_enter) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
      } else {
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs,
                                    ready);
      }
      output_frame->num_pending_inputs--;
    }
    is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
                                                         input_iter, ready);
  } else if (item->is_exit) {
    if (is_dead) {
      mutex_lock l(input_frame->mu);
      // Stop and remember this node if it is a dead exit.
      if (input_iter == input_frame->iteration_count) {
        input_frame->dead_exits.push_back(node);
      }
      is_frame_done = input_frame->DecrementOutstandingOpsLocked(
          &impl_->gview_, input_iter, ready);
    } else {
      output_frame = input_frame->parent_frame;
      output_iter = input_frame->parent_iter;
      {
        mutex_lock l(output_frame->mu);
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs,
                                    ready);
      }
      is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
                                                           input_iter, ready);
    }
  } else {
    DCHECK(IsNextIteration(node));
    mutex_lock l(input_frame->mu);
    if (is_dead) {
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      if (input_iter == input_frame->iteration_count &&
          input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
        // Reached the maximum for parallel iterations.
        input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
        output_frame = nullptr;
      } else {
        // If this is a new iteration, start it.
        if (input_iter == input_frame->iteration_count) {
          input_frame->IncrementIteration(&impl_->gview_, ready);
        }
        output_iter = input_iter + 1;
      }
    }
    if (output_frame != nullptr) {
      // This is the case when node is not Enter, Exit, or NextIteration.
      DCHECK(input_frame == output_frame);
      output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    }
    is_frame_done = input_frame->DecrementOutstandingOpsLocked(
        &impl_->gview_, input_iter, ready);
  }

  // At this point, this node is completely done. We also know if the
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    const int64 parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}

bool ExecutorState::NodeDone(const Status& s, const Node* node,
                             const TaggedNodeSeq& ready,
                             NodeExecStatsInterface* stats,
                             TaggedNodeReadyQueue* inline_ready) {
  nodestats::SetAllEnd(stats);
  if (stats) {
    if (stats_collector_) {
      stats->Done(impl_->params_.device->name());
    } else {
      delete stats;
    }
  }

  bool abort_run = false;
  if (!s.ok()) {
    // Some error happened. This thread of computation is done.
    mutex_lock l(mu_);
    if (status_.ok()) {
      abort_run = true;

      // If execution has been cancelled, mark any new errors as being derived.
      // This ensures any errors triggered by cancellation are marked as
      // derived.
      if (cancellation_manager_ && cancellation_manager_->IsCancelled()) {
        status_ = StatusGroup::MakeDerived(s);
      } else {
        status_ = s;
      }
    }
  }
  if (abort_run) {
    TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
    if (cancellation_manager_) {
      // only log when the abort happens during the actual run time.
      auto device_name = impl_->params_.device->name();
      // Use VLOG instead of LOG(warning) because error status is expected when
      // the executor is run under the grappler optimization phase or when
      // iterating through a tf.data input pipeline.
      VLOG(1) << "[" << device_name << "] Executor start aborting: " << s;
    }

    if (rendezvous_) {
      rendezvous_->StartAbort(s);
    }
    if (collective_executor_) {
      collective_executor_->StartAbort(s);
    }
    if (cancellation_manager_) {
      cancellation_manager_->StartCancel();
    }
  }

  bool completed = false;
  const size_t ready_size = ready.size();
  if (ready_size == 0 || !s.ok()) {
    completed = (num_outstanding_ops_.fetch_sub(1) == 1);
  } else if (ready_size > 1) {
    num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
  }

  // Schedule the ready nodes in 'ready'.
  if (s.ok()) {
    ScheduleReady(ready, inline_ready);
  }
  return completed;
}

void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
                                              const bool is_dead, int64 iter,
                                              EntryVector* outputs,
                                              TaggedNodeSeq* ready) {
  const GraphView& gview = executor->gview_;
  IterationState* iter_state = GetIteration(iter);
  const size_t num_output_edges = item->num_output_edges;
  const EdgeInfo* edges = item->output_edge_list();
  Entry* input_tensors = iter_state->input_tensors;
  for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
    const EdgeInfo& e = edges[out_index];
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = gview.node(dst_id);
    const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
    const int src_slot = e.output_slot;

    // TODO(yuanbyu): We don't need this if we require the subgraph
    // given to an executor not to contain a sink node.
    if (dst_item->is_sink) continue;

    bool dst_dead = false;
    bool dst_ready = false;
    // True iff this input for dst is needed. We only set this input for
    // dst if this flag is true. This is needed to make the thread safety
    // analysis happy.
    const bool is_control_edge = (src_slot == Graph::kControlSlot);
    bool dst_need_input = !is_control_edge;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if (is_control_edge) {
        iter_state->decrement_pending(dst_pending_id, 2);
        int count = iter_state->pending(dst_pending_id);
        int dead_cnt = iter_state->dead_count(dst_pending_id);
        dst_dead = (dead_cnt == dst_item->num_inputs);
        dst_ready = (count == 0) || ((count == 1) && dst_dead);
      } else {
        if ((*outputs)[src_slot].has_value) {
          // This is a live data input.
          int count = iter_state->pending(dst_pending_id);
          iter_state->mark_live(dst_pending_id);
          // Only the first live edge sets the input and (potentially)
          // triggers execution. The low bit of count is set if and
          // only if no live input has been used yet (mark_live clears
          // it). The node should be started if and only if this is
          // the first live input and there are no pending control
          // edges, i.e. count == 1.
          dst_ready = (count == 1);
          dst_need_input = ((count & 0x1) == 1);
        } else {
          // This is a dead data input. Note that dst_node is dead if node is
          // a dead enter. We need this to handle properly a while loop on
          // the untaken branch of a conditional.
          // TODO(yuanbyu): This is a bit hacky, but a good solution for
          // now.
          iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
          dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
          dst_need_input = false;
        }
      }
    } else {
      const bool increment_dead =
          (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
      int pending, dead;
      iter_state->adjust_for_activation(dst_pending_id, increment_dead,
                                        &pending, &dead);
      dst_dead = (dead > 0);
      dst_ready = (pending == 0);
    }

    if (dst_need_input) {
      const int dst_slot = e.input_slot;
      const int dst_loc = dst_item->input_start + dst_slot;
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }
    }

    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item->node, this, iter, dst_dead);
      iter_state->outstanding_ops++;
    }
  }
}
```
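One subtle piece of NodeDone above is the accounting on num_outstanding_ops_: a node that finishes with k newly ready successors changes the counter by k - 1, and execution is complete exactly when the counter drops to zero. A stripped-down sketch of just that logic (hypothetical names, everything else removed):

```cpp
// Sketch of the outstanding-op accounting in NodeDone.
#include <atomic>
#include <cassert>
#include <cstddef>

std::atomic<long> num_outstanding_ops{0};

// Returns true iff this was the last outstanding op, mirroring the
// ready_size == 0 / ready_size > 1 branches in NodeDone.
bool OnNodeDone(std::size_t ready_size) {
  if (ready_size == 0) {
    // This node finished and woke nothing: net -1; the last one wins.
    return num_outstanding_ops.fetch_sub(1) == 1;
  }
  if (ready_size > 1) {
    // One successor "inherits" this op's count; add the remainder.
    num_outstanding_ops.fetch_add(ready_size - 1, std::memory_order_relaxed);
  }
  // ready_size == 1 leaves the counter unchanged (-1 for this node, +1 for
  // its successor), so execution cannot be complete yet.
  return false;
}

int main() {
  num_outstanding_ops = 1;  // RunAsync scheduled one root node
  assert(!OnNodeDone(2));   // root finished, woke 2 nodes -> counter is 2
  assert(!OnNodeDone(0));   // one leaf finished           -> counter is 1
  assert(OnNodeDone(0));    // last node finished          -> complete
}
```

The counter starts at the number of root nodes (num_outstanding_ops_ = ready.size() in RunAsync), so it reaches zero exactly when no scheduled node remains and Finish() can run.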
Finally, the execution log also confirms that the program really does run according to the computation graph:
```
executor.cc:1688] Process node: 0 step 1 {{node _SOURCE}} = NoOp[]()
executor.cc:1837] Synchronous kernel done: 0 step 1 {{node _SOURCE}} = NoOp[]()
executor.cc:1688] Process node: 2 step 1 {{node scalar}} = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [] values: 2>, ]()
executor.cc:1837] Synchronous kernel done: 2 step 1 {{node scalar}} = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [] values: 2>, ]()
executor.cc:1688] Process node: 3 step 1 {{node _recv_A_0}} = _Recv[client_terminated=true, recv, , , tensor_name="A:0", tensor_type=DT_INT32, ]()
executor.cc:1688] Process node: 5 step 1 {{node _recv_B_0}} = _Recv[client_terminated=true, recv, , , tensor_name="B:0", tensor_type=DT_INT32, ]()
executor.cc:1748] Async kernel done: 3 step 1 {{node _recv_A_0}} = _Recv[client_terminated=true, recv, , , tensor_name="A:0", tensor_type=DT_INT32, ]()
executor.cc:1688] Process node: 4 step 1 {{node plus2}} = AddN[N=2, T=DT_INT32, ](_recv_A_0, scalar)
executor.cc:1837] Synchronous kernel done: 4 step 1 {{node plus2}} = AddN[N=2, T=DT_INT32, ](_recv_A_0, scalar)
executor.cc:1688] Process node: 7 step 1 {{node _send_plus2_0}} = _Send[T=DT_INT32, client_terminated=true, recv, , , tensor_name="plus2:0", ](plus2)
executor.cc:1837] Synchronous kernel done: 7 step 1 {{node _send_plus2_0}} = _Send[T=DT_INT32, client_terminated=true, recv, , , tensor_name="plus2:0", ](plus2)
```
This work is licensed under a Creative Commons Attribution 4.0 International License.