原创文章,转载请注明: 转载自慢慢的回味
本文链接地址: Tensorflow 函数式编程的测试
本文通过一个简单的程序来验证Tensorflow底层对函数式编程的支持。由此也说明,Tensorflow在2.0发布之前的1.15版本中,底层就已经支持函数式编程,进而能够实现动态图计算。
回目录
函数式编程测试程序
下面的程序验证了可以用Tensorflow的函数式编程实现如下代码逻辑:
int feed1 = 2;
int feed2 = 10;
int feed3 = 3;
while (feed1 < feed2) {
    feed1 = feed1 + 3;
}
feed1 = feed1 + feed3;
//============================================================================ // Name : TensorflowTest.cpp // Author : // Version : // Copyright : Your copyright notice // Description : Hello World in C++, Ansi-style //============================================================================ #include <iostream> #include <tensorflow/c/c_api.h> #include <tensorflow/c/c_test_util.h> #include <tensorflow/c/c_api_experimental.h> #include <algorithm> #include <cstddef> #include <iterator> #include <memory> #include <vector> #include <string.h> using namespace std; static const char** ToArray(const vector<string>& strs) { const char** ptr = nullptr; if (!strs.empty()) { ptr = new const char*[strs.size()]; for (size_t i = 0; i < strs.size(); ++i) { ptr[i] = strs[i].c_str(); } } return ptr; } static vector<TF_Output> ToOutput(const vector<TF_Operation*> ops) { vector<TF_Output> out; for (auto op : ops) { out.push_back( { op, 0 }); } return out; } int main() { cout << "!!!Hello World!!! FunctionTest.cpp " << endl; // prints !!!Hello World!!! 
cout << "Hello from TensorFlow C library version" << TF_Version() << endl; TF_Status* s = TF_NewStatus(); TF_Graph* func_graph_ = TF_NewGraph(); TF_Graph* host_graph_ = TF_NewGraph(); TF_Operation* feed1 = Placeholder(func_graph_, s, "feed1"); cout << TF_Message(s); TF_Operation* feed2 = Placeholder(func_graph_, s, "feed2"); cout << TF_Message(s); TF_Operation* feed3 = Placeholder(func_graph_, s, "feed3"); cout << TF_Message(s); //========================While================================= // Add while loop to func_graph_ // The inputs to the while loop vector<TF_Output> whileInputs = ToOutput( { feed1, feed2 }); vector<TF_Output> whileOutputs; unique_ptr<TF_WhileParams> params( new TF_WhileParams( TF_NewWhile(func_graph_, &whileInputs[0], whileInputs.size(), s))); cout << TF_Message(s); params->name = "test_loop"; // Initialize outputs so we can easily detect errors/bugs whileOutputs.resize(2, { nullptr, -1 }); // Create loop: while (input1 < input2) input1 += input2 + 1 TF_Operation* less_than = LessThan(params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s); cout << TF_Message(s); params->cond_output = {less_than, 0}; TF_Operation* three = ScalarConst(3, params->body_graph, s); cout << TF_Message(s); TF_Operation* add1 = Add(params->body_inputs[0], { three, 0 }, params->body_graph, s, "add1"); cout << TF_Message(s); params->body_outputs[0] = {add1, 0}; params->body_outputs[1] = params->body_inputs[1]; // Finalize while loop TF_FinishWhile(params.get(), s, &whileOutputs[0]); cout << TF_Message(s); //========================While================================= TF_Operation* add2 = Add(whileOutputs[0], { feed3, 0 }, func_graph_, s, "add2"); cout << TF_Message(s); vector<TF_Output> inputs = ToOutput( { feed1, feed2, feed3 }); vector<TF_Output> outputs = ToOutput( { add2 }); const char** output_names_ptr = ToArray( { }); TF_Function* func_ = TF_GraphToFunction(func_graph_, "myFunction", false, -1, nullptr, inputs.size(), inputs.data(), 
outputs.size(), outputs.data(), output_names_ptr, /*opts=*/nullptr, /*description=*/nullptr, s); delete[] output_names_ptr; TF_GraphCopyFunction(host_graph_, func_, nullptr, s); size_t* len = new size_t(); cout << "func_graph_ string==============" << endl; cout << TF_GraphDebugString(func_graph_, len) << endl; cout << "func_ string====================" << endl; cout << TF_FunctionDebugString(func_, len) << endl; cout << "host_graph_ string==============" << endl; cout << TF_GraphDebugString(host_graph_, len) << endl; // Use, run, and verify TF_Operation* two = ScalarConst(2, host_graph_, s, "two"); TF_Operation* ten = ScalarConst(10, host_graph_, s, "ten"); TF_Operation* func_feed = Placeholder(host_graph_, s); TF_OperationDescription* desc = TF_NewOperation(host_graph_, "myFunction", "myFunction_node"); TF_AddInput(desc, TF_Output { two, 0 }); TF_AddInput(desc, TF_Output { ten, 0 }); TF_AddInput(desc, TF_Output { func_feed, 0 }); TF_Operation* func_op = TF_FinishOperation(desc, s); TF_SessionOptions* opts = TF_NewSessionOptions(); TF_Session* sess = TF_NewSession(host_graph_, opts, s); TF_DeleteSessionOptions(opts); TF_Output feeds[] = { TF_Output { func_feed, 0 } }; TF_Output fetches[] = { TF_Output { func_op, 0 } }; const char* handle = nullptr; TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches, TF_ARRAYSIZE(fetches), NULL, 0, &handle, s); cout << TF_Message(s); TF_Output feeds1[] = { TF_Output { func_feed, 0 } }; TF_Output fetches1[] = { TF_Output { func_op, 0 } }; TF_Tensor* feedValues1[] = { Int32Tensor(3) }; TF_Tensor* fetchValues1[1]; TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1, 1, NULL, 0, s); cout << TF_Message(s); cout << "Result for logic:" << endl; cout << "int feed1 = 2;" << endl; cout << "int feed2 = 10;" << endl; cout << "int feed3 = 3;" << endl; cout << "while( feed1 < feed2 ){" << endl; cout << " feed1 = feed1 + 3;" << endl; cout << "}" << endl; cout << "feed1 = feed1 + feed3" << endl; cout << 
*(static_cast<int*>(TF_TensorData(fetchValues1[0]))) << endl; // Clean up. TF_DeletePRunHandle(handle); TF_DeleteSession(sess, s); cout << TF_Message(s); TF_DeleteGraph(host_graph_); TF_DeleteStatus(s); return 0; } |
本作品采用知识共享署名 4.0 国际许可协议进行许可。