Original article. When reposting, please credit: reposted from 慢慢的回味
Link to this article: TF Operation Registration
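TensorFlow OP Registration
This article uses AddNOp as an example to show how an Operation is registered.
When the TensorFlow framework library (DLL) is loaded, the static registrar objects that the following macros define in aggregate_ops.cc are initialized: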
// Code in aggregate_ops.cc
#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)

#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)

TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
After macro expansion (shown here for the int64 instantiation), this becomes:
constexpr bool should_register_389__flag = true;
static ::tensorflow::kernel_factory::OpKernelRegistrar
    registrar__body__389__object(
        should_register_389__flag
            ? ::tensorflow::register_kernel::Name("AddN")
                  .Device(DEVICE_CPU)
                  .TypeConstraint<::tensorflow::int64>("T")
                  .Build()
            : nullptr,
        "AddNOp<CPUDevice, ::tensorflow::int64>",
        [](::tensorflow::OpKernelConstruction* context)
            -> ::tensorflow::OpKernel* {
          return new AddNOp<CPUDevice, ::tensorflow::int64>(context);
        });
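To make the expansion pattern concrete, here is a minimal sketch of how a "call this macro once per type" helper can stamp out one uniquely named static registrar per type. All names here (NameRegistrar, RegisteredNames, the SKETCH_* macros) are hypothetical stand-ins and not TensorFlow APIs. The sketch also assumes the compiler supports `__COUNTER__`, which is a common extension in GCC, Clang, and MSVC rather than standard C++.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical registry: a list of registered kernel class names.
// The function-local static guarantees it exists before first use.
std::vector<std::string>& RegisteredNames() {
  static auto* names = new std::vector<std::string>;
  return *names;
}

// Hypothetical registrar: its constructor performs the registration.
struct NameRegistrar {
  explicit NameRegistrar(const char* name) { RegisteredNames().push_back(name); }
};

// Two-level concatenation so that __COUNTER__ is expanded before pasting.
#define SKETCH_CONCAT_IMPL(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_IMPL(a, b)

// Defines one static registrar with a unique name for the given type.
#define SKETCH_REGISTER(type)                                 \
  static NameRegistrar SKETCH_CONCAT(sketch_registrar_,       \
                                     __COUNTER__)(            \
      "AddNOp<CPUDevice, " #type ">")

// Applies macro m to each type in a fixed list (assumed list, for illustration).
#define SKETCH_CALL_NUMBER_TYPES(m) m(float); m(double); m(int)

// One line produces three static registrar objects.
SKETCH_CALL_NUMBER_TYPES(SKETCH_REGISTER);
```

Because the registrars are namespace-scope statics, their constructors run during static initialization, so by the time `main` starts `RegisteredNames()` already holds one entry per type.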
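Next, the constructor of ::tensorflow::kernel_factory::OpKernelRegistrar runs.
Registration stores the kernel_def, the kernel_class_name, and the factory in a map.
Later, when an AddNOp has to be created, PtrOpKernelFactory::Create is called; it invokes the stored function pointer, which is the lambda from the expansion above, and that lambda constructs the AddNOp instance from the OpKernelConstruction context. The relevant code is: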
// Code in op_kernel.h
class OpKernelRegistrar {
 public:
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    if (kernel_def != nullptr) {
      // Wraps the plain function pointer in an OpKernelFactory.
      struct PtrOpKernelFactory : public OpKernelFactory {
        explicit PtrOpKernelFactory(
            OpKernel* (*create_func)(OpKernelConstruction*))
            : create_func_(create_func) {}

        OpKernel* Create(OpKernelConstruction* context) override {
          return (*create_func_)(context);
        }

        OpKernel* (*create_func_)(OpKernelConstruction*);
      };
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }
  // ... (rest of the class omitted in this excerpt)
};

// Code in op_kernel.cc
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());
    reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry())
        ->emplace(key, KernelRegistration(*kernel_def, kernel_class_name,
                                          std::move(factory)));
  }
  delete kernel_def;
}
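The registration flow above can be reduced to a small self-contained sketch: a static registrar object stores a creation function under a key, and a later lookup calls that function to build the kernel. Everything here (Kernel, CreateFn, Registry, Registrar, AddNKernel, CreateKernel, and the "AddN:CPU" key) is a hypothetical simplification, with an `int` standing in for OpKernelConstruction. It is not the TensorFlow implementation.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for OpKernel.
struct Kernel {
  virtual ~Kernel() = default;
  virtual std::string Name() const = 0;
};

// Hypothetical stand-in for OpKernel* (*)(OpKernelConstruction*).
using CreateFn = Kernel* (*)(int /*context stand-in*/);

// Function-local static: constructed on first use, so registrars in any
// translation unit can safely call it during static initialization.
std::unordered_multimap<std::string, CreateFn>& Registry() {
  static auto* registry = new std::unordered_multimap<std::string, CreateFn>;
  return *registry;
}

// The constructor does the registration, mirroring the role of OpKernelRegistrar.
struct Registrar {
  Registrar(const std::string& key, CreateFn fn) { Registry().emplace(key, fn); }
};

struct AddNKernel : Kernel {
  explicit AddNKernel(int /*context*/) {}
  std::string Name() const override { return "AddN"; }
};

// A captureless lambda converts implicitly to the function pointer type.
static Registrar addn_registrar(
    "AddN:CPU", [](int context) -> Kernel* { return new AddNKernel(context); });

// Looks up the key and, if found, invokes the stored creation function.
std::unique_ptr<Kernel> CreateKernel(const std::string& key, int context) {
  auto it = Registry().find(key);
  if (it == Registry().end()) return nullptr;
  return std::unique_ptr<Kernel>(it->second(context));
}
```

Calling `CreateKernel("AddN:CPU", 0)` from `main` returns a kernel whose `Name()` is "AddN", while an unregistered key yields `nullptr`. The design keeps kernel source files decoupled from the code that instantiates them: adding a kernel only requires defining one more static registrar.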
Note that GlobalKernelRegistry() returns a pointer to a map, specifically a std::unordered_multimap keyed by string:
// Code in op_kernel.cc
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}
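A multimap fits here because the key is built only from the op name, device type, and label, so the AddN kernels registered for different element types on the same device all share one key. The sketch below illustrates a lookup over such a registry using `equal_range`. The Registration struct, the "op:device:label" key layout, and the string-based type matching are simplifying assumptions made for illustration, not the actual TensorFlow lookup code.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical, simplified stand-in for KernelRegistration.
struct Registration {
  std::string type_constraint;    // stand-in for the constraint on attr "T"
  std::string kernel_class_name;
};

using Registry = std::unordered_multimap<std::string, Registration>;

// Assumed key layout for this sketch: "op:device:label".
std::string MakeKey(const std::string& op, const std::string& device,
                    const std::string& label) {
  return op + ":" + device + ":" + label;
}

// Scans every registration stored under the key and returns the first one
// whose type constraint matches, or nullptr if none does.
const Registration* FindRegistration(const Registry& registry,
                                     const std::string& key,
                                     const std::string& type) {
  auto range = registry.equal_range(key);
  for (auto it = range.first; it != range.second; ++it) {
    if (it->second.type_constraint == type) return &it->second;
  }
  return nullptr;
}

// Registers two AddN kernels that share a key but differ in element type.
Registry BuildExampleRegistry() {
  Registry registry;
  const std::string key = MakeKey("AddN", "CPU", "");
  registry.emplace(key, Registration{"int64", "AddNOp<CPUDevice, int64>"});
  registry.emplace(key, Registration{"float", "AddNOp<CPUDevice, float>"});
  return registry;
}
```

With a plain `std::unordered_map`, the second `emplace` under the same key would be ignored; the multimap keeps both entries, and the lookup then disambiguates by type.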
This work is licensed under a Creative Commons Attribution 4.0 International License.