Original article from 慢慢的回味; please credit the source when reposting.
TensorFlow op and op kernel registration
Generating op definitions
When the TensorFlow library is loaded, the static variables in its shared object (.so) files are instantiated.
Take the AddN op as an example. In math_ops.cc, REGISTER_OP("AddN") is in fact a static variable definition:
```cpp
REGISTER_OP("AddN")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, variant}")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      ............
    });
```
REGISTER_OP("AddN") expands to:
```cpp
static ::tensorflow::register_op::OpDefBuilderReceiver register_op165
    __attribute__((unused)) =
        ::tensorflow::register_op::OpDefBuilderWrapper<true>("AddN")
```
The constructor of OpDefBuilderReceiver lives in op.cc:
```cpp
namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(
    const OpDefBuilderWrapper<true>& wrapper) {
  OpRegistry::Global()->Register(
      [wrapper](OpRegistrationData* op_reg_data) -> Status {
        return wrapper.builder().Finalize(op_reg_data);
      });
}
}  // namespace register_op
```
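The wrapper → receiver → registry chain above can be sketched in miniature. This is a hypothetical reconstruction with invented names (MiniRegistry, MiniBuilderWrapper, MiniReceiver, MiniOpDef), not TensorFlow's actual classes; it only demonstrates how a static variable whose initializer is a builder expression ends up registering a deferred factory:

```cpp
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for OpRegistrationData.
struct MiniOpDef {
  std::string name;
  int num_attrs = 0;
};

// Hypothetical stand-in for OpRegistry: a leaked singleton, like
// OpRegistry::Global() in op.cc.
class MiniRegistry {
 public:
  static MiniRegistry* Global() {
    static MiniRegistry* r = new MiniRegistry;
    return r;
  }
  using Factory = std::function<MiniOpDef()>;
  void Register(Factory f) {
    MiniOpDef def = f();  // run the builder's Finalize-equivalent
    registry_[def.name] = def;
  }
  const MiniOpDef* LookUp(const std::string& name) const {
    auto it = registry_.find(name);
    return it == registry_.end() ? nullptr : &it->second;
  }

 private:
  std::unordered_map<std::string, MiniOpDef> registry_;
};

// Accumulates the op description fluently, like OpDefBuilderWrapper.
class MiniBuilderWrapper {
 public:
  explicit MiniBuilderWrapper(std::string name) { def_.name = std::move(name); }
  MiniBuilderWrapper& Attr(const std::string&) {
    ++def_.num_attrs;
    return *this;
  }
  MiniOpDef Finalize() const { return def_; }

 private:
  MiniOpDef def_;
};

// The receiver's constructor fires during static initialization and hands
// the registry a callback, like OpDefBuilderReceiver.
struct MiniReceiver {
  MiniReceiver(const MiniBuilderWrapper& w) {
    MiniRegistry::Global()->Register([w]() { return w.Finalize(); });
  }
};

// Rough equivalent of the expanded REGISTER_OP("AddN") static variable.
static MiniReceiver register_op_addn =
    MiniBuilderWrapper("AddN").Attr("N: int >= 1").Attr("T: {numbertype}");
```

The key point is that the static variable's type is the receiver, while its initializer expression is the wrapper; the implicit conversion between the two is what triggers registration before main() runs.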
OpRegistry::Register in op.cc calls OpRegistry::RegisterAlreadyLocked, which runs op_data_factory to produce the op definition (OpRegistrationData).
```cpp
OpRegistry* OpRegistry::Global() {
  static OpRegistry* global_op_registry = new OpRegistry;
  return global_op_registry;
}

void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }
}

mutable std::unordered_map<string, const OpRegistrationData*> registry_;

Status OpRegistry::RegisterAlreadyLocked(
    const OpRegistrationDataFactory& op_data_factory) const {
  std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
  Status s = op_data_factory(op_reg_data.get());
  if (s.ok()) {
    s = ValidateOpDef(op_reg_data->op_def);
    if (s.ok() &&
        !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
                                 op_reg_data.get())) {
      s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
    }
  }
  Status watcher_status = s;
  if (watcher_) {
    watcher_status = watcher_(s, op_reg_data->op_def);
  }
  if (s.ok()) {
    op_reg_data.release();
  } else {
    op_reg_data.reset();
  }
  return watcher_status;
}
```
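Note the two branches in Register: factories that arrive before the registry is marked initialized_ are queued in deferred_ rather than registered immediately. A minimal sketch of that deferral pattern, using hypothetical names (DeferredRegistry is not a TensorFlow class), might look like:

```cpp
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical miniature of OpRegistry's deferred registration: factories
// registered before initialization are queued and only materialized when the
// registry is first consulted.
class DeferredRegistry {
 public:
  using Factory = std::function<std::string()>;  // returns the op name

  void Register(Factory f) {
    if (initialized_) {
      RegisterNow(f);
    } else {
      deferred_.push_back(std::move(f));  // queue until first use
    }
  }

  bool Contains(const std::string& name) {
    Initialize();  // first lookup drains the deferred queue
    return registry_.count(name) > 0;
  }

 private:
  void Initialize() {
    if (initialized_) return;
    initialized_ = true;
    for (auto& f : deferred_) RegisterNow(f);
    deferred_.clear();
  }
  void RegisterNow(const Factory& f) { registry_[f()] = true; }

  bool initialized_ = false;
  std::vector<Factory> deferred_;
  std::unordered_map<std::string, bool> registry_;
};
```

Deferral matters because these registrations run from static initializers, whose order across translation units is unspecified; queueing decouples "a factory exists" from "the registry is ready to validate it".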
OpDefBuilder::Finalize in op_def_builder.cc populates the OpRegistrationData:
```cpp
Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
  std::vector<string> errors = errors_;
  *op_reg_data = op_reg_data_;

  OpDef* op_def = &op_reg_data->op_def;
  for (StringPiece attr : attrs_) {
    FinalizeAttr(attr, op_def, &errors);
  }
  for (StringPiece input : inputs_) {
    FinalizeInputOrOutput(input, false, op_def, &errors);
  }
  for (StringPiece output : outputs_) {
    FinalizeInputOrOutput(output, true, op_def, &errors);
  }
  for (StringPiece control_output : control_outputs_) {
    FinalizeControlOutput(control_output, op_def, &errors);
  }
  FinalizeDoc(doc_, op_def, &errors);

  if (errors.empty()) return Status::OK();
  return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
}
```
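Finalize's structure is worth noting: each Finalize* helper appends to a shared error list instead of failing fast, and only at the end are the errors joined into a single status. A hypothetical miniature of that error-accumulation style (MiniDef, FinalizeAttr, and the ":"-based spec check are invented for illustration, not TensorFlow's real parsing):

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for OpDef.
struct MiniDef {
  std::vector<std::string> attrs;
};

// Like TensorFlow's FinalizeAttr, this appends to `errors` rather than
// returning early, so every bad spec is reported at once.
void FinalizeAttr(const std::string& spec, MiniDef* def,
                  std::vector<std::string>* errors) {
  if (spec.find(':') == std::string::npos) {
    errors->push_back("bad attr spec: " + spec);
  } else {
    def->attrs.push_back(spec);
  }
}

// Mirrors Finalize's shape: run every step, then join accumulated errors.
bool Finalize(const std::vector<std::string>& attr_specs, MiniDef* def,
              std::string* error_out) {
  std::vector<std::string> errors;
  for (const auto& a : attr_specs) FinalizeAttr(a, def, &errors);
  if (errors.empty()) return true;
  std::string joined;
  for (const auto& e : errors) joined += e + "\n";
  *error_out = joined;
  return false;
}
```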
Afterwards, the definition for a given op type can be fetched through OpRegistry::LookUp in op.cc:
```cpp
Status OpRegistry::LookUp(const string& op_type_name,
                          const OpRegistrationData** op_reg_data) const {
  {
    tf_shared_lock l(mu_);
    if (initialized_) {
      if (const OpRegistrationData* res =
              gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
        *op_reg_data = res;
        return Status::OK();
      }
    }
  }
  return LookUpSlow(op_type_name, op_reg_data);
}
```
Generating op kernel factory definitions
This, too, happens when the TensorFlow library is loaded.
Again taking AddN as the example, the code for the AddNOp kernel in aggregate_ops.cc looks like this:
```cpp
template <typename Device>
class AddNOp<Device, Variant> : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    ............
    ctx->set_output(0, out);
  }
};

#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)

#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)

TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
REGISTER_ADDN_CPU(Variant);
#undef REGISTER_ADDN_CPU
```
When the TF_CALL_NUMBER_TYPES macro is expanded, it creates one OpKernelRegistrar object for every supported data type:
```cpp
constexpr bool should_register_389__flag = true;
static ::tensorflow::kernel_factory::OpKernelRegistrar
    registrar__body__389__object(
        should_register_389__flag
            ? ::tensorflow::register_kernel::Name("AddN")
                  .Device(DEVICE_CPU)
                  .TypeConstraint<::tensorflow::int64>("T")
                  .Build()
            : nullptr,
        "AddNOp<CPUDevice, ::tensorflow::int64>",
        [](::tensorflow::OpKernelConstruction* context)
            -> ::tensorflow::OpKernel* {
          return new AddNOp<CPUDevice, ::tensorflow::int64>(context);
        });
............
```
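The net effect of this expansion is one static registrar object per (op, device, type) combination, each capturing a factory lambda for its kernel specialization. A hypothetical miniature of that pattern (MiniKernel, MiniAddN, MiniKernelRegistrar, and the MINI_REGISTER_ADDN macro are invented for illustration):

```cpp
#include <functional>
#include <map>
#include <string>
#include <typeinfo>
#include <utility>

// Hypothetical kernel base class and one templated kernel, standing in for
// OpKernel and AddNOp<Device, T>.
struct MiniKernel {
  virtual ~MiniKernel() = default;
  virtual std::string type_name() const = 0;
};

template <typename T>
struct MiniAddN : MiniKernel {
  std::string type_name() const override { return typeid(T).name(); }
};

using KernelFactory = std::function<MiniKernel*()>;

// Global factory table, keyed by an "op:device:type" string.
std::map<std::string, KernelFactory>& KernelFactories() {
  static std::map<std::string, KernelFactory> m;
  return m;
}

// The registrar's constructor stores the factory, like OpKernelRegistrar.
struct MiniKernelRegistrar {
  MiniKernelRegistrar(const std::string& key, KernelFactory f) {
    KernelFactories()[key] = std::move(f);
  }
};

// What one REGISTER_ADDN(type, CPU) instantiation boils down to: a uniquely
// named static registrar holding a lambda that news up the specialization.
#define MINI_REGISTER_ADDN(T)                        \
  static MiniKernelRegistrar registrar_addn_##T(     \
      "AddN:CPU:" #T,                                \
      []() -> MiniKernel* { return new MiniAddN<T>(); });

MINI_REGISTER_ADDN(int)
MINI_REGISTER_ADDN(float)
```

Applying the macro once per element type is exactly what TF_CALL_NUMBER_TYPES automates.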
op_kernel.h defines the OpKernelRegistrar constructor used above. The expression ::tensorflow::register_kernel::Name("AddN").Device(DEVICE_CPU).TypeConstraint<::tensorflow::int64>("T").Build() produces the KernelDef argument, and the lambda [](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new AddNOp<CPUDevice, ::tensorflow::int64>(context); } is passed as the create_fn factory function:
```cpp
class OpKernelRegistrar {
 public:
  ............
  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }
```
OpKernelRegistrar::InitInternal in op_kernel.cc then inserts the newly created KernelRegistration object into the global global_registry->registry map.
Later, when the computation graph needs this kernel, the KernelRegistration object is looked up by its kernel key, and its factory is invoked to produce an AddNOp instance, which is then assigned to the graph node.
```cpp
struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry GUARDED_BY(mu);
};

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());

    // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
    // here.
    // InitInternal gets called by static initializers, so it ends up executing
    // before main. This causes LoadKernelLibraries function to get called
    // before some file libraries can initialize, which in turn crashes the
    // program flakily. Until we get rid of static initializers in kernel
    // registration mechanism, we have this workaround here.
    auto global_registry =
        reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
    mutex_lock l(global_registry->mu);
    global_registry->registry.emplace(
        key,
        KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
  }
  delete kernel_def;
}
```
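The registry key combines op name, device type, and label, and the container is an unordered_multimap, so several kernel registrations can legitimately coexist for different devices or type constraints. A small sketch of that key scheme, with hypothetical names (Key's ":" separator, Registration, GlobalKernels, and RegisterKernel are illustrative, not TensorFlow's exact implementation):

```cpp
#include <string>
#include <unordered_map>

// Hypothetical key scheme: join op name, device type, and label, mirroring
// the Key(op, device_type, label) call in op_kernel.cc.
std::string Key(const std::string& op, const std::string& device,
                const std::string& label) {
  return op + ":" + device + ":" + label;
}

// Hypothetical stand-in for KernelRegistration.
struct Registration {
  std::string kernel_class_name;
};

// A multimap allows multiple kernels to be registered without clobbering
// each other, as in the real KernelRegistry.
std::unordered_multimap<std::string, Registration>& GlobalKernels() {
  static std::unordered_multimap<std::string, Registration> m;
  return m;
}

void RegisterKernel(const std::string& op, const std::string& device,
                    const std::string& label, const std::string& cls) {
  GlobalKernels().emplace(Key(op, device, label), Registration{cls});
}
```

Graph construction later rebuilds the same key from a node's op and assigned device to find the matching factory.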
This work is licensed under a Creative Commons Attribution 4.0 International License.