diff --git a/nebula-algorithm/src/main/resources/application.conf b/nebula-algorithm/src/main/resources/application.conf
index a518e89..38de228 100644
--- a/nebula-algorithm/src/main/resources/application.conf
+++ b/nebula-algorithm/src/main/resources/application.conf
@@ -10,7 +10,7 @@
   }
 
   data: {
-    # data source. optional of nebula,csv,json
+    # data source. optional of nebula,nebula-ngql,csv,json
     source: csv
     # data sink, means the algorithm result will be write into this sink. optional of nebula,csv,text
     sink: csv
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala
index fa60bca..4aa1812 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala
@@ -119,6 +119,10 @@ object Main {
         val reader = new NebulaReader(spark, configs, partitionNum)
         reader.read()
       }
+      case "nebula-ngql" => {
+        val reader = new NebulaReader(spark, configs, partitionNum)
+        reader.readNgql()
+      }
       case "csv" => {
         val reader = new CsvReader(spark, configs, partitionNum)
         reader.read()
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
index a508ab6..a2161cc 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
@@ -90,8 +90,18 @@ object NebulaConfigEntry {
     } else {
       List()
     }
-    val readConfigEntry =
+    val readConfigEntry = if (nebulaConfig.hasPath("read.ngql")) {
+      val readGraphAddress = nebulaConfig.getString("read.graphAddress")
+      val ngql = nebulaConfig.getString("read.ngql")
+      NebulaReadConfigEntry(readMetaAddress,
+                            readSpace,
+                            readLabels,
+                            readWeightCols,
+                            readGraphAddress,
+                            ngql)
+    } else {
       NebulaReadConfigEntry(readMetaAddress, readSpace, readLabels, readWeightCols)
+    }
 
     val graphAddress = nebulaConfig.getString("write.graphAddress")
     val writeMetaAddress = nebulaConfig.getString("write.metaAddress")
@@ -203,11 +213,13 @@ case class NebulaConfigEntry(readConfigEntry: NebulaReadConfigEntry,
 case class NebulaReadConfigEntry(address: String = "",
                                  space: String = "",
                                  labels: List[String] = List(),
-                                 weightCols: List[String] = List()) {
+                                 weightCols: List[String] = List(),
+                                 graphAddress: String = "",
+                                 ngql: String = "") {
   override def toString: String = {
     s"NebulaReadConfigEntry: " +
       s"{address: $address, space: $space, labels: ${labels.mkString(",")}, " +
-      s"weightCols: ${weightCols.mkString(",")}}"
+      s"weightCols: ${weightCols.mkString(",")}, graphAddress: $graphAddress, ngql: $ngql}"
   }
 }
 
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala
index e478812..431db22 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala
@@ -65,6 +65,52 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String)
     }
     dataset
   }
+
+  /**
+    * Read edges from Nebula Graph using the nGQL statement of the read config.
+    *
+    * One DataFrame is loaded per configured label and the results are unioned. When weight
+    * columns are configured, each DataFrame is narrowed to `_srcId`, `_dstId` and the weight
+    * column at the same index as its label.
+    *
+    * @return the unioned edges, or null when no labels are configured
+    * @throws IllegalArgumentException if the nGQL is empty or a label lacks a weight column
+    */
+  def readNgql(): DataFrame = {
+    val readConfig = configs.nebulaConfig.readConfigEntry
+    val labels     = readConfig.labels
+    val weights    = readConfig.weightCols
+    val ngql       = readConfig.ngql
+    val partition  = partitionNum.toInt
+
+    // Fail fast on bad configuration rather than deep inside the connector or at weights(i).
+    require(ngql.trim.nonEmpty, "read.ngql must not be empty for the nebula-ngql data source")
+    require(weights.isEmpty || weights.size >= labels.size,
+            "read.weightCols must provide one weight column per entry of read.labels")
+
+    val config =
+      NebulaConnectionConfig
+        .builder()
+        .withMetaAddress(readConfig.address)
+        .withGraphAddress(readConfig.graphAddress)
+        .withConenctionRetry(2)
+        .build()
+
+    // NOTE(review): confirm loadEdgesToDF() honours withNgql(...) in the connector version used.
+    val edgeFrames = labels.indices.map { i =>
+      val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
+        .builder()
+        .withSpace(readConfig.space)
+        .withLabel(labels(i))
+        .withPartitionNum(partition)
+        .withNgql(ngql)
+        .build()
+      val edges = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
+      if (weights.nonEmpty) edges.select("_srcId", "_dstId", weights(i)) else edges
+    }
+    // orNull keeps the previous contract of returning null when there are no labels.
+    edgeFrames.reduceOption(_.union(_)).orNull
+  }
 }
 
 class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String)