Original article; when reposting, please credit 慢慢的回味.
Permalink: TF Operation的创建
How TensorFlow creates an Op
Using AddNOp as an example, this post shows how a concrete kernel instance is created from an op definition.
When the TensorFlow Executor is initialized, it iterates over every node in the computation graph and creates the Operation (kernel) for each node via params_.create_kernel(n->def(), &item->kernel):
// Code in executor.cc
Status ExecutorImpl::Initialize() {
  ......
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }

    NodeItem* item = gview_.node(id);
    item->node = n;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->def(), &item->kernel);
params_.create_kernel is a lambda that was set up earlier; invoking it eventually calls lib->CreateKernel(ndef, kernel), where lib is a FunctionLibraryRuntimeImpl instance:
// Code in direct_session.cc
params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) {
  // NOTE(mrry): We must not share function kernels (implemented
  // using `CallOp`) between subgraphs, because `CallOp::handle_`
  // is tied to a particular subgraph. Even if the function itself
  // is stateful, the `CallOp` that invokes it is not.
  if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
    return lib->CreateKernel(ndef, kernel);
  }
  auto create_fn = [lib, &ndef](OpKernel** kernel) {
    return lib->CreateKernel(ndef, kernel);
  };
  // Kernels created for subgraph nodes need to be cached. On
  // cache miss, create_fn() is invoked to create a kernel based
  // on the function library here + global op registry.
  return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                             create_fn);
};
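The opseg->FindOrCreate call above is what caches kernels per session: it looks a kernel up by session handle and node name, and only runs create_fn on a cache miss. Below is a minimal, hypothetical sketch of that find-or-create pattern (SimpleKernelCache is invented for illustration; the real OpSegment in op_segment.cc additionally handles reference counting, ownership, and thread safety):

#include <functional>
#include <string>
#include <unordered_map>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical illustration of the find-or-create pattern used by OpSegment.
// Ownership and deletion of the cached kernels are omitted for brevity.
class SimpleKernelCache {
 public:
  tensorflow::Status FindOrCreate(
      const std::string& session_handle, const std::string& node_name,
      tensorflow::OpKernel** kernel,
      const std::function<tensorflow::Status(tensorflow::OpKernel**)>& create_fn) {
    const std::string key = session_handle + "/" + node_name;
    auto it = kernels_.find(key);
    if (it != kernels_.end()) {  // cache hit: reuse the existing kernel
      *kernel = it->second;
      return tensorflow::Status::OK();
    }
    tensorflow::Status s = create_fn(kernel);  // cache miss: build the kernel
    if (s.ok()) kernels_[key] = *kernel;
    return s;
  }

 private:
  std::unordered_map<std::string, tensorflow::OpKernel*> kernels_;
};

Caching by (session handle, node name) matters because a stateful kernel must be reused across Run() calls within the same session rather than being rebuilt each time.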
CreateKernel in turn ends up calling CreateNonCachedKernel:
// Code in function.cc
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, base_lib_def_, kernel);
}

Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_(this, ndef, &ret);
    if (s.ok()) {
      *kernel = ret.release();
      return s;
    } else {
      VLOG(2) << "Custom creator error: " << s;
      // Falls through.
      s = Status::OK();
    }
  }
  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }
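The lib_def->Find(ndef.op()) == nullptr test above is how the runtime tells primitive ops apart from user-defined functions. A tiny hypothetical helper spelling out that check (the helper name is invented for illustration):

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical helper: an op is "primitive" exactly when the function library
// has no FunctionDef registered under that name, which is the same condition
// CreateKernel evaluates with lib_def->Find(ndef.op()) == nullptr.
bool IsPrimitiveOp(const tensorflow::FunctionLibraryDefinition& lib_def,
                   const tensorflow::NodeDef& ndef) {
  // Find() returns a pointer to the FunctionDef for user-defined functions,
  // or nullptr for built-in ops such as "AddN".
  return lib_def.Find(ndef.op()) == nullptr;
}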
CreateNonCachedKernel in executor.cc calls CreateOpKernel in op_kernel.cc, which creates the Operation through registration->factory->Create(&context). Here registration is obtained by FindKernelRegistration, which looks the op up in GlobalKernelRegistry() by its name, AddN. registration->factory is the PtrOpKernelFactory instance created when the kernel was registered, so registration->factory->Create ultimately executes new AddNOp.
// Code in executor.cc
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const NodeDef& ndef, int graph_def_version,
                             OpKernel** kernel) {
  const auto device_type = DeviceType(device->attributes().device_type());
  auto allocator = device->GetAllocator(AllocatorAttributes());
  return CreateOpKernel(device_type, device, allocator, flib, ndef,
                        graph_def_version, kernel);
}

// Code in op_kernel.cc
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
  ......
  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
  ......
  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}
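This lookup assumes kernels were registered beforehand: REGISTER_KERNEL_BUILDER stores, per op name and device type, a factory whose Create() simply invokes the kernel's constructor, which is why registration->factory->Create(&context) ends up running new AddNOp. The following self-contained sketch is a hypothetical miniature of that registry-and-factory mechanism (MiniKernel, MiniAddN and the other names are invented; the real machinery is GlobalKernelRegistry, KernelRegistration and PtrOpKernelFactory in op_kernel.cc, and it also keys on device type and attribute constraints):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical miniature of the kernel-registry mechanism, for illustration only.
struct MiniKernel {                // stands in for tensorflow::OpKernel
  virtual ~MiniKernel() = default;
  virtual void Compute() = 0;
};

struct MiniAddN : MiniKernel {     // stands in for AddNOp
  void Compute() override { std::cout << "AddN compute\n"; }
};

using Factory = std::function<MiniKernel*()>;

// Global map from op name to factory, analogous to GlobalKernelRegistry().
std::unordered_map<std::string, Factory>& GlobalRegistry() {
  static std::unordered_map<std::string, Factory> registry;
  return registry;
}

// Registration: what REGISTER_KERNEL_BUILDER ultimately does is record a
// factory that runs `new AddNOp(...)` under the op name "AddN".
struct Registrar {
  Registrar(const std::string& op, Factory f) { GlobalRegistry()[op] = std::move(f); }
};
static Registrar addn_registrar("AddN", [] { return new MiniAddN(); });

// Lookup + creation: what FindKernelRegistration + factory->Create() do.
std::unique_ptr<MiniKernel> CreateKernelByName(const std::string& op) {
  auto it = GlobalRegistry().find(op);
  if (it == GlobalRegistry().end()) return nullptr;
  return std::unique_ptr<MiniKernel>(it->second());
}

int main() {
  auto kernel = CreateKernelByName("AddN");
  if (kernel) kernel->Compute();   // prints "AddN compute"
}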
The context argument is passed in through the AddNOp constructor:
Once Session.run is invoked and a node's Operation is executed, its Compute method is called.
// Code in aggregate_ops.cc
template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
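The excerpt stops right after the shape check. Conceptually, the rest of Compute allocates an output shaped like the first input and accumulates every input into it. Here is a simplified sketch of that accumulation, assuming float inputs (the real kernel is templated on Device and T and dispatches to Eigen Add2/Add3/... functors instead of looping like this):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

// Simplified sketch of the accumulation AddNOp::Compute performs, assuming
// float inputs; this is not the actual TensorFlow implementation.
void AddNComputeSketch(tensorflow::OpKernelContext* ctx) {
  const tensorflow::Tensor& input0 = ctx->input(0);

  // Allocate an output tensor with the same shape as the first input.
  tensorflow::Tensor* output = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));

  // Sum every input tensor element-wise into the output.
  auto out = output->flat<float>();
  out.setZero();
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    out += ctx->input(i).flat<float>();
  }
}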
This work is licensed under a Creative Commons Attribution 4.0 International License.