/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.activemq.transport.auto;

import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

import javax.net.ServerSocketFactory;

import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.openwire.OpenWireFormatFactory;
import org.apache.activemq.transport.InactivityIOException;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.TransportServer;
import org.apache.activemq.transport.protocol.AmqpProtocolVerifier;
import org.apache.activemq.transport.protocol.MqttProtocolVerifier;
import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier;
import org.apache.activemq.transport.protocol.ProtocolVerifier;
import org.apache.activemq.transport.protocol.StompProtocolVerifier;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer;
import org.apache.activemq.transport.tcp.TcpTransportFactory;
import org.apache.activemq.transport.tcp.TcpTransportServer;
import org.apache.activemq.util.FactoryFinder;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.IntrospectionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.apache.activemq.wireformat.WireFormatFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A TCP based implementation of {@link TransportServer} that inspects the
 * first 8 bytes sent by a client in order to pick the wire format and the
 * transport factory used for that connection.
 * <p>
 * Protocol detection is delegated to the registered {@link ProtocolVerifier}
 * instances (OpenWire, AMQP, STOMP and MQTT, depending on the enabled protocols).
 */
public class AutoTcpTransportServer extends TcpTransportServer {

    private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class);

    protected Map<String, Map<String, Object>> wireFormatOptions;
    protected Map<String, Object> autoTransportOptions;
    protected Set<String> enabledProtocols;
    protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>();

    protected BrokerService brokerService;

    protected final ThreadPoolExecutor newConnectionExecutor;
    protected final ThreadPoolExecutor protocolDetectionExecutor;
    protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE;
    protected int protocolDetectionTimeOut = 30000;

    private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/");
    private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>();

    private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/");

    /**
     * Creates and configures the wire format factory registered for the given scheme.
     * <p>
     * Options stored under {@link AutoTransportUtils#ALL} are applied first and are
     * then overlaid by the options stored under the scheme itself. When the created
     * factory is an {@link OpenWireFormatFactory}, the OpenWire protocol verifier is
     * replaced by one that is backed by this newly created factory.
     *
     * @param scheme the wire format scheme to look up
     * @param options per-scheme wire format options, may be {@code null}
     * @return the configured wire format factory
     * @throws IOException if the factory could not be created or configured
     */
    public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException {
        try {
            final WireFormatFactory wff = (WireFormatFactory) WIREFORMAT_FACTORY_FINDER.newInstance(scheme);
            if (options != null) {
                final Map<String, Object> wfOptions = new HashMap<>();
                final Map<String, Object> sharedOptions = options.get(AutoTransportUtils.ALL);
                if (sharedOptions != null) {
                    wfOptions.putAll(sharedOptions);
                }
                final Map<String, Object> schemeOptions = options.get(scheme);
                if (schemeOptions != null) {
                    wfOptions.putAll(schemeOptions);
                }
                IntrospectionSupport.setProperties(wff, wfOptions);
            }
            if (wff instanceof OpenWireFormatFactory) {
                protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff));
            }
            return wff;
        } catch (Throwable e) {
            throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e);
        }
    }

    /**
     * Looks up (and caches) the transport factory for the given protocol scheme.
     * <p>
     * The scheme is extended with {@code nio} and/or {@code ssl} when the bind
     * location's scheme contains those tokens (see {@link #append(String, String)});
     * an empty result falls back to {@code tcp}.
     *
     * @param scheme the detected protocol scheme, may be empty
     * @param options properties applied to a newly created factory, may be {@code null}
     * @return the transport factory registered for the resulting scheme
     * @throws IOException if no factory could be created for the scheme
     */
    public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException {
        String resolvedScheme = append(scheme, "nio");
        resolvedScheme = append(resolvedScheme, "ssl");

        if (resolvedScheme.isEmpty()) {
            resolvedScheme = "tcp";
        }

        TransportFactory tf = transportFactories.get(resolvedScheme);
        if (tf == null) {
            // Try to load if from a META-INF property.
            try {
                tf = (TransportFactory) TRANSPORT_FACTORY_FINDER.newInstance(resolvedScheme);
                if (options != null) {
                    IntrospectionSupport.setProperties(tf, options);
                }
                transportFactories.put(resolvedScheme, tf);
            } catch (Throwable e) {
                throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + resolvedScheme + "]", e);
            }
        }
        return tf;
    }

    /**
     * Appends {@code scheme} to {@code currentScheme} (separated by {@code +})
     * when the scheme of the bind location contains {@code scheme}.
     *
     * @param currentScheme the scheme built so far, may be empty
     * @param scheme the token to append when present in the bind location's scheme
     * @return the possibly extended scheme
     */
    protected String append(String currentScheme, String scheme) {
        if (this.getBindLocation().getScheme().contains(scheme)) {
            if (!currentScheme.isEmpty()) {
                currentScheme += "+";
            }
            currentScheme += scheme;
        }
        return currentScheme;
    }

    /**
     * @param transportFactory
     * @param location
     * @param serverSocketFactory
     * @param brokerService broker handed to detected transport factories that are {@link BrokerServiceAware}
     * @param enabledProtocols protocols to detect; {@code null} or empty enables all of them
     * @throws IOException
     * @throws URISyntaxException
     */
    public AutoTcpTransportServer(TcpTransportFactory transportFactory,
            URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService,
            Set<String> enabledProtocols)
            throws IOException, URISyntaxException {
        super(transportFactory, location, serverSocketFactory);

        //Use an executor service here to handle new connections.  Setting the max number
        //of threads to the maximum number of connections the thread count isn't unbounded
        newConnectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
                maxConnectionThreadPoolSize,
                30L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>());
        //allow the thread pool to shrink if the max number of threads isn't needed
        //and the pool can grow and shrink as needed if contention is high
        newConnectionExecutor.allowCoreThreadTimeOut(true);

        //Executor for waiting for bytes to detection of protocol
        protocolDetectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
                maxConnectionThreadPoolSize,
                30L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>());
        //allow the thread pool to shrink if the max number of threads isn't needed
        protocolDetectionExecutor.allowCoreThreadTimeOut(true);

        this.brokerService = brokerService;
        this.enabledProtocols = enabledProtocols;
        initProtocolVerifiers();
    }

    public int getMaxConnectionThreadPoolSize() {
        return maxConnectionThreadPoolSize;
    }

    /**
     * Set the number of threads to be used for processing connections.  Defaults
     * to {@link Integer#MAX_VALUE}.  Set this value to be lower to reduce the
     * number of simultaneous connection attempts.  If not set then the maximum number of
     * threads will generally be controlled by the transport maxConnections setting:
     * {@link TcpTransportServer#setMaximumConnections(int)}.
     *<p>
     * Note that this setter controls two thread pools because connection attempts
     * require 1 thread to start processing the connection and another thread to read from the
     * socket and to detect the protocol. Two threads are needed because some transports
     * block on socket read so the first thread needs to be able to abort the second thread on timeout.
     * Therefore this setting will set each thread pool to the size passed in essentially giving
     * 2 times as many potential threads as the value set.
     *<p>
     * Both thread pools will close idle threads after a period of time
     * essentially allowing the thread pools to grow and shrink dynamically based on load.
     *
     * @see TcpTransportServer#setMaximumConnections(int)
     * @param maxConnectionThreadPoolSize the size applied to both thread pools
     */
    public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) {
        this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize;
        resizePool(newConnectionExecutor, maxConnectionThreadPoolSize);
        resizePool(protocolDetectionExecutor, maxConnectionThreadPoolSize);
    }

    /**
     * Sets both the core and the maximum size of the pool to {@code size}.
     * <p>
     * The two setters are ordered so that the core size never exceeds the maximum
     * size in between the calls: the maximum is raised first when growing, the core
     * is lowered first otherwise. Setting the core size above the current maximum
     * is rejected with an IllegalArgumentException on Java 9 and later.
     */
    private static void resizePool(ThreadPoolExecutor executor, int size) {
        if (size > executor.getMaximumPoolSize()) {
            executor.setMaximumPoolSize(size);
            executor.setCorePoolSize(size);
        } else {
            executor.setCorePoolSize(size);
            executor.setMaximumPoolSize(size);
        }
    }

    /**
     * Sets how long to wait for the first 8 bytes of a new connection.
     *
     * @param protocolDetectionTimeOut timeout in milliseconds; a value of 0 or less waits indefinitely
     */
    public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) {
        this.protocolDetectionTimeOut = protocolDetectionTimeOut;
    }

    @Override
    public void setWireFormatFactory(WireFormatFactory factory) {
        super.setWireFormatFactory(factory);
        //the OpenWire verifier wraps the wire format factory so it has to be rebuilt
        initOpenWireProtocolVerifier();
    }

    /**
     * Registers a {@link ProtocolVerifier} for every enabled protocol.
     */
    protected void initProtocolVerifiers() {
        initOpenWireProtocolVerifier();

        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) {
            protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier());
        }
        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) {
            protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier());
        }
        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.MQTT)) {
            protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier());
        }
    }

    /**
     * Registers the OpenWire verifier when OpenWire is enabled. The verifier is
     * backed by the configured wire format factory when that factory is an
     * {@link OpenWireFormatFactory}, otherwise by a default one.
     */
    protected void initOpenWireProtocolVerifier() {
        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) {
            OpenWireProtocolVerifier owpv;
            if (wireFormatFactory instanceof OpenWireFormatFactory) {
                owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory);
            } else {
                owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory());
            }
            protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv);
        }
    }

    /**
     * @return {@code true} when no explicit protocol list has been configured
     */
    protected boolean isAllProtocols() {
        return enabledProtocols == null || enabledProtocols.isEmpty();
    }

    @Override
    protected void handleSocket(final Socket socket) {
        final AutoTcpTransportServer server = this;
        //This needs to be done in a new thread because
        //the socket might be waiting on the client to send bytes
        //doHandleSocket can't complete until the protocol can be detected
        newConnectionExecutor.submit(new Runnable() {
            @Override
            public void run() {
                server.doHandleSocket(socket);
            }
        });
    }

    @Override
    protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception {
        final InputStream is = socket.getInputStream();
        final AtomicInteger readBytes = new AtomicInteger(0);
        final ByteBuffer data = ByteBuffer.allocate(8);

        // We need to peek at the first 8 bytes of the buffer to detect the protocol
        Future<?> future = protocolDetectionExecutor.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    do {
                        //will block until enough bytes or read or a timeout
                        //and the socket is closed
                        int read = is.read();
                        if (read == -1) {
                            throw new IOException("Connection failed, stream is closed.");
                        }
                        data.put((byte) read);
                        readBytes.incrementAndGet();
                    } while (readBytes.get() < 8 && !Thread.interrupted());
                } catch (Exception e) {
                    throw new IllegalStateException(e);
                }
            }
        });

        try {
            //If this fails and throws an exception and the socket will be closed
            waitForProtocolDetectionFinish(future, readBytes);
        } finally {
            //call cancel in case task didn't complete
            future.cancel(true);
        }
        data.flip();
        ProtocolInfo protocolInfo = detectProtocol(data.array());

        //Only hand over the bytes that were actually read: the detection loop also
        //ends when its thread is interrupted, in which case fewer than 8 bytes are
        //available and copying the whole backing array would overflow the buffer
        final int bytesRead = readBytes.get();
        InitBuffer initBuffer = new InitBuffer(bytesRead, ByteBuffer.allocate(bytesRead));
        initBuffer.buffer.put(data.array(), 0, bytesRead);

        if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) {
            ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService);
        }

        WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat();
        Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory, initBuffer);

        return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory);
    }

    /**
     * Blocks until the protocol detection task has finished or
     * {@link #protocolDetectionTimeOut} milliseconds have elapsed.
     *
     * @param future the running detection task
     * @param readBytes number of bytes read so far, only used for the error message
     * @throws Exception an {@link InactivityIOException} on timeout, otherwise whatever
     *         {@link Future#get()} throws
     */
    protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception {
        try {
            //Wait for protocolDetectionTimeOut if defined
            if (protocolDetectionTimeOut > 0) {
                future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS);
            } else {
                future.get();
            }
        } catch (TimeoutException e) {
            throw new InactivityIOException("Client timed out before wire format could be detected. " +
                    " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent.");
        }
    }

    /**
     * Creates the transport for an accepted socket.
     *
     * @param socket the accepted client socket
     * @param format the wire format created for the detected protocol
     * @param detectedTransportFactory the transport factory of the detected protocol (unused here)
     * @param initBuffer the bytes already consumed from the socket during protocol detection
     * @return a new {@link TcpTransport}
     * @throws IOException if the transport could not be created
     */
    protected TcpTransport createTransport(Socket socket, WireFormat format,
            TcpTransportFactory detectedTransportFactory, InitBuffer initBuffer) throws IOException {
        return new TcpTransport(format, socket, initBuffer);
    }

    public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) {
        this.wireFormatOptions = wireFormatOptions;
    }

    public void setEnabledProtocols(Set<String> enabledProtocols) {
        this.enabledProtocols = enabledProtocols;
    }

    /**
     * Stores the auto transport options and, when a {@code protocols} entry is
     * present, replaces the enabled protocols with its parsed value.
     *
     * @param autoTransportOptions the options, may be {@code null}
     */
    public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) {
        this.autoTransportOptions = autoTransportOptions;
        if (autoTransportOptions != null && autoTransportOptions.get("protocols") != null) {
            this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols"));
        }
    }

    @Override
    protected void doStop(ServiceStopper stopper) throws Exception {
        boolean interrupted = false;
        try {
            interrupted |= shutdownExecutor(newConnectionExecutor, "newConnectionExecutor");
            interrupted |= shutdownExecutor(protocolDetectionExecutor, "protocolDetectionExecutor");
            super.doStop(stopper);
        } finally {
            //restore the interrupt status only once the stop sequence is complete
            //so that the remaining shutdown steps are not cut short by it
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Shuts the executor down immediately and waits up to 3 seconds for it to terminate.
     *
     * @param executor the executor to stop, may be {@code null}
     * @param name the name used in the warning that is logged when termination times out
     * @return {@code true} if the calling thread was interrupted while waiting
     */
    private static boolean shutdownExecutor(ThreadPoolExecutor executor, String name) {
        if (executor == null) {
            return false;
        }
        executor.shutdownNow();
        try {
            if (!executor.awaitTermination(3, TimeUnit.SECONDS)) {
                LOG.warn("Auto Transport {} didn't shutdown cleanly", name);
            }
            return false;
        } catch (InterruptedException e) {
            return true;
        }
    }

    /**
     * Runs the registered protocol verifiers against the given bytes and resolves
     * the wire format factory and transport factory of the first one that matches.
     *
     * @param buffer the first bytes received from the client
     * @return the factories to use for the connection
     * @throws IOException if a factory for the detected protocol could not be created
     * @throws IllegalStateException if no verifier recognizes the bytes
     */
    protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException {
        for (Map.Entry<String, ProtocolVerifier> verifier : protocolVerifiers.entrySet()) {
            if (verifier.getValue().isProtocol(buffer)) {
                final String scheme = verifier.getKey();
                LOG.debug("Detected protocol {}", scheme);
                final WireFormatFactory detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions);

                //the "default" scheme maps to the plain transport of the bind location
                final String transportScheme = scheme.equals("default") ? "" : scheme;
                final TcpTransportFactory detectedTransportFactory =
                        (TcpTransportFactory) findTransportFactory(transportScheme, transportOptions);

                return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory);
            }
        }

        throw new IllegalStateException("Could not detect the wire format");
    }

    /**
     * Holder for the factories selected by protocol detection.
     */
    protected class ProtocolInfo {
        public final TcpTransportFactory detectedTransportFactory;
        public final WireFormatFactory detectedWireFormatFactory;

        public ProtocolInfo(TcpTransportFactory detectedTransportFactory,
                WireFormatFactory detectedWireFormatFactory) {
            super();
            this.detectedTransportFactory = detectedTransportFactory;
            this.detectedWireFormatFactory = detectedWireFormatFactory;
        }
    }

}