Java 8 Stream API - a new way of dealing with Collections

25 March 2015


Java 8 Stream API

Perhaps the most widely used Java API is the Collections API: Lists, HashMaps, Sets and the like are used pretty much all the time. The good news is that Java 8 provides a new, enhanced way of working with these data structures.

The Java 8 Streams API complements the Collections API with a functional programming model and built-in support for parallel processing.


Typical Scenario

  1. Given a list of points on a 2D plane, find the K points closest to the origin.

Code in Java

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class Solution {

<span class="kd">public</span> <span class="kd">static</span> <span class="kt">void</span> <span class="nf">main</span><span class="o">(</span><span class="n">String</span> <span class="n">args</span><span class="o">[])</span> <span class="o">{</span>

    <span class="n">Point</span> <span class="n">p1</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p2</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(</span><span class="mi">3</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(</span><span class="mi">3</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p3</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(-</span><span class="mi">1</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(-</span><span class="mi">1</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p4</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(-</span><span class="mi">2</span><span class="o">,</span> <span class="mi">2</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(-</span><span class="mi">2</span><span class="o">,</span> <span class="mi">2</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p5</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(</span><span class="mi">2</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(</span><span class="mi">2</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>

    <span class="n">List</span><span class="o">&lt;</span><span class="n">Point</span><span class="o">&gt;</span> <span class="n">ptList</span> <span class="o">=</span> <span class="k">new</span> <span class="n">ArrayList</span><span class="o">&lt;</span><span class="n">Point</span><span class="o">&gt;();</span>

    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p1</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p2</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p3</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p4</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p5</span><span class="o">);</span>

    <span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">3</span><span class="o">;</span>
    <span class="n">printClosestPointsToOrigin</span><span class="o">(</span><span class="n">ptList</span><span class="o">,</span> <span class="n">k</span><span class="o">);</span>
<span class="o">}</span>

<span class="kd">public</span> <span class="kd">static</span> <span class="kt">void</span> <span class="nf">printClosestPointsToOrigin</span><span class="o">(</span><span class="n">List</span><span class="o">&lt;</span><span class="n">Point</span><span class="o">&gt;</span> <span class="n">ptList</span><span class="o">,</span> <span class="kt">int</span> <span class="n">k</span><span class="o">)</span> <span class="o">{</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">sort</span><span class="o">(</span><span class="k">new</span> <span class="n">PointDistanceFromOriginComparator</span><span class="o">());</span>
    <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span>

    <span class="k">while</span> <span class="o">(</span><span class="n">i</span> <span class="o">&lt;</span> <span class="n">ptList</span><span class="o">.</span><span class="na">size</span><span class="o">()</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">k</span><span class="o">)</span> <span class="o">{</span>
        <span class="n">Point</span> <span class="n">p</span> <span class="o">=</span> <span class="n">ptList</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="n">i</span><span class="o">);</span>
        <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"X: "</span> <span class="o">+</span> <span class="n">p</span><span class="o">.</span><span class="na">x</span> <span class="o">+</span> <span class="s">" Y: "</span> <span class="o">+</span> <span class="n">p</span><span class="o">.</span><span class="na">y</span><span class="o">);</span>
        <span class="n">i</span><span class="o">++;</span>
    <span class="o">}</span>

<span class="o">}</span>

<span class="kd">public</span> <span class="kd">static</span> <span class="kt">int</span> <span class="nf">findDistanceOfPoints</span><span class="o">(</span><span class="kt">int</span> <span class="n">x1</span><span class="o">,</span> <span class="kt">int</span> <span class="n">x2</span><span class="o">,</span> <span class="kt">int</span> <span class="n">y1</span><span class="o">,</span> <span class="kt">int</span> <span class="n">y2</span><span class="o">)</span> <span class="o">{</span>
    <span class="n">Double</span> <span class="n">dist</span> <span class="o">=</span> <span class="n">Math</span><span class="o">.</span><span class="na">sqrt</span><span class="o">(</span><span class="n">Math</span><span class="o">.</span><span class="na">pow</span><span class="o">(</span><span class="n">x1</span> <span class="o">-</span> <span class="n">x2</span><span class="o">,</span> <span class="mi">2</span><span class="o">)</span> <span class="o">+</span> <span class="n">Math</span><span class="o">.</span><span class="na">pow</span><span class="o">(</span><span class="n">y1</span> <span class="o">-</span> <span class="n">y2</span><span class="o">,</span> <span class="mi">2</span><span class="o">));</span>
    <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="na">intValue</span><span class="o">();</span>
<span class="o">}</span>

}

class Point {

    /** Coordinates of the point. */
    final int x, y;

    /** Distance from the origin, stored exactly as passed to the constructor. */
    final int distanceFromOrigin;

    /**
     * Creates a point. The distance is stored as given; it is not recomputed from (x, y).
     *
     * @param x     x coordinate
     * @param y     y coordinate
     * @param distO precomputed distance from the origin
     */
    public Point(int x, int y, int distO) {
        this.x = x;
        this.y = y;
        this.distanceFromOrigin = distO;
    }

    @Override
    public String toString() {
        return "Point [x=" + x + ", y=" + y + ", distanceFromOrigin="
                + distanceFromOrigin + "]";
    }
}

class PointDistanceFromOriginComparator implements Comparator<Point> {

<span class="kd">public</span> <span class="kt">int</span> <span class="nf">compare</span><span class="o">(</span><span class="n">Point</span> <span class="n">p1</span><span class="o">,</span> <span class="n">Point</span> <span class="n">p2</span><span class="o">)</span> <span class="o">{</span>
    <span class="k">if</span> <span class="o">(</span><span class="n">p1</span><span class="o">.</span><span class="na">distanceFromOrigin</span> <span class="o">&lt;</span> <span class="n">p2</span><span class="o">.</span><span class="na">distanceFromOrigin</span><span class="o">)</span> <span class="o">{</span>
        <span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="o">;</span>
    <span class="o">}</span> <span class="k">else</span> <span class="k">if</span> <span class="o">(</span><span class="n">p1</span><span class="o">.</span><span class="na">distanceFromOrigin</span> <span class="o">&gt;</span> <span class="n">p2</span><span class="o">.</span><span class="na">distanceFromOrigin</span><span class="o">)</span> <span class="o">{</span>
        <span class="k">return</span> <span class="mi">1</span><span class="o">;</span>
    <span class="o">}</span>
    <span class="k">return</span> <span class="mi">0</span><span class="o">;</span>
<span class="o">}</span>

}


In this version there is a lot of boilerplate code we have to write ourselves: sorting the list, fetching the first K elements, iterating over them and printing the results.

    public static void printClosestPointsToOrigin(List<Point> ptList, int k) {
        ptList.sort(new PointDistanceFromOriginComparator());
        int i = 0;

    <span class="k">while</span> <span class="o">(</span><span class="n">i</span> <span class="o">&lt;</span> <span class="n">ptList</span><span class="o">.</span><span class="na">size</span><span class="o">()</span> <span class="o">&amp;&amp;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">k</span><span class="o">)</span> <span class="o">{</span>
        <span class="n">Point</span> <span class="n">p</span> <span class="o">=</span> <span class="n">ptList</span><span class="o">.</span><span class="na">get</span><span class="o">(</span><span class="n">i</span><span class="o">);</span>
        <span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">"X: "</span> <span class="o">+</span> <span class="n">p</span><span class="o">.</span><span class="na">x</span> <span class="o">+</span> <span class="s">" Y: "</span> <span class="o">+</span> <span class="n">p</span><span class="o">.</span><span class="na">y</span><span class="o">);</span>
        <span class="n">i</span><span class="o">++;</span>
    <span class="o">}</span>

<span class="o">}</span></code></pre></figure>

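With the Java 8 Stream API we can do all of that in essentially one statement. Unlike `List.sort`, `sorted()` leaves the original list's order untouched. The snippet uses a shared comparator, `PointDistanceFromOriginComparator.INSTANCE`, which is defined in the final version below: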

    public static void printClosestPointsToOrigin(List<Point> ptList, int k) {

    <span class="c1">// sort by distance from origin and print</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">stream</span><span class="o">().</span><span class="na">sorted</span><span class="o">(</span><span class="n">PointDistanceFromOriginComparator</span><span class="o">.</span><span class="na">INSTANCE</span><span class="o">)</span>
            <span class="o">.</span><span class="na">limit</span><span class="o">(</span><span class="n">k</span><span class="o">).</span><span class="na">forEach</span><span class="o">(</span><span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">::</span><span class="n">println</span><span class="o">);</span>
<span class="o">}</span></code></pre></figure>

Java 8 final version

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class Solution {

<span class="kd">public</span> <span class="kd">static</span> <span class="kt">void</span> <span class="nf">main</span><span class="o">(</span><span class="n">String</span> <span class="n">args</span><span class="o">[])</span> <span class="o">{</span>

    <span class="n">Point</span> <span class="n">p1</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(</span><span class="mi">1</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p2</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(</span><span class="mi">3</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(</span><span class="mi">3</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p3</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(-</span><span class="mi">1</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(-</span><span class="mi">1</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p4</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(-</span><span class="mi">2</span><span class="o">,</span> <span class="mi">2</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(-</span><span class="mi">2</span><span class="o">,</span> <span class="mi">2</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>
    <span class="n">Point</span> <span class="n">p5</span> <span class="o">=</span> <span class="k">new</span> <span class="n">Point</span><span class="o">(</span><span class="mi">2</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="n">findDistanceOfPoints</span><span class="o">(</span><span class="mi">2</span><span class="o">,</span> <span class="mi">3</span><span class="o">,</span> <span class="mi">0</span><span class="o">,</span> <span class="mi">0</span><span class="o">));</span>

    <span class="n">List</span><span class="o">&lt;</span><span class="n">Point</span><span class="o">&gt;</span> <span class="n">ptList</span> <span class="o">=</span> <span class="k">new</span> <span class="n">ArrayList</span><span class="o">&lt;</span><span class="n">Point</span><span class="o">&gt;();</span>

    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p1</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p2</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p3</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p4</span><span class="o">);</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">add</span><span class="o">(</span><span class="n">p5</span><span class="o">);</span>

    <span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">3</span><span class="o">;</span>

    <span class="n">printClosestPointsToOrigin</span><span class="o">(</span><span class="n">ptList</span><span class="o">,</span> <span class="n">k</span><span class="o">);</span>

<span class="o">}</span>

<span class="kd">public</span> <span class="kd">static</span> <span class="kt">void</span> <span class="nf">printClosestPointsToOrigin</span><span class="o">(</span><span class="n">List</span><span class="o">&lt;</span><span class="n">Point</span><span class="o">&gt;</span> <span class="n">ptList</span><span class="o">,</span> <span class="kt">int</span> <span class="n">k</span><span class="o">)</span> <span class="o">{</span>

    <span class="c1">// sort by distance from origin and print</span>
    <span class="n">ptList</span><span class="o">.</span><span class="na">stream</span><span class="o">().</span><span class="na">sorted</span><span class="o">(</span><span class="n">PointDistanceFromOriginComparator</span><span class="o">.</span><span class="na">INSTANCE</span><span class="o">)</span>
            <span class="o">.</span><span class="na">limit</span><span class="o">(</span><span class="n">k</span><span class="o">).</span><span class="na">forEach</span><span class="o">(</span><span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">::</span><span class="n">println</span><span class="o">);</span>
<span class="o">}</span>

<span class="kd">public</span> <span class="kd">static</span> <span class="kt">int</span> <span class="nf">findDistanceOfPoints</span><span class="o">(</span><span class="kt">int</span> <span class="n">x1</span><span class="o">,</span> <span class="kt">int</span> <span class="n">x2</span><span class="o">,</span> <span class="kt">int</span> <span class="n">y1</span><span class="o">,</span> <span class="kt">int</span> <span class="n">y2</span><span class="o">)</span> <span class="o">{</span>
    <span class="n">Double</span> <span class="n">dist</span> <span class="o">=</span> <span class="n">Math</span><span class="o">.</span><span class="na">sqrt</span><span class="o">(</span><span class="n">Math</span><span class="o">.</span><span class="na">pow</span><span class="o">(</span><span class="n">x1</span> <span class="o">-</span> <span class="n">x2</span><span class="o">,</span> <span class="mi">2</span><span class="o">)</span> <span class="o">+</span> <span class="n">Math</span><span class="o">.</span><span class="na">pow</span><span class="o">(</span><span class="n">y1</span> <span class="o">-</span> <span class="n">y2</span><span class="o">,</span> <span class="mi">2</span><span class="o">));</span>
    <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="na">intValue</span><span class="o">();</span>
<span class="o">}</span>

}

class Point {

    /** Coordinates of the point (boxed; may be null if the caller passes null). */
    final Integer x, y;

    /** Distance from the origin, stored exactly as passed to the constructor. */
    final Integer distanceFromOrigin;

    /**
     * Creates a point. The distance is stored as given; it is not recomputed from (x, y).
     *
     * @param x     x coordinate
     * @param y     y coordinate
     * @param distO precomputed distance from the origin
     */
    public Point(Integer x, Integer y, Integer distO) {
        this.x = x;
        this.y = y;
        this.distanceFromOrigin = distO;
    }

    @Override
    public String toString() {
        return "Point [x=" + x + ", y=" + y + ", distanceFromOrigin="
                + distanceFromOrigin + "]";
    }
}

class PointDistanceFromOriginComparator implements Comparator<Point> {

<span class="kd">public</span> <span class="kd">static</span> <span class="kd">final</span> <span class="n">PointDistanceFromOriginComparator</span> <span class="n">INSTANCE</span> <span class="o">=</span> <span class="k">new</span> <span class="n">PointDistanceFromOriginComparator</span><span class="o">();</span>

<span class="kd">public</span> <span class="kt">int</span> <span class="nf">compare</span><span class="o">(</span><span class="n">Point</span> <span class="n">p1</span><span class="o">,</span> <span class="n">Point</span> <span class="n">p2</span><span class="o">)</span> <span class="o">{</span>
    <span class="k">if</span> <span class="o">(</span><span class="n">p1</span><span class="o">.</span><span class="na">distanceFromOrigin</span> <span class="o">&lt;</span> <span class="n">p2</span><span class="o">.</span><span class="na">distanceFromOrigin</span><span class="o">)</span> <span class="o">{</span>
        <span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="o">;</span>
    <span class="o">}</span> <span class="k">else</span> <span class="k">if</span> <span class="o">(</span><span class="n">p1</span><span class="o">.</span><span class="na">distanceFromOrigin</span> <span class="o">&gt;</span> <span class="n">p2</span><span class="o">.</span><span class="na">distanceFromOrigin</span><span class="o">)</span> <span class="o">{</span>
        <span class="k">return</span> <span class="mi">1</span><span class="o">;</span>
    <span class="o">}</span>
    <span class="k">return</span> <span class="mi">0</span><span class="o">;</span>
<span class="o">}</span>

}



comments powered by Disqus