基于tensorflow c lib调试的主程序

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: 基于tensorflow c lib调试的主程序


回目录

主程序明细

下面贴上示例代码,程序完成:kone + ktwo = kthree, A + ktwo = plus2, plus2 + B = plusB, plusB + kthree = plusC。

//============================================================================
// Name        : TensorflowTest.cpp
// Author      : 
// Version     :
// Copyright   : Your copyright notice
// Description : Builds a small graph with the TensorFlow C API and evaluates it via a partial run
//============================================================================
 
#include <string.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <memory>
#include <vector>

#include <tensorflow/c/c_api.h>
#include <tensorflow/c/c_test_util.h>
 
using namespace std;
 
int main() {
	cout << "!!!Hello World!!!" << endl; // prints !!!Hello World!!!
	cout << "Hello from TensorFlow C library version" << TF_Version() << endl;
 
	TF_Status* s = TF_NewStatus();
	TF_Graph* graph = TF_NewGraph();
 
	// Construct the graph: A + 2 + B
	TF_Operation* a = Placeholder(graph, s, "A");
	cout << TF_Message(s);
 
	TF_Operation* b = Placeholder(graph, s, "B");
	cout << TF_Message(s);
 
	TF_Operation* one = ScalarConst(1, graph, s, "kone");
	cout << TF_Message(s);
 
	TF_Operation* two = ScalarConst(2, graph, s, "ktwo");
	cout << TF_Message(s);
 
	TF_Operation* three = Add(one, two, graph, s, "kthree");
	cout << TF_Message(s);
 
	TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
	cout << TF_Message(s);
 
	TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
	cout << TF_Message(s);
 
	TF_Operation* plusC = Add(plusB, three, graph, s, "plusC");
	cout << TF_Message(s);
 
	// Setup a session and a partial run handle.  The partial run will allow
	// computation of A + 2 + B in two phases (calls to TF_SessionPRun):
	// 1. Feed A and get (A+2)
	// 2. Feed B and get (A+2)+B
	TF_SessionOptions* opts = TF_NewSessionOptions();
	TF_Session* sess = TF_NewSession(graph, opts, s);
	TF_DeleteSessionOptions(opts);
 
	TF_Output feeds[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
	TF_Output fetches[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 }, TF_Output { plusC, 0 }  };
 
	const char* handle = nullptr;
	TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
			TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);
	cout << TF_Message(s);
 
	// Feed A and fetch A + 2.
	TF_Output feeds1[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
	TF_Output fetches1[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 }, TF_Output { plusC, 0 } };
	TF_Tensor* feedValues1[] = { Int32Tensor(1), Int32Tensor(3) };
	TF_Tensor* fetchValues1[3];
	TF_SessionPRun(sess, handle, feeds1, feedValues1, 2, fetches1, fetchValues1,
			3, NULL, 0, s);
	cout << TF_Message(s);
	cout << *(static_cast<int*>(TF_TensorData(fetchValues1[0]))) << endl;
	cout << *(static_cast<int*>(TF_TensorData(fetchValues1[1]))) << endl;
	cout << *(static_cast<int*>(TF_TensorData(fetchValues1[2]))) << endl;
 
	// Clean up.
	TF_DeletePRunHandle(handle);
	TF_DeleteSession(sess, s);
	cout << TF_Message(s);
	TF_DeleteGraph(graph);
	TF_DeleteStatus(s);
	return 0;
}
 

代码解释

创建一个计算图,后面的节点会在此图上添加节点。

TF_Graph* graph = TF_NewGraph();

创建一个Placeholder节点,在图运行的时候才会赋值。

TF_Operation* a = Placeholder(graph, s, "A");

创建一个常量节点。

TF_Operation* one = ScalarConst(1, graph, s, "kone");

创建一个Add运算节点,将两个输入节点的值相加。

TF_Operation* three = Add(one, two, graph, s, "kthree");

使用上面的计算图创建一个会话。

TF_SessionOptions* opts = TF_NewSessionOptions();
TF_Session* sess = TF_NewSession(graph, opts, s);

对会话进行配置,图中的a,b是其需要输入的数据,所以需要配置为feeds;plus2,plusB,plusC是需要输出的数据,所以配置成fetches。
调用TF_SessionPRunSetup,框架就会生成计算图,优化计算图,生成executor,并把第一个根节点送入executor。

	TF_Output feeds[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
	TF_Output fetches[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 }, TF_Output { plusC, 0 }  };
 
	const char* handle = nullptr;
	TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
			TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);

送入数据,设置输出变量,运行计算图。

	TF_Output feeds1[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
	TF_Output fetches1[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 }, TF_Output { plusC, 0 } };
	TF_Tensor* feedValues1[] = { Int32Tensor(1), Int32Tensor(3) };
	TF_Tensor* fetchValues1[3];
	TF_SessionPRun(sess, handle, feeds1, feedValues1, 2, fetches1, fetchValues1,
			3, NULL, 0, s);

附上一个工程:TensorflowTest。

本作品采用知识共享署名 4.0 国际许可协议进行许可。

发表回复