Skip to content

Commit 360ffe5

Browse files
feat: all paths between two nodes (#828)
* feat: all paths between two nodes * fix: let's have some 2.12 vs 2.13 fun * feat: adressing comments * feat: addressing comments * feat: PySpark bindings * docs: add docs page * feat: propagate all the agg_nbrs args * fix: 3.5.x does not respect spark.checkpoint.dir
1 parent 427e06d commit 360ffe5

12 files changed

Lines changed: 802 additions & 55 deletions

File tree

connect/src/main/protobuf/graphframes.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ message GraphFramesAPI {
4040
RandomWalkEmbeddings rw_embeddings = 23;
4141
AggregateNeighbors aggregate_neighbors = 24;
4242
NeighborhoodAwareCDLP neighborhood_aware_cdlp = 25;
43+
AllPaths all_paths = 26;
4344
}
4445
}
4546

@@ -88,6 +89,17 @@ message BFS {
8889
int32 max_path_length = 4;
8990
}
9091

92+
message AllPaths {
93+
ColumnOrExpression from_expr = 1;
94+
ColumnOrExpression to_expr = 2;
95+
ColumnOrExpression edge_filter = 3;
96+
int32 max_path_length = 4;
97+
bool is_directed = 5;
98+
int32 checkpoint_interval = 6;
99+
bool use_local_checkpoints = 7;
100+
optional StorageLevel storage_level = 8;
101+
}
102+
91103
message ConnectedComponents {
92104
string algorithm = 1;
93105
int32 checkpoint_interval = 2;

connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,17 @@ object GraphFramesConnectUtils {
202202
.maxPathLength(bfsProto.getMaxPathLength)
203203
.run()
204204
}
205+
case proto.GraphFramesAPI.MethodCase.ALL_PATHS => {
206+
val allPathsProto = apiMessage.getAllPaths
207+
graphFrame.allPaths
208+
.toExpr(parseColumnOrExpression(allPathsProto.getToExpr, planner))
209+
.fromExpr(parseColumnOrExpression(allPathsProto.getFromExpr, planner))
210+
.edgeFilter(parseColumnOrExpression(allPathsProto.getEdgeFilter, planner))
211+
.maxPathLength(allPathsProto.getMaxPathLength)
212+
.setIsDirected(allPathsProto.getIsDirected)
213+
.setUseLocalCheckpoints(allPathsProto.getUseLocalCheckpoints)
214+
.run()
215+
}
205216
case proto.GraphFramesAPI.MethodCase.CONNECTED_COMPONENTS => {
206217
val cc = apiMessage.getConnectedComponents
207218
val ccBuilder = graphFrame.connectedComponents

core/src/main/scala/org/graphframes/GraphFrame.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,15 @@ class GraphFrame private (
701701
*/
702702
def bfs: BFS = new BFS(this)
703703

704+
/**
705+
* Enumerate all paths between source and destination vertices.
706+
*
707+
* See [[org.graphframes.lib.AllPaths]] for details.
708+
*
709+
* @group stdlib
710+
*/
711+
def allPaths: AllPaths = new AllPaths(this)
712+
704713
/**
705714
* Aggregate information from neighboring vertices and edges through a controlled traversal.
706715
*
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.graphframes.lib
19+
20+
import org.apache.spark.sql.Column
21+
import org.apache.spark.sql.DataFrame
22+
import org.apache.spark.sql.functions.array
23+
import org.apache.spark.sql.functions.array_contains
24+
import org.apache.spark.sql.functions.col
25+
import org.apache.spark.sql.functions.concat
26+
import org.apache.spark.sql.functions.expr
27+
import org.apache.spark.sql.graphframes.SparkShims
28+
import org.graphframes.GraphFrame
29+
import org.graphframes.WithCheckpointInterval
30+
import org.graphframes.WithDirection
31+
import org.graphframes.WithIntermediateStorageLevel
32+
import org.graphframes.WithLocalCheckpoints
33+
34+
/**
35+
* Computes all simple paths between source and destination vertices.
36+
*
37+
* This algorithm enumerates paths up to `maxPathLength` hops. It supports directed and undirected
38+
* traversal as well as optional edge filtering. It returns all simple paths between source and
39+
* destination vertices. Here the term "simple" means no repeated vertices. For example, if there
40+
* are paths A-B-C and A-D-C plus the edge B-A, and the user asks for all paths between "A" and
41+
* "C", only A-B-C and A-D-C are returned, but not A-B-A-D-C. The default value of the
42+
* `maxPathLength` is `5`. Keep in mind that with a `maxPathLength` on the scale of the graph
43+
* diameter the algorithm will try to return (almost) all simple paths in the graph, which
44+
* can cause huge performance degradation or even OOM-like errors. Algorithm supports both
45+
* directed and undirected graphs.
46+
*
47+
* Returned DataFrame schema:
48+
* - `path`: array of vertex ids in traversal order
49+
* - `len`: number of edges in the path (Long)
50+
*
51+
* Note: for an undirected graph the algorithm runs on an internal graph built as the union of
52+
* the edges and the reversed edges. It is assumed that the graph has no multi-edges. Results may
53+
* be unstable and unpredictable for graphs with multi-edges.
54+
*/
55+
class AllPaths private[graphframes] (private val graph: GraphFrame)
56+
extends Arguments
57+
with Serializable
58+
with WithDirection
59+
with WithLocalCheckpoints
60+
with WithCheckpointInterval
61+
with WithIntermediateStorageLevel {
62+
63+
private var maxPathLength: Int = 5
64+
private var fromExpression: Column = _
65+
private var toExpression: Column = _
66+
private var edgeFilterExpression: Option[Column] = None
67+
68+
/**
69+
* Sets the expression identifying the source (starting) vertices.
70+
*
71+
* @param value
72+
* a Column expression evaluated against vertex attributes to select source vertices
73+
* @return
74+
* this instance for method chaining
75+
*/
76+
def fromExpr(value: Column): this.type = {
77+
fromExpression = value
78+
this
79+
}
80+
81+
/**
82+
* Sets the expression identifying the source (starting) vertices.
83+
*
84+
* @param value
85+
* a SQL expression string evaluated against vertex attributes to select source vertices
86+
* @return
87+
* this instance for method chaining
88+
*/
89+
def fromExpr(value: String): this.type = fromExpr(expr(value))
90+
91+
/**
92+
* Sets the expression identifying the destination (target) vertices.
93+
*
94+
* @param value
95+
* a Column expression evaluated against vertex attributes to select destination vertices
96+
* @return
97+
* this instance for method chaining
98+
*/
99+
def toExpr(value: Column): this.type = {
100+
toExpression = value
101+
this
102+
}
103+
104+
/**
105+
* Sets the expression identifying the destination (target) vertices.
106+
*
107+
* @param value
108+
* a SQL expression string evaluated against vertex attributes to select destination vertices
109+
* @return
110+
* this instance for method chaining
111+
*/
112+
def toExpr(value: String): this.type = toExpr(expr(value))
113+
114+
/**
115+
* Sets the maximum path length (number of edges) for the enumerated paths.
116+
*
117+
* Setting a large value (e.g. on the scale of the graph diameter) may cause the algorithm to
118+
* attempt to collect a very large number of paths, leading to severe performance degradation or
119+
* out-of-memory errors. Use with caution on large or densely connected graphs.
120+
*
121+
* @param value
122+
* the maximum number of edges in a path; must be greater than 0. Default is 5.
123+
* @return
124+
* this instance for method chaining
125+
*/
126+
def maxPathLength(value: Int): this.type = {
127+
require(value > 0, s"AllPaths maxPathLength must be > 0, but was set to $value")
128+
maxPathLength = value
129+
this
130+
}
131+
132+
/**
133+
* Sets an optional filter expression applied to edges during traversal. Only edges satisfying
134+
* this condition will be considered.
135+
*
136+
* @param value
137+
* a Column expression evaluated against edge attributes
138+
* @return
139+
* this instance for method chaining
140+
*/
141+
def edgeFilter(value: Column): this.type = {
142+
edgeFilterExpression = Some(value)
143+
this
144+
}
145+
146+
/**
147+
* Sets an optional filter expression applied to edges during traversal. Only edges satisfying
148+
* this condition will be considered.
149+
*
150+
* @param value
151+
* a SQL expression string evaluated against edge attributes
152+
* @return
153+
* this instance for method chaining
154+
*/
155+
def edgeFilter(value: String): this.type = edgeFilter(expr(value))
156+
157+
/**
158+
* Executes the AllPaths algorithm and returns all simple paths between the specified source and
159+
* destination vertices.
160+
*
161+
* @return
162+
* a DataFrame with the following columns:
163+
* - `path`: an array of vertex ids in traversal order
164+
* - `len`: the number of edges in the path (Long)
165+
*/
166+
def run(): DataFrame = {
167+
require(fromExpression != null, "fromExpr is required.")
168+
require(toExpression != null, "toExpr is required.")
169+
require(
170+
graph.vertices.columns.toSet.intersect(Set("hop", "path", "len")).isEmpty,
171+
"columns `hop`, `path` and `len` are reserved by algorithm")
172+
173+
val traversalGraph = if (isDirected) {
174+
graph
175+
} else {
176+
val edgeColumns = graph.edges.columns.toSeq
177+
val reversed = graph.edges.select(
178+
(Seq(
179+
col(GraphFrame.DST).alias(GraphFrame.SRC),
180+
col(GraphFrame.SRC).alias(GraphFrame.DST)) ++
181+
edgeColumns.filterNot(c => c == GraphFrame.SRC || c == GraphFrame.DST).map(col)): _*)
182+
GraphFrame(graph.vertices, graph.edges.unionByName(reversed))
183+
}
184+
185+
val agg = traversalGraph.aggregateNeighbors
186+
.setStartingVertices(fromExpression)
187+
.setMaxHops(maxPathLength)
188+
.setTargetCondition(SparkShims.applyExprToCol(graph.spark, toExpression, "dst_attributes"))
189+
.setStoppingCondition(
190+
array_contains(col("path"), AggregateNeighbors.dstAttr(GraphFrame.ID)))
191+
.addAccumulator(
192+
"path",
193+
array(col(GraphFrame.ID)),
194+
concat(col("path"), array(AggregateNeighbors.dstAttr(GraphFrame.ID))))
195+
.setUseLocalCheckpoints(useLocalCheckpoints)
196+
.setCheckpointInterval(checkpointInterval)
197+
.setIntermediateStorageLevel(intermediateStorageLevel)
198+
199+
edgeFilterExpression.foreach { ef =>
200+
agg.setEdgeFilter(SparkShims.applyExprToCol(graph.spark, ef, "edge_attributes"))
201+
}
202+
203+
agg
204+
.run()
205+
.select(col("path"), col("hop").alias("len"))
206+
.distinct()
207+
}
208+
}

0 commit comments

Comments
 (0)