2 changes: 1 addition & 1 deletion example/pom.xml
@@ -182,7 +182,7 @@
<dependency>
<groupId>com.vesoft</groupId>
<artifactId>nebula-algorithm</artifactId>
<version>${project.version}</version>
<version>3.0.0</version>
</dependency>
</dependencies>
</project>
127 changes: 127 additions & 0 deletions example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala
@@ -0,0 +1,127 @@
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.vesoft.nebula.algorithm

import com.vesoft.nebula.connector.connector.NebulaDataFrameReader
import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig}
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.graphx.{Edge, EdgeDirection, EdgeTriplet, Graph, Pregel, VertexId}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}

import scala.collection.mutable

object DeepQueryTest {
private val LOGGER = Logger.getLogger(this.getClass)

def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf()
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
val spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate()
val iter = args(0).toInt
val id = args(1).toInt

query(spark, iter, id)
}

def readNebulaData(spark: SparkSession): DataFrame = {

val config =
NebulaConnectionConfig
.builder()
.withMetaAddress("192.168.15.5:9559")
.withTimeout(6000)
.withConenctionRetry(2)
.build()
val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
.builder()
.withSpace("twitter")
.withLabel("FOLLOW")
.withNoColumn(true)
.withLimit(20000)
.withPartitionNum(120)
.build()
val df: DataFrame =
spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
df
}

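  // Expand each vertex's attribute set with GraphX Pregel: every vertex starts with a set
  // containing its own id and repeatedly absorbs ids arriving from its neighbors, for at most
  // maxIterations rounds.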
def deepQuery(df: DataFrame,
maxIterations: Int,
startId: Int): Graph[mutable.HashSet[Int], Double] = {
implicit val encoder: Encoder[Edge[Double]] = org.apache.spark.sql.Encoders.kryo[Edge[Double]]
val edges: RDD[Edge[Double]] = df
.map(row => {
Edge(row.get(0).toString.toLong, row.get(1).toString.toLong, 1.0)
})(encoder)
.rdd

val graph = Graph.fromEdges(edges, None)

val queryGraph = graph.mapVertices { (vid, _) =>
mutable.HashSet[Int](vid.toInt)
}
queryGraph.cache()
queryGraph.numVertices
queryGraph.numEdges
df.unpersist()

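    // Build the message from the two endpoint attribute sets (iterating the smaller one) and
    // send it to the destination vertex; stop sending once the smaller set has already grown
    // to maxIterations elements.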
def sendMessage(edge: EdgeTriplet[mutable.HashSet[Int], Double])
: Iterator[(VertexId, mutable.HashSet[Int])] = {
val (smallSet, largeSet) = if (edge.srcAttr.size < edge.dstAttr.size) {
(edge.srcAttr, edge.dstAttr)
} else {
(edge.dstAttr, edge.srcAttr)
}

if (smallSet.size == maxIterations) {
Iterator.empty
} else {
val newNeighbors =
(for (id <- smallSet; neighbor <- largeSet if neighbor != id) yield neighbor)
Iterator((edge.dstId, newNeighbors))
}
}

val initialMessage = mutable.HashSet[Int]()

val pregelGraph = Pregel(queryGraph, initialMessage, maxIterations, EdgeDirection.Both)(
vprog = (id, attr, msg) => attr ++ msg,
sendMsg = sendMessage,
mergeMsg = (a, b) => {
val setResult = a ++ b
setResult
}
)
pregelGraph.cache()
pregelGraph.numVertices
pregelGraph.numEdges
queryGraph.unpersist()
pregelGraph
}

def query(spark: SparkSession, maxIter: Int, startId: Int): Unit = {
val start = System.currentTimeMillis()
val df = readNebulaData(spark)
df.cache()
df.count()
println(s"read data cost time ${(System.currentTimeMillis() - start)}")

val startQuery = System.currentTimeMillis()
val graph = deepQuery(df, maxIter, startId)

val endQuery = System.currentTimeMillis()
val num = graph.vertices.filter(row => row._2.contains(startId)).count()
val end = System.currentTimeMillis()
println(s"query cost: ${endQuery - startQuery}")
println(s"count: ${num}, cost: ${end - endQuery}")
}
}
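
The following is a hedged local sketch, not part of this commit: it builds a tiny in-memory edge DataFrame and calls DeepQueryTest.deepQuery on it, so the expected input shape (the first two columns are read as source and destination vertex ids) can be tried without a NebulaGraph cluster. The object name and column names are illustrative assumptions.

import com.vesoft.nebula.algorithm.DeepQueryTest
import org.apache.spark.sql.SparkSession

object DeepQueryLocalSketch {
  def main(args: Array[String]): Unit = {
    // a local[*] Spark session is enough for this toy graph
    val spark = SparkSession.builder().master("local[*]").appName("deep-query-sketch").getOrCreate()
    import spark.implicits._

    // deepQuery only reads row.get(0) and row.get(1) as vertex ids, so two Long columns suffice
    val edges = Seq((1L, 2L), (2L, 3L), (3L, 4L), (4L, 1L)).toDF("src", "dst")

    val graph = DeepQueryTest.deepQuery(edges, maxIterations = 2, startId = 1)
    // each vertex attribute is the set of ids it has accumulated after the Pregel rounds
    graph.vertices.collect().foreach { case (vid, reached) => println(s"$vid -> $reached") }

    spark.stop()
  }
}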
example/src/main/scala/com/vesoft/nebula/algorithm/PageRankExample.scala
@@ -6,11 +6,11 @@
package com.vesoft.nebula.algorithm

import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.algorithm.config.{CcConfig, PRConfig}
import com.vesoft.nebula.algorithm.lib.{PageRankAlgo, StronglyConnectedComponentsAlgo}
import com.vesoft.nebula.algorithm.config.PRConfig
import com.vesoft.nebula.algorithm.lib.PageRankAlgo
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, dense_rank}
import org.apache.spark.sql.functions.{col, dense_rank, monotonically_increasing_id}
import org.apache.spark.sql.{DataFrame, SparkSession}

object PageRankExample {
@@ -69,6 +69,8 @@ object PageRankExample {
// encode id to Long type using dense_rank, the encodeId has two columns: id, encodedId
// then you need to save the encodeId to convert back for the algorithm's result.
val encodeId = idDF.withColumn("encodedId", dense_rank().over(Window.orderBy("id")))
    // alternatively, you can encode with monotonically_increasing_id(); see https://spark.apache.org/docs/3.0.2/api/java/org/apache/spark/sql/functions.html#monotonically_increasing_id--
// val encodeId = idDF.withColumn("encodedId", monotonically_increasing_id())
encodeId.write.option("header", true).csv("file:///tmp/encodeId.csv")
encodeId.show()

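As a hedged sketch of the "convert back" step mentioned in the comments above (not part of this commit): after PageRankAlgo runs on the encoded ids, its result can be joined with the saved encodeId mapping to recover the original ids. The prResult variable and the "_id" / "pagerank" column names are assumptions about the algorithm's output, not code from this file.

    // read back the id <-> encodedId mapping that was written above
    val mapping = spark.read.option("header", true).csv("file:///tmp/encodeId.csv")
    // prResult is assumed to be the DataFrame returned by PageRankAlgo
    val decoded = prResult
      .join(mapping, prResult.col("_id") === mapping.col("encodedId").cast("long"))
      .select(col("id"), col("pagerank"))
    decoded.show()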
11 changes: 11 additions & 0 deletions nebula-algorithm/src/main/resources/application.conf
@@ -83,23 +83,27 @@
pagerank: {
maxIter: 10
resetProb: 0.15 # default 0.15
encodeId:false # if your data has string type ids, set encodeId to true.
}

# Louvain parameter
louvain: {
maxIter: 20
internalIter: 10
tol: 0.5
encodeId:false
}

# connected component parameter.
connectedcomponent: {
maxIter: 20
encodeId:false
}

# LabelPropagation parameter
labelpropagation: {
maxIter: 20
encodeId:false
}

# ShortestPaths parameter
@@ -115,6 +119,7 @@
kcore:{
maxIter:10
degree:1
encodeId:false
}

# Trianglecount parameter
@@ -126,13 +131,15 @@
# Betweenness centrality parameter. maxIter parameter means the max times of iterations.
betweenness:{
maxIter:5
encodeId:false
}

# Clustering Coefficient parameter. The type parameter has two choice, local or global
# local type will compute the clustering coefficient for each vertex, and print the average coefficient for graph.
# global type just compute the graph's clustering coefficient.
clusteringcoefficient:{
type: local
encodeId:false
}

# ClosenessAlgo parameter
@@ -142,19 +149,22 @@
bfs:{
maxIter:5
root:"10"
encodeId:false
}

# DFS parameter
dfs:{
maxIter:5
root:"10"
encodeId:false
}

# HanpAlgo parameter
hanp:{
hopAttenuation:0.1
maxIter:10
preference:1.0
encodeId:false
}

#Node2vecAlgo parameter
@@ -173,6 +183,7 @@
degree: 30,
embSeparate: ",",
modelPath: "hdfs://127.0.0.1:9000/model"
encodeId:false
}

# JaccardAlgo parameter