Tensorflow 函数式编程的测试

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: Tensorflow 函数式编程的测试

本文通过一个简单的程序,验证TensorFlow底层对函数式编程的支持。由此也说明,TensorFlow在2.0发布之前的1.15版本中,底层就已经支持函数式编程,进而能够实现动态图计算。
回目录

函数式编程测试程序

下面的程序用TensorFlow的函数式编程实现了如下代码逻辑:

int feed1 = 2;
int feed2 = 10;
int feed3 = 3;
while( feed1 < feed2 ){
  feed1 = feed1 + 3;
}
feed1 = feed1 + feed3;

//============================================================================
// Name        : TensorflowTest.cpp
// Author      : 
// Version     :
// Copyright   : Your copyright notice
// Description : Builds a TensorFlow function containing a while loop via the
//               C API, then calls it from a host graph and runs it
//============================================================================
 
#include <string.h>

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include <tensorflow/c/c_api.h>
#include <tensorflow/c/c_api_experimental.h>
#include <tensorflow/c/c_test_util.h>
 
using namespace std;
 
// Converts `strs` into a C array of `const char*` for TensorFlow C API
// functions that expect a list of names.
//
// Returns nullptr when `strs` is empty. Otherwise the caller owns the returned
// array and must release it with `delete[]`. The array only borrows each
// string's buffer, so it must not be used after `strs` is destroyed or
// modified (do not pass a temporary vector if the result is used afterwards).
static const char** ToArray(const vector<string>& strs) {
	if (strs.empty()) {
		return nullptr;
	}
	const char** names = new const char*[strs.size()];
	std::transform(strs.begin(), strs.end(), names,
			[](const string& str) { return str.c_str(); });
	return names;
}
 
// Wraps the first output (index 0) of each operation in `ops` as a TF_Output,
// preserving order, for C API calls that take arrays of TF_Output.
// `ops` is taken by const reference to avoid copying the vector on every call.
static vector<TF_Output> ToOutput(const vector<TF_Operation*>& ops) {
	vector<TF_Output> out;
	out.reserve(ops.size());
	for (TF_Operation* op : ops) {
		out.push_back(TF_Output { op, 0 });
	}
	return out;
}
 
int main() {
	cout << "!!!Hello World!!! FunctionTest.cpp " << endl; // prints !!!Hello World!!!
	cout << "Hello from TensorFlow C library version" << TF_Version() << endl;
 
	TF_Status* s = TF_NewStatus();
	TF_Graph* func_graph_ = TF_NewGraph();
	TF_Graph* host_graph_ = TF_NewGraph();
 
	TF_Operation* feed1 = Placeholder(func_graph_, s, "feed1");
	cout << TF_Message(s);
	TF_Operation* feed2 = Placeholder(func_graph_, s, "feed2");
	cout << TF_Message(s);
	TF_Operation* feed3 = Placeholder(func_graph_, s, "feed3");
	cout << TF_Message(s);
 
	//========================While=================================
	// Add while loop to func_graph_
 
	// The inputs to the while loop
	vector<TF_Output> whileInputs = ToOutput( { feed1, feed2 });
	vector<TF_Output> whileOutputs;
	unique_ptr<TF_WhileParams> params(
			new TF_WhileParams(
					TF_NewWhile(func_graph_, &whileInputs[0],
							whileInputs.size(), s)));
	cout << TF_Message(s);
	params->name = "test_loop";
 
	// Initialize outputs so we can easily detect errors/bugs
	whileOutputs.resize(2, { nullptr, -1 });
 
	// Create loop: while (input1 < input2) input1 += input2 + 1
	TF_Operation* less_than = LessThan(params->cond_inputs[0],
			params->cond_inputs[1], params->cond_graph, s);
	cout << TF_Message(s);
	params->cond_output = {less_than, 0};
 
	TF_Operation* three = ScalarConst(3, params->body_graph, s);
	cout << TF_Message(s);
	TF_Operation* add1 = Add(params->body_inputs[0], { three, 0 },
			params->body_graph, s, "add1");
	cout << TF_Message(s);
	params->body_outputs[0] = {add1, 0};
	params->body_outputs[1] = params->body_inputs[1];
 
	// Finalize while loop
	TF_FinishWhile(params.get(), s, &whileOutputs[0]);
	cout << TF_Message(s);
	//========================While=================================
 
	TF_Operation* add2 = Add(whileOutputs[0], { feed3, 0 }, func_graph_, s,
			"add2");
	cout << TF_Message(s);
 
	vector<TF_Output> inputs = ToOutput( { feed1, feed2, feed3 });
	vector<TF_Output> outputs = ToOutput( { add2 });
 
	const char** output_names_ptr = ToArray( { });
	TF_Function* func_ = TF_GraphToFunction(func_graph_, "myFunction", false,
			-1, nullptr, inputs.size(), inputs.data(), outputs.size(),
			outputs.data(), output_names_ptr,
			/*opts=*/nullptr, /*description=*/nullptr, s);
	delete[] output_names_ptr;
	TF_GraphCopyFunction(host_graph_, func_, nullptr, s);
 
	size_t* len = new size_t();
	cout << "func_graph_ string==============" << endl;
	cout << TF_GraphDebugString(func_graph_, len) << endl;
	cout << "func_ string====================" << endl;
	cout << TF_FunctionDebugString(func_, len) << endl;
	cout << "host_graph_ string==============" << endl;
	cout << TF_GraphDebugString(host_graph_, len) << endl;
 
	// Use, run, and verify
	TF_Operation* two = ScalarConst(2, host_graph_, s, "two");
	TF_Operation* ten = ScalarConst(10, host_graph_, s, "ten");
	TF_Operation* func_feed = Placeholder(host_graph_, s);
	TF_OperationDescription* desc = TF_NewOperation(host_graph_, "myFunction",
			"myFunction_node");
	TF_AddInput(desc, TF_Output { two, 0 });
	TF_AddInput(desc, TF_Output { ten, 0 });
	TF_AddInput(desc, TF_Output { func_feed, 0 });
	TF_Operation* func_op = TF_FinishOperation(desc, s);
 
	TF_SessionOptions* opts = TF_NewSessionOptions();
	TF_Session* sess = TF_NewSession(host_graph_, opts, s);
	TF_DeleteSessionOptions(opts);
 
	TF_Output feeds[] = { TF_Output { func_feed, 0 } };
	TF_Output fetches[] = { TF_Output { func_op, 0 } };
 
	const char* handle = nullptr;
	TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
			TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);
	cout << TF_Message(s);
 
	TF_Output feeds1[] = { TF_Output { func_feed, 0 } };
	TF_Output fetches1[] = { TF_Output { func_op, 0 } };
	TF_Tensor* feedValues1[] = { Int32Tensor(3) };
	TF_Tensor* fetchValues1[1];
	TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1,
			1, NULL, 0, s);
	cout << TF_Message(s);
	cout << "Result for logic:" << endl;
	cout << "int feed1 = 2;" << endl;
	cout << "int feed2 = 10;" << endl;
	cout << "int feed3 = 3;" << endl;
	cout << "while( feed1 < feed2 ){" << endl;
	cout << "  feed1 = feed1 + 3;" << endl;
	cout << "}" << endl;
	cout << "feed1 = feed1 + feed3" << endl;
	cout << *(static_cast<int*>(TF_TensorData(fetchValues1[0]))) << endl;
 
	// Clean up.
	TF_DeletePRunHandle(handle);
	TF_DeleteSession(sess, s);
	cout << TF_Message(s);
	TF_DeleteGraph(host_graph_);
	TF_DeleteStatus(s);
	return 0;
}

本作品采用知识共享署名 4.0 国际许可协议进行许可。

发表回复