001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.activemq.transport.nio;
019    
020    import org.apache.activemq.command.Command;
021    import org.apache.activemq.openwire.OpenWireFormat;
022    import org.apache.activemq.thread.DefaultThreadPools;
023    import org.apache.activemq.util.IOExceptionSupport;
024    import org.apache.activemq.util.ServiceStopper;
025    import org.apache.activemq.wireformat.WireFormat;
026    
027    import javax.net.SocketFactory;
028    import javax.net.ssl.*;
029    import java.io.DataInputStream;
030    import java.io.DataOutputStream;
031    import java.io.EOFException;
032    import java.io.IOException;
033    import java.net.Socket;
034    import java.net.URI;
035    import java.net.UnknownHostException;
036    import java.nio.ByteBuffer;
037    
038    public class NIOSSLTransport extends NIOTransport  {
039    
040        protected boolean needClientAuth;
041        protected boolean wantClientAuth;
042        protected String[] enabledCipherSuites;
043    
044        protected SSLContext sslContext;
045        protected SSLEngine sslEngine;
046        protected SSLSession sslSession;
047    
048    
049        protected boolean handshakeInProgress = false;
050        protected SSLEngineResult.Status status = null;
051        protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
052    
053        public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
054            super(wireFormat, socketFactory, remoteLocation, localLocation);
055        }
056    
057        public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
058            super(wireFormat, socket);
059        }
060    
061        public void setSslContext(SSLContext sslContext) {
062            this.sslContext = sslContext;
063        }
064    
065        @Override
066        protected void initializeStreams() throws IOException {
067            try {
068                channel = socket.getChannel();
069                channel.configureBlocking(false);
070    
071                if (sslContext == null) {
072                    sslContext = SSLContext.getDefault();
073                }
074    
075                // initialize engine
076                sslEngine = sslContext.createSSLEngine();
077                sslEngine.setUseClientMode(false);
078                if (enabledCipherSuites != null) {
079                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
080                }
081                sslEngine.setNeedClientAuth(needClientAuth);
082                sslEngine.setWantClientAuth(wantClientAuth);
083    
084                sslSession = sslEngine.getSession();
085    
086                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
087                inputBuffer.clear();
088                currentBuffer = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
089    
090                NIOOutputStream outputStream = new NIOOutputStream(channel);
091                outputStream.setEngine(sslEngine);
092                this.dataOut = new DataOutputStream(outputStream);
093                this.buffOut = outputStream;
094                sslEngine.beginHandshake();
095                handshakeStatus = sslEngine.getHandshakeStatus();
096                doHandshake();
097    
098            } catch (Exception e) {
099                throw new IOException(e);
100            }
101    
102        }
103    
104        protected void finishHandshake() throws Exception  {
105              if (handshakeInProgress) {
106                  handshakeInProgress = false;
107                  nextFrameSize = -1;
108    
109                  // listen for events telling us when the socket is readable.
110                  selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
111                      public void onSelect(SelectorSelection selection) {
112                          serviceRead();
113                      }
114    
115                      public void onError(SelectorSelection selection, Throwable error) {
116                          if (error instanceof IOException) {
117                              onException((IOException) error);
118                          } else {
119                              onException(IOExceptionSupport.create(error));
120                          }
121                      }
122                  });
123              }
124        }
125    
126        protected void serviceRead() {
127            try {
128                if (handshakeInProgress) {
129                    doHandshake();
130                }
131    
132                ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
133                plain.position(plain.limit());
134    
135                while(true) {
136                    if (!plain.hasRemaining()) {
137    
138                        if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
139                            plain.clear();
140                        } else {
141                            plain.compact();
142                        }
143                        int readCount = secureRead(plain);
144    
145    
146                        if (readCount == 0)
147                            break;
148    
149                        // channel is closed, cleanup
150                        if (readCount== -1) {
151                            onException(new EOFException());
152                            selection.close();
153                            break;
154                        }
155                    }
156    
157                    if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
158                        processCommand(plain);
159                    }
160    
161                }
162            } catch (IOException e) {
163                onException(e);
164            } catch (Throwable e) {
165                onException(IOExceptionSupport.create(e));
166            }
167        }
168    
169        protected void processCommand(ByteBuffer plain) throws Exception {
170            nextFrameSize = plain.getInt();
171            if (wireFormat instanceof OpenWireFormat) {
172                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
173                if (nextFrameSize > maxFrameSize) {
174                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
175                }
176            }
177            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
178            currentBuffer.putInt(nextFrameSize);
179            if (currentBuffer.hasRemaining()) {
180                if (currentBuffer.remaining() >= plain.remaining()) {
181                    currentBuffer.put(plain);
182                } else {
183                    byte[] fill = new byte[currentBuffer.remaining()];
184                    plain.get(fill);
185                    currentBuffer.put(fill);
186                }
187            }
188    
189            if (currentBuffer.hasRemaining()) {
190                return;
191            } else {
192                currentBuffer.flip();
193                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
194                doConsume((Command) command);
195                nextFrameSize = -1;
196            }
197        }
198    
199        protected int secureRead(ByteBuffer plain) throws Exception {
200    
201            if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
202                int bytesRead = channel.read(inputBuffer);
203    
204                if (bytesRead == -1) {
205                    sslEngine.closeInbound();
206                    if (inputBuffer.position() == 0 ||
207                            status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
208                        return -1;
209                    }
210                }
211            }
212    
213            plain.clear();
214    
215            inputBuffer.flip();
216            SSLEngineResult res;
217            do {
218                res = sslEngine.unwrap(inputBuffer, plain);
219            } while (res.getStatus() == SSLEngineResult.Status.OK &&
220                    res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
221                    res.bytesProduced() == 0);
222    
223            if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
224               finishHandshake();
225            }
226    
227            status = res.getStatus();
228            handshakeStatus = res.getHandshakeStatus();
229    
230    
231            //TODO deal with BUFFER_OVERFLOW
232    
233            if (status == SSLEngineResult.Status.CLOSED) {
234                sslEngine.closeInbound();
235                return -1;
236            }
237    
238            inputBuffer.compact();
239            plain.flip();
240    
241            return plain.remaining();
242        }
243    
244        protected void doHandshake() throws Exception {
245            handshakeInProgress = true;
246            while (true) {
247                switch (sslEngine.getHandshakeStatus()) {
248                    case NEED_UNWRAP:
249                        secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
250                        break;
251                    case NEED_TASK:
252                        Runnable task;
253                        while ((task = sslEngine.getDelegatedTask()) != null) {
254                            DefaultThreadPools.getDefaultTaskRunnerFactory().execute(task);
255                        }
256                        break;
257                    case NEED_WRAP:
258                        ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0));
259                        break;
260                    case FINISHED:
261                    case NOT_HANDSHAKING:
262                        finishHandshake();
263                        return;
264                }
265            }
266        }
267    
268        @Override
269        protected void doStop(ServiceStopper stopper) throws Exception {
270            if (channel != null) {
271                channel.close();
272                channel = null;
273            }
274            super.doStop(stopper);
275        }
276    
277        public boolean isNeedClientAuth() {
278            return needClientAuth;
279        }
280    
281        public void setNeedClientAuth(boolean needClientAuth) {
282            this.needClientAuth = needClientAuth;
283        }
284    
285        public boolean isWantClientAuth() {
286            return wantClientAuth;
287        }
288    
289        public void setWantClientAuth(boolean wantClientAuth) {
290            this.wantClientAuth = wantClientAuth;
291        }
292    
293        public String[] getEnabledCipherSuites() {
294            return enabledCipherSuites;
295        }
296    
297        public void setEnabledCipherSuites(String[] enabledCipherSuites) {
298            this.enabledCipherSuites = enabledCipherSuites;
299        }
300    }