Example Programs
Strongly Connected Components
In a directed graph, if every vertex can reach every other vertex along the graph's edges, the graph is called strongly connected. A maximal strongly connected subgraph of a directed graph is called a strongly connected component. This example is based on the parallel Coloring algorithm.
Each vertex carries two fields, as follows:
colorID: stores the color of vertex v during the forward traversal; when the computation ends, vertices with the same colorID belong to the same strongly connected component.
transposeNeighbors: stores the IDs of vertex v's neighbors in the transpose of the input graph.
The algorithm consists of the following four phases:
Transpose graph formation: takes two supersteps. In the first superstep, every vertex sends its ID to its outgoing neighbors; in the second superstep, the received IDs are stored as the transposeNeighbors value.
Trimming: takes one superstep. Every vertex that has only incoming edges, only outgoing edges, or neither sets its colorID to its own ID and becomes inactive; messages subsequently sent to it are ignored.
Forward traversal: consists of two subphases (supersteps), Start and Rest. In the Start phase, every vertex sets its colorID to its own ID and sends its ID to its outgoing neighbors. In the Rest phase, every vertex updates its colorID with the minimum colorID it has received and propagates its colorID whenever it changes, until the colorIDs converge. Once they converge, the master sets the global phase to backward traversal.
Backward traversal: again consists of two subphases, Start and Rest. In the Start phase, every vertex whose ID equals its colorID sends its ID to its transpose-graph neighbors and sets itself inactive; messages subsequently sent to it are ignored. In each Rest superstep, every vertex that receives a message matching its colorID propagates its colorID in the transpose graph and then sets itself inactive. If any vertices are still active when this phase ends, the algorithm goes back to the trimming phase.
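To make the phases concrete, here is a minimal single-machine sketch of one coloring round on a tiny hard-coded digraph: forward minimum-label propagation, then backward confirmation on the transpose graph. The class and variable names are illustrative only; the distributed job below implements the same logic with supersteps and an aggregator-driven state machine:
import java.util.*;

/** One coloring round: forward min-label propagation, then backward
 *  confirmation on the transpose graph. Vertices confirmed in both
 *  directions share one strongly connected component. */
public class ColoringSketch {
  public static void main(String[] args) {
    // 1 -> 2, 2 -> 1, 2 -> 3 : the SCCs are {1, 2} and {3}
    Map<Long, List<Long>> out = new HashMap<>();
    out.put(1L, Arrays.asList(2L));
    out.put(2L, Arrays.asList(1L, 3L));
    out.put(3L, Collections.<Long>emptyList());
    // Forward traversal: start with colorID = own ID, take the minimum seen.
    Map<Long, Long> color = new HashMap<>();
    for (long v : out.keySet()) color.put(v, v);
    boolean changed = true;
    while (changed) {
      changed = false;
      for (Map.Entry<Long, List<Long>> e : out.entrySet())
        for (long dst : e.getValue())
          if (color.get(e.getKey()) < color.get(dst)) {
            color.put(dst, color.get(e.getKey()));
            changed = true;
          }
    }
    // Backward traversal: from each root (ID == colorID), follow the
    // transpose edges; vertices reached with a matching color form the SCC.
    Map<Long, List<Long>> in = new HashMap<>();
    for (long v : out.keySet()) in.put(v, new ArrayList<Long>());
    for (Map.Entry<Long, List<Long>> e : out.entrySet())
      for (long dst : e.getValue()) in.get(dst).add(e.getKey());
    Map<Long, Long> scc = new HashMap<>();
    for (long root : out.keySet()) {
      if (color.get(root) != root) continue;
      Deque<Long> stack = new ArrayDeque<>(Collections.singleton(root));
      while (!stack.isEmpty()) {
        long v = stack.pop();
        if (scc.containsKey(v) || color.get(v) != root) continue;
        scc.put(v, root);
        stack.addAll(in.get(v));
      }
    }
    // Prints {1=1, 2=1}; vertex 3 stays unresolved here and would be
    // handled by the next trimming round in the full algorithm.
    System.out.println(scc);
  }
}
In the Pregel implementation below, the same round is driven by the STAGE_* constants and the aggregator's state machine, and unresolved vertices loop back through the trimming phase.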
Code Example
The code for strongly connected components is shown below:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.BooleanWritable;
import com.aliyun.odps.io.IntWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
/**
* Definition from Wikipedia:
* In the mathematical theory of directed graphs, a graph is said
* to be strongly connected if every vertex is reachable from every
* other vertex. The strongly connected components of an arbitrary
* directed graph form a partition into subgraphs that are themselves
* strongly connected.
*
* The algorithm has four phases, as follows.
* 1. Transpose Graph Formation: Requires two supersteps. In the first
* superstep, each vertex sends a message with its ID to all its outgoing
* neighbors, which in the second superstep are stored in transposeNeighbors.
*
* 2. Trimming: Takes one superstep. Every vertex with only in-coming or
* only outgoing edges (or neither) sets its colorID to its own ID and
* becomes inactive. Messages subsequently sent to the vertex are ignored.
*
* 3. Forward-Traversal: There are two sub phases: Start and Rest. In the
* Start phase, each vertex sets its colorID to its own ID and propagates
* its ID to its outgoing neighbors. In the Rest phase, vertices update
* their own colorIDs with the minimum colorID they have seen, and propagate
* their colorIDs, if updated, until the colorIDs converge.
* Set the phase to Backward-Traversal when the colorIDs converge.
*
* 4. Backward-Traversal: We again break the phase into Start and Rest.
* In Start, every vertex whose ID equals its colorID propagates its ID to
* the vertices in transposeNeighbors and sets itself inactive. Messages
* subsequently sent to the vertex are ignored. In each of the Rest phase supersteps,
* each vertex receiving a message that matches its colorID: (1) propagates
* its colorID in the transpose graph; (2) sets itself inactive. Messages
* subsequently sent to the vertex are ignored. Set the phase back to Trimming
* if not all vertices are inactive.
*
* http://ilpubs.stanford.edu:8090/1077/3/p535-salihoglu.pdf
*/
public class StronglyConnectedComponents {
public final static int STAGE_TRANSPOSE_1 = 0;
public final static int STAGE_TRANSPOSE_2 = 1;
public final static int STAGE_TRIMMING = 2;
public final static int STAGE_FW_START = 3;
public final static int STAGE_FW_REST = 4;
public final static int STAGE_BW_START = 5;
public final static int STAGE_BW_REST = 6;
/**
* The value is composed of component id, incoming neighbors,
* active status and updated status.
*/
public static class MyValue implements Writable {
LongWritable sccID;// strongly connected component id
Tuple inNeighbors; // transpose neighbors
BooleanWritable active; // vertex is active or not
BooleanWritable updated; // sccID is updated or not
public MyValue() {
this.sccID = new LongWritable(Long.MAX_VALUE);
this.inNeighbors = new Tuple();
this.active = new BooleanWritable(true);
this.updated = new BooleanWritable(false);
}
public void setSccID(LongWritable sccID) {
this.sccID = sccID;
}
public LongWritable getSccID() {
return this.sccID;
}
public void setInNeighbors(Tuple inNeighbors) {
this.inNeighbors = inNeighbors;
}
public Tuple getInNeighbors() {
return this.inNeighbors;
}
public void addInNeighbor(LongWritable neighbor) {
this.inNeighbors.append(new LongWritable(neighbor.get()));
}
public boolean isActive() {
return this.active.get();
}
public void setActive(boolean status) {
this.active.set(status);
}
public boolean isUpdated() {
return this.updated.get();
}
public void setUpdated(boolean update) {
this.updated.set(update);
}
@Override
public void write(DataOutput out) throws IOException {
this.sccID.write(out);
this.inNeighbors.write(out);
this.active.write(out);
this.updated.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
this.sccID.readFields(in);
this.inNeighbors.readFields(in);
this.active.readFields(in);
this.updated.readFields(in);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("sccID: " + sccID.get());
sb.append(" inNeighbores: " + inNeighbors.toDelimitedString(','));
sb.append(" active: " + active.get());
sb.append(" updated: " + updated.get());
return sb.toString();
}
}
public static class SCCVertex extends
Vertex<LongWritable, MyValue, NullWritable, LongWritable> {
public SCCVertex() {
this.setValue(new MyValue());
}
@Override
public void compute(
ComputeContext<LongWritable, MyValue, NullWritable, LongWritable> context,
Iterable<LongWritable> msgs) throws IOException {
// Messages sent to inactive vertex are ignored.
if (!this.getValue().isActive()) {
this.voteToHalt();
return;
}
int stage = ((SCCAggrValue)context.getLastAggregatedValue(0)).getStage();
switch (stage) {
case STAGE_TRANSPOSE_1:
context.sendMessageToNeighbors(this, this.getId());
break;
case STAGE_TRANSPOSE_2:
for (LongWritable msg: msgs) {
this.getValue().addInNeighbor(msg);
}
break;
case STAGE_TRIMMING:
this.getValue().setSccID(getId());
if (this.getValue().getInNeighbors().size() == 0 ||
this.getNumEdges() == 0) {
this.getValue().setActive(false);
}
break;
case STAGE_FW_START:
this.getValue().setSccID(getId());
context.sendMessageToNeighbors(this, this.getValue().getSccID());
break;
case STAGE_FW_REST:
long minSccID = Long.MAX_VALUE;
for (LongWritable msg : msgs) {
if (msg.get() < minSccID) {
minSccID = msg.get();
}
}
if (minSccID < this.getValue().getSccID().get()) {
this.getValue().setSccID(new LongWritable(minSccID));
context.sendMessageToNeighbors(this, this.getValue().getSccID());
this.getValue().setUpdated(true);
} else {
this.getValue().setUpdated(false);
}
break;
case STAGE_BW_START:
if (this.getId().equals(this.getValue().getSccID())) {
for (Writable neighbor : this.getValue().getInNeighbors().getAll()) {
context.sendMessage((LongWritable)neighbor, this.getValue().getSccID());
}
this.getValue().setActive(false);
}
break;
case STAGE_BW_REST:
this.getValue().setUpdated(false);
for (LongWritable msg : msgs) {
if (msg.equals(this.getValue().getSccID())) {
for (Writable neighbor : this.getValue().getInNeighbors().getAll()) {
context.sendMessage((LongWritable)neighbor, this.getValue().getSccID());
}
this.getValue().setActive(false);
this.getValue().setUpdated(true);
break;
}
}
break;
}
context.aggregate(0, getValue());
}
@Override
public void cleanup(
WorkerContext<LongWritable, MyValue, NullWritable, LongWritable> context)
throws IOException {
context.write(getId(), getValue().getSccID());
}
}
/**
* The SCCAggrValue maintains the global stage and the graph's updated and
* active status. updated is true if at least one vertex was updated.
* active is true if at least one vertex is active.
*/
public static class SCCAggrValue implements Writable {
IntWritable stage = new IntWritable(STAGE_TRANSPOSE_1);
BooleanWritable updated = new BooleanWritable(false);
BooleanWritable active = new BooleanWritable(false);
public void setStage(int stage) {
this.stage.set(stage);
}
public int getStage() {
return this.stage.get();
}
public void setUpdated(boolean updated) {
this.updated.set(updated);
}
public boolean getUpdated() {
return this.updated.get();
}
public void setActive(boolean active) {
this.active.set(active);
}
public boolean getActive() {
return this.active.get();
}
@Override
public void write(DataOutput out) throws IOException {
this.stage.write(out);
this.updated.write(out);
this.active.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
this.stage.readFields(in);
this.updated.readFields(in);
this.active.readFields(in);
}
}
/**
* The job of SCCAggregator is to schedule the global stage in every superstep.
*/
public static class SCCAggregator extends Aggregator<SCCAggrValue> {
@SuppressWarnings("rawtypes")
@Override
public SCCAggrValue createStartupValue(WorkerContext context) throws IOException {
return new SCCAggrValue();
}
@SuppressWarnings("rawtypes")
@Override
public SCCAggrValue createInitialValue(WorkerContext context)
throws IOException {
return (SCCAggrValue) context.getLastAggregatedValue(0);
}
@Override
public void aggregate(SCCAggrValue value, Object item) throws IOException {
MyValue v = (MyValue)item;
if ((value.getStage() == STAGE_FW_REST || value.getStage() == STAGE_BW_REST)
&& v.isUpdated()) {
value.setUpdated(true);
}
// only active vertices invoke aggregate()
value.setActive(true);
}
@Override
public void merge(SCCAggrValue value, SCCAggrValue partial)
throws IOException {
boolean updated = value.getUpdated() || partial.getUpdated();
value.setUpdated(updated);
boolean active = value.getActive() || partial.getActive();
value.setActive(active);
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, SCCAggrValue value)
throws IOException {
// If all vertices are inactive, the job is over.
if (!value.getActive()) {
return true;
}
// state machine
switch (value.getStage()) {
case STAGE_TRANSPOSE_1:
value.setStage(STAGE_TRANSPOSE_2);
break;
case STAGE_TRANSPOSE_2:
value.setStage(STAGE_TRIMMING);
break;
case STAGE_TRIMMING:
value.setStage(STAGE_FW_START);
break;
case STAGE_FW_START:
value.setStage(STAGE_FW_REST);
break;
case STAGE_FW_REST:
if (value.getUpdated()) {
value.setStage(STAGE_FW_REST);
} else {
value.setStage(STAGE_BW_START);
}
break;
case STAGE_BW_START:
value.setStage(STAGE_BW_REST);
break;
case STAGE_BW_REST:
if (value.getUpdated()) {
value.setStage(STAGE_BW_REST);
} else {
value.setStage(STAGE_TRIMMING);
}
break;
}
value.setActive(false);
value.setUpdated(false);
return false;
}
}
public static class SCCVertexReader extends
GraphLoader<LongWritable, MyValue, NullWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, MyValue, NullWritable, LongWritable> context)
throws IOException {
SCCVertex vertex = new SCCVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
try {
long destID = Long.parseLong(edges[i]);
vertex.addEdge(new LongWritable(destID), NullWritable.get());
} catch(NumberFormatException nfe) {
System.err.println("Ignore " + nfe);
}
}
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(SCCVertexReader.class);
job.setVertexClass(SCCVertex.class);
job.setAggregatorClass(SCCAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
Connected Components
Two vertices are connected if there is a path between them. If every pair of vertices in an undirected graph G is connected, G is called a connected graph; otherwise it is disconnected. A maximal connected subgraph of G is called a connected component.
This algorithm computes the connected-component membership of every vertex: the smallest vertex ID in each component is propagated along the edges to all vertices of that component, and every vertex is output together with that smallest ID.
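The implementation below reuses MinLongCombiner from the SSSP example package (com.aliyun.odps.graph.examples.SSSP). If that class is not on your classpath, an equivalent combiner could look like this; it is a sketch following the same Combiner API as the LongSumCombiner in the topological-sort example later in this article:
import java.io.IOException;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.io.LongWritable;

/** Keeps only the smallest message per destination vertex, which is all
 *  the connected-components compute() ever needs to see. */
public class MinLongCombiner extends Combiner<LongWritable, LongWritable> {
  @Override
  public void combine(LongWritable vertexId, LongWritable combinedMessage,
      LongWritable messageToCombine) throws IOException {
    if (messageToCombine.get() < combinedMessage.get()) {
      combinedMessage.set(messageToCombine.get());
    }
  }
}
Combining messages this way cuts message traffic without changing the result, since only the minimum ID ever updates a vertex.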
Code Example
The code for connected components is shown below:
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.examples.SSSP.MinLongCombiner;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.WritableRecord;
/**
* Compute the connected component membership of each vertex, and output
* for each vertex the smallest vertex id in the connected component
* containing that vertex.
*
* Algorithm: propagate the smallest vertex id along the edges to all
* vertices of a connected component.
*
*/
public class ConnectedComponents {
public static class CCVertex extends
Vertex<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, NullWritable, LongWritable> context,
Iterable<LongWritable> msgs) throws IOException {
if (context.getSuperstep() == 0L) {
this.setValue(getId());
context.sendMessageToNeighbors(this, getValue());
return;
}
long minID = Long.MAX_VALUE;
for (LongWritable id : msgs) {
if (id.get() < minID) {
minID = id.get();
}
}
if (minID < this.getValue().get()) {
this.setValue(new LongWritable(minID));
context.sendMessageToNeighbors(this, getValue());
} else {
this.voteToHalt();
}
}
/**
* Output Table Description:
* +-----------------+----------------------------------------+
* | Field | Type | Comment |
* +-----------------+----------------------------------------+
* | v | bigint | vertex id |
* | minID | bigint | smallest id in the connected component |
* +-----------------+----------------------------------------+
*/
@Override
public void cleanup(
WorkerContext<LongWritable, LongWritable, NullWritable, LongWritable> context)
throws IOException {
context.write(getId(), getValue());
}
}
/**
* Input Table Description:
* +-----------------+----------------------------------------------------+
* | Field | Type | Comment |
* +-----------------+----------------------------------------------------+
* | v | bigint | vertex id |
* | es | string | comma separated target vertex id of outgoing edges |
* +-----------------+----------------------------------------------------+
*
* Example:
* For graph:
* 1 ----- 2
* | |
* 3 ----- 4
* Input table:
* +-----------+
* | v | es |
* +-----------+
* | 1 | 2,3 |
* | 2 | 1,4 |
* | 3 | 1,4 |
* | 4 | 2,3 |
* +-----------+
*/
public static class CCVertexReader extends
GraphLoader<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, NullWritable, LongWritable> context)
throws IOException {
CCVertex vertex = new CCVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
long destID = Long.parseLong(edges[i]);
vertex.addEdge(new LongWritable(destID), NullWritable.get());
}
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(CCVertexReader.class);
job.setVertexClass(CCVertex.class);
job.setCombinerClass(MinLongCombiner.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
Topological Sort
For the directed edges (u, v) of a graph, a topological order is an ordering of the vertices in which every u comes before its v. The algorithm steps are as follows:
Find a vertex with no incoming edges in the graph and output it.
Delete that vertex and all of its outgoing edges from the graph.
Repeat the steps above until every vertex has been output (a sequential sketch follows this list).
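These three steps are Kahn's algorithm. Here is a minimal single-machine sketch over a small hard-coded adjacency list matching the sample input below; the names are illustrative only, and the Pregel-style job afterwards computes the same order with message passing, using the superstep number as the position in the order:
import java.util.*;

/** Kahn's algorithm: repeatedly output and remove in-degree-0 vertices. */
public class TopoSortSketch {
  public static void main(String[] args) {
    // Same graph as the sample input table: 0 -> {1,2}, 1 -> {3}, 2 -> {3}
    Map<Long, List<Long>> out = new HashMap<>();
    out.put(0L, Arrays.asList(1L, 2L));
    out.put(1L, Arrays.asList(3L));
    out.put(2L, Arrays.asList(3L));
    out.put(3L, Collections.<Long>emptyList());
    // Count incoming edges per vertex.
    Map<Long, Integer> indegree = new HashMap<>();
    for (long v : out.keySet()) indegree.put(v, 0);
    for (List<Long> dsts : out.values())
      for (long d : dsts) indegree.put(d, indegree.get(d) + 1);
    // Seed the queue with every vertex that has no incoming edge.
    Deque<Long> ready = new ArrayDeque<>();
    for (long v : out.keySet()) if (indegree.get(v) == 0) ready.add(v);
    List<Long> order = new ArrayList<>();
    while (!ready.isEmpty()) {
      long v = ready.poll();
      order.add(v);                  // step 1: output the vertex
      for (long dst : out.get(v)) {  // step 2: remove its outgoing edges
        indegree.put(dst, indegree.get(dst) - 1);
        if (indegree.get(dst) == 0) ready.add(dst);
      }
    }
    // Fewer vertices here than in the graph means the graph has a cycle.
    System.out.println(order); // [0, 1, 2, 3]
  }
}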
Code Example
The code for the topological sort algorithm is shown below:
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.BooleanWritable;
import com.aliyun.odps.io.WritableRecord;
public class TopologySort {
private final static Log LOG = LogFactory.getLog(TopologySort.class);
public static class TopologySortVertex extends
Vertex<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, NullWritable, LongWritable> context,
Iterable<LongWritable> messages) throws IOException {
// in superstep 0, each vertex sends a message whose value is 1 to its
// neighbors
if (context.getSuperstep() == 0) {
if (hasEdges()) {
context.sendMessageToNeighbors(this, new LongWritable(1L));
}
} else if (context.getSuperstep() >= 1) {
// compute each vertex's indegree
long indegree = getValue().get();
for (LongWritable msg : messages) {
indegree += msg.get();
}
setValue(new LongWritable(indegree));
if (indegree == 0) {
voteToHalt();
if (hasEdges()) {
context.sendMessageToNeighbors(this, new LongWritable(-1L));
}
context.write(new LongWritable(context.getSuperstep()), getId());
LOG.info("vertex: " + getId());
}
context.aggregate(new LongWritable(indegree));
}
}
}
public static class TopologySortVertexReader extends
GraphLoader<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, NullWritable, LongWritable> context)
throws IOException {
TopologySortVertex vertex = new TopologySortVertex();
vertex.setId((LongWritable) record.get(0));
vertex.setValue(new LongWritable(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
long edge = Long.parseLong(edges[i]);
if (edge >= 0) {
vertex.addEdge(new LongWritable(edge), NullWritable.get());
}
}
LOG.info(record.toString());
context.addVertexRequest(vertex);
}
}
public static class LongSumCombiner extends
Combiner<LongWritable, LongWritable> {
@Override
public void combine(LongWritable vertexId, LongWritable combinedMessage,
LongWritable messageToCombine) throws IOException {
combinedMessage.set(combinedMessage.get() + messageToCombine.get());
}
}
public static class TopologySortAggregator extends
Aggregator<BooleanWritable> {
@SuppressWarnings("rawtypes")
@Override
public BooleanWritable createInitialValue(WorkerContext context)
throws IOException {
return new BooleanWritable(true);
}
@Override
public void aggregate(BooleanWritable value, Object item)
throws IOException {
boolean hasCycle = value.get();
boolean inDegreeNotZero = ((LongWritable) item).get() != 0;
value.set(hasCycle && inDegreeNotZero);
}
@Override
public void merge(BooleanWritable value, BooleanWritable partial)
throws IOException {
value.set(value.get() && partial.get());
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, BooleanWritable value)
throws IOException {
if (context.getSuperstep() == 0) {
// the initial aggregator value is true, and no aggregation is
// performed in superstep 0
return false;
}
return value.get();
}
}
public static void main(String[] args) throws IOException {
if (args.length != 2) {
System.out.println("Usage : <inputTable> <outputTable>");
System.exit(-1);
}
// The input table looks like:
// 0 1,2
// 1 3
// 2 3
// 3 -1
// The first column is the vertex id; the second column lists the destination
// vertex ids of its outgoing edges, where -1 means the vertex has no outgoing edge.
// The output table looks like:
// 0 0
// 1 1
// 1 2
// 2 3
// The first column is the superstep number, which encodes the topological
// order; the second column is the vertex id.
// TopologySortAggregator is used to detect whether the graph has a cycle.
// If the input graph has a cycle, iteration stops once every active vertex
// has a non-zero in-degree.
// Users can compare the record counts of the input and output tables to tell
// whether a directed graph has a cycle.
GraphJob job = new GraphJob();
job.setGraphLoaderClass(TopologySortVertexReader.class);
job.setVertexClass(TopologySortVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
job.setCombinerClass(LongSumCombiner.class);
job.setAggregatorClass(TopologySortAggregator.class);
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
Linear Regression
In statistics, linear regression is a method for modeling the dependency between two or more variables. Unlike classification algorithms, which predict discrete labels, regression algorithms predict continuous values.
Linear regression defines the loss function as the sum of squared errors over the sample set, and solves for the weight vector by minimizing this loss.
A common solver is gradient descent, which proceeds as follows:
Initialize the weight vector, and fix a learning rate and an iteration count (or a convergence condition).
For each sample, compute the squared error.
Sum the squared errors and update the weights according to the learning rate.
Repeat until convergence (the update rule is written out below).
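Concretely, writing the hypothesis as h_theta(x) = theta_0 + theta_1*x_1 + ... + theta_n*x_n over m samples, the loss and the batch update step implemented by the aggregator below are:

J(theta) = 1/(2m) * sum_{i=1..m} (h_theta(x^(i)) - y^(i))^2
theta_j <- theta_j - (alpha/m) * sum_{i=1..m} (h_theta(x^(i)) - y^(i)) * x_j^(i)

Here alpha is the learning rate (alpha in the code) and m is the number of vertices, i.e. samples (context.getTotalNumVertices()); vecMul() computes h_theta(x^(i)) - y^(i) for one sample, and tmpGradient accumulates the sums.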
Code Example
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
/**
* LinearRegression input: y,x1,x2,x3,...
**/
public class LinearRegression {
public static class GradientWritable implements Writable {
Tuple lastTheta;
Tuple currentTheta;
Tuple tmpGradient;
LongWritable count;
DoubleWritable lost;
@Override
public void readFields(DataInput in) throws IOException {
lastTheta = new Tuple();
lastTheta.readFields(in);
currentTheta = new Tuple();
currentTheta.readFields(in);
tmpGradient = new Tuple();
tmpGradient.readFields(in);
count = new LongWritable();
count.readFields(in);
/* update 1: add a variable to store lost at every iteration */
lost = new DoubleWritable();
lost.readFields(in);
}
@Override
public void write(DataOutput out) throws IOException {
lastTheta.write(out);
currentTheta.write(out);
tmpGradient.write(out);
count.write(out);
lost.write(out);
}
}
public static class LinearRegressionVertex extends
Vertex<LongWritable, Tuple, NullWritable, NullWritable> {
@Override
public void compute(
ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
}
public static class LinearRegressionVertexReader extends
GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
throws IOException {
LinearRegressionVertex vertex = new LinearRegressionVertex();
vertex.setId(recordNum);
vertex.setValue(new Tuple(record.getAll()));
context.addVertexRequest(vertex);
}
}
public static class LinearRegressionAggregator extends
Aggregator<GradientWritable> {
@SuppressWarnings("rawtypes")
@Override
public GradientWritable createInitialValue(WorkerContext context)
throws IOException {
if (context.getSuperstep() == 0) {
/* set initial value, all 0 */
GradientWritable grad = new GradientWritable();
grad.lastTheta = new Tuple();
grad.currentTheta = new Tuple();
grad.tmpGradient = new Tuple();
grad.count = new LongWritable(1);
grad.lost = new DoubleWritable(0.0);
int n = (int) Long.parseLong(context.getConfiguration()
.get("Dimension"));
for (int i = 0; i < n; i++) {
grad.lastTheta.append(new DoubleWritable(0));
grad.currentTheta.append(new DoubleWritable(0));
grad.tmpGradient.append(new DoubleWritable(0));
}
return grad;
} else
return (GradientWritable) context.getLastAggregatedValue(0);
}
public static double vecMul(Tuple value, Tuple theta) {
/* partial computation: h_theta(x(i)) - y(i) for one sample */
/* value denotes one sample, and value(0) is y */
double sum = 0.0;
for (int j = 1; j < value.size(); j++)
sum += Double.parseDouble(value.get(j).toString())
* Double.parseDouble(theta.get(j).toString());
Double tmp = Double.parseDouble(theta.get(0).toString()) + sum
- Double.parseDouble(value.get(0).toString());
return tmp;
}
@Override
public void aggregate(GradientWritable gradient, Object value)
throws IOException {
/*
* runs on each vertex, i.e. once per sample i, accumulating that
* sample's gradient contribution for every dimension
*/
double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
/*
* update 2: the local worker's aggregate() behaves like merge() below,
* which means the variable gradient holds the previously aggregated value
*/
gradient.tmpGradient.set(0, new DoubleWritable(
((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
/*
* accumulate (h_theta(x(i)) - y(i)) * x(i)(j) for each sample i and each
* dimension j
*/
for (int j = 1; j < gradient.tmpGradient.size(); j++)
gradient.tmpGradient.set(j, new DoubleWritable(
((DoubleWritable) gradient.tmpGradient.get(j)).get() + tmpVar
* Double.parseDouble(((Tuple) value).get(j).toString())));
}
@Override
public void merge(GradientWritable gradient, GradientWritable partial)
throws IOException {
/* perform SumAll on each dimension for all samples. */
Tuple master = (Tuple) gradient.tmpGradient;
Tuple part = (Tuple) partial.tmpGradient;
for (int j = 0; j < gradient.tmpGradient.size(); j++) {
DoubleWritable s = (DoubleWritable) master.get(j);
s.set(s.get() + ((DoubleWritable) part.get(j)).get());
}
gradient.lost.set(gradient.lost.get() + partial.lost.get());
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, GradientWritable gradient)
throws IOException {
/*
* 1. calculate the new theta; 2. check the diff between the last step and
* this step, and stop iterating if it is smaller than the threshold
*/
gradient.lost = new DoubleWritable(gradient.lost.get()
/ (2 * context.getTotalNumVertices()));
/*
* lost lets us check that the algorithm is moving in the right direction
* (for debugging)
*/
System.out.println(gradient.count + " lost:" + gradient.lost);
Tuple tmpGradient = gradient.tmpGradient;
System.out.println("tmpGra" + tmpGradient);
Tuple lastTheta = gradient.lastTheta;
Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
double alpha = 0.07; // learning rate
// alpha =
// Double.parseDouble(context.getConfiguration().get("Alpha"));
/* perform theta(j) = theta(j)-alpha*tmpGradient */
long M = context.getTotalNumVertices();
/*
* update 3: divide by M; the original code forgot this step
*/
for (int j = 0; j < lastTheta.size(); j++) {
tmpCurrentTheta
.set(
j,
new DoubleWritable(Double.parseDouble(lastTheta.get(j)
.toString())
- alpha
/ M
* Double.parseDouble(tmpGradient.get(j).toString())));
}
System.out.println(gradient.count + " terminate_start_current:"
+ tmpCurrentTheta);
// judge if convergence is happening.
double diff = 0.00d;
for (int j = 0; j < gradient.currentTheta.size(); j++)
diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
- ((DoubleWritable) lastTheta.get(j)).get(), 2);
// convergence check disabled: Math.sqrt(diff) < 0.00000000005d ||
if (Long.parseLong(context.getConfiguration().get("Max_Iter_Num")) == gradient.count
.get()) {
context.write(gradient.currentTheta.toArray());
return true;
}
gradient.lastTheta = tmpCurrentTheta;
gradient.currentTheta = tmpCurrentTheta;
gradient.count.set(gradient.count.get() + 1);
int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
/*
* update 4: Important!!! Remember this step. Graph won't reset the
* initial value for global variables at the beginning of each iteration
*/
for (int i = 0; i < n; i++) {
gradient.tmpGradient.set(i, new DoubleWritable(0));
}
// also reset the accumulated lost for the next iteration
gradient.lost = new DoubleWritable(0.0);
return false;
}
}
public static void main(String[] args) throws IOException {
GraphJob job = new GraphJob();
job.setGraphLoaderClass(LinearRegressionVertexReader.class);
job.setRuntimePartitioning(false);
job.setNumWorkers(3);
job.setVertexClass(LinearRegressionVertex.class);
job.setAggregatorClass(LinearRegressionAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
job.setMaxIteration(Integer.parseInt(args[2])); // Numbers of Iteration
job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
job.setInt("Dimension", Integer.parseInt(args[3])); // Dimension
job.setFloat("Alpha", Float.parseFloat(args[4])); // Learning rate
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
Triangle Count
The triangle count algorithm computes the number of triangles passing through each vertex.
The algorithm proceeds as follows:
Each vertex sends its ID to all of its outgoing neighbors.
Each vertex stores its incoming and outgoing neighbors and sends them to its outgoing neighbors.
For each edge, compute the size of the intersection of the neighbor sets at its destination, sum these sizes, and write the result to the output table.
Sum the output table and divide by three to obtain the number of triangles (see the sketch after this list).
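The job writes one (vertex, count) row per vertex; the division by three happens outside the job, because each triangle is counted once at each of its three corners. A minimal post-processing sketch, assuming the per-vertex counts have already been fetched into a map (reading the output table is omitted; TriangleTotal is an illustrative name):
import java.util.Map;

/** Turns the per-vertex triangle counts into the global triangle count. */
public class TriangleTotal {
  public static long totalTriangles(Map<Long, Long> perVertexCounts) {
    long sum = 0;
    for (long c : perVertexCounts.values()) {
      sum += c;
    }
    return sum / 3; // each triangle is counted at each of its 3 corners
  }
}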
Code Example
The code for the triangle count algorithm is shown below:
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
/**
* Compute the number of triangles passing through each vertex.
*
* The algorithm can be computed in three supersteps:
* I. Each vertex sends a message with its ID to all its outgoing
* neighbors.
* II. The incoming and outgoing neighbors are stored and sent to the
* outgoing neighbors.
* III. For each edge compute the intersection of the sets at the destination
* vertex and sum them, then output to table.
*
* The triangle count is the sum over the output table divided by three,
* since each triangle is counted three times.
*
**/
public class TriangleCount {
public static class TCVertex extends
Vertex<LongWritable, Tuple, NullWritable, Tuple> {
@Override
public void setup(
WorkerContext<LongWritable, Tuple, NullWritable, Tuple> context)
throws IOException {
// collect the outgoing neighbors
Tuple t = new Tuple();
if (this.hasEdges()) {
for (Edge<LongWritable, NullWritable> edge : this.getEdges()) {
t.append(edge.getDestVertexId());
}
}
this.setValue(t);
}
@Override
public void compute(
ComputeContext<LongWritable, Tuple, NullWritable, Tuple> context,
Iterable<Tuple> msgs) throws IOException {
if (context.getSuperstep() == 0L) {
// sends a message with its ID to all its outgoing neighbors
Tuple t = new Tuple();
t.append(getId());
context.sendMessageToNeighbors(this, t);
} else if (context.getSuperstep() == 1L) {
// store the incoming neighbors
for (Tuple msg : msgs) {
for (Writable item : msg.getAll()) {
if (!this.getValue().getAll().contains((LongWritable)item)) {
this.getValue().append((LongWritable)item);
}
}
}
// send both incoming and outgoing neighbors to all outgoing neighbors
context.sendMessageToNeighbors(this, getValue());
} else if (context.getSuperstep() == 2L) {
// count the sum of intersection at each edge
long count = 0;
for (Tuple msg : msgs) {
for (Writable id : msg.getAll()) {
if (getValue().getAll().contains(id)) {
count ++;
}
}
}
// output to table
context.write(getId(), new LongWritable(count));
this.voteToHalt();
}
}
}
public static class TCVertexReader extends
GraphLoader<LongWritable, Tuple, NullWritable, Tuple> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, Tuple, NullWritable, Tuple> context)
throws IOException {
TCVertex vertex = new TCVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
try {
long destID = Long.parseLong(edges[i]);
vertex.addEdge(new LongWritable(destID), NullWritable.get());
} catch(NumberFormatException nfe) {
System.err.println("Ignore " + nfe);
}
}
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(TCVertexReader.class);
job.setVertexClass(TCVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
Vertex Input Table Example
The code for loading a vertex input table is shown below:
import java.io.IOException;
import com.aliyun.odps.conf.Configuration;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexResolver;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.VertexChanges;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.WritableComparable;
import com.aliyun.odps.io.WritableRecord;
/**
* This example shows how to write a graph job that loads data of different
* formats. It focuses on how GraphLoader and VertexResolver cooperate to
* build the graph.
*
* The input of an ODPS Graph job is always ODPS tables. Suppose the job has
* two input tables, one storing vertex information and one storing edge
* information.
* The vertex table has the following format:
* +------------------------+
* | VertexID | VertexValue |
* +------------------------+
* | id0| 9|
* +------------------------+
* | id1| 7|
* +------------------------+
* | id2| 8|
* +------------------------+
*
* The edge table has the following format:
* +-----------------------------------+
* | VertexID | DestVertexID| EdgeValue|
* +-----------------------------------+
* | id0| id1| 1|
* +-----------------------------------+
* | id0| id2| 2|
* +-----------------------------------+
* | id2| id1| 3|
* +-----------------------------------+
*
* Together, the two tables say that id0 has two outgoing edges, pointing to
* id1 and id2; id2 has one outgoing edge, pointing to id1; and id1 has no
* outgoing edges.
*
* For data of this kind, GraphLoader::load(LongWritable, Record,
* MutationContext) can use MutationContext#addVertexRequest(Vertex) to
* request adding a vertex to the graph and
* MutationContext#addEdgeRequest(WritableComparable, Edge) to request adding
* an edge. Then, in
* VertexResolver#resolve(WritableComparable, Vertex, VertexChanges, boolean),
* the vertices and edges requested in load are merged into a single Vertex
* object, which is returned and added to the graph that takes part in the
* computation.
*
**/
public class VertexInputFormat {
private final static String EDGE_TABLE = "edge.table";
/**
* Interprets each Record as a Vertex or an Edge; depending on its source
* table, each Record represents one Vertex or one Edge.
*
* Similar to com.aliyun.odps.mapreduce.Mapper#map: it takes a Record and
* emits key-value pairs, where the key is the Vertex ID and the value is a
* Vertex or an Edge, written out through the context. The pairs are grouped
* by Vertex ID in LoadingVertexResolver.
*
* Note: the vertices and edges added here are only requests derived from the
* Record contents, not the vertices and edges that finally take part in the
* computation. Only those added in the subsequent VertexResolver do.
**/
public static class VertexInputLoader extends
GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
private boolean isEdgeData;
/**
* Configures the VertexInputLoader.
*
* @param conf
* the job configuration parameters, set via GraphJob in main or via set in the console
* @param workerId
* the index of the current worker, starting from 0; can be used to construct unique vertex ids
* @param inputTableInfo
* information about the input table loaded by the current worker; can be used to determine which kind of data, i.e. which Record format, the current input holds
**/
@Override
public void setup(Configuration conf, int workerId, TableInfo inputTableInfo) {
isEdgeData = conf.get(EDGE_TABLE).equals(inputTableInfo.getTableName());
}
/**
* Parses the Record into the corresponding vertex or edge and requests adding it to the graph.
*
* @param recordNum
* record sequence number, starting from 1, counted separately on each worker
* @param record
* a record of the input table; for the edge table it has three columns: source vertex, destination vertex, and edge weight
* @param context
* the context through which the parsed vertices and edges are requested to be added to the graph
**/
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
throws IOException {
if (isEdgeData) {
/**
* The data comes from the edge table.
*
* 1. The first column is the source vertex ID.
**/
LongWritable sourceVertexID = (LongWritable) record.get(0);
/**
* 2. The second column is the destination vertex ID.
**/
LongWritable destinationVertexID = (LongWritable) record.get(1);
/**
* 3. The third column is the edge weight.
**/
LongWritable edgeValue = (LongWritable) record.get(2);
/**
* 4. Create the edge from the destination vertex ID and the edge weight.
**/
Edge<LongWritable, LongWritable> edge = new Edge<LongWritable, LongWritable>(
destinationVertexID, edgeValue);
/**
* 5. Request adding the edge to the source vertex.
**/
context.addEdgeRequest(sourceVertexID, edge);
/**
* 6. If each Record represents a bidirectional edge, repeat steps 4 and 5:
* Edge<LongWritable, LongWritable> edge2 = new
* Edge<LongWritable, LongWritable>( sourceVertexID, edgeValue);
* context.addEdgeRequest(destinationVertexID, edge2);
**/
} else {
/**
* The data comes from the vertex table.
*
* 1. The first column is the vertex ID.
**/
LongWritable vertexID = (LongWritable) record.get(0);
/**
* 2. The second column is the vertex value.
**/
LongWritable vertexValue = (LongWritable) record.get(1);
/**
* 3. Create the vertex from its ID and value.
**/
MyVertex vertex = new MyVertex();
/**
* 4. Initialize the vertex.
**/
vertex.setId(vertexID);
vertex.setValue(vertexValue);
/**
* 5. Request adding the vertex.
**/
context.addVertexRequest(vertex);
}
}
}
/**
* Aggregates the key-value pairs generated by
* GraphLoader::load(LongWritable, Record, MutationContext), similar to
* reduce in com.aliyun.odps.mapreduce.Reducer. For each unique Vertex ID,
* all requests to add or remove vertices or edges on that ID are collected
* into a VertexChanges object.
*
* Note: this is invoked not only for IDs with conflicting requests (a
* conflict means, e.g., adding several Vertex objects with the same ID, or
* adding duplicate edges); every ID requested in the load method is
* resolved here.
**/
public static class LoadingResolver extends
VertexResolver<LongWritable, LongWritable, LongWritable, LongWritable> {
/**
* Handles the requests to add or remove vertices or edges for one ID.
*
* VertexChanges has four interfaces, matching the four interfaces of
* MutationContext:
* VertexChanges::getAddedVertexList() matches
* MutationContext::addVertexRequest(Vertex);
* the Vertex objects requested in load under the same ID are collected into
* the returned List.
* VertexChanges::getAddedEdgeList() matches
* MutationContext::addEdgeRequest(WritableComparable, Edge);
* the Edge objects requested with the same source vertex ID are collected
* into the returned List.
* VertexChanges::getRemovedVertexCount() matches
* MutationContext::removeVertexRequest(WritableComparable);
* the number of removal requests for the ID is returned.
* VertexChanges#getRemovedEdgeList() matches
* MutationContext#removeEdgeRequest(WritableComparable, WritableComparable);
* the Edge objects whose removal was requested with the same source vertex
* ID are collected into the returned List.
*
* By handling the changes for this ID, the user decides through the return
* value whether the ID takes part in the computation: if the returned
* Vertex is not null, the ID joins the subsequent computation; if null, it
* does not.
*
* @param vertexId
* the ID of the requested vertex, or the source vertex ID of the requested edges
* @param vertex
* the existing Vertex object; always null during the loading phase
* @param vertexChanges
* the collected requests to add or remove vertices or edges for this ID
* @param hasMessages
* whether this ID has incoming messages; always false during the loading phase
**/
@Override
public Vertex<LongWritable, LongWritable, LongWritable, LongWritable> resolve(
LongWritable vertexId,
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> vertex,
VertexChanges<LongWritable, LongWritable, LongWritable, LongWritable> vertexChanges,
boolean hasMessages) throws IOException {
/**
* 1. Get the Vertex object that will take part in the computation.
**/
MyVertex computeVertex = null;
if (vertexChanges.getAddedVertexList() == null
|| vertexChanges.getAddedVertexList().isEmpty()) {
computeVertex = new MyVertex();
computeVertex.setId(vertexId);
} else {
/**
* Here we assume that each Record in the vertex table represents a unique vertex.
**/
computeVertex = (MyVertex) vertexChanges.getAddedVertexList().get(0);
}
/**
* 2. Add the requested edges to the Vertex object. If duplicates are possible, decide whether to deduplicate according to the needs of the algorithm.
**/
if (vertexChanges.getAddedEdgeList() != null) {
for (Edge<LongWritable, LongWritable> edge : vertexChanges
.getAddedEdgeList()) {
computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
}
}
/**
* 3. Return the Vertex object, adding it to the final graph that takes part in the computation.
**/
return computeVertex;
}
}
/**
* Defines the behavior of the vertices that take part in the computation.
*
**/
public static class MyVertex extends
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
/**
* Writes the vertex's edges back to the result tables in the format of the
* input tables. The output tables have the same format and data as the
* input tables.
*
* @param context
* runtime context
* @param messages
* incoming messages
**/
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
Iterable<LongWritable> messages) throws IOException {
/**
* Write the vertex ID and value to the vertex result table.
**/
context.write("vertex", getId(), getValue());
/**
* Write the vertex's edges to the edge result table.
**/
if (hasEdges()) {
for (Edge<LongWritable, LongWritable> edge : getEdges()) {
context.write("edge", getId(), edge.getDestVertexId(),
edge.getValue());
}
}
/**
* Iterate for a single round only.
**/
voteToHalt();
}
}
/**
* @param args
* @throws IOException
*/
public static void main(String[] args) throws IOException {
if (args.length < 4) {
throw new IOException(
"Usage: VertexInputFormat <vertex input> <edge input> <vertex output> <edge output>");
}
/**
* GraphJob is used to configure the Graph job.
*/
GraphJob job = new GraphJob();
/**
* 1. Specify the input graph data, and mark which table stores the edge data.
*/
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addInput(TableInfo.builder().tableName(args[1]).build());
job.set(EDGE_TABLE, args[1]);
/**
* 2. Specify how the data is loaded: Records are interpreted as edges. This step is similar to map; the generated key is the vertex ID and the value is an edge.
*/
job.setGraphLoaderClass(VertexInputLoader.class);
/**
* 3. Specify the loading-phase resolver that produces the vertices taking part in the computation. This step is similar to reduce: the edges generated by the map step are merged into one vertex.
*/
job.setLoadingVertexResolverClass(LoadingResolver.class);
/**
* 4. Specify the behavior of the vertices taking part in the computation. vertex.compute is executed in every iteration.
*/
job.setVertexClass(MyVertex.class);
/**
* 5. Specify the output tables of the graph job, to which the computed results are written.
*/
job.addOutput(TableInfo.builder().tableName(args[2]).label("vertex").build());
job.addOutput(TableInfo.builder().tableName(args[3]).label("edge").build());
/**
* 6. Submit the job for execution.
*/
job.run();
}
}
Author: 雲花
This article is original content of the Alibaba Cloud Yunqi Community and may not be reproduced without permission.