diff --git a/crunch-core/src/main/java/org/apache/crunch/PCollection.java b/crunch-core/src/main/java/org/apache/crunch/PCollection.java index 2d62d003..1d3598c5 100644 --- a/crunch-core/src/main/java/org/apache/crunch/PCollection.java +++ b/crunch-core/src/main/java/org/apache/crunch/PCollection.java @@ -267,4 +267,9 @@ PTable parallelDo(String name, DoFn> doFn, PTableType * Returns a {@code PObject} of the minimum element of this instance. */ PObject min(); + + /** + * Returns a {@code PObject} of the result of combining every element of this instance with {@code aggregator}. + */ + PObject aggregate(Aggregator aggregator); } diff --git a/crunch-core/src/main/java/org/apache/crunch/impl/dist/collect/PCollectionImpl.java b/crunch-core/src/main/java/org/apache/crunch/impl/dist/collect/PCollectionImpl.java index ee820f0c..6e1a713d 100644 --- a/crunch-core/src/main/java/org/apache/crunch/impl/dist/collect/PCollectionImpl.java +++ b/crunch-core/src/main/java/org/apache/crunch/impl/dist/collect/PCollectionImpl.java @@ -19,6 +19,8 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; + +import org.apache.crunch.Aggregator; import org.apache.crunch.CachingOptions; import org.apache.crunch.DoFn; import org.apache.crunch.FilterFn; @@ -259,6 +261,11 @@ public PObject max() { public PObject min() { return Aggregate.min(this); } + + @Override + public PObject aggregate(Aggregator aggregator) { + return Aggregate.aggregate(this, aggregator); + } @Override public PTypeFamily getTypeFamily() { diff --git a/crunch-core/src/main/java/org/apache/crunch/impl/mem/collect/MemCollection.java b/crunch-core/src/main/java/org/apache/crunch/impl/mem/collect/MemCollection.java index 81433eb9..c586fa52 100644 --- a/crunch-core/src/main/java/org/apache/crunch/impl/mem/collect/MemCollection.java +++ b/crunch-core/src/main/java/org/apache/crunch/impl/mem/collect/MemCollection.java @@ -25,6 +25,7 @@ import javassist.util.proxy.MethodHandler; import javassist.util.proxy.ProxyFactory; +import 
org.apache.crunch.Aggregator; import org.apache.crunch.CachingOptions; import org.apache.crunch.DoFn; import org.apache.crunch.FilterFn; @@ -240,6 +241,11 @@ public PObject min() { return Aggregate.min(this); } + @Override + public PObject aggregate(Aggregator aggregator) { + return Aggregate.aggregate(this, aggregator); + } + @Override public PCollection filter(FilterFn filterFn) { return parallelDo(filterFn, getPType()); diff --git a/crunch-core/src/main/java/org/apache/crunch/lib/Aggregate.java b/crunch-core/src/main/java/org/apache/crunch/lib/Aggregate.java index d8388b33..a07a2846 100644 --- a/crunch-core/src/main/java/org/apache/crunch/lib/Aggregate.java +++ b/crunch-core/src/main/java/org/apache/crunch/lib/Aggregate.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.PriorityQueue; +import org.apache.crunch.Aggregator; import org.apache.crunch.CombineFn; import org.apache.crunch.DoFn; import org.apache.crunch.Emitter; @@ -277,4 +278,16 @@ public Collection map(Iterable values) { } }, tf.collections(collect.getValueType())); } + + /** Combines every element of {@code collect} with {@code aggregator} and exposes the result as a {@code PObject}. */ public static PObject aggregate(PCollection collect, Aggregator aggregator) { + PTypeFamily tf = collect.getTypeFamily(); + PCollection aggregation = collect.parallelDo("Aggregate.aggregator", new MapFn>() { + public Pair map(S input) { + return Pair.of(0L, input); /* constant key, so groupByKey() gathers every element into one group */ + } + }, tf.tableOf(tf.longs(), collect.getPType())) + .groupByKey() + .combineValues(aggregator).values(); + return new FirstElementPObject(aggregation); /* presumably yields the first (expected to be the only) combined value - confirm against FirstElementPObject */ + } } diff --git a/crunch-examples/src/main/java/org/apache/crunch/examples/TotalWordCount.java b/crunch-examples/src/main/java/org/apache/crunch/examples/TotalWordCount.java new file mode 100644 index 00000000..a4309eed --- /dev/null +++ b/crunch-examples/src/main/java/org/apache/crunch/examples/TotalWordCount.java @@ -0,0 +1,78 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.crunch.examples; + +import java.io.Serializable; + +import org.apache.crunch.DoFn; +import org.apache.crunch.Emitter; +import org.apache.crunch.PCollection; +import org.apache.crunch.PObject; +import org.apache.crunch.PTable; +import org.apache.crunch.Pipeline; +import org.apache.crunch.PipelineResult; +import org.apache.crunch.fn.Aggregators; +import org.apache.crunch.impl.mr.MRPipeline; +import org.apache.crunch.types.writable.Writables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.conf.Configured; +import org.apache.hadoop.util.GenericOptionsParser; +import org.apache.hadoop.util.Tool; +import org.apache.hadoop.util.ToolRunner; + +public class TotalWordCount extends Configured implements Tool, Serializable { + public int run(String[] args) throws Exception { + if (args.length != 1) { + System.err.println(); + System.err.println("Usage: " + this.getClass().getName() + " [generic options] input"); + System.err.println(); + GenericOptionsParser.printGenericCommandUsage(System.err); + return 1; + } + // Create an object to coordinate pipeline creation and execution. + Pipeline pipeline = new MRPipeline(TotalWordCount.class, getConf()); + // Reference a given text file as a collection of Strings. 
+ PCollection lines = pipeline.readTextFile(args[0]); + + // Map each line of text to the number of whitespace-separated words it contains. + PCollection numberOfWords = lines.parallelDo(new DoFn() { + public void process(String line, Emitter emitter) { + // Trim first: split() returns one empty token for a blank line and an extra one for leading whitespace. + String trimmed = line.trim(); + emitter.emit(trimmed.isEmpty() ? 0L : (long) trimmed.split("\\s+").length); + } + }, Writables.longs()); // Indicates the serialization format + + // The aggregate method combines every element of a collection into a single PObject value. + PObject totalCount = numberOfWords.aggregate(Aggregators.SUM_LONGS()); + + // Execute the pipeline as a MapReduce. + PipelineResult result = pipeline.run(); + + System.out.println("Total number of words: " + totalCount.getValue()); + + pipeline.done(); + + return result.succeeded() ? 0 : 1; + } + + public static void main(String[] args) throws Exception { + int result = ToolRunner.run(new Configuration(), new TotalWordCount(), args); + System.exit(result); + } +}