Commit 435b214 (parent 0871159)

add dfs algorithm

5 files changed: +98 −0 lines

nebula-algorithm/src/main/resources/application.conf

Lines changed: 6 additions & 0 deletions
@@ -144,6 +144,12 @@
       root:"10"
   }

+   # DFS parameter
+   dfs:{
+       maxIter:5
+       root:"10"
+   }
+
   # HanpAlgo parameter
   hanp:{
       hopAttenuation:0.1
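
For reference, a minimal sketch of what this block yields once it is parsed by DfsConfig.getDfsConfig (added below): root is written as a string in the conf file and converted to a Long. Selecting DFS presumably also requires setting the executed algorithm name to "dfs", matching the case label added in Main.scala; that setting is not part of this hunk.

    // Hypothetical illustration only, not part of the commit:
    // the DfsConfig these keys produce.
    val dfsConfig = DfsConfig(maxIter = 5, root = 10L)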

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,7 @@ import com.vesoft.nebula.algorithm.config.{
   CcConfig,
   CoefficientConfig,
   Configs,
+  DfsConfig,
   HanpConfig,
   JaccardConfig,
   KCoreConfig,
@@ -30,6 +31,7 @@ import com.vesoft.nebula.algorithm.lib.{
   ClusteringCoefficientAlgo,
   ConnectedComponentsAlgo,
   DegreeStaticAlgo,
+  DfsAlgo,
   GraphTriangleCountAlgo,
   HanpAlgo,
   JaccardAlgo,
@@ -204,6 +206,10 @@ object Main {
         val bfsConfig = BfsConfig.getBfsConfig(configs)
         BfsAlgo(spark, dataSet, bfsConfig)
       }
+      case "dfs" => {
+        val dfsConfig = DfsConfig.getDfsConfig(configs)
+        DfsAlgo(spark, dataSet, dfsConfig)
+      }
       case "jaccard" => {
         val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
         JaccardAlgo(spark, dataSet, jaccardConfig)
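
Outside of Main, the same branch can be exercised directly from Scala. A minimal sketch, assuming spark is an active SparkSession and edgeDf is an edge DataFrame of the shape the other algorithms consume (both names are placeholders, not part of this commit):

    // Run DFS from vertex 10 with an iteration bound of 5 and print the result.
    val result = DfsAlgo(spark, edgeDf, DfsConfig(maxIter = 5, root = 10L))
    result.show()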

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala

Lines changed: 16 additions & 0 deletions
@@ -183,6 +183,22 @@ object BfsConfig {
   }
 }

+/**
+  * dfs
+  */
+case class DfsConfig(maxIter: Int, root: Long)
+object DfsConfig {
+  var maxIter: Int = _
+  var root: Long   = _
+
+  def getDfsConfig(configs: Configs): DfsConfig = {
+    val dfsConfig = configs.algorithmConfig.map
+    maxIter = dfsConfig("algorithm.dfs.maxIter").toInt
+    root = dfsConfig("algorithm.dfs.root").toLong
+    DfsConfig(maxIter, root)
+  }
+}
+
 /**
   * Hanp
   */
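
As with BfsConfig above, the companion object stashes the parsed values in mutable vars before building the case class. A sketch of the lookup getDfsConfig performs, assuming configs.algorithmConfig.map behaves like a Map[String, String] keyed by full config paths (the keys used above):

    // Stand-in for configs.algorithmConfig.map; values arrive as strings.
    val algoMap = Map("algorithm.dfs.maxIter" -> "5", "algorithm.dfs.root" -> "10")
    val dfsConfig = DfsConfig(algoMap("algorithm.dfs.maxIter").toInt,
                              algoMap("algorithm.dfs.root").toLong)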
nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DfsAlgo.scala

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.algorithm.lib
+
+import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig, DfsConfig}
+import com.vesoft.nebula.algorithm.utils.NebulaUtil
+import org.apache.spark.graphx.{EdgeDirection, Graph, VertexId}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructField, StructType}
+
+import scala.collection.mutable
+
+object DfsAlgo {
+  var iterNums = 0
+
+  def apply(spark: SparkSession, dataset: Dataset[Row], dfsConfig: DfsConfig): DataFrame = {
+    val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false)
+    val bfsVertices = dfs(graph, dfsConfig.root, mutable.Seq.empty[VertexId])(dfsConfig.maxIter)
+
+    val schema = StructType(List(StructField("dfs", LongType, nullable = false)))
+
+    val rdd        = spark.sparkContext.parallelize(bfsVertices.toSeq, 1).map(row => Row(row))
+    val algoResult = spark.sqlContext
+      .createDataFrame(rdd, schema)
+
+    algoResult.repartition(1)
+  }
+
+  def dfs(g: Graph[None.type, Double], vertexId: VertexId, visited: mutable.Seq[VertexId])(
+      maxIter: Int): mutable.Seq[VertexId] = {
+    if (visited.contains(vertexId)) {
+      visited
+    } else {
+      if (iterNums > maxIter) {
+        return visited
+      }
+      val newVisited = visited :+ vertexId
+      val neighbors  = g.collectNeighbors(EdgeDirection.Out).lookup(vertexId).flatten
+      iterNums = iterNums + 1
+      neighbors.foldLeft(newVisited)((visited, neighbor) => dfs(g, neighbor._1, visited)(maxIter))
+    }
+  }
+
+}
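
To see the recursion in isolation, here is a small sketch that drives the dfs helper above on a toy GraphX graph. It assumes an active SparkSession named spark; the vertices and the 1.0 edge weights are placeholders, not part of the commit:

    import org.apache.spark.graphx.{Edge, Graph, VertexId}
    import scala.collection.mutable

    // Toy directed graph: 1 -> 2, 1 -> 3, 2 -> 4, with dummy weights.
    val edges = spark.sparkContext.parallelize(
      Seq(Edge(1L, 2L, 1.0), Edge(1L, 3L, 1.0), Edge(2L, 4L, 1.0)))
    val graph: Graph[None.type, Double] = Graph.fromEdges(edges, None)

    // Depth-first visit order starting from vertex 1, bounded by maxIter.
    val order = DfsAlgo.dfs(graph, 1L, mutable.Seq.empty[VertexId])(maxIter = 10)
    // e.g. Seq(1, 2, 4, 3), depending on neighbor ordering

Note that iterNums lives on the DfsAlgo object and is never reset, so repeated calls within one JVM keep counting against maxIter.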
nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DfsAlgoSuite.scala

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package scala.com.vesoft.nebula.algorithm.lib
+
+import com.vesoft.nebula.algorithm.config.{BfsConfig, DfsConfig}
+import com.vesoft.nebula.algorithm.lib.{BfsAlgo, DfsAlgo}
+import org.apache.spark.sql.SparkSession
+import org.junit.Test
+
+class DfsAlgoSuite {
+  @Test
+  def bfsAlgoSuite(): Unit = {
+    val spark = SparkSession.builder().master("local").getOrCreate()
+    val data  = spark.read.option("header", true).csv("src/test/resources/edge.csv")
+    val dfsAlgoConfig = new DfsConfig(5, 3)
+    val result        = DfsAlgo.apply(spark, data, dfsAlgoConfig)
+    result.show()
+    assert(result.count() == 4)
+  }
+}
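
The src/test/resources/edge.csv fixture is not included in this commit. To try the algorithm without it, a hypothetical in-memory substitute could look like the following; it assumes NebulaUtil.loadInitGraph treats the first two columns as source and destination IDs, which is an assumption here, not something this diff shows:

    import spark.implicits._

    // Hypothetical tiny edge list; column names are placeholders.
    val data   = Seq(("3", "2"), ("3", "4"), ("2", "1")).toDF("src", "dst")
    val result = DfsAlgo(spark, data, new DfsConfig(5, 3))
    result.show()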
