TensorFlow Session Setup


Session setup in TensorFlow creates the session itself, registers the input endpoints (feeds) and output endpoints (fetches), builds a session-specific base graph from the Graph, and starts the executors, which then wait for the session run to begin.



SessionPRunSetup

The main program calls TF_SessionPRunSetup in c_api.cc to perform the session setup:

	TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
			TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);
 
void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;
 
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }
 
  std::vector<string> input_names(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = OutputName(inputs[i]);
  }
 
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }
 
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }
 
  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                               target_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}
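
On success, *handle points at a heap-allocated copy of the handle string. The caller later passes it to TF_SessionPRun to feed and fetch in stages, and finally releases it with TF_DeletePartialRunHandle. As a minimal sketch of that lifecycle (a and out are hypothetical operations created elsewhere in the graph; sess and s are as above):

// Set up a partial run with one feed (a:0) and one fetch (out:0).
TF_Output feed = {a, 0};
TF_Output fetch = {out, 0};
const char* handle = nullptr;
TF_SessionPRunSetup(sess, &feed, 1, &fetch, 1, nullptr, 0, &handle, s);
 
// Later: feed a scalar float for a:0 and fetch out:0.
TF_Tensor* feed_val = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
*static_cast<float*>(TF_TensorData(feed_val)) = 1.0f;
TF_Tensor* fetch_val = nullptr;
TF_SessionPRun(sess, handle, &feed, &feed_val, 1, &fetch, &fetch_val, 1,
               nullptr, 0, s);
 
TF_DeleteTensor(feed_val);
TF_DeleteTensor(fetch_val);
TF_DeletePartialRunHandle(handle);  // frees the buffer allocated above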

Creating the base graph (baseGraph)

First, ExtendSessionGraphHelper in c_api.cc is called to extend the session's graph with any nodes added since the last run:

bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;
 
    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }
 
    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      // TODO(nolivia): check this on a subset of the graph instead of all of
      // it.
      status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
      if (TF_GetCode(status) != TF_OK) {
        session->graph->mu.unlock();
        return false;
      }
 
      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      session->graph->mu.unlock();
      status->status = session->session->Extend(std::move(graph_def));
      if (TF_GetCode(status) != TF_OK) {
        // Contract is we always delete input_values[i].
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}

This calls DirectSession::Extend in direct_session.cc, which in turn calls DirectSession::ExtendLocked:

Status DirectSession::Extend(GraphDef&& graph) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(graph_state_lock_);
  return ExtendLocked(std::move(graph));
}
 
Status DirectSession::ExtendLocked(GraphDef graph) {
  if (!(flib_def_ && execution_state_)) {
    // If this is the first call, we can initialize the execution state
    // with `graph` and do not need to call `Extend()`.
    // NOTE(mrry): The function library created here will be used for
    // all subsequent extensions of the graph.
    flib_def_.reset(
        new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
    GraphExecutionStateOptions options;
    options.device_set = &device_set_;
    options.session_options = &options_;
    options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        std::move(graph), options, &execution_state_));
    graph_created_ = true;
  } else {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<GraphExecutionState> state;
    // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
    // value and move `graph` in here.
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}
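
The same semantics are exposed at the public C++ API by Session::Create and Session::Extend. A minimal sketch, assuming base_def and extra_def are GraphDefs built elsewhere:

#include "tensorflow/core/public/session.h"
 
// Create() installs the initial graph (the first ExtendLocked call above);
// Extend() appends the nodes of a second GraphDef to the running session.
std::unique_ptr<tensorflow::Session> session(
    tensorflow::NewSession(tensorflow::SessionOptions()));
TF_CHECK_OK(session->Create(base_def));   // initializes execution_state_
TF_CHECK_OK(session->Extend(extra_def));  // takes ExtendLocked's else branch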

GraphExecutionState::MakeForBaseGraph in https://gitee.com/mirrors/tensorflow/blob/v1.15.0/tensorflow/core/common_runtime/graph_execution_state.cc then creates the base graph.
ConvertGraphDefToGraph converts the graph definition (a GraphDef proto) into a computation graph (Graph); a standalone sketch of this step follows the listing below.
ret->InitBaseGraph initializes the graph, checking that each node's op kernel has an available device and recording the kernel's priority on every device.

/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef&& graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {
#ifndef __ANDROID__
  VLOG(4) << "Graph proto is \n" << graph_def.DebugString();
#endif  // __ANDROID__
 
  auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());
 
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));
 
  if (options.session_options->config.graph_options().place_pruned_graph() ||
      !options.session_options->config.experimental()
           .optimize_for_static_graph()) {
    auto ret = absl::WrapUnique(new GraphExecutionState(
        absl::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def),
        options));
 
    // When place_pruned_graph is true, a different Graph* will be initialized
    // each time we prune the original graph, so there is no need to
    // construct a Graph* in this case.
    if (!options.session_options->config.graph_options().place_pruned_graph()) {
      auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
      TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_,
                                                base_graph.get()));
      TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    }
    *out_state = std::move(ret);
  } else {
    auto ret = absl::WrapUnique(
        new GraphExecutionState(nullptr, std::move(flib_def), options));
    auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    *out_state = std::move(ret);
  }
  return Status::OK();
}
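
For reference, a standalone sketch of the GraphDef-to-Graph conversion used above (options left at their defaults, matching the {} passed in MakeForBaseGraph; BuildBaseGraph is a hypothetical helper):

#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
 
// Build an in-memory Graph from a GraphDef against the global op registry.
tensorflow::Status BuildBaseGraph(const tensorflow::GraphDef& graph_def,
                                  std::unique_ptr<tensorflow::Graph>* out) {
  auto graph =
      absl::make_unique<tensorflow::Graph>(tensorflow::OpRegistry::Global());
  tensorflow::GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(
      tensorflow::ConvertGraphDefToGraph(opts, graph_def, graph.get()));
  *out = std::move(graph);
  return tensorflow::Status::OK();
}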

Setting up the Session

DirectSession::PRunSetup in direct_session.cc performs the actual setup.
It first calls GetOrCreateExecutors to create the executors (see "Creating the TensorFlow Executor").
It then starts each executor asynchronously via item.executor->RunAsync(args, barrier->Get()); once all executors finish, the barrier's done callback notifies run_state->executors_done.
Note the pairing of thread::ThreadPool* pool = thread_pools_[0].first; with args.runner = [this, pool](Executor::Args::Closure c) { pool->Schedule(std::move(c)); }: the threads the executors run on come from this ThreadPool. A minimal sketch of this pattern follows the listing below.

Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
 
  // RunOptions is not available in PRunSetup, so use thread pool 0.
  thread::ThreadPool* pool = thread_pools_[0].first;
 
  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  // TODO(cais): TFDBG support for partial runs.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);
  run_state_args.is_partial_run = true;
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));
 
  // Create the run state and save it for future PRun calls.
  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);
  RunState* run_state =
      new RunState(input_names, output_names, args.step_id, &devices_);
  run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
  {
    mutex_lock l(executor_lock_);
    if (!partial_runs_
             .emplace(run_state_args.handle,
                      std::unique_ptr<RunState>(run_state))
             .second) {
      return errors::Internal("The handle '", run_state_args.handle,
                              "' created for this partial run is not unique.");
    }
  }
 
  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez, [run_state](const Status& ret) {
        if (!ret.ok()) {
          mutex_lock l(run_state->mu_);
          run_state->status.Update(ret);
        }
        run_state->executors_done.Notify();
      });
 
  args.rendezvous = run_state->rendez;
  args.cancellation_manager = cancellation_manager_;
  // Note that Collectives are not supported in partial runs
  // because RunOptions is not passed in so we can't know whether
  // their use is intended.
  args.collective_executor = nullptr;
  args.runner = [this, pool](Executor::Args::Closure c) {
    pool->Schedule(std::move(c));
  };
  args.session_state = &session_state_;
  args.session_handle = session_handle_;
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  args.sync_on_finish = sync_on_finish_;
 
  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();
  }
 
  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }
 
  *handle = run_state_args.handle;
  return Status::OK();
}
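
As a minimal, self-contained sketch of the runner-plus-barrier pattern (demo code, not TensorFlow internals): the runner simply forwards closures to a thread pool, and a Notification stands in for run_state->executors_done:

#include <functional>
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
 
int main() {
  // Plays the role of thread_pools_[0].first in PRunSetup.
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "prun_demo", /*num_threads=*/4);
  tensorflow::Notification done;  // stands in for run_state->executors_done
 
  // Same shape as args.runner: hand each closure to the pool.
  auto runner = [&pool](std::function<void()> c) {
    pool.Schedule(std::move(c));
  };
 
  runner([&done] { done.Notify(); });  // the executor's work would run here
  done.WaitForNotification();          // DirectSession::PRun waits like this
  return 0;
}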

