TensorFlow Java API 學習筆記

maweiliang發表於2019-03-03

之前寫過一篇TensorFlow Java 環境的搭建 TensorFlow Java+eclipse下環境搭建,今天看看TensorFlow Java API 的簡單說明 和操作。

TensorFlow是什麼

由 Google 開源,是一個深度學習庫,
是一套使用資料流圖 (data flow graphics)進行資料計算的軟體庫(software library) 和應用介面(API),並以此作為基礎加上其它功能的庫和開發工具成為一套進行機器學習、特別是深度學習(deep learning)的應用程式開發框架 (framework)。 —————谷歌開發技術推广部 大中華區主管 欒躍 (Bill Luan)

支援CNN、RNN和LSTM演算法,是目前在 Image,NLP (神經語言學)最流行的深度神經網路模型。

TensorFlow 優點

基於Python,寫的很快並且具有可讀性。

在多GPU系統上的執行更為順暢。

程式碼編譯效率較高。

社群發展的非常迅速並且活躍。

能夠生成顯示網路拓撲結構和效能的視覺化圖

TensorFlow 的工作原理

TensorFlow是用資料流圖(data flow graphs)技術來進行數值計算的

邊:用於傳送節點之間的多維陣列,即張量( tensor )

節點:表示數學運算操作符 用operation表示,簡稱op

TensorFlow Java API  學習筆記
TensorFlow Java API  學習筆記

TensorFlow Java API

TensorFlow Java API  學習筆記
public class HelloTF {
	public static void main(String[] args) throws Exception {
		try (Graph g = new Graph(); Session s = new Session(g)) {
			// 使用佔位符構造一個圖,新增兩個浮點型的張量
			Output x = g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().output(0);//建立一個OP
			Output y = g.opBuilder("Placeholder", "y").setAttr("dtype", DataType.FLOAT).build().output(0);
			Output z = g.opBuilder("Add", "z").addInput(x).addInput(y).build().output(0);
			System.out.println( " z= " + z);
			// 多次執行,每次使用不同的x和y值
			float[] X = new float[] { 1, 2, 3 };
			float[] Y = new float[] { 4, 5, 6 };
			for (int i = 0; i < X.length; i++) {
				try (Tensor tx = Tensor.create(X[i]);
						Tensor ty = Tensor.create(Y[i]);
						Tensor tz = s.runner().feed("x", tx).feed("y", ty).fetch("z").run().get(0)) {
					System.out.println(X[i] + " + " + Y[i] + " = " + tz.floatValue());
				}
			}
		}
     
		Graph graph = new Graph();       
		Tensor tensor = Tensor.create(2);
		Tensor tensor2 = tensor.create(3);
		Output output = graph.opBuilder("Const", "mx").setAttr("dtype", tensor.dataType()).setAttr("value", tensor).build().output(0);
		Output output2 = graph.opBuilder("Const", "my").setAttr("dtype", tensor2.dataType()).setAttr("value", tensor2).build().output(0);
		Output output3 =graph.opBuilder("Sub", "mz").addInput(output).addInput(output2).build().output(0);
		Session session = new Session(graph);	
		Tensor ttt=  session.runner().fetch("mz").run().get(0);
		System.out.println(ttt.intValue());
		Tensor t= session.runner().feed("mx", tensor).feed("my", tensor2).fetch("mz").run().get(0);
		System.out.println(t.intValue());
		session.close();
		tensor.close();
		tensor2.close();
		graph.close();
	}
}
複製程式碼
 z= <Add `z:0` shape=<unknown> dtype=FLOAT>
1.0 + 4.0 = 5.0
2.0 + 5.0 = 7.0
3.0 + 6.0 = 9.0
-1
-1
複製程式碼

相關文章