Creating the TensorFlow Computation Graph
At the core of a TensorFlow machine-learning job is the computation graph: after abstracting the user-defined graph model and the chosen parameter-solving method, TensorFlow generates a directed acyclic graph of nodes and directed edges that determines a unique computation. The graph defines how data flows, how it is computed, and the dependencies among the individual computations. Nodes come in three kinds — computation nodes (Operation), storage nodes (Variable), and data nodes (Placeholder) — which compute and store data; the directed edges express data flow and dependencies.
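To make the three node kinds concrete, here is a minimal sketch written for this walkthrough (not code from the TensorFlow sources) using the public C++ client API under tensorflow/cc. It builds a graph with one Placeholder (data node), one Variable (storage node), and one MatMul (computation node); the operands form the directed edges:

// A minimal sketch, assuming the tensorflow/cc client headers are available.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;

  Scope root = Scope::NewRootScope();
  // Data node: values are fed in at run time.
  auto x = Placeholder(root.WithOpName("x"), DT_FLOAT);
  // Storage node: holds state across Run() calls.
  auto w = Variable(root.WithOpName("w"), {2, 2}, DT_FLOAT);
  auto init_w = Assign(root, w, Const(root, {{1.f, 0.f}, {0.f, 1.f}}));
  // Computation node: its input edges come from x and w.
  auto y = MatMul(root.WithOpName("y"), x, w);

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({init_w}, &outputs));  // Run the initializer once.
  TF_CHECK_OK(session.Run({{x, {{1.f, 2.f}, {3.f, 4.f}}}}, {y}, &outputs));
  return 0;
}

For a local target, ClientSession wraps the same DirectSession machinery examined below.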
Creating the Computation Graph
The CreateGraphs method in direct_session.cc creates the computation graphs. It calls BuildGraph to build the client graph, Partition to split the graph across devices, and ConvertGraphDefToGraph to convert each partitioned GraphDef back into an executable Graph.
Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types, int64* collective_graph_key) {
  mutex_lock l(graph_state_lock_);
  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a
    // new GraphExecutionState for every new unseen graph,
    // and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    prune_options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        *execution_state_, prune_options, subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }
  *collective_graph_key = client_graph->collective_graph_key;

  if (subgraph_options.callable_options.feed_size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.callable_options.feed_size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.callable_options.fetch_size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.callable_options.fetch_size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }

  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's
  // placements. If there are any mismatches for a node,
  // we should fail, as this should never happen.
  for (auto placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }

  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }

  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The direct session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = &client_graph->graph.flib_def();
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));

  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          absl::StrJoin(device_names, ","));
    }
  }

  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;
    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}
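The partitioning key deserves a closer look: popts.node_to_loc simply returns each node's assigned device name, so Partition buckets nodes by device before inserting send/recv pairs. The following standalone sketch (plain C++ with a stand-in FakeNode type, not TensorFlow code) illustrates just that grouping step:

// Standalone sketch of the grouping Partition performs. FakeNode is a
// hypothetical stand-in for tensorflow::Node.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeNode {
  std::string name;
  std::string assigned_device;
};

int main() {
  std::vector<FakeNode> nodes = {
      {"x", "/job:localhost/replica:0/task:0/device:CPU:0"},
      {"w", "/job:localhost/replica:0/task:0/device:GPU:0"},
      {"y", "/job:localhost/replica:0/task:0/device:GPU:0"},
  };
  // node_to_loc, as DirectSession::CreateGraphs sets it up: the partition
  // key is the node's assigned device name.
  auto node_to_loc = [](const FakeNode& n) { return n.assigned_device; };

  std::unordered_map<std::string, std::vector<std::string>> partitions;
  for (const auto& n : nodes) partitions[node_to_loc(n)].push_back(n.name);

  for (const auto& p : partitions)
    std::cout << p.first << " -> " << p.second.size() << " node(s)\n";
  // Expected: one CPU partition {x} and one GPU partition {w, y}. The real
  // Partition then stitches the cross-device edge x->y with _Send/_Recv.
  return 0;
}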
Building the Computation Graph
The BuildGraph method in graph_execution_state.cc builds the computation graph. For the OptimizeGraph step, see the companion article "TensorFlow计算图的优化" (TensorFlow computation-graph optimization).
Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                       std::unique_ptr<ClientGraph>* out) {
  VLOG(1) << "BuildGraph";
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  if (!graph_) {
    // It is only valid to call this method directly when the original graph
    // was created with the option `place_pruned_graph == false`.
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
  }

  // Grappler optimization might change the structure of a graph itself, and
  // also it can add/prune functions to/from the library.
  std::unique_ptr<Graph> optimized_graph;
  std::unique_ptr<FunctionLibraryDefinition> optimized_flib;

  Status s = OptimizeGraph(options, &optimized_graph, &optimized_flib);
  if (!s.ok()) {
    VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
    // Simply copy the original graph and the function library if we couldn't
    // optimize it.
    optimized_graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*graph_, optimized_graph.get());
    optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
  }

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(
        PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
  } else {
    // This GraphExecutionState represents a graph that was
    // pruned when this was constructed, so we copy the metadata from
    // a member variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }

  CHECK_EQ(options.callable_options.feed_size(),
           rewrite_metadata.feed_types.size());
  CHECK_EQ(options.callable_options.fetch_size(),
           rewrite_metadata.fetch_types.size());

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &optimized_graph;
  optimization_options.flib_def = optimized_flib.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  int64 collective_graph_key = options.collective_graph_key;
  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // BuildGraphOptions does not specify a collective_graph_key. Check all
    // nodes in the Graph and FunctionLibraryDefinition for collective ops and
    // if found, initialize a collective_graph_key as a hash of the ordered set
    // of instance keys.
    std::set<int32> instance_key_set;
    for (Node* node : optimized_graph->nodes()) {
      if (node->IsCollective()) {
        int32 instance_key;
        TF_RETURN_IF_ERROR(
            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
        instance_key_set.emplace(instance_key);
      } else {
        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
        if (fdef != nullptr) {
          for (const NodeDef& ndef : fdef->node_def()) {
            if (ndef.op() == "CollectiveReduce" ||
                ndef.op() == "CollectiveBcastSend" ||
                ndef.op() == "CollectiveBcastRecv" ||
                ndef.op() == "CollectiveGather") {
              int32 instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(ndef, "instance_key", &instance_key));
              instance_key_set.emplace(instance_key);
            }
          }
        }
      }
    }
    if (!instance_key_set.empty()) {
      uint64 hash = 0x8774aa605c729c72ULL;
      for (int32 instance_key : instance_key_set) {
        hash = Hash64Combine(instance_key, hash);
      }
      collective_graph_key = hash;
    }
  }

  // Make collective execution order deterministic if needed.
  if (options.collective_order != GraphCollectiveOrder::kNone) {
    TF_RETURN_IF_ERROR(
        OrderCollectives(optimized_graph.get(), options.collective_order));
  }

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<ClientGraph> dense_copy(
      new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
                      rewrite_metadata.fetch_types, collective_graph_key));
  CopyGraph(*optimized_graph, &dense_copy->graph);

  // TODO(vrv): We should check invariants of the graph here.
  metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() -
                                start_time_usecs);
  *out = std::move(dense_copy);
  return Status::OK();
}
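The collective_graph_key derivation above is easy to miss: BuildGraph collects the instance_key attribute of every collective op into an ordered std::set and folds it into a hash seeded with 0x8774aa605c729c72. Here is a standalone sketch of the idea; the combine function is a stand-in, not TensorFlow's real Hash64Combine from tensorflow/core/lib/hash/hash.h:

// Standalone sketch of the collective_graph_key hashing logic.
#include <cstdint>
#include <iostream>
#include <set>

// Stand-in combiner (boost-style); TF uses its own Hash64Combine.
uint64_t Hash64Combine(uint64_t a, uint64_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

int main() {
  std::set<int32_t> instance_key_set = {7, 3, 11};  // hypothetical keys
  uint64_t hash = 0x8774aa605c729c72ULL;            // seed used by BuildGraph
  // std::set iterates in sorted order, so the resulting key is deterministic
  // for a given set of collective ops, independent of graph traversal order.
  for (int32_t instance_key : instance_key_set)
    hash = Hash64Combine(instance_key, hash);
  std::cout << "collective_graph_key = " << hash << "\n";
  return 0;
}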
If the graph is not placed in pruned form here, BuildGraph compiles directly from the base graph that the earlier Extend method established:
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }
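Which branch runs is controlled by the place_pruned_graph field of GraphOptions in the session configuration. A minimal, hedged example of flipping it when creating a session; the field is real, but the surrounding program is illustrative only:

// Sketch: enabling place_pruned_graph on a session, assuming the TensorFlow
// core headers are available.
#include "tensorflow/core/public/session.h"

int main() {
  tensorflow::SessionOptions options;
  // false (the default): the full graph is placed and optimized once, and
  // each Run() prunes from that base graph (the else branch above).
  // true: a fresh GraphExecutionState is built and placed per pruned graph
  // (the MakeForPrunedGraph branch), which partial runs rely on.
  options.config.mutable_graph_options()->set_place_pruned_graph(true);

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(options));
  return session != nullptr ? 0 : 1;
}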
Partitioning the Computation Graph
The Partition method in graph_partition.cc partitions the computation graph into per-device subgraphs, grouping each node under the device it was assigned to and stitching every cross-device edge with a matching pair of send/recv nodes.
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual device) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  string dstp;
  std::vector<const Edge*> inputs;
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  int32 num_data = 0;
  int32 num_control = 0;
  for (const Node* dst : g->op_nodes()) {
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below

    if (opts.need_to_record_start_times) {
      int64 start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    const Edge* control_flow_edge = nullptr;
    int32 num_control_flow_edges = 0;
    int32 num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      int64 send_start_time = 0;
      int64 recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the start
        // times of it's consumers. So we update this whenever we use a recv,
        // and write it out to the attribute at the end of the subroutine
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, &status);
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv =
          AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; Retargeting the dependencies
    // to 'dst' to those nodes would not introduce cycles if there isn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  if (VLOG_IS_ON(2)) {
    for (auto& it : *partitions) {
      GraphDef* gdef = &it.second;
      DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
                                         reinterpret_cast<uintptr_t>(gdef)),
                         *gdef);
    }
  }
  return Status::OK();
}
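For reference, the minimum a caller must supply to Partition mirrors the setup in CreateGraphs. The fragment below is a sketch, not a complete program: it assumes `placed_graph` is a tensorflow::Graph whose nodes already carry assigned device names, and it is written against the TF-internal types used throughout this article:

// Sketch: minimal PartitionOptions setup, mirroring CreateGraphs.
PartitionOptions popts;
popts.node_to_loc = [](const Node* node) {
  return node->assigned_device_name();  // partition key = device name
};
int64 counter = 0;  // single-threaded sketch; DirectSession uses an atomic
popts.new_name = [&counter](const string& prefix) {
  // Unique names for the inserted _Send/_Recv/dummy const nodes.
  return strings::StrCat(prefix, "/_", counter++);
};
popts.get_incarnation = [](const string& name) { return 1; };
// popts.flib_def may be left unset: Partition falls back to g->flib_def().

std::unordered_map<string, GraphDef> partitions;
Status s = Partition(popts, &placed_graph, &partitions);
// On success, `partitions` holds one GraphDef per device, and every
// cross-device edge has been split into a matching _Send/_Recv pair.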