Registering TensorFlow ops and op kernels

Generating op definitions

When a TensorFlow library is loaded, the static variables inside its shared object (.so) are instantiated.
Taking the AddN op as an example, REGISTER_OP("AddN") in the file math_ops.cc is in fact a static variable definition:

REGISTER_OP("AddN")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, variant}")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      ............
    });
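
As a side note, custom op libraries go through exactly the same mechanism. Below is a minimal sketch (the op name "MyAddOne" and its signature are hypothetical, not part of TensorFlow) that, when its shared library is loaded, triggers the same static-variable registration described next:

// Minimal sketch of registering a hypothetical custom op "MyAddOne".
// Loading the shared library instantiates the static OpDefBuilderReceiver
// behind REGISTER_OP, just like AddN above.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

REGISTER_OP("MyAddOne")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, int32}")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));  // output shape equals input shape
      return ::tensorflow::Status::OK();
    });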


After expansion, REGISTER_OP("AddN") becomes:

static ::tensorflow::register_op::OpDefBuilderReceiver register_op165
    __attribute__((unused)) =
        ::tensorflow::register_op::OpDefBuilderWrapper<true>("AddN")

The constructor of OpDefBuilderReceiver is defined in op.cc:

namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(
    const OpDefBuilderWrapper<true>& wrapper) {
  OpRegistry::Global()->Register(
      [wrapper](OpRegistrationData* op_reg_data) -> Status {
        return wrapper.builder().Finalize(op_reg_data);
      });
}
}

OpRegistry::Register in op.cc calls OpRegistry::RegisterAlreadyLocked (immediately if the registry is already initialized, otherwise the factory is deferred), which invokes op_data_factory to produce the op definition, an OpRegistrationData:

// static
OpRegistry* OpRegistry::Global() {
  static OpRegistry* global_op_registry = new OpRegistry;
  return global_op_registry;
}
 
void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
  mutex_lock lock(mu_);
  if (initialized_) {
    TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
  } else {
    deferred_.push_back(op_data_factory);
  }
}
 
mutable std::unordered_map<string, const OpRegistrationData*> registry_;
 
Status OpRegistry::RegisterAlreadyLocked(
    const OpRegistrationDataFactory& op_data_factory) const {
  std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
  Status s = op_data_factory(op_reg_data.get());
  if (s.ok()) {
    s = ValidateOpDef(op_reg_data->op_def);
    if (s.ok() &&
        !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
                                 op_reg_data.get())) {
      s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
    }
  }
  Status watcher_status = s;
  if (watcher_) {
    watcher_status = watcher_(s, op_reg_data->op_def);
  }
  if (s.ok()) {
    op_reg_data.release();
  } else {
    op_reg_data.reset();
  }
  return watcher_status;
}

OpDefBuilder::Finalize in op_def_builder.cc then populates the OpRegistrationData:

Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
  std::vector<string> errors = errors_;
  *op_reg_data = op_reg_data_;
 
  OpDef* op_def = &op_reg_data->op_def;
  for (StringPiece attr : attrs_) {
    FinalizeAttr(attr, op_def, &errors);
  }
  for (StringPiece input : inputs_) {
    FinalizeInputOrOutput(input, false, op_def, &errors);
  }
  for (StringPiece output : outputs_) {
    FinalizeInputOrOutput(output, true, op_def, &errors);
  }
  for (StringPiece control_output : control_outputs_) {
    FinalizeControlOutput(control_output, op_def, &errors);
  }
  FinalizeDoc(doc_, op_def, &errors);
 
  if (errors.empty()) return Status::OK();
  return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
}

Later on, the definition corresponding to an op type can be retrieved via OpRegistry::LookUp in op.cc.

Status OpRegistry::LookUp(const string& op_type_name,
                          const OpRegistrationData** op_reg_data) const {
  {
    tf_shared_lock l(mu_);
    if (initialized_) {
      if (const OpRegistrationData* res =
              gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
        *op_reg_data = res;
        return Status::OK();
      }
    }
  }
  return LookUpSlow(op_type_name, op_reg_data);
}
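
For example, a caller can retrieve the finalized OpDef through the global registry. A minimal sketch (the helper function name is ours, not TensorFlow's), using the public LookUpOpDef wrapper around the LookUp shown above:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/logging.h"

// Look up the OpDef registered for "AddN" and dump its signature.
void PrintAddNOpDef() {
  const ::tensorflow::OpDef* op_def = nullptr;
  ::tensorflow::Status s =
      ::tensorflow::OpRegistry::Global()->LookUpOpDef("AddN", &op_def);
  if (s.ok()) {
    LOG(INFO) << op_def->DebugString();  // inputs/outputs/attrs set by Finalize
  }
}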
Generating op kernel factory definitions

This likewise happens when the TensorFlow library is loaded.
Again taking the AddN op as an example, the code related to the AddNOp op_kernel in the file aggregate_ops.cc is as follows:

template <typename Device>
class AddNOp<Device, Variant> : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
 
  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;
 
    ............
    ctx->set_output(0, out);
  }
};
 
#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)
 
#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
 
TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
REGISTER_ADDN_CPU(Variant);
 
#undef REGISTER_ADDN_CPU
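
A custom kernel would use the same REGISTER_KERNEL_BUILDER macro. Here is a minimal sketch for the hypothetical "MyAddOne" op registered earlier (a float CPU kernel only), following the style of the official custom-op examples:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

// CPU kernel for the hypothetical "MyAddOne" op: y = x + 1, float only.
class MyAddOneOp : public OpKernel {
 public:
  explicit MyAddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    auto in = input.flat<float>();
    auto out = output->flat<float>();
    for (int64_t i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
  }
};

// Expands to an OpKernelRegistrar, exactly like the AddN expansion shown next.
REGISTER_KERNEL_BUILDER(
    Name("MyAddOne").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    MyAddOneOp);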

The TF_CALL_NUMBER_TYPES macro expands as shown below; it creates a separate OpKernelRegistrar object for each of the supported data types:

constexpr bool should_register_389__flag = true;
static ::tensorflow::kernel_factory::OpKernelRegistrar
    registrar__body__389__object(
        should_register_389__flag
            ? ::tensorflow::register_kernel::Name("AddN")
                  .Device(DEVICE_CPU)
                  .TypeConstraint<::tensorflow::int64>("T")
                  .Build()
            : nullptr,
        "AddNOp<CPUDevice, ::tensorflow::int64>",
        [](::tensorflow::OpKernelConstruction* context)
            -> ::tensorflow::OpKernel* {
          return new AddNOp<CPUDevice, ::tensorflow::int64>(context);
        });
............

The constructor of the OpKernelRegistrar used above is defined in op_kernel.h.
::tensorflow::register_kernel::Name("AddN").Device(DEVICE_CPU).TypeConstraint<::tensorflow::int64>("T").Build() produces the KernelDef, and
the lambda [](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new AddNOp<CPUDevice, ::tensorflow::int64>(context); } is the factory function that creates a kernel of type AddNOp; it is passed as the create_fn parameter.

class OpKernelRegistrar {
 public:
............
 
  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

OpKernelRegistrar::InitInternal in op_kernel.cc inserts the newly created KernelRegistration object into the global global_registry->registry map.
Later, when the computation graph needs this kernel, the KernelRegistration object is looked up by its kernel key and its factory is called to produce an AddNOp instance, which is then assigned to the graph node.

struct KernelRegistry {
  mutex mu;
  std::unordered_multimap<string, KernelRegistration> registry GUARDED_BY(mu);
};
 
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     std::unique_ptr<OpKernelFactory> factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());
 
    // To avoid calling LoadDynamicKernels DO NOT CALL GlobalKernelRegistryTyped
    // here.
    // InitInternal gets called by static initializers, so it ends up executing
    // before main. This causes LoadKernelLibraries function to get called
    // before some file libraries can initialize, which in turn crashes the
    // program flakily. Until we get rid of static initializers in kernel
    // registration mechanism, we have this workaround here.
    auto global_registry =
        reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
    mutex_lock l(global_registry->mu);
    global_registry->registry.emplace(
        key,
        KernelRegistration(*kernel_def, kernel_class_name, std::move(factory)));
  }
  delete kernel_def;
}
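
To illustrate the consumer side mentioned above, here is a minimal sketch (the helper function is ours; it assumes a valid NodeDef for an AddN node is already available, e.g. taken from a GraphDef) using the public FindKernelDef and CreateOpKernel helpers, which read this registry and invoke the stored factory:

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

// Resolve and instantiate the CPU kernel registered for `node_def`.
::tensorflow::Status CreateKernelForNode(const ::tensorflow::NodeDef& node_def,
                                         ::tensorflow::DeviceBase* device,
                                         ::tensorflow::Allocator* allocator,
                                         ::tensorflow::OpKernel** kernel) {
  const ::tensorflow::KernelDef* kernel_def = nullptr;
  ::tensorflow::string kernel_class_name;
  // Looks up global_registry->registry with Key(op, device_type, label).
  TF_RETURN_IF_ERROR(::tensorflow::FindKernelDef(
      ::tensorflow::DeviceType(::tensorflow::DEVICE_CPU), node_def, &kernel_def,
      &kernel_class_name));
  // Calls the registered factory, i.e. `new AddNOp<CPUDevice, T>(context)`.
  return ::tensorflow::CreateOpKernel(
      ::tensorflow::DeviceType(::tensorflow::DEVICE_CPU), device, allocator,
      node_def, TF_GRAPH_DEF_VERSION, kernel);
}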
