TF Operation的创建

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: TF Operation的创建

Tensorflow创建OP的过程

以AddNOp为例说明Operation怎样从ops的定义创建具体的kernel实例。
在Tensorflow Executor初始化的时候,会迭代计算图中的所有节点,对每个节点的Operation进行创建。如下方法params_.create_kernel(n->def(), &item->kernel):
Tensorflow源码解读

// Code in executor.cc 
// Excerpt of ExecutorImpl::Initialize (truncated by the article; '......'
// marks elided code): iterates over every node of the graph and creates the
// node's OpKernel via the params_.create_kernel callback.
Status ExecutorImpl::Initialize() {
......
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    // Control-flow bookkeeping: look up the frame this node belongs to.
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);
 
    // See if this node is a root node, and if so, add to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(n);
    }
 
    NodeItem* item = gview_.node(id);
    item->node = n;
 
    // Reserve this node's slice of the frame's input slots.
    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();
 
    // Instantiate the kernel for this node. create_kernel is the lambda
    // installed in direct_session.cc (see below).
    Status s = params_.create_kernel(n->def(), &item->kernel);


params_.create_kernel是一个前面创建的lambda函数,对它的调用最后会调用到函数lib->CreateKernel(ndef, kernel)上,lib为FunctionLibraryRuntimeImpl实例:

//Code in direct_session.cc
    // Kernel-creation callback handed to the executor. Kernels the OpSegment
    // should not own are created uncached; everything else is cached in the
    // per-session OpSegment keyed by (session handle, node name).
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
        return lib->CreateKernel(ndef, kernel);
      }
      // Deferred factory: only invoked by the OpSegment on a cache miss.
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached.  On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };

代码CreateKernel最后调用CreateNonCachedKernel:

//Code in function.cc
// Two-argument overload: delegates to the three-argument CreateKernel,
// resolving ops against the base function library (base_lib_def_).
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, base_lib_def_, kernel);
}
 
// Creates an OpKernel for `ndef`, looking the op up in `lib_def`.
// Excerpt truncated by the article: the function-op path (when the op IS
// found in lib_def) continues in function.cc.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_(this, ndef, &ret);
    if (s.ok()) {
      // Custom creator succeeded: transfer ownership to the caller.
      *kernel = ret.release();
      return s;
    } else {
      VLOG(2) << "Custom creator error: " << s;
      // Falls through.
      s = Status::OK();
    }
  }
 
  // Op not defined in the function library: it is a primitive (registered)
  // op, so build it straight from the global kernel registry.
  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }

executor.cc中的CreateNonCachedKernel方法调用op_kernel.cc中的CreateOpKernel方法,通过registration->factory->Create(&context)创建Operation。 其中,registration是通过FindKernelRegistration方法在GlobalKernelRegistry()里面根据名称AddN找到的。registration->factory就是在注册时创建的PtrOpKernelFactory实例。registration->factory->Create最后就是调用new AddNOp(context)了。

//Code in executor.cc
// Builds a fresh (uncached) OpKernel for `ndef` on `device`: resolves the
// device's type and default allocator, then defers to the global factory
// CreateOpKernel in op_kernel.cc.
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const NodeDef& ndef, int graph_def_version,
                             OpKernel** kernel) {
  return CreateOpKernel(DeviceType(device->attributes().device_type()), device,
                        device->GetAllocator(AllocatorAttributes()), flib, ndef,
                        graph_def_version, kernel);
}
 
//Code in op_kernel.cc
// Builds an OpKernel for `node_def`: finds the matching registration in the
// global kernel registry, then invokes its factory with an
// OpKernelConstruction context ('......' marks code elided by the article).
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
......
  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
......
  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  // For AddN this ends up invoking `new AddNOp(context)` through the factory
  // registered via REGISTER_KERNEL_BUILDER.
  *kernel = registration->factory->Create(&context);
  if (!s.ok()) {
    // Construction reported an error through `s`: discard the partially
    // constructed kernel so the caller never sees a bad pointer.
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}

通过AddNOp的构造方法把context参数传入:
Session run以后,在节点调用Operation计算的时候就会调用Compute方法。

//Code in aggregate_ops.cc
// Kernel implementing the AddN op (excerpt truncated by the article).
template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  // No construction-time attributes are read; state lives in OpKernel.
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
 
  // Invoked per Session::Run when this node executes.
  void Compute(OpKernelContext* ctx) override {
    // All inputs must share one shape; on mismatch an error is recorded on
    // the context and the kernel bails out early.
    if (!ctx->ValidateInputsAreSameShape(this)) return;

本作品采用知识共享署名 4.0 国际许可协议进行许可。

发表回复