Original article. When reposting, please credit: reposted from 慢慢的回味
Article link: Main program for debugging with the TensorFlow C library
Main program in detail
The sample code is posted below. The program computes: kone + ktwo = kthree, A + ktwo = plus2, plus2 + B = plusB, plusB + kthree = plusC.
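To make the expected results easy to check before running anything, the four additions above can be restated in plain C++. This is only an arithmetic sketch without TensorFlow; the names `GraphValues` and `expectedValues` are made up for illustration, and kone = 1 and ktwo = 2 are taken from the program's `ScalarConst` calls.

```cpp
#include <cassert>

// Values of the four Add nodes in the example graph for given inputs A and B.
struct GraphValues {
    int kthree;
    int plus2;
    int plusB;
    int plusC;
};

// Plain C++ restatement of the graph arithmetic (hypothetical helper,
// no TensorFlow involved).
inline GraphValues expectedValues(int a, int b) {
    const int kone = 1;  // ScalarConst(1, ..., "kone")
    const int ktwo = 2;  // ScalarConst(2, ..., "ktwo")
    GraphValues v{};
    v.kthree = kone + ktwo;        // kone + ktwo = kthree
    v.plus2 = a + ktwo;            // A + ktwo = plus2
    v.plusB = v.plus2 + b;         // plus2 + B = plusB
    v.plusC = v.plusB + v.kthree;  // plusB + kthree = plusC
    assert(v.plusC == a + b + 5);  // sanity check: plusC = A + B + 2 + 3
    return v;
}
```

With the values the program feeds (A = 1, B = 3), `expectedValues(1, 3)` gives plus2 = 3, plusB = 6 and plusC = 9, which are the three numbers the program should print.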
//============================================================================
// Name        : TensorflowTest.cpp
// Author      :
// Version     :
// Copyright   : Your copyright notice
// Description : Hello World in C++, Ansi-style
//============================================================================

#include <iostream>
#include <tensorflow/c/c_api.h>
#include <tensorflow/c/c_test_util.h>
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>
#include <string.h>

using namespace std;

int main() {
    cout << "!!!Hello World!!!" << endl; // prints !!!Hello World!!!
    cout << "Hello from TensorFlow C library version " << TF_Version() << endl;

    TF_Status* s = TF_NewStatus();
    TF_Graph* graph = TF_NewGraph();

    // Construct the graph. TF_Message(s) is empty unless the call failed.
    TF_Operation* a = Placeholder(graph, s, "A");
    cout << TF_Message(s);
    TF_Operation* b = Placeholder(graph, s, "B");
    cout << TF_Message(s);
    TF_Operation* one = ScalarConst(1, graph, s, "kone");
    cout << TF_Message(s);
    TF_Operation* two = ScalarConst(2, graph, s, "ktwo");
    cout << TF_Message(s);
    TF_Operation* three = Add(one, two, graph, s, "kthree");
    cout << TF_Message(s);
    TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
    cout << TF_Message(s);
    TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
    cout << TF_Message(s);
    TF_Operation* plusC = Add(plusB, three, graph, s, "plusC");
    cout << TF_Message(s);

    // Set up a session and a partial run handle. A partial run allows the
    // graph to be computed in several phases (calls to TF_SessionPRun);
    // this program feeds A and B together and fetches everything in one call.
    TF_SessionOptions* opts = TF_NewSessionOptions();
    TF_Session* sess = TF_NewSession(graph, opts, s);
    TF_DeleteSessionOptions(opts);

    TF_Output feeds[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
    TF_Output fetches[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 },
            TF_Output { plusC, 0 } };
    const char* handle = nullptr;
    TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
            TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);
    cout << TF_Message(s);

    // Feed A = 1 and B = 3, then fetch plus2, plusB and plusC.
    TF_Output feeds1[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
    TF_Output fetches1[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 },
            TF_Output { plusC, 0 } };
    TF_Tensor* feedValues1[] = { Int32Tensor(1), Int32Tensor(3) };
    TF_Tensor* fetchValues1[3] = { nullptr, nullptr, nullptr };
    TF_SessionPRun(sess, handle, feeds1, feedValues1, 2, fetches1,
            fetchValues1, 3, NULL, 0, s);
    cout << TF_Message(s);

    // Only read the outputs if the run succeeded. Expected output: 3, 6, 9.
    if (TF_GetCode(s) == TF_OK) {
        for (TF_Tensor* t : fetchValues1) {
            cout << *(static_cast<int*>(TF_TensorData(t))) << endl;
            TF_DeleteTensor(t);
        }
    }
    for (TF_Tensor* t : feedValues1) {
        TF_DeleteTensor(t);
    }

    // Clean up.
    TF_DeletePRunHandle(handle);
    TF_DeleteSession(sess, s);
    cout << TF_Message(s);
    TF_DeleteGraph(graph);
    TF_DeleteStatus(s);
    return 0;
}
Code walkthrough
Create a computation graph. All nodes created afterwards are added to this graph.
TF_Graph* graph = TF_NewGraph();
Create a Placeholder node. It receives its value only when the graph is run.
TF_Operation* a = Placeholder(graph, s, "A");
Create a constant node.
TF_Operation* one = ScalarConst(1, graph, s, "kone");
Create an Add node, which sums its input nodes.
TF_Operation* three = Add(one, two, graph, s, "kthree");
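The construction steps above only record nodes in the graph; nothing is computed until a session runs it. A minimal toy sketch of that idea follows. `ToyNode`, `ToyGraph` and `buildExampleGraph` are hypothetical stand-ins written for this illustration, not TensorFlow types, and they make no claim about how `TF_Graph` is implemented internally.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a graph node; not TensorFlow's TF_Operation.
struct ToyNode {
    std::string op;                   // "Placeholder", "Const" or "Add"
    std::string name;                 // node name, e.g. "plus2"
    int value;                        // only meaningful for "Const"
    std::vector<std::size_t> inputs;  // indices of the input nodes
};

// Toy stand-in for TF_Graph: nodes are recorded here, never evaluated.
struct ToyGraph {
    std::vector<ToyNode> nodes;

    std::size_t placeholder(const std::string& name) {
        nodes.push_back({"Placeholder", name, 0, {}});
        return nodes.size() - 1;
    }
    std::size_t scalarConst(int v, const std::string& name) {
        nodes.push_back({"Const", name, v, {}});
        return nodes.size() - 1;
    }
    std::size_t add(std::size_t l, std::size_t r, const std::string& name) {
        assert(l < nodes.size() && r < nodes.size());  // inputs must exist
        nodes.push_back({"Add", name, 0, {l, r}});
        return nodes.size() - 1;
    }
};

// Records the same eight nodes as the article's program, in the same order.
inline ToyGraph buildExampleGraph() {
    ToyGraph g;
    std::size_t a = g.placeholder("A");
    std::size_t b = g.placeholder("B");
    std::size_t one = g.scalarConst(1, "kone");
    std::size_t two = g.scalarConst(2, "ktwo");
    std::size_t three = g.add(one, two, "kthree");
    std::size_t plus2 = g.add(a, two, "plus2");
    std::size_t plusB = g.add(plus2, b, "plusB");
    g.add(plusB, three, "plusC");
    return g;
}
```

Calling `buildExampleGraph()` yields eight nodes whose `inputs` fields describe the data flow. The Placeholder nodes carry no value at this point, which mirrors the fact that A and B are supplied only at run time.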
Create a session from the computation graph built above.
TF_SessionOptions* opts = TF_NewSessionOptions();
TF_Session* sess = TF_NewSession(graph, opts, s);
Configure the run. The nodes a and b are the graph's inputs, so they are listed as feeds; plus2, plusB and plusC are the outputs, so they are listed as fetches.
When TF_SessionPRunSetup is called, the framework builds the computation graph, optimizes it, creates the executor, and hands the first root node to the executor.
TF_Output feeds[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
TF_Output fetches[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 },
        TF_Output { plusC, 0 } };
const char* handle = nullptr;
TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
        TF_ARRAYSIZE(fetches), NULL, 0, &handle, s);
Feed the input data, set up the output variables, and run the computation graph.
TF_Output feeds1[] = { TF_Output { a, 0 }, TF_Output { b, 0 } };
TF_Output fetches1[] = { TF_Output { plus2, 0 }, TF_Output { plusB, 0 },
        TF_Output { plusC, 0 } };
TF_Tensor* feedValues1[] = { Int32Tensor(1), Int32Tensor(3) };
TF_Tensor* fetchValues1[3];
TF_SessionPRun(sess, handle, feeds1, feedValues1, 2, fetches1,
        fetchValues1, 3, NULL, 0, s);
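The program feeds A and B in a single TF_SessionPRun call, but the comment in the listing notes that a partial run can be split into phases. The toy model below sketches that idea: values fed or computed in one call stay available to the next call. `ToyOp`, `exampleGraph` and `ToyPartialRun` are hypothetical names invented for this sketch. It does not model TensorFlow's executor or any restrictions the real API enforces.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Toy node description; a stand-in, not TensorFlow's representation.
struct ToyOp {
    enum Kind { Placeholder, Const, Add } kind;
    int value;                        // used by Const only
    std::vector<std::string> inputs;  // used by Add only
};
using ToyGraphDef = std::map<std::string, ToyOp>;

// The article's graph expressed in the toy representation.
inline ToyGraphDef exampleGraph() {
    ToyGraphDef g;
    g["A"] = ToyOp{ToyOp::Placeholder, 0, {}};
    g["B"] = ToyOp{ToyOp::Placeholder, 0, {}};
    g["kone"] = ToyOp{ToyOp::Const, 1, {}};
    g["ktwo"] = ToyOp{ToyOp::Const, 2, {}};
    g["kthree"] = ToyOp{ToyOp::Add, 0, {"kone", "ktwo"}};
    g["plus2"] = ToyOp{ToyOp::Add, 0, {"A", "ktwo"}};
    g["plusB"] = ToyOp{ToyOp::Add, 0, {"plus2", "B"}};
    g["plusC"] = ToyOp{ToyOp::Add, 0, {"plusB", "kthree"}};
    return g;
}

// Toy counterpart of a partial-run handle: state persists across run() calls.
class ToyPartialRun {
public:
    explicit ToyPartialRun(ToyGraphDef graph) : graph_(std::move(graph)) {}

    // Stores the fed placeholder values, then evaluates the fetches in order.
    std::vector<int> run(const std::map<std::string, int>& feeds,
                         const std::vector<std::string>& fetches) {
        for (const auto& f : feeds) {
            assert(graph_.at(f.first).kind == ToyOp::Placeholder);
            known_[f.first] = f.second;
        }
        std::vector<int> out;
        for (const auto& name : fetches) out.push_back(eval(name));
        return out;
    }

private:
    // Recursively evaluates a node, reusing values from earlier phases.
    int eval(const std::string& name) {
        auto hit = known_.find(name);
        if (hit != known_.end()) return hit->second;
        const ToyOp& op = graph_.at(name);
        int result = 0;
        switch (op.kind) {
            case ToyOp::Placeholder:
                throw std::runtime_error("placeholder not fed: " + name);
            case ToyOp::Const:
                result = op.value;
                break;
            case ToyOp::Add:
                for (const auto& in : op.inputs) result += eval(in);
                break;
        }
        known_[name] = result;
        return result;
    }

    ToyGraphDef graph_;
    std::map<std::string, int> known_;  // fed and already computed values
};
```

With this toy, a first call `run({{"A", 1}}, {"plus2"})` returns 3 without B being known, and a second call `run({{"B", 3}}, {"plusB", "plusC"})` returns 6 and 9, reusing the plus2 value from the first phase. Fetching plusB before B has been fed throws, since a placeholder has no value of its own.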
An accompanying project is attached: TensorflowTest
This work is licensed under a Creative Commons Attribution 4.0 International License.