Skip to content

Commit fcebfb6

Browse files
committed
Merge remote-tracking branch 'origin/master' into df-self-join
2 parents 9a5ce19 + 75dc296 commit fcebfb6

File tree

226 files changed

+3204
-1955
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

226 files changed

+3204
-1955
lines changed

R/pkg/DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Depends:
1111
R (>= 3.0),
1212
methods,
1313
Suggests:
14-
testthat
14+
testthat,
15+
e1071
1516
Description: R frontend for Spark
1617
License: Apache License (== 2.0)
1718
Collate:

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ exportMethods("glm",
1515
"predict",
1616
"summary",
1717
"kmeans",
18-
"fitted")
18+
"fitted",
19+
"naiveBayes")
1920

2021
# Job group lifecycle management methods
2122
export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,3 +1175,7 @@ setGeneric("kmeans")
11751175
#' @rdname fitted
11761176
#' @export
11771177
setGeneric("fitted")
1178+
1179+
#' @rdname naiveBayes
1180+
#' @export
1181+
setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })

R/pkg/R/mllib.R

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
#' @export
2323
setClass("PipelineModel", representation(model = "jobj"))
2424

25+
#' @title S4 class that represents a NaiveBayesModel
26+
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
27+
#' @export
28+
setClass("NaiveBayesModel", representation(jobj = "jobj"))
29+
2530
#' Fits a generalized linear model
2631
#'
2732
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj"))
4247
#' @rdname glm
4348
#' @export
4449
#' @examples
45-
#'\dontrun{
50+
#' \dontrun{
4651
#' sc <- sparkR.init()
4752
#' sqlContext <- sparkRSQL.init(sc)
4853
#' data(iris)
@@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
7176
#' @rdname predict
7277
#' @export
7378
#' @examples
74-
#'\dontrun{
79+
#' \dontrun{
7580
#' model <- glm(y ~ x, trainingData)
7681
#' predicted <- predict(model, testData)
7782
#' showDF(predicted)
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
8186
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
8287
})
8388

89+
#' Make predictions from a naive Bayes model
90+
#'
91+
#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
92+
#'
93+
#' @param object A fitted naive Bayes model
94+
#' @param newData DataFrame for testing
95+
#' @return DataFrame containing predicted labels in a column named "prediction"
96+
#' @rdname predict
97+
#' @export
98+
#' @examples
99+
#' \dontrun{
100+
#' model <- naiveBayes(y ~ x, trainingData)
101+
#' predicted <- predict(model, testData)
102+
#' showDF(predicted)
103+
#'}
104+
setMethod("predict", signature(object = "NaiveBayesModel"),
105+
function(object, newData) {
106+
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
107+
})
108+
84109
#' Get the summary of a model
85110
#'
86111
#' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"),
97122
#' @rdname summary
98123
#' @export
99124
#' @examples
100-
#'\dontrun{
125+
#' \dontrun{
101126
#' model <- glm(y ~ x, trainingData)
102127
#' summary(model)
103128
#'}
@@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"),
140165
}
141166
})
142167

168+
#' Get the summary of a naive Bayes model
169+
#'
170+
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
171+
#'
172+
#' @param object A fitted naive Bayes model
173+
#' @return a list containing 'apriori', the label distribution, and 'tables', conditional
174+
#' probabilities given the target label
175+
#' @rdname summary
176+
#' @export
177+
#' @examples
178+
#' \dontrun{
179+
#' model <- naiveBayes(y ~ x, trainingData)
180+
#' summary(model)
181+
#'}
182+
setMethod("summary", signature(object = "NaiveBayesModel"),
183+
function(object, ...) {
184+
jobj <- object@jobj
185+
features <- callJMethod(jobj, "features")
186+
labels <- callJMethod(jobj, "labels")
187+
apriori <- callJMethod(jobj, "apriori")
188+
apriori <- t(as.matrix(unlist(apriori)))
189+
colnames(apriori) <- unlist(labels)
190+
tables <- callJMethod(jobj, "tables")
191+
tables <- matrix(tables, nrow = length(labels))
192+
rownames(tables) <- unlist(labels)
193+
colnames(tables) <- unlist(features)
194+
return(list(apriori = apriori, tables = tables))
195+
})
196+
143197
#' Fit a k-means model
144198
#'
145199
#' Fit a k-means model, similarly to R's kmeans().
@@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"),
152206
#' @rdname kmeans
153207
#' @export
154208
#' @examples
155-
#'\dontrun{
209+
#' \dontrun{
156210
#' model <- kmeans(x, centers = 2, algorithm="random")
157211
#'}
158212
setMethod("kmeans", signature(x = "DataFrame"),
@@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"),
173227
#' @rdname fitted
174228
#' @export
175229
#' @examples
176-
#'\dontrun{
230+
#' \dontrun{
177231
#' model <- kmeans(trainingData, 2)
178232
#' fitted.model <- fitted(model)
179233
#' showDF(fitted.model)
@@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
192246
stop(paste("Unsupported model", modelName, sep = " "))
193247
}
194248
})
249+
250+
#' Fit a Bernoulli naive Bayes model
251+
#'
252+
#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
253+
#' categorical features are supported. The input should be a DataFrame of observations instead of a
254+
#' contingency table.
255+
#'
256+
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
258+
#' operators are supported, including '~', '.', ':', '+', and '-'.
258+
#' @param data DataFrame for training
259+
#' @param laplace Smoothing parameter
260+
#' @return a fitted naive Bayes model
261+
#' @rdname naiveBayes
262+
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
263+
#' @export
264+
#' @examples
265+
#' \dontrun{
266+
#' df <- createDataFrame(sqlContext, infert)
267+
#' model <- naiveBayes(education ~ ., df, laplace = 0)
268+
#'}
269+
setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
270+
function(formula, data, laplace = 0, ...) {
271+
formula <- paste(deparse(formula), collapse = "")
272+
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
273+
formula, data@sdf, laplace)
274+
return(new("NaiveBayesModel", jobj = jobj))
275+
})

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,62 @@ test_that("kmeans", {
141141
cluster <- summary.model$cluster
142142
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
143143
})
144+
145+
test_that("naiveBayes", {
146+
# R code to reproduce the result.
147+
# We do not support instance weights yet. So we ignore the frequencies.
148+
#
149+
#' library(e1071)
150+
#' t <- as.data.frame(Titanic)
151+
#' t1 <- t[t$Freq > 0, -5]
152+
#' m <- naiveBayes(Survived ~ ., data = t1)
153+
#' m
154+
#' predict(m, t1)
155+
#
156+
# -- output of 'm'
157+
#
158+
# A-priori probabilities:
159+
# Y
160+
# No Yes
161+
# 0.4166667 0.5833333
162+
#
163+
# Conditional probabilities:
164+
# Class
165+
# Y 1st 2nd 3rd Crew
166+
# No 0.2000000 0.2000000 0.4000000 0.2000000
167+
# Yes 0.2857143 0.2857143 0.2857143 0.1428571
168+
#
169+
# Sex
170+
# Y Male Female
171+
# No 0.5 0.5
172+
# Yes 0.5 0.5
173+
#
174+
# Age
175+
# Y Child Adult
176+
# No 0.2000000 0.8000000
177+
# Yes 0.4285714 0.5714286
178+
#
179+
# -- output of 'predict(m, t1)'
180+
#
181+
# Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
182+
#
183+
184+
t <- as.data.frame(Titanic)
185+
t1 <- t[t$Freq > 0, -5]
186+
df <- suppressWarnings(createDataFrame(sqlContext, t1))
187+
m <- naiveBayes(Survived ~ ., data = df)
188+
s <- summary(m)
189+
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
190+
expect_equal(sum(s$apriori), 1)
191+
expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
192+
p <- collect(select(predict(m, df), "prediction"))
193+
expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
194+
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
195+
"Yes", "Yes", "No", "No"))
196+
197+
# Test e1071::naiveBayes
198+
if (requireNamespace("e1071", quietly = TRUE)) {
199+
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
200+
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
201+
}
202+
})

common/network-common/src/main/java/org/apache/spark/network/TransportContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343

4444
/**
4545
* Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
46-
* setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
46+
* setup Netty Channel pipelines with a
47+
* {@link org.apache.spark.network.server.TransportChannelHandler}.
4748
*
4849
* There are two communication protocols that the TransportClient provides, control-plane RPCs and
4950
* data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the

common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
import java.nio.ByteBuffer;
2222

2323
/**
24-
* Callback for streaming data. Stream data will be offered to the {@link #onData(String, ByteBuffer)}
25-
* method as it arrives. Once all the stream data is received, {@link #onComplete(String)} will be
26-
* called.
24+
* Callback for streaming data. Stream data will be offered to the
25+
* {@link #onData(String, ByteBuffer)} method as it arrives. Once all the stream data is received,
26+
* {@link #onComplete(String)} will be called.
2727
* <p>
2828
* The network library guarantees that a single thread will call these methods at a time, but
2929
 * different calls may be made by different threads.

common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ private static class ClientPool {
6464
TransportClient[] clients;
6565
Object[] locks;
6666

67-
public ClientPool(int size) {
67+
ClientPool(int size) {
6868
clients = new TransportClient[size];
6969
locks = new Object[size];
7070
for (int i = 0; i < size; i++) {

common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ public interface Message extends Encodable {
3333
boolean isBodyInFrame();
3434

3535
/** Preceding every serialized Message is its type, which allows us to deserialize it. */
36-
public static enum Type implements Encodable {
36+
enum Type implements Encodable {
3737
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
3838
RpcRequest(3), RpcResponse(4), RpcFailure(5),
3939
StreamRequest(6), StreamResponse(7), StreamFailure(8),
4040
OneWayMessage(9), User(-1);
4141

4242
private final byte id;
4343

44-
private Type(int id) {
44+
Type(int id) {
4545
assert id < 128 : "Cannot have more than 128 message types";
4646
this.id = (byte) id;
4747
}

common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.network.protocol;
1919

20-
import org.apache.spark.network.protocol.Message;
21-
2220
/** Messages from the client to the server. */
2321
public interface RequestMessage extends Message {
2422
// token interface

0 commit comments

Comments
 (0)