encode root for path algo #71

Merged 1 commit on Jun 16, 2023
@@ -200,16 +200,16 @@ object CoefficientConfig {
 /**
   * bfs
   */
-case class BfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
+case class BfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
 object BfsConfig {
   var maxIter: Int = _
-  var root: Long = _
+  var root: String = _
   var encodeId: Boolean = false

   def getBfsConfig(configs: Configs): BfsConfig = {
     val bfsConfig = configs.algorithmConfig.map
     maxIter = bfsConfig("algorithm.bfs.maxIter").toInt
-    root = bfsConfig("algorithm.bfs.root").toLong
+    root = bfsConfig("algorithm.bfs.root").toString
     encodeId = ConfigUtil.getOrElseBoolean(bfsConfig, "algorithm.bfs.encodeId", false)
     BfsConfig(maxIter, root, encodeId)
   }
@@ -218,16 +218,16 @@ object BfsConfig {
 /**
   * dfs
   */
-case class DfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
+case class DfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
 object DfsConfig {
   var maxIter: Int = _
-  var root: Long = _
+  var root: String = _
   var encodeId: Boolean = false

   def getDfsConfig(configs: Configs): DfsConfig = {
     val dfsConfig = configs.algorithmConfig.map
     maxIter = dfsConfig("algorithm.dfs.maxIter").toInt
-    root = dfsConfig("algorithm.dfs.root").toLong
+    root = dfsConfig("algorithm.dfs.root").toString
     encodeId = ConfigUtil.getOrElseBoolean(dfsConfig, "algorithm.dfs.encodeId", false)
     DfsConfig(maxIter, root, encodeId)
   }
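The effect of these config changes: the root id is now read from the config verbatim and only converted to a numeric id later inside the algorithm, so non-numeric vertex ids no longer fail at parse time. A minimal sketch of the parsing behavior (the config map and the id "player100" are invented for illustration):

    // Hedged sketch: root stays a String at parse time, so a non-numeric id
    // no longer throws a NumberFormatException in getBfsConfig/getDfsConfig.
    val conf = Map(
      "algorithm.dfs.maxIter" -> "5",
      "algorithm.dfs.root"    -> "player100" // hypothetical non-numeric vertex id
    )
    val maxIter = conf("algorithm.dfs.maxIter").toInt
    val root    = conf("algorithm.dfs.root") // kept as String; .toLong here would throw
    DfsConfig(maxIter, root) // root travels as a String into the algorithm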
@@ -386,6 +386,7 @@ object AlgoConstants {
   val HANP_RESULT_COL: String = "hanp"
   val NODE2VEC_RESULT_COL: String = "node2vec"
   val BFS_RESULT_COL: String = "bfs"
+  val DFS_RESULT_COL: String = "dfs"
   val ENCODE_ID_COL: String = "encodedId"
   val ORIGIN_ID_COL: String = "id"
 }
@@ -26,15 +26,18 @@ object BfsAlgo {
    */
   def apply(spark: SparkSession, dataset: Dataset[Row], bfsConfig: BfsConfig): DataFrame = {
     var encodeIdDf: DataFrame = null
+    var finalRoot: Long = 0

     val graph: Graph[None.type, Double] = if (bfsConfig.encodeId) {
       val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
       encodeIdDf = encodeId
+      finalRoot = encodeIdDf.filter(row => row.get(0).toString == bfsConfig.root).first().getLong(1)
       NebulaUtil.loadInitGraph(data, false)
     } else {
+      finalRoot = bfsConfig.root.toLong
       NebulaUtil.loadInitGraph(dataset, false)
     }
-    val bfsGraph = execute(graph, bfsConfig.maxIter, bfsConfig.root)
+    val bfsGraph = execute(graph, bfsConfig.maxIter, finalRoot)

     // filter out the vertices that were not traversed
     val visitedVertices = bfsGraph.vertices.filter(v => v._2 != Double.PositiveInfinity)
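A hedged usage sketch of the new behavior (the edge DataFrame, its column names, and the ids are assumptions, not from the PR): with encodeId = true the root may be any string and is looked up in the generated string-to-long mapping; with encodeId = false the root string must still parse as a Long.

    import org.apache.spark.sql.SparkSession
    import com.vesoft.nebula.algorithm.config.BfsConfig
    import com.vesoft.nebula.algorithm.lib.BfsAlgo

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._
    // toy edge list with string ids; the column names are an assumption here
    val edges = Seq(("a", "b", 1.0), ("b", "c", 1.0)).toDF("src", "dst", "weight")

    // a string root works because encodeId = true maps ids to Longs internally
    val result = BfsAlgo.apply(spark, edges, BfsConfig(5, "a", encodeId = true))
    result.show()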
@@ -5,9 +5,16 @@

 package com.vesoft.nebula.algorithm.lib

+import com.vesoft.nebula.algorithm.config.AlgoConstants.{
+  ALGO_ID_COL,
+  DFS_RESULT_COL,
+  ENCODE_ID_COL,
+  ORIGIN_ID_COL
+}
 import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig, DfsConfig}
 import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
-import org.apache.spark.graphx.{EdgeDirection, Graph, VertexId}
+import org.apache.spark.graphx.{EdgeDirection, EdgeTriplet, Graph, Pregel, VertexId}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructField, StructType}
@@ -18,21 +25,28 @@ object DfsAlgo {

   def apply(spark: SparkSession, dataset: Dataset[Row], dfsConfig: DfsConfig): DataFrame = {
     var encodeIdDf: DataFrame = null
+    var finalRoot: Long = 0

     val graph: Graph[None.type, Double] = if (dfsConfig.encodeId) {
       val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
       encodeIdDf = encodeId
+      finalRoot = encodeIdDf.filter(row => row.get(0).toString == dfsConfig.root).first().getLong(1)
       NebulaUtil.loadInitGraph(data, false)
     } else {
+      finalRoot = dfsConfig.root.toLong
       NebulaUtil.loadInitGraph(dataset, false)
     }
-    val bfsVertices = dfs(graph, dfsConfig.root, mutable.Seq.empty[VertexId])(dfsConfig.maxIter)
+    val bfsVertices =
+      dfs(graph, finalRoot, mutable.Seq.empty[VertexId])(dfsConfig.maxIter).vertices.filter(v =>
+        v._2 != Double.PositiveInfinity)

-    val schema = StructType(List(StructField("dfs", LongType, nullable = false)))
+    val schema = StructType(
+      List(StructField(ALGO_ID_COL, LongType, nullable = false),
+           StructField(DFS_RESULT_COL, DoubleType, nullable = true)))

-    val rdd = spark.sparkContext.parallelize(bfsVertices.toSeq, 1).map(row => Row(row))
-    val algoResult = spark.sqlContext
-      .createDataFrame(rdd, schema)
+    val resultRDD = bfsVertices.map(v => Row(v._1, v._2))
+    val algoResult =
+      spark.sqlContext.createDataFrame(resultRDD, schema).orderBy(col(DFS_RESULT_COL))

     if (dfsConfig.encodeId) {
       DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf).coalesce(1)
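For reference, the result-building step above in isolation: converting GraphX vertices of (VertexId, depth) into a two-column DataFrame and ordering by depth. A self-contained sketch; the column names "_id" and "dfs" are assumptions standing in for ALGO_ID_COL and DFS_RESULT_COL:

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}

    val spark    = SparkSession.builder().master("local").getOrCreate()
    val vertices = spark.sparkContext.parallelize(Seq((1L, 0.0), (2L, 1.0), (3L, 2.0)))

    val schema = StructType(
      List(StructField("_id", LongType, nullable = false),   // assumed id column name
           StructField("dfs", DoubleType, nullable = true))) // assumed result column name

    val df = spark.createDataFrame(vertices.map { case (id, depth) => Row(id, depth) }, schema)
    df.orderBy("dfs").show() // vertices sorted by traversal depth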
@@ -42,18 +56,35 @@ object DfsAlgo {
   }

   def dfs(g: Graph[None.type, Double], vertexId: VertexId, visited: mutable.Seq[VertexId])(
-      maxIter: Int): mutable.Seq[VertexId] = {
-    if (visited.contains(vertexId)) {
-      visited
-    } else {
-      if (iterNums > maxIter) {
-        return visited
-      }
-      val newVisited = visited :+ vertexId
-      val neighbors = g.collectNeighbors(EdgeDirection.Out).lookup(vertexId).flatten
-      iterNums = iterNums + 1
-      neighbors.foldLeft(newVisited)((visited, neighbor) => dfs(g, neighbor._1, visited)(maxIter))
-    }
+      maxIter: Int): Graph[Double, Double] = {
+
+    val initialGraph =
+      g.mapVertices((id, _) => if (id == vertexId) 0.0 else Double.PositiveInfinity)
+
+    def vertexProgram(id: VertexId, attr: Double, msg: Double): Double = {
+      math.min(attr, msg)
+    }
+
+    def sendMessage(edge: EdgeTriplet[Double, Double]): Iterator[(VertexId, Double)] = {
+      val sourceVertex = edge.srcAttr
+      val targetVertex = edge.dstAttr
+      if (sourceVertex + 1 < targetVertex && sourceVertex < maxIter) {
+        Iterator((edge.dstId, sourceVertex + 1))
+      } else {
+        Iterator.empty
+      }
+    }
+
+    def mergeMessage(a: Double, b: Double): Double = {
+      math.min(a, b)
+    }
+
+    // start the Pregel iteration
+    val resultGraph =
+      Pregel(initialGraph, Double.PositiveInfinity)(vertexProgram, sendMessage, mergeMessage)
+
+    resultGraph
   }

 }
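Worth noting: the rewritten dfs is no longer a recursive depth-first walk; it propagates hop counts via Pregel supersteps, which is what makes it distributed. A self-contained sketch of the same propagation on a toy graph (the graph data here is invented for illustration): each superstep offers depth + 1 across out-edges, a vertex keeps the smallest depth it has seen, and propagation stops once no vertex improves or the depth reaches maxIter.

    import org.apache.spark.graphx.{Edge, EdgeTriplet, Graph, Pregel, VertexId}
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    val sc    = spark.sparkContext

    // 1 -> 2 -> 3 -> 4, plus a shortcut 1 -> 3
    val edges = sc.parallelize(
      Seq(Edge(1L, 2L, 1.0), Edge(2L, 3L, 1.0), Edge(3L, 4L, 1.0), Edge(1L, 3L, 1.0)))
    val graph = Graph.fromEdges(edges, defaultValue = None)

    val root: VertexId = 1L
    val maxIter        = 5

    // root starts at depth 0, everything else starts unreachable
    val initial = graph.mapVertices((id, _) => if (id == root) 0.0 else Double.PositiveInfinity)

    val result = Pregel(initial, Double.PositiveInfinity)(
      (id, attr, msg) => math.min(attr, msg), // keep the smallest depth seen
      (e: EdgeTriplet[Double, Double]) =>
        if (e.srcAttr + 1 < e.dstAttr && e.srcAttr < maxIter)
          Iterator((e.dstId, e.srcAttr + 1)) // offer the neighbor a shorter depth
        else Iterator.empty,
      (a, b) => math.min(a, b) // merge competing offers
    )

    result.vertices.filter(_._2 != Double.PositiveInfinity).collect().sorted.foreach(println)
    // expected: (1,0.0), (2,1.0), (3,1.0), (4,2.0)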
@@ -14,7 +14,7 @@ class BfsAlgoSuite {
   def bfsAlgoSuite(): Unit = {
     val spark = SparkSession.builder().master("local").getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
-    val bfsAlgoConfig = new BfsConfig(5, 1)
+    val bfsAlgoConfig = new BfsConfig(5, "1")
     val result = BfsAlgo.apply(spark, data, bfsAlgoConfig)
     result.show()
     assert(result.count() == 4)
@@ -13,11 +13,21 @@ import org.junit.Test
 class DfsAlgoSuite {
   @Test
   def bfsAlgoSuite(): Unit = {
-    val spark = SparkSession.builder().master("local").getOrCreate()
+    val spark = SparkSession
+      .builder()
+      .master("local")
+      .config("spark.sql.shuffle.partitions", 5)
+      .getOrCreate()
     val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
-    val dfsAlgoConfig = new DfsConfig(5, 3)
-    val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
-    result.show()
-    assert(result.count() == 4)
+    val dfsAlgoConfig = new DfsConfig(5, "3")
+    // val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
+    // result.show()
+    // assert(result.count() == 4)
+
+    val encodeDfsConfig = new DfsConfig(5, "3", true)
+    val encodeResult = DfsAlgo.apply(spark, data, encodeDfsConfig)
+
+    encodeResult.show()
+    assert(encodeResult.count() == 4)
   }
 }