TF Operation的注册

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: TF Operation的注册

Tensorflow OP的注册

以AddNOp为例说明Operation的注册:
当TensorFlow框架的动态库(dll)被加载时,aggregate_ops.cc中由下面的宏生成的静态注册对象会被初始化:

//Code in aggregate_ops.cc
// REGISTER_ADDN(type, dev): expands to a REGISTER_KERNEL_BUILDER call that
// pairs the builder expression Name("AddN").Device(DEVICE_<dev>)
// .TypeConstraint<type>("T") with the kernel class AddNOp<<dev>Device, type>.
// The `##` operator pastes `dev` into the identifiers, e.g. dev = CPU yields
// DEVICE_CPU and CPUDevice.
#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)
 
// Convenience wrapper that fixes the device argument to CPU.
#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
 
// NOTE(review): TF_CALL_NUMBER_TYPES is defined elsewhere; presumably it
// invokes REGISTER_ADDN_CPU once per numeric type -- confirm against its
// definition.
TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);


宏定义展开后(以int64类型为例)为:

// Expansion of the registration macro for type = ::tensorflow::int64 on CPU,
// written as ordinary statements (no macro line continuations).

// Compile-time switch: when false, nullptr is passed instead of a KernelDef.
constexpr bool should_register_389__flag = true;

// Static registrar object. Its constructor receives the built kernel
// definition, the kernel class name as a string, and a captureless lambda
// that creates the kernel instance.
static ::tensorflow::kernel_factory::OpKernelRegistrar
    registrar__body__389__object(
        should_register_389__flag
            ? ::tensorflow::register_kernel::Name("AddN")
                  .Device(DEVICE_CPU)
                  .TypeConstraint<::tensorflow::int64>("T")
                  .Build()
            : nullptr,
        "AddNOp<CPUDevice, ::tensorflow::int64>",
        [](::tensorflow::OpKernelConstruction* context)
            -> ::tensorflow::OpKernel* {
          return new AddNOp<CPUDevice, ::tensorflow::int64>(context);
        });

然后::tensorflow::kernel_factory::OpKernelRegistrar的构造函数被调用:
通过注册, 把kernel_def, kernel_class_name, factory注册到一个map里面。
后面创建AddNOp的时候,通过调用PtrOpKernelFactory的Create方法,利用上面的lambda表达式[](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new AddNOp<CPUDevice, ::tensorflow::int64>(context); }就能创建AddNOp了。

//Code in op_kernel.h
class OpKernelRegistrar {
 public:
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    if (kernel_def != nullptr) {
      struct PtrOpKernelFactory : public OpKernelFactory {
        explicit PtrOpKernelFactory(
            OpKernel* (*create_func)(OpKernelConstruction*))
            : create_func_(create_func) {}
 
        OpKernel* Create(OpKernelConstruction* context) override {
          return (*create_func_)(context);
        }
 
        OpKernel* (*create_func_)(OpKernelConstruction*);
      };
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }
 
//Code in op_kernel.cc
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());
    reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry())
        ->emplace(key, KernelRegistration(*kernel_def, kernel_class_name,
                                          std::move(factory)));
  }
  delete kernel_def;
}

注意到GlobalKernelRegistry()返回的是指向一个全局Map(KernelRegistry,即std::unordered_multimap)的指针:

//Code in op_kernel.cc
// Maps a registration key string to the KernelRegistration entries stored
// under it. It is a multimap, so one key may hold several registrations.
// Declared before GlobalKernelRegistry() because that function uses the name.
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

// Returns the shared kernel registry as an untyped pointer. The registry is
// allocated on the first call (function-local static) and the same pointer is
// returned on every later call.
void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}

本作品采用知识共享署名 4.0 国际许可协议进行许可。

发表回复