001 /**
002 * Copyright (C) 2010-2011, FuseSource Corp. All rights reserved.
003 *
004 * http://fusesource.com
005 *
006 * The software in this package is published under the terms of the
007 * CDDL license a copy of which has been included with this distribution
008 * in the license.txt file.
009 */
010 package org.fusesource.hawtdispatch.transport;
011
012 import javax.net.ssl.*;
013 import java.io.IOException;
014 import java.net.URI;
015 import java.nio.ByteBuffer;
016 import java.nio.channels.ReadableByteChannel;
017 import java.nio.channels.SocketChannel;
018 import java.nio.channels.WritableByteChannel;
019 import java.security.cert.Certificate;
020 import java.security.cert.X509Certificate;
021 import java.util.ArrayList;
022 import java.util.concurrent.Executor;
023
024 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
025 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
026 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
027 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
028
029 /**
030 * An SSL Transport for secure communications.
031 *
032 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
033 */
034 public class SslTransport extends TcpTransport {
035
036
037 /**
038 * Maps uri schemes to a protocol algorithm names.
039 * Valid algorithm names listed at:
040 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
041 */
042 public static String protocol(String scheme) {
043 if( scheme.equals("tls") ) {
044 return "TLS";
045 } else if( scheme.startsWith("tlsv") ) {
046 return "TLSv"+scheme.substring(4);
047 } else if( scheme.equals("ssl") ) {
048 return "SSL";
049 } else if( scheme.startsWith("sslv") ) {
050 return "SSLv"+scheme.substring(4);
051 }
052 return null;
053 }
054
055 private SSLContext sslContext;
056 private SSLEngine engine;
057
058 private ByteBuffer readBuffer;
059 private boolean readUnderflow;
060
061 private ByteBuffer writeBuffer;
062 private boolean writeFlushing;
063
064 private ByteBuffer readOverflowBuffer;
065 private SSLChannel ssl_channel = new SSLChannel();
066
067 private Executor blockingExecutor;
068
069 public void setSSLContext(SSLContext ctx) {
070 this.sslContext = ctx;
071 }
072
073 /**
074 * Allows subclasses of TcpTransportFactory to create custom instances of
075 * TcpTransport.
076 */
077 public static SslTransport createTransport(URI uri) throws Exception {
078 String protocol = protocol(uri.getScheme());
079 if( protocol !=null ) {
080 SslTransport rc = new SslTransport();
081 rc.setSSLContext(SSLContext.getInstance(protocol));
082 return rc;
083 }
084 return null;
085 }
086
087 class SSLChannel implements ReadableByteChannel, WritableByteChannel {
088
089 public int write(ByteBuffer plain) throws IOException {
090 return secure_write(plain);
091 }
092
093 public int read(ByteBuffer plain) throws IOException {
094 return secure_read(plain);
095 }
096
097 public boolean isOpen() {
098 return getSocketChannel().isOpen();
099 }
100
101 public void close() throws IOException {
102 getSocketChannel().close();
103 }
104 }
105
106 public SSLSession getSSLSession() {
107 return engine==null ? null : engine.getSession();
108 }
109
110 public X509Certificate[] getPeerX509Certificates() {
111 if( engine==null ) {
112 return null;
113 }
114 try {
115 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
116 for( Certificate c:engine.getSession().getPeerCertificates() ) {
117 if(c instanceof X509Certificate) {
118 rc.add((X509Certificate) c);
119 }
120 }
121 return rc.toArray(new X509Certificate[rc.size()]);
122 } catch (SSLPeerUnverifiedException e) {
123 return null;
124 }
125 }
126
127 @Override
128 public void connecting(URI remoteLocation, URI localLocation) throws Exception {
129 assert engine == null;
130 engine = sslContext.createSSLEngine();
131 engine.setUseClientMode(true);
132 super.connecting(remoteLocation, localLocation);
133 }
134
135 @Override
136 public void connected(SocketChannel channel) throws Exception {
137 if (engine == null) {
138 engine = sslContext.createSSLEngine();
139 engine.setUseClientMode(false);
140 engine.setWantClientAuth(true);
141 }
142 SSLSession session = engine.getSession();
143 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
144 readBuffer.flip();
145 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
146
147 super.connected(channel);
148 }
149
150 @Override
151 protected void onConnected() throws IOException {
152 super.onConnected();
153 engine.setWantClientAuth(true);
154 engine.beginHandshake();
155 handshake();
156 }
157
158 @Override
159 public void flush() {
160 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
161 handshake();
162 } else {
163 super.flush();
164 }
165 }
166
167 @Override
168 protected void drainInbound() {
169 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
170 handshake();
171 } else {
172 super.drainInbound();
173 }
174 }
175
176 /**
177 * @return true if fully flushed.
178 * @throws IOException
179 */
180 protected boolean transportFlush() throws IOException {
181 while (true) {
182 if(writeFlushing) {
183 int count = super.writeChannel().write(writeBuffer);
184 if( !writeBuffer.hasRemaining() ) {
185 writeBuffer.clear();
186 writeFlushing = false;
187 suspendWrite();
188 return true;
189 } else {
190 return false;
191 }
192 } else {
193 if( writeBuffer.position()!=0 ) {
194 writeBuffer.flip();
195 writeFlushing = true;
196 resumeWrite();
197 } else {
198 return true;
199 }
200 }
201 }
202 }
203
204 private int secure_write(ByteBuffer plain) throws IOException {
205 if( !transportFlush() ) {
206 // can't write anymore until the write_secured_buffer gets fully flushed out..
207 return 0;
208 }
209 int rc = 0;
210 while ( plain.hasRemaining() || engine.getHandshakeStatus()==NEED_WRAP ) {
211 SSLEngineResult result = engine.wrap(plain, writeBuffer);
212 assert result.getStatus()!= BUFFER_OVERFLOW;
213 rc += result.bytesConsumed();
214 if( !transportFlush() ) {
215 break;
216 }
217 }
218 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
219 dispatchQueue.execute(new Runnable() {
220 public void run() {
221 handshake();
222 }
223 });
224 }
225 return rc;
226 }
227
228 private int secure_read(ByteBuffer plain) throws IOException {
229 int rc=0;
230 while ( plain.hasRemaining() || engine.getHandshakeStatus() == NEED_UNWRAP ) {
231 if( readOverflowBuffer !=null ) {
232 if( plain.hasRemaining() ) {
233 // lets drain the overflow buffer before trying to suck down anymore
234 // network bytes.
235 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
236 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
237 readOverflowBuffer.position(readOverflowBuffer.position()+size);
238 if( !readOverflowBuffer.hasRemaining() ) {
239 readOverflowBuffer = null;
240 }
241 rc += size;
242 } else {
243 return rc;
244 }
245 } else if( readUnderflow ) {
246 int count = super.readChannel().read(readBuffer);
247 if( count == -1 ) { // peer closed socket.
248 if (rc==0) {
249 return -1;
250 } else {
251 return rc;
252 }
253 }
254 if( count==0 ) { // no data available right now.
255 return rc;
256 }
257 // read in some more data, perhaps now we can unwrap.
258 readUnderflow = false;
259 readBuffer.flip();
260 } else {
261 SSLEngineResult result = engine.unwrap(readBuffer, plain);
262 rc += result.bytesProduced();
263 if( result.getStatus() == BUFFER_OVERFLOW ) {
264 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
265 result = engine.unwrap(readBuffer, readOverflowBuffer);
266 if( readOverflowBuffer.position()==0 ) {
267 readOverflowBuffer = null;
268 } else {
269 readOverflowBuffer.flip();
270 }
271 }
272 switch( result.getStatus() ) {
273 case CLOSED:
274 if (rc==0) {
275 engine.closeInbound();
276 return -1;
277 } else {
278 return rc;
279 }
280 case OK:
281 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
282 dispatchQueue.execute(new Runnable() {
283 public void run() {
284 handshake();
285 }
286 });
287 }
288 break;
289 case BUFFER_UNDERFLOW:
290 readBuffer.compact();
291 readUnderflow = true;
292 break;
293 case BUFFER_OVERFLOW:
294 throw new AssertionError("Unexpected case.");
295 }
296 }
297 }
298 return rc;
299 }
300
301 public void handshake() {
302 try {
303 if( !transportFlush() ) {
304 return;
305 }
306 switch (engine.getHandshakeStatus()) {
307 case NEED_TASK:
308 final Runnable task = engine.getDelegatedTask();
309 if( task!=null ) {
310 blockingExecutor.execute(new Runnable() {
311 public void run() {
312 task.run();
313 dispatchQueue.execute(new Runnable() {
314 public void run() {
315 if (isConnected()) {
316 handshake();
317 }
318 }
319 });
320 }
321 });
322 }
323 break;
324
325 case NEED_WRAP:
326 secure_write(ByteBuffer.allocate(0));
327 break;
328
329 case NEED_UNWRAP:
330 secure_read(ByteBuffer.allocate(0));
331 break;
332
333 case FINISHED:
334 case NOT_HANDSHAKING:
335 break;
336
337 default:
338 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
339 break;
340 }
341 } catch (IOException e ) {
342 onTransportFailure(e);
343 }
344 }
345
346
347 public ReadableByteChannel readChannel() {
348 return ssl_channel;
349 }
350
351 public WritableByteChannel writeChannel() {
352 return ssl_channel;
353 }
354
355 public Executor getBlockingExecutor() {
356 return blockingExecutor;
357 }
358
359 public void setBlockingExecutor(Executor blockingExecutor) {
360 this.blockingExecutor = blockingExecutor;
361 }
362 }
363
364