Executing the TF Computation Graph


After the computation graph has been optimized, once the Session starts running it is the graph's turn to execute. Based on the node information in the graph, TensorFlow first takes the nodes that have no inputs as root nodes and creates a task for each of them, handing it to a thread pool for execution. When a node finishes, it notifies its downstream nodes along the out edges to compute, until every node has finished and the result is produced.
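To make this dataflow scheduling model concrete before diving into the source, here is a minimal standalone sketch. It is not TensorFlow code (every name in it is invented); it compresses the machinery this article walks through, namely Executor, ScheduleReady, PropagateOutputs and NodeDone, into a few dozen lines: root nodes seed a ready queue, worker threads pop and "compute" nodes, and finishing a node decrements the pending count of each successor, which is scheduled once its count reaches zero.

//toy_executor.cc (illustrative sketch, not TensorFlow source)
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct ToyNode {
  std::string name;
  std::vector<int> out_edges;  // ids of downstream nodes
  int pending = 0;             // unfinished inputs (cf. PendingCounts)
};

int main() {
  // Toy graph: a -> c and b -> c; c may run only after both a and b.
  std::vector<ToyNode> g(3);
  g[0].name = "a"; g[0].out_edges = {2};
  g[1].name = "b"; g[1].out_edges = {2};
  g[2].name = "c"; g[2].pending = 2;

  std::mutex mu;
  std::condition_variable cv;
  std::deque<int> ready;  // cf. TaggedNodeSeq
  int outstanding = 0;    // cf. num_outstanding_ops_

  // Seed the ready queue with the root nodes, as RunAsync does below.
  for (int i = 0; i < (int)g.size(); ++i)
    if (g[i].pending == 0) { ready.push_back(i); ++outstanding; }

  auto worker = [&] {
    for (;;) {
      int id;
      {
        std::unique_lock<std::mutex> l(mu);
        cv.wait(l, [&] { return !ready.empty() || outstanding == 0; });
        if (ready.empty()) return;  // every node is done
        id = ready.front(); ready.pop_front();
      }
      std::cout << "Process node: " << g[id].name << "\n";  // cf. Process()
      {
        // cf. PropagateOutputs: a successor whose last input just arrived
        // becomes ready; cf. NodeDone: account for the finished op.
        std::lock_guard<std::mutex> l(mu);
        for (int dst : g[id].out_edges)
          if (--g[dst].pending == 0) { ready.push_back(dst); ++outstanding; }
        --outstanding;
      }
      cv.notify_all();
    }
  };
  std::thread t1(worker), t2(worker);
  t1.join(); t2.join();
  return 0;
}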

TensorFlow Source Code Walkthrough

Computation Graph Execution

We again take the program from "TF 生成计算图" as the example.
Execution enters the following call stack, where DirectSession::CreateExecutors() starts creating the executors.

tensorflow::DirectSession::CreateExecutors() at direct_session.cc:1,301 0x7ffff591344b	
tensorflow::DirectSession::GetOrCreateExecutors() at direct_session.cc:1,435 0x7ffff5914bec	
tensorflow::DirectSession::PRunSetup() at direct_session.cc:849 0x7ffff59162b7	
TF_SessionPRunSetup() at c_api.cc:2,668 0x7ffff2090e54	
main() at TensorflowTest2.cpp:58 0x40522e


First, the graphs are created by calling CreateGraphs; at this point the graph has already been partitioned.
Then the loop "for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {" iterates over the partitioned graphs and creates an Executor for each:
1. look up the device of the current partition via "device_mgr_->LookupDevice";
2. create item (of type PerPartitionExecutorsAndLib, stored in ExecutorsAndKeys::items) and params (of type LocalExecutorParams);
3. call "NewExecutor(executor_type, params, std::move(partition_graph), &item->executor)" to create the Executor.

//code in direct_session.h
  struct ExecutorsAndKeys {
    ExecutorsAndKeys() : step_count(0) {}
 
    std::atomic_int_fast64_t step_count;
    std::unique_ptr<Graph> graph;
    NameNodeMap name_to_node;
    std::vector<PerPartitionExecutorsAndLib> items;
    std::unordered_map<string, size_t> input_name_to_index;
    std::unordered_map<string, string> input_name_to_rendezvous_key;
    std::unordered_map<string, size_t> output_name_to_index;
    std::unordered_map<string, string> output_name_to_rendezvous_key;
 
    DataTypeVector input_types;
    DataTypeVector output_types;
 
    CallableOptions callable_options;
 
    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
  };
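// NOTE: each element of items is a PerPartitionExecutorsAndLib. That struct
// is not reproduced in this article; judging purely from the assignments in
// the loop below (item->flib, item->graph, item->device, item->executor),
// it roughly bundles the following. This is an inference from usage, not a
// quote from the header:
//
//   struct PerPartitionExecutorsAndLib {
//     Graph* graph = nullptr;                  // the partition's graph
//     Device* device = nullptr;                // the partition's device
//     FunctionLibraryRuntime* flib = nullptr;  // per-device function runtime
//     std::unique_ptr<Executor> executor;      // created by NewExecutor
//   };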
//code in direct_session.cc
Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  options.use_function_convention = !run_state_args->is_partial_run;
  options.collective_graph_key =
      callable_options.run_options().experimental().collective_graph_key();
  if (options_.config.experimental()
          .collective_deterministic_sequential_execution()) {
    options.collective_order = GraphCollectiveOrder::kEdges;
  }
 
  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
 
  ek->callable_options = callable_options;
 
  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(
      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
      &ek->output_types, &ek->collective_graph_key));
......
  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;
 
    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
 
    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
/*
func_info->proc_flr is of type ProcessFunctionLibraryRuntime;
func_info->proc_flr->GetFLR() returns a FunctionLibraryRuntime*:
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR
*/
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;
 
    LocalExecutorParams params;
    params.device = device; /* e.g. the CPU device */
    params.function_library = lib; /* the function runtime, a FunctionLibraryRuntime* */
/*  ProcessFunctionLibraryRuntime(
      const DeviceMgr* device_mgr, Env* env, int graph_def_version,
      const FunctionLibraryDefinition* lib_def,
      const OptimizerOptions& optimizer_options,
      thread::ThreadPool* thread_pool = nullptr,
      DistributedFunctionLibraryRuntime* parent = nullptr);
*/
    auto opseg = device->op_segment();
/* lambda that creates the OpKernel for a node */
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
/*
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
......
  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }
*/
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached.  On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
        delete kernel;
    };
 
    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       /*shape_map=*/nullptr);
 
    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }
 
    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));
    // NewLocalExecutor takes ownership of partition_graph.
    item->graph = partition_graph.get();
    item->executor = nullptr;
    item->device = device;
    auto executor_type = options_.config.experimental().executor_type();
    TF_RETURN_IF_ERROR(NewExecutor(
        executor_type, params, std::move(partition_graph), &item->executor));
  }
 
......
  return Status::OK();
}

In executor.cc, a factory is registered under the names "" and "DEFAULT", so GetFactory in executor_factory.cc resolves the (normally empty) executor_type from the session config to this factory.
Its NewExecutor ends up calling NewLocalExecutor, which creates an ExecutorImpl and initializes it.
Note the line "root_nodes_.push_back(n)" in the initialization: it collects every node of the graph that has no in-edges into root_nodes_.
Later, RunAsync starts the first task for each of these root nodes; when a root node's task completes, it triggers the tasks of its downstream nodes.

//code in executor_factory.cc
Status ExecutorFactory::GetFactory(const string& executor_type,
                                   ExecutorFactory** out_factory) {
  tf_shared_lock l(executor_factory_lock);
 
  auto iter = executor_factories()->find(executor_type);
  if (iter == executor_factories()->end()) {
    return errors::NotFound(
        "No executor factory registered for the given executor type: ",
        executor_type, " ", RegisteredFactoriesErrorMessageLocked());
  }
 
  *out_factory = iter->second;
  return Status::OK();
}
 
Status NewExecutor(const string& executor_type,
                   const LocalExecutorParams& params,
                   std::unique_ptr<const Graph> graph,
                   std::unique_ptr<Executor>* out_executor) {
  ExecutorFactory* factory = nullptr;
  TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
  return factory->NewExecutor(params, std::move(graph), out_executor);
}
 
//code in executor.cc
namespace {
 
class DefaultExecutorRegistrar {
 public:
  DefaultExecutorRegistrar() {
    Factory* factory = new Factory;
    ExecutorFactory::Register("", factory);
    ExecutorFactory::Register("DEFAULT", factory);
  }
 
 private:
  class Factory : public ExecutorFactory {
    Status NewExecutor(const LocalExecutorParams& params,
                       std::unique_ptr<const Graph> graph,
                       std::unique_ptr<Executor>* out_executor) override {
      Executor* ret = nullptr;
      TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
      out_executor->reset(ret);
      return Status::OK();
    }
  };
};
static DefaultExecutorRegistrar registrar;
} 
 
//code in executor.cc
Status NewLocalExecutor(const LocalExecutorParams& params,
                        std::unique_ptr<const Graph> graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
  const Status s = impl->Initialize();
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}
 
Status ExecutorImpl::Initialize() {
  gview_.Initialize(graph_.get());
......
  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);
 
    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }
......

Once the executors above have been created, each item in executors_and_keys is executed asynchronously via item.executor->RunAsync:

//code in direct_session.cc
Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
......
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));
 
......
 
  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }
 
  *handle = run_state_args.handle;
  return Status::OK();
}
 
//code in executor.cc
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  (new ExecutorState(args, this))->RunAsync(std::move(done));
}

RunAsync turns every node in impl_->root_nodes_ into a TaggedNode, appends them to ready (a TaggedNodeSeq), and then calls ScheduleReady(ready, nullptr) to schedule the tasks.
This is the entry point of all node computation:

//code in executor.cc
ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
......
 
void ExecutorState::RunAsync(Executor::DoneCallback done) {
  const Graph* graph = impl_->graph_.get();
  TaggedNodeSeq ready;
 
  // Ask the device to fill in the device context map.
  Device* device = impl_->params_.device;
  const Status fill_status =
      device->FillContextMap(graph, &device_context_map_);
  if (!fill_status.ok()) {
    delete this;
    done(fill_status);
    return;
  }
 
  // Initialize the ready queue.
  for (const Node* n : impl_->root_nodes_) {
    DCHECK_EQ(n->in_edges().size(), 0);
    ready.push_back(TaggedNode{n, root_frame_, 0, false});
  }
  if (ready.empty()) {
    delete this;
    done(Status::OK());
  } else {
    num_outstanding_ops_ = ready.size();
    root_frame_->iterations[0]->outstanding_ops = ready.size();
    done_cb_ = std::move(done);
    // Schedule to run all the ready ops in thread pool.
    ScheduleReady(ready, nullptr);
  }
}
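For reference, here is a hedged sketch of how an Executor created by NewExecutor could be driven directly, outside DirectSession. The Executor::Args fields and the Args::Closure runner contract match the fragments quoted in this article, but the helper RunOnce is invented and error handling is elided, so treat it as a sketch rather than a supported usage pattern:

//sketch: driving an Executor directly (RunOnce is a made-up helper)
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"

tensorflow::Status RunOnce(tensorflow::Executor* executor,
                           tensorflow::thread::ThreadPool* pool) {
  tensorflow::Executor::Args args;
  args.step_id = 1;
  // Same contract as the runner_ quoted below: closures go to the pool.
  args.runner = [pool](tensorflow::Executor::Args::Closure c) {
    pool->Schedule(std::move(c));
  };
  tensorflow::Notification done;
  tensorflow::Status status;
  executor->RunAsync(args, [&status, &done](const tensorflow::Status& s) {
    status = s;  // the final status handed to the DoneCallback
    done.Notify();
  });
  done.WaitForNotification();  // block until the graph has drained
  return status;
}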

The ScheduleReady method in executor.cc dispatches each tagged_node to Process: the closure "[=]() { Process(tagged_node, scheduled_nsec); }" is handed to runner_, which (via SchedClosure and ScheduleWithHint below) wraps it in a Task and pushes it onto a worker queue; if the push fails, the task is executed directly with env_.ExecuteTask(t).

//code in executor.cc
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
                                  TaggedNodeReadyQueue* inline_ready) {
  if (ready.empty()) return;
 
  int64 scheduled_nsec = 0;
  if (stats_collector_) {
    scheduled_nsec = nodestats::NowInNsec();
  }
 
  if (inline_ready == nullptr) {
    // Schedule to run all the ready ops in thread pool.
    for (auto& tagged_node : ready) {
      runner_([=]() { Process(tagged_node, scheduled_nsec); });
/*
runner_ was defined earlier as:
  args.runner = [this, pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };
*/
    }
    return;
  }
 
//code in direct_session.cc
void DirectSession::SchedClosure(thread::ThreadPool* pool,
                                 std::function<void()> c) {
  if (pool != nullptr) {
    pool->Schedule(std::move(c));
  } else {
    c();
  }
}
 
//code in threadpool.cc
void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  impl_->Schedule(std::move(fn));
}
 
//code in NonBlockingThreadPool.h
  void Schedule(std::function<void()> fn) EIGEN_OVERRIDE {
    ScheduleWithHint(std::move(fn), 0, num_threads_);
  }
 
  void ScheduleWithHint(std::function<void()> fn, int start,
                        int limit) override {
/*
fn is the lambda expression "[=]() { Process(tagged_node, scheduled_nsec); }" created above.
*/
    Task t = env_.CreateTask(std::move(fn));
    PerThread* pt = GetPerThread();
    if (pt->pool == this) {
      // Worker thread of this pool, push onto the thread's queue.
      Queue& q = thread_data_[pt->thread_id].queue;
      t = q.PushFront(std::move(t));
    } else {
      // A free-standing thread (or worker of another pool), push onto a random
      // queue.
      eigen_plain_assert(start < limit);
      eigen_plain_assert(limit <= num_threads_);
      int num_queues = limit - start;
      int rnd = Rand(&pt->rand) % num_queues;
      eigen_plain_assert(start + rnd < limit);
      Queue& q = thread_data_[start + rnd].queue;
      t = q.PushBack(std::move(t));
    }
    if (!t.f) {
      ec_.Notify(false);
    } else {
      env_.ExecuteTask(t);  // Push failed, execute directly.
    }
  }

A task pushed onto one of these queues is then executed by the thread pool created earlier:

//code in NonBlockingThreadPool.h
  ThreadPoolTempl(int num_threads, bool allow_spinning,
                  Environment env = Environment())
      : env_(env),
        num_threads_(num_threads),
        allow_spinning_(allow_spinning),
        thread_data_(num_threads),
        all_coprimes_(num_threads),
        waiters_(num_threads),
        blocked_(0),
        spinning_(0),
        done_(false),
        cancelled_(false),
        ec_(waiters_) {
......
    thread_data_.resize(num_threads_);
    for (int i = 0; i < num_threads_; i++) {
      SetStealPartition(i, EncodePartition(0, num_threads_));
      thread_data_[i].thread.reset(
          env_.CreateThread([this, i]() { WorkerLoop(i); }));
    }
......
  }
 
  void WorkerLoop(int thread_id) {
......
    if (num_threads_ == 1) {
......
    } else {
      while (!cancelled_) {
        Task t = q.PopFront();
......
        if (t.f) {
          env_.ExecuteTask(t);
        }
      }
    }
  }
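The push/pop structure of ScheduleWithHint and WorkerLoop can be reproduced in a small standalone sketch. ToyPool below is invented: it keeps the per-thread queues, the random-victim push from an outside thread, and the worker loop, but drops the spinning, work stealing and EventCount machinery of the real Eigen pool:

//toy_pool.cc (illustrative sketch, not the Eigen implementation)
#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <random>
#include <thread>
#include <vector>

class ToyPool {
 public:
  explicit ToyPool(int n) : data_(n) {
    for (int i = 0; i < n; ++i)
      threads_.emplace_back([this, i] { WorkerLoop(i); });
  }
  ~ToyPool() {  // lets each worker drain its queue, then joins it
    for (auto& d : data_) {
      std::lock_guard<std::mutex> l(d.mu);
      d.done = true;
      d.cv.notify_one();
    }
    for (auto& t : threads_) t.join();
  }
  // cf. ScheduleWithHint: push onto a randomly chosen worker's queue.
  void Schedule(std::function<void()> fn) {
    int victim =
        std::uniform_int_distribution<int>(0, (int)data_.size() - 1)(rng_);
    PerThread& d = data_[victim];
    std::lock_guard<std::mutex> l(d.mu);
    d.queue.push_back(std::move(fn));
    d.cv.notify_one();
  }

 private:
  struct PerThread {
    std::mutex mu;
    std::condition_variable cv;
    std::deque<std::function<void()>> queue;
    bool done = false;
  };
  // cf. WorkerLoop: each worker pops tasks from its own queue and runs them.
  void WorkerLoop(int id) {
    PerThread& d = data_[id];
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> l(d.mu);
        d.cv.wait(l, [&] { return d.done || !d.queue.empty(); });
        if (d.queue.empty()) return;  // done and fully drained
        task = std::move(d.queue.front());
        d.queue.pop_front();
      }
      task();  // cf. env_.ExecuteTask(t)
    }
  }
  std::vector<PerThread> data_;
  std::vector<std::thread> threads_;
  std::mt19937 rng_{42};  // Schedule() is only called from one thread here
};

int main() {
  ToyPool pool(2);
  for (int i = 0; i < 4; ++i)
    pool.Schedule([i] { std::cout << "task " << i << " done\n"; });
  return 0;  // ~ToyPool drains the queues before joining
}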

When a node's computation completes, PropagateOutputs in executor.cc follows the out edges and appends the downstream nodes that have become ready to the ready list; NodeDone in executor.cc is then called to notify those downstream nodes, going through ScheduleReady (and from there ScheduleWithHint) for the next round of execution:

//code in executor.cc
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
  WithContext wc(context_);
  const GraphView& gview = impl_->gview_;
  TaggedNodeSeq ready;
  TaggedNodeReadyQueue inline_ready;
......
  EntryVector outputs;
  bool completed = false;
  inline_ready.push_back(tagged_node);
  while (!inline_ready.empty()) {
......
      // Set up compute params.
      OpKernel* op_kernel = item.kernel;
      params.op_kernel = op_kernel;
      params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
      params.is_input_dead = is_input_dead;
      params.output_attr_array = item.output_attrs();
      params.forward_from_array = item.forward_from();
 
      if (item.kernel_is_async) {
        // Asynchronous computes.
        AsyncOpKernel* async = item.kernel->AsAsync();
        DCHECK(async != nullptr);
        launched_asynchronously = true;
        AsyncState* state =
            new AsyncState(params, tagged_node, &item, first_input, stats);
 
        auto done = [this, state]() {
          Device* device = impl_->params_.device;
          NodeExecStatsInterface* stats = state->stats;  // Shorthand
          Entry* first_input = state->first_input;       // Shorthand
 
          nodestats::SetOpEnd(stats);
          EntryVector outputs;
          Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
          nodestats::SetMemory(stats, &state->ctx);
          if (vlog_) {
            VLOG(2) << "Async kernel done: " << state->item->node->id()
                    << " step " << step_id_ << " "
                    << SummarizeNode(*state->item->node)
                    << (state->tagged_node.is_dead ? " is dead" : "")
                    << " device: " << device->name();
          }
 
          // Clears inputs.
          const int num_inputs = state->item->num_inputs;
          for (int i = 0; i < num_inputs; ++i) {
            (first_input + i)->ClearVal();
          }
          FrameState* input_frame = state->tagged_node.input_frame;
          const int64 input_iter = state->tagged_node.input_iter;
          const int id = state->tagged_node.node->id();
          MaybeMarkCompleted(input_frame, input_iter, id);
          TaggedNodeSeq ready;
          if (s.ok()) {
            PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
          }
          outputs.clear();
          if (s.ok() && impl_->device_record_tensor_accesses_) {
            // Get the list of all tensors accessed during the execution
            TensorReferenceVector accessed;
            state->ctx.retrieve_accessed_tensors(&accessed);
            nodestats::SetReferencedTensors(stats, accessed);
            // callee takes ownership of the vector
            device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
                                                 accessed);
          }
          const bool completed =
              NodeDone(s, state->item->node, ready, stats, nullptr);
          delete state;
          if (completed) ScheduleFinish();
        };
        nodestats::SetOpStart(stats);
        {
          profiler::TraceMe activity(
              [&] {
                return strings::StrCat(
                    op_kernel->name(), ":", op_kernel->type_string(),
                    "#id=", step_container_ ? step_container_->step_id() : 0,
                    ",device=", device->name(), ",async=true#");
              },
              profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
          device->ComputeAsync(async, &state->ctx, done);
        }
      } else {
......
 
        nodestats::SetOpEnd(stats);
        s = ProcessOutputs(item, &ctx, &outputs, stats);
        if (s.ok() && impl_->device_record_tensor_accesses_) {
          // Get the list of all tensors accessed during the execution
          ctx.retrieve_accessed_tensors(&accessed_tensors);
          device_context = ctx.op_device_context();
        }
        nodestats::SetMemory(stats, &ctx);
      }
    }
 
    if (!launched_asynchronously) {
......
      // Propagates outputs.
      if (s.ok()) {
        PropagateOutputs(tagged_node, &item, &outputs, &ready);
      }
      outputs.clear();
      if (!accessed_tensors.empty()) {
        nodestats::SetReferencedTensors(stats, accessed_tensors);
        // device_context is set above in synchronous computes
        device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
      }
      if (stats) {
        scheduled_nsec = nodestats::NowInNsec();
      }
      // Postprocess.
      completed = NodeDone(s, item.node, ready, stats, &inline_ready);
    }
  }  // while !inline_ready.empty()
 
  // This thread of computation is done if completed = true.
  if (completed) Finish();
}

PropagateOutputs in executor.cc calls output_frame->ActivateNodes, which executes ready->emplace_back(dst_item->node, this, iter, dst_dead) to add each newly ready node to ready.
NodeDone then calls ScheduleReady to keep execution going. It also maintains num_outstanding_ops_: finishing one node that makes ready_size new nodes ready changes the counter by ready_size - 1, and the whole run is complete when the counter drops from 1 to 0.

void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
                                     const NodeItem* item, EntryVector* outputs,
                                     TaggedNodeSeq* ready) {
  auto activity_handle = absl::make_unique<profiler::TraceMe>(
      [&]() {
        return strings::StrCat("ExecutorPropagateOutputs:",
                               item->kernel->name(), "#id=", step_id_, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
 
  const Node* node = tagged_node.node;
  FrameState* input_frame = tagged_node.input_frame;
  const int64 input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;
 
  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  ready->clear();
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  int64 output_iter = input_iter;
 
  if (!item->is_enter_exit_or_next_iter) {
    // Fast path for nodes types that don't need special handling
    DCHECK_EQ(input_frame, output_frame);
    // Normal path for most nodes
    mutex_lock l(input_frame->mu);
    output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    is_frame_done = input_frame->DecrementOutstandingOpsLocked(
        &impl_->gview_, input_iter, ready);
  } else if (item->is_enter) {
    FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
    output_iter = 0;
    {
      const NodeItem* item = impl_->gview_.node(node->id());
      mutex_lock l(output_frame->mu);
      if (item->is_constant_enter) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
      } else {
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
      }
      output_frame->num_pending_inputs--;
    }
    is_frame_done =
        input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
  } else if (item->is_exit) {
    if (is_dead) {
      mutex_lock l(input_frame->mu);
      // Stop and remember this node if it is a dead exit.
      if (input_iter == input_frame->iteration_count) {
        input_frame->dead_exits.push_back(node);
      }
      is_frame_done = input_frame->DecrementOutstandingOpsLocked(
          &impl_->gview_, input_iter, ready);
    } else {
      output_frame = input_frame->parent_frame;
      output_iter = input_frame->parent_iter;
      {
        mutex_lock l(output_frame->mu);
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
      }
      is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
                                                           input_iter, ready);
    }
  } else {
    DCHECK(IsNextIteration(node));
    mutex_lock l(input_frame->mu);
    if (is_dead) {
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      if (input_iter == input_frame->iteration_count &&
          input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
        // Reached the maximum for parallel iterations.
        input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
        output_frame = nullptr;
      } else {
        // If this is a new iteration, start it.
        if (input_iter == input_frame->iteration_count) {
          input_frame->IncrementIteration(&impl_->gview_, ready);
        }
        output_iter = input_iter + 1;
      }
    }
    if (output_frame != nullptr) {
      // This is the case when node is not Enter, Exit, or NextIteration.
      DCHECK(input_frame == output_frame);
      output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    }
    is_frame_done = input_frame->DecrementOutstandingOpsLocked(
        &impl_->gview_, input_iter, ready);
  }
 
  // At this point, this node is completely done. We also know if the
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    const int64 parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}
 
bool ExecutorState::NodeDone(const Status& s, const Node* node,
                             const TaggedNodeSeq& ready,
                             NodeExecStatsInterface* stats,
                             TaggedNodeReadyQueue* inline_ready) {
  nodestats::SetAllEnd(stats);
  if (stats) {
    if (stats_collector_) {
      stats->Done(impl_->params_.device->name());
    } else {
      delete stats;
    }
  }
 
  bool abort_run = false;
  if (!s.ok()) {
    // Some error happened. This thread of computation is done.
    mutex_lock l(mu_);
    if (status_.ok()) {
      abort_run = true;
 
      // If execution has been cancelled, mark any new errors as being derived.
      // This ensures any errors triggered by cancellation are marked as
      // derived.
      if (cancellation_manager_ && cancellation_manager_->IsCancelled()) {
        status_ = StatusGroup::MakeDerived(s);
      } else {
        status_ = s;
      }
    }
  }
  if (abort_run) {
    TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
    if (cancellation_manager_) {
      // only log when the abort happens during the actual run time.
      auto device_name = impl_->params_.device->name();
      // Use VLOG instead of LOG(warning) because error status is expected when
      // the executor is run under the grappler optimization phase or when
      // iterating through a tf.data input pipeline.
      VLOG(1) << "[" << device_name << "] Executor start aborting: " << s;
    }
 
    if (rendezvous_) {
      rendezvous_->StartAbort(s);
    }
    if (collective_executor_) {
      collective_executor_->StartAbort(s);
    }
    if (cancellation_manager_) {
      cancellation_manager_->StartCancel();
    }
  }
 
  bool completed = false;
  const size_t ready_size = ready.size();
  if (ready_size == 0 || !s.ok()) {
    completed = (num_outstanding_ops_.fetch_sub(1) == 1);
  } else if (ready_size > 1) {
    num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
  }
 
  // Schedule the ready nodes in 'ready'.
  if (s.ok()) {
    ScheduleReady(ready, inline_ready);
  }
  return completed;
}
 
void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
                                              const bool is_dead, int64 iter,
                                              EntryVector* outputs,
                                              TaggedNodeSeq* ready) {
  const GraphView& gview = executor->gview_;
  IterationState* iter_state = GetIteration(iter);
  const size_t num_output_edges = item->num_output_edges;
  const EdgeInfo* edges = item->output_edge_list();
  Entry* input_tensors = iter_state->input_tensors;
  for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
    const EdgeInfo& e = edges[out_index];
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = gview.node(dst_id);
    const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
    const int src_slot = e.output_slot;
 
    // TODO(yuanbyu): We don't need this if we require the subgraph
    // given to an executor not to contain a sink node.
    if (dst_item->is_sink) continue;
 
    bool dst_dead = false;
    bool dst_ready = false;
    // True iff this input for dst is needed. We only set this input for
    // dst if this flag is true. This is needed to make the thread safety
    // analysis happy.
    const bool is_control_edge = (src_slot == Graph::kControlSlot);
    bool dst_need_input = !is_control_edge;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if (is_control_edge) {
        iter_state->decrement_pending(dst_pending_id, 2);
        int count = iter_state->pending(dst_pending_id);
        int dead_cnt = iter_state->dead_count(dst_pending_id);
        dst_dead = (dead_cnt == dst_item->num_inputs);
        dst_ready = (count == 0) || ((count == 1) && dst_dead);
      } else {
        if ((*outputs)[src_slot].has_value) {
          // This is a live data input.
          int count = iter_state->pending(dst_pending_id);
          iter_state->mark_live(dst_pending_id);
          // Only the first live edge sets the input and (potentially)
          // triggers execution. The low bit of count is set if and
          // only if no live input has been used yet (mark_live clears
          // it). The node should be started if and only if this is
          // the first live input and there are no pending control
          // edges, i.e. count == 1.
          dst_ready = (count == 1);
          dst_need_input = ((count & 0x1) == 1);
        } else {
          // This is a dead data input. Note that dst_node is dead if node is
          // a dead enter. We need this to handle properly a while loop on
          // the untaken branch of a conditional.
          // TODO(yuanbyu): This is a bit hacky, but a good solution for
          // now.
          iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
          dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
          dst_need_input = false;
        }
      }
    } else {
      const bool increment_dead =
          (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
      int pending, dead;
      iter_state->adjust_for_activation(dst_pending_id, increment_dead,
                                        &pending, &dead);
      dst_dead = (dead > 0);
      dst_ready = (pending == 0);
    }
 
    if (dst_need_input) {
      const int dst_slot = e.input_slot;
      const int dst_loc = dst_item->input_start + dst_slot;
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }
    }
 
    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item->node, this, iter, dst_dead);
      iter_state->outstanding_ops++;
    }
  }
}
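The Merge branch of ActivateNodes relies on a bit trick in the pending counts: the low bit of a Merge node's count records "no live data input has been seen yet", while each unfinished control edge contributes 2, so control-edge arrivals (decrement_pending(..., 2)) never disturb the liveness bit. Assuming, consistently with the code above though the initialization itself is not shown here, that a Merge node starts with pending = 1 + 2 * num_control_edges, the arithmetic works out as in this worked example:

//merge_pending.cc (worked example of the assumed Merge pending encoding)
#include <iostream>

int main() {
  const int num_control_edges = 2;
  // Assumed initial value: LSB = 1 ("no live data yet") plus 2 per
  // pending control edge.
  int pending = 1 + 2 * num_control_edges;  // 5 == 0b101

  pending -= 2;        // a control input arrives          -> 3 (0b011)
  // A live data input arrives: readiness is tested on the count seen
  // before mark_live (dst_ready = (count == 1)), then the LSB is cleared.
  int count = pending;                      // 3, so not ready yet
  pending &= ~0x1;     // mark_live                        -> 2 (0b010)
  std::cout << "ready after first data input: " << (count == 1) << "\n";

  pending -= 2;        // the last control input arrives   -> 0
  // pending == 0 now satisfies dst_ready in the control-edge branch,
  // so the Merge node is pushed onto the ready queue.
  std::cout << "pending now: " << pending << "\n";
  return 0;
}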

Finally, the execution log also confirms that the program really does run according to the computation graph:

executor.cc:1688] Process node: 0 step 1 {{node _SOURCE}} = NoOp[]() 
executor.cc:1837] Synchronous kernel done: 0 step 1 {{node _SOURCE}} = NoOp[]() 
executor.cc:1688] Process node: 2 step 1 {{node scalar}} = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [] values: 2>, ]() 
executor.cc:1837] Synchronous kernel done: 2 step 1 {{node scalar}} = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [] values: 2>, ]() 
executor.cc:1688] Process node: 3 step 1 {{node _recv_A_0}} = _Recv[client_terminated=true, recv, , , tensor_name="A:0", tensor_type=DT_INT32, ]() 
executor.cc:1688] Process node: 5 step 1 {{node _recv_B_0}} = _Recv[client_terminated=true, recv, , , tensor_name="B:0", tensor_type=DT_INT32, ]() 
executor.cc:1748] Async kernel done: 3 step 1 {{node _recv_A_0}} = _Recv[client_terminated=true, recv, , , tensor_name="A:0", tensor_type=DT_INT32, ]() 
executor.cc:1688] Process node: 4 step 1 {{node plus2}} = AddN[N=2, T=DT_INT32, ](_recv_A_0, scalar) 
executor.cc:1837] Synchronous kernel done: 4 step 1 {{node plus2}} = AddN[N=2, T=DT_INT32, ](_recv_A_0, scalar) 
executor.cc:1688] Process node: 7 step 1 {{node _send_plus2_0}} = _Send[T=DT_INT32, client_terminated=true, recv, , , tensor_name="plus2:0", ](plus2) 
executor.cc:1837] Synchronous kernel done: 7 step 1 {{node _send_plus2_0}} = _Send[T=DT_INT32, client_terminated=true, recv, , , tensor_name="plus2:0", ](plus2)
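These lines come from the VLOG(1)/VLOG(2) statements in executor.cc ("Process node:", "Synchronous kernel done:", "Async kernel done:"). Assuming a standard build, they can be reproduced by raising the verbose log level, for example by setting the environment variable TF_CPP_MIN_VLOG_LEVEL=2 before running the test program.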
