001    /**
002     * Copyright (C) 2010-2011, FuseSource Corp.  All rights reserved.
003     *
004     *     http://fusesource.com
005     *
006     * The software in this package is published under the terms of the
007     * CDDL license a copy of which has been included with this distribution
008     * in the license.txt file.
009     */
010    package org.fusesource.hawtdispatch.transport;
011    
012    import javax.net.ssl.*;
013    import java.io.IOException;
014    import java.net.URI;
015    import java.nio.ByteBuffer;
016    import java.nio.channels.ReadableByteChannel;
017    import java.nio.channels.SocketChannel;
018    import java.nio.channels.WritableByteChannel;
019    import java.security.cert.Certificate;
020    import java.security.cert.X509Certificate;
021    import java.util.ArrayList;
022    import java.util.concurrent.Executor;
023    
024    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
025    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
026    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
027    import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
028    
029    /**
030     * An SSL Transport for secure communications.
031     *
032     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
033     */
034    public class SslTransport extends TcpTransport {
035    
036    
037        /**
038         * Maps uri schemes to a protocol algorithm names.
039         * Valid algorithm names listed at:
040         * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
041         */
042        public static String protocol(String scheme) {
043            if( scheme.equals("tls") ) {
044                return "TLS";
045            } else if( scheme.startsWith("tlsv") ) {
046                return "TLSv"+scheme.substring(4);
047            } else if( scheme.equals("ssl") ) {
048                return "SSL";
049            } else if( scheme.startsWith("sslv") ) {
050                return "SSLv"+scheme.substring(4);
051            }
052            return null;
053        }
054    
055        private SSLContext sslContext;
056        private SSLEngine engine;
057    
058        private ByteBuffer readBuffer;
059        private boolean readUnderflow;
060    
061        private ByteBuffer writeBuffer;
062        private boolean writeFlushing;
063    
064        private ByteBuffer readOverflowBuffer;
065        private SSLChannel ssl_channel = new SSLChannel();
066    
067        private Executor blockingExecutor;
068    
069        public void setSSLContext(SSLContext ctx) {
070            this.sslContext = ctx;
071        }
072    
073        /**
074         * Allows subclasses of TcpTransportFactory to create custom instances of
075         * TcpTransport.
076         */
077        public static SslTransport createTransport(URI uri) throws Exception {
078            String protocol = protocol(uri.getScheme());
079            if( protocol !=null ) {
080                SslTransport rc = new SslTransport();
081                rc.setSSLContext(SSLContext.getInstance(protocol));
082                return rc;
083            }
084            return null;
085        }
086    
087        class SSLChannel implements ReadableByteChannel, WritableByteChannel {
088    
089            public int write(ByteBuffer plain) throws IOException {
090                return secure_write(plain);
091            }
092    
093            public int read(ByteBuffer plain) throws IOException {
094                return secure_read(plain);
095            }
096    
097            public boolean isOpen() {
098                return getSocketChannel().isOpen();
099            }
100    
101            public void close() throws IOException {
102                getSocketChannel().close();
103            }
104        }
105    
106        public SSLSession getSSLSession() {
107            return engine==null ? null : engine.getSession();
108        }
109    
110        public X509Certificate[] getPeerX509Certificates() {
111            if( engine==null ) {
112                return null;
113            }
114            try {
115                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
116                for( Certificate c:engine.getSession().getPeerCertificates() ) {
117                    if(c instanceof X509Certificate) {
118                        rc.add((X509Certificate) c);
119                    }
120                }
121                return rc.toArray(new X509Certificate[rc.size()]);
122            } catch (SSLPeerUnverifiedException e) {
123                return null;
124            }
125        }
126    
127        @Override
128        public void connecting(URI remoteLocation, URI localLocation) throws Exception {
129            assert engine == null;
130            engine = sslContext.createSSLEngine();
131            engine.setUseClientMode(true);
132            super.connecting(remoteLocation, localLocation);
133        }
134    
135        @Override
136        public void connected(SocketChannel channel) throws Exception {
137            if (engine == null) {
138                engine = sslContext.createSSLEngine();
139                engine.setUseClientMode(false);
140                engine.setWantClientAuth(true);
141            }
142            SSLSession session = engine.getSession();
143            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
144            readBuffer.flip();
145            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
146    
147            super.connected(channel);
148        }
149    
150        @Override
151        protected void onConnected() throws IOException {
152            super.onConnected();
153            engine.setWantClientAuth(true);
154            engine.beginHandshake();
155            handshake();
156        }
157    
158        @Override
159        public void flush() {
160            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
161                handshake();
162            } else {
163                super.flush();
164            }
165        }
166    
167        @Override
168        protected void drainInbound() {
169            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
170                handshake();
171            } else {
172                super.drainInbound();
173            }
174        }
175    
176        /**
177         * @return true if fully flushed.
178         * @throws IOException
179         */
180        protected boolean transportFlush() throws IOException {
181            while (true) {
182                if(writeFlushing) {
183                    int count = super.writeChannel().write(writeBuffer);
184                    if( !writeBuffer.hasRemaining() ) {
185                        writeBuffer.clear();
186                        writeFlushing = false;
187                        suspendWrite();
188                        return true;
189                    } else {
190                        return false;
191                    }
192                } else {
193                    if( writeBuffer.position()!=0 ) {
194                        writeBuffer.flip();
195                        writeFlushing = true;
196                        resumeWrite();
197                    } else {
198                        return true;
199                    }
200                }
201            }
202        }
203    
204        private int secure_write(ByteBuffer plain) throws IOException {
205            if( !transportFlush() ) {
206                // can't write anymore until the write_secured_buffer gets fully flushed out..
207                return 0;
208            }
209            int rc = 0;
210            while ( plain.hasRemaining() || engine.getHandshakeStatus()==NEED_WRAP ) {
211                SSLEngineResult result = engine.wrap(plain, writeBuffer);
212                assert result.getStatus()!= BUFFER_OVERFLOW;
213                rc += result.bytesConsumed();
214                if( !transportFlush() ) {
215                    break;
216                }
217            }
218            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
219                dispatchQueue.execute(new Runnable() {
220                    public void run() {
221                        handshake();
222                    }
223                });
224            }
225            return rc;
226        }
227    
228        private int secure_read(ByteBuffer plain) throws IOException {
229            int rc=0;
230            while ( plain.hasRemaining() || engine.getHandshakeStatus() == NEED_UNWRAP ) {
231                if( readOverflowBuffer !=null ) {
232                    if(  plain.hasRemaining() ) {
233                        // lets drain the overflow buffer before trying to suck down anymore
234                        // network bytes.
235                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
236                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
237                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
238                        if( !readOverflowBuffer.hasRemaining() ) {
239                            readOverflowBuffer = null;
240                        }
241                        rc += size;
242                    } else {
243                        return rc;
244                    }
245                } else if( readUnderflow ) {
246                    int count = super.readChannel().read(readBuffer);
247                    if( count == -1 ) {  // peer closed socket.
248                        if (rc==0) {
249                            return -1;
250                        } else {
251                            return rc;
252                        }
253                    }
254                    if( count==0 ) {  // no data available right now.
255                        return rc;
256                    }
257                    // read in some more data, perhaps now we can unwrap.
258                    readUnderflow = false;
259                    readBuffer.flip();
260                } else {
261                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
262                    rc += result.bytesProduced();
263                    if( result.getStatus() == BUFFER_OVERFLOW ) {
264                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
265                        result = engine.unwrap(readBuffer, readOverflowBuffer);
266                        if( readOverflowBuffer.position()==0 ) {
267                            readOverflowBuffer = null;
268                        } else {
269                            readOverflowBuffer.flip();
270                        }
271                    }
272                    switch( result.getStatus() ) {
273                        case CLOSED:
274                            if (rc==0) {
275                                engine.closeInbound();
276                                return -1;
277                            } else {
278                                return rc;
279                            }
280                        case OK:
281                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
282                                dispatchQueue.execute(new Runnable() {
283                                    public void run() {
284                                        handshake();
285                                    }
286                                });
287                            }
288                            break;
289                        case BUFFER_UNDERFLOW:
290                            readBuffer.compact();
291                            readUnderflow = true;
292                            break;
293                        case BUFFER_OVERFLOW:
294                            throw new AssertionError("Unexpected case.");
295                    }
296                }
297            }
298            return rc;
299        }
300    
301        public void handshake() {
302            try {
303                if( !transportFlush() ) {
304                    return;
305                }
306                switch (engine.getHandshakeStatus()) {
307                    case NEED_TASK:
308                        final Runnable task = engine.getDelegatedTask();
309                        if( task!=null ) {
310                            blockingExecutor.execute(new Runnable() {
311                                public void run() {
312                                    task.run();
313                                    dispatchQueue.execute(new Runnable() {
314                                        public void run() {
315                                            if (isConnected()) {
316                                                handshake();
317                                            }
318                                        }
319                                    });
320                                }
321                            });
322                        }
323                        break;
324    
325                    case NEED_WRAP:
326                        secure_write(ByteBuffer.allocate(0));
327                        break;
328    
329                    case NEED_UNWRAP:
330                        secure_read(ByteBuffer.allocate(0));
331                        break;
332    
333                    case FINISHED:
334                    case NOT_HANDSHAKING:
335                        break;
336    
337                    default:
338                        System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
339                        break;
340                }
341            } catch (IOException e ) {
342                onTransportFailure(e);
343            }
344        }
345    
346    
347        public ReadableByteChannel readChannel() {
348            return ssl_channel;
349        }
350    
351        public WritableByteChannel writeChannel() {
352            return ssl_channel;
353        }
354    
355        public Executor getBlockingExecutor() {
356            return blockingExecutor;
357        }
358    
359        public void setBlockingExecutor(Executor blockingExecutor) {
360            this.blockingExecutor = blockingExecutor;
361        }
362    }
363    
364