001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.activemq.transport.auto; 018 019import java.io.IOException; 020import java.io.InputStream; 021import java.net.Socket; 022import java.net.URI; 023import java.net.URISyntaxException; 024import java.nio.ByteBuffer; 025import java.util.HashMap; 026import java.util.Map; 027import java.util.Set; 028import java.util.concurrent.ConcurrentHashMap; 029import java.util.concurrent.ConcurrentMap; 030import java.util.concurrent.Future; 031import java.util.concurrent.LinkedBlockingQueue; 032import java.util.concurrent.ThreadPoolExecutor; 033import java.util.concurrent.TimeUnit; 034import java.util.concurrent.TimeoutException; 035import java.util.concurrent.atomic.AtomicInteger; 036 037import javax.net.ServerSocketFactory; 038 039import org.apache.activemq.broker.BrokerService; 040import org.apache.activemq.broker.BrokerServiceAware; 041import org.apache.activemq.openwire.OpenWireFormatFactory; 042import org.apache.activemq.transport.InactivityIOException; 043import org.apache.activemq.transport.Transport; 044import org.apache.activemq.transport.TransportFactory; 045import org.apache.activemq.transport.TransportServer; 046import org.apache.activemq.transport.protocol.AmqpProtocolVerifier; 047import org.apache.activemq.transport.protocol.MqttProtocolVerifier; 048import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier; 049import org.apache.activemq.transport.protocol.ProtocolVerifier; 050import org.apache.activemq.transport.protocol.StompProtocolVerifier; 051import org.apache.activemq.transport.tcp.TcpTransport; 052import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer; 053import org.apache.activemq.transport.tcp.TcpTransportFactory; 054import org.apache.activemq.transport.tcp.TcpTransportServer; 055import org.apache.activemq.util.FactoryFinder; 056import org.apache.activemq.util.IOExceptionSupport; 057import org.apache.activemq.util.IntrospectionSupport; 058import org.apache.activemq.util.ServiceStopper; 059import org.apache.activemq.wireformat.WireFormat; 060import org.apache.activemq.wireformat.WireFormatFactory; 061import org.slf4j.Logger; 062import org.slf4j.LoggerFactory; 063 064/** 065 * A TCP based implementation of {@link TransportServer} 066 */ 067public class AutoTcpTransportServer extends TcpTransportServer { 068 069 private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class); 070 071 protected Map<String, Map<String, Object>> wireFormatOptions; 072 protected Map<String, Object> autoTransportOptions; 073 protected Set<String> enabledProtocols; 074 protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>(); 075 076 protected BrokerService brokerService; 077 078 protected final ThreadPoolExecutor newConnectionExecutor; 079 protected final ThreadPoolExecutor protocolDetectionExecutor; 080 protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE; 081 protected int protocolDetectionTimeOut = 30000; 082 083 private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/"); 084 private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>(); 085 086 private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/"); 087 088 public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException { 089 WireFormatFactory wff = null; 090 try { 091 wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme); 092 if (options != null) { 093 final Map<String, Object> wfOptions = new HashMap<>(); 094 if (options.get(AutoTransportUtils.ALL) != null) { 095 wfOptions.putAll(options.get(AutoTransportUtils.ALL)); 096 } 097 if (options.get(scheme) != null) { 098 wfOptions.putAll(options.get(scheme)); 099 } 100 IntrospectionSupport.setProperties(wff, wfOptions); 101 } 102 return wff; 103 } catch (Throwable e) { 104 throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e); 105 } 106 } 107 108 public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException { 109 scheme = append(scheme, "nio"); 110 scheme = append(scheme, "ssl"); 111 112 if (scheme.isEmpty()) { 113 scheme = "tcp"; 114 } 115 116 TransportFactory tf = transportFactories.get(scheme); 117 if (tf == null) { 118 // Try to load if from a META-INF property. 119 try { 120 tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme); 121 if (options != null) { 122 IntrospectionSupport.setProperties(tf, options); 123 } 124 transportFactories.put(scheme, tf); 125 } catch (Throwable e) { 126 throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e); 127 } 128 } 129 return tf; 130 } 131 132 protected String append(String currentScheme, String scheme) { 133 if (this.getBindLocation().getScheme().contains(scheme)) { 134 if (!currentScheme.isEmpty()) { 135 currentScheme += "+"; 136 } 137 currentScheme += scheme; 138 } 139 return currentScheme; 140 } 141 142 /** 143 * @param transportFactory 144 * @param location 145 * @param serverSocketFactory 146 * @throws IOException 147 * @throws URISyntaxException 148 */ 149 public AutoTcpTransportServer(TcpTransportFactory transportFactory, 150 URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService, 151 Set<String> enabledProtocols) 152 throws IOException, URISyntaxException { 153 super(transportFactory, location, serverSocketFactory); 154 155 //Use an executor service here to handle new connections. Setting the max number 156 //of threads to the maximum number of connections the thread count isn't unbounded 157 newConnectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize, 158 maxConnectionThreadPoolSize, 159 30L, TimeUnit.SECONDS, 160 new LinkedBlockingQueue<Runnable>()); 161 //allow the thread pool to shrink if the max number of threads isn't needed 162 //and the pool can grow and shrink as needed if contention is high 163 newConnectionExecutor.allowCoreThreadTimeOut(true); 164 165 //Executor for waiting for bytes to detection of protocol 166 protocolDetectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize, 167 maxConnectionThreadPoolSize, 168 30L, TimeUnit.SECONDS, 169 new LinkedBlockingQueue<Runnable>()); 170 //allow the thread pool to shrink if the max number of threads isn't needed 171 protocolDetectionExecutor.allowCoreThreadTimeOut(true); 172 173 this.brokerService = brokerService; 174 this.enabledProtocols = enabledProtocols; 175 initProtocolVerifiers(); 176 } 177 178 public int getMaxConnectionThreadPoolSize() { 179 return maxConnectionThreadPoolSize; 180 } 181 182 /** 183 * Set the number of threads to be used for processing connections. Defaults 184 * to Integer.MAX_SIZE. Set this value to be lower to reduce the 185 * number of simultaneous connection attempts. If not set then the maximum number of 186 * threads will generally be controlled by the transport maxConnections setting: 187 * {@link TcpTransportServer#setMaximumConnections(int)}. 188 *<p> 189 * Note that this setter controls two thread pools because connection attempts 190 * require 1 thread to start processing the connection and another thread to read from the 191 * socket and to detect the protocol. Two threads are needed because some transports 192 * block on socket read so the first thread needs to be able to abort the second thread on timeout. 193 * Therefore this setting will set each thread pool to the size passed in essentially giving 194 * 2 times as many potential threads as the value set. 195 *<p> 196 * Both thread pools will close idle threads after a period of time 197 * essentially allowing the thread pools to grow and shrink dynamically based on load. 198 * 199 * @see {@link TcpTransportServer#setMaximumConnections(int)}. 200 * @param maxConnectionThreadPoolSize 201 */ 202 public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) { 203 this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize; 204 newConnectionExecutor.setCorePoolSize(maxConnectionThreadPoolSize); 205 newConnectionExecutor.setMaximumPoolSize(maxConnectionThreadPoolSize); 206 protocolDetectionExecutor.setCorePoolSize(maxConnectionThreadPoolSize); 207 protocolDetectionExecutor.setMaximumPoolSize(maxConnectionThreadPoolSize); 208 } 209 210 public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) { 211 this.protocolDetectionTimeOut = protocolDetectionTimeOut; 212 } 213 214 @Override 215 public void setWireFormatFactory(WireFormatFactory factory) { 216 super.setWireFormatFactory(factory); 217 initOpenWireProtocolVerifier(); 218 } 219 220 protected void initProtocolVerifiers() { 221 initOpenWireProtocolVerifier(); 222 223 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) { 224 protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier()); 225 } 226 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) { 227 protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier()); 228 } 229 if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) { 230 protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier()); 231 } 232 } 233 234 protected void initOpenWireProtocolVerifier() { 235 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) { 236 OpenWireProtocolVerifier owpv; 237 if (wireFormatFactory instanceof OpenWireFormatFactory) { 238 owpv = new OpenWireProtocolVerifier(((OpenWireFormatFactory) wireFormatFactory).isSizePrefixDisabled()); 239 } else { 240 owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory().isSizePrefixDisabled()); 241 } 242 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv); 243 } 244 } 245 246 protected boolean isAllProtocols() { 247 return enabledProtocols == null || enabledProtocols.isEmpty(); 248 } 249 250 @Override 251 protected void handleSocket(final Socket socket) { 252 final AutoTcpTransportServer server = this; 253 //This needs to be done in a new thread because 254 //the socket might be waiting on the client to send bytes 255 //doHandleSocket can't complete until the protocol can be detected 256 newConnectionExecutor.submit(new Runnable() { 257 @Override 258 public void run() { 259 server.doHandleSocket(socket); 260 } 261 }); 262 } 263 264 @Override 265 protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception { 266 final InputStream is = socket.getInputStream(); 267 final AtomicInteger readBytes = new AtomicInteger(0); 268 final ByteBuffer data = ByteBuffer.allocate(8); 269 270 // We need to peak at the first 8 bytes of the buffer to detect the protocol 271 Future<?> future = protocolDetectionExecutor.submit(new Runnable() { 272 @Override 273 public void run() { 274 try { 275 do { 276 //will block until enough bytes or read or a timeout 277 //and the socket is closed 278 int read = is.read(); 279 if (read == -1) { 280 throw new IOException("Connection failed, stream is closed."); 281 } 282 data.put((byte) read); 283 readBytes.incrementAndGet(); 284 } while (readBytes.get() < 8 && !Thread.interrupted()); 285 } catch (Exception e) { 286 throw new IllegalStateException(e); 287 } 288 } 289 }); 290 291 try { 292 //If this fails and throws an exception and the socket will be closed 293 waitForProtocolDetectionFinish(future, readBytes); 294 } finally { 295 //call cancel in case task didn't complete 296 future.cancel(true); 297 } 298 data.flip(); 299 ProtocolInfo protocolInfo = detectProtocol(data.array()); 300 301 InitBuffer initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get())); 302 initBuffer.buffer.put(data.array()); 303 304 if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) { 305 ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService); 306 } 307 308 WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat(); 309 Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory, initBuffer); 310 311 return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory); 312 } 313 314 protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception { 315 try { 316 //Wait for protocolDetectionTimeOut if defined 317 if (protocolDetectionTimeOut > 0) { 318 future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS); 319 } else { 320 future.get(); 321 } 322 } catch (TimeoutException e) { 323 throw new InactivityIOException("Client timed out before wire format could be detected. " + 324 " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent."); 325 } 326 } 327 328 /** 329 * @param socket 330 * @param format 331 * @param detectedTransportFactory 332 * @return 333 */ 334 protected TcpTransport createTransport(Socket socket, WireFormat format, 335 TcpTransportFactory detectedTransportFactory, InitBuffer initBuffer) throws IOException { 336 return new TcpTransport(format, socket, initBuffer); 337 } 338 339 public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) { 340 this.wireFormatOptions = wireFormatOptions; 341 } 342 343 public void setEnabledProtocols(Set<String> enabledProtocols) { 344 this.enabledProtocols = enabledProtocols; 345 } 346 347 public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) { 348 this.autoTransportOptions = autoTransportOptions; 349 if (autoTransportOptions.get("protocols") != null) { 350 this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols")); 351 } 352 } 353 @Override 354 protected void doStop(ServiceStopper stopper) throws Exception { 355 if (newConnectionExecutor != null) { 356 newConnectionExecutor.shutdownNow(); 357 try { 358 if (!newConnectionExecutor.awaitTermination(3, TimeUnit.SECONDS)) { 359 LOG.warn("Auto Transport newConnectionExecutor didn't shutdown cleanly"); 360 } 361 } catch (InterruptedException e) { 362 } 363 } 364 if (protocolDetectionExecutor != null) { 365 protocolDetectionExecutor.shutdownNow(); 366 try { 367 if (!protocolDetectionExecutor.awaitTermination(3, TimeUnit.SECONDS)) { 368 LOG.warn("Auto Transport protocolDetectionExecutor didn't shutdown cleanly"); 369 } 370 } catch (InterruptedException e) { 371 } 372 } 373 super.doStop(stopper); 374 } 375 376 protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException { 377 TcpTransportFactory detectedTransportFactory = transportFactory; 378 WireFormatFactory detectedWireFormatFactory = wireFormatFactory; 379 380 boolean found = false; 381 for (String scheme : protocolVerifiers.keySet()) { 382 if (protocolVerifiers.get(scheme).isProtocol(buffer)) { 383 LOG.debug("Detected protocol " + scheme); 384 detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions); 385 386 if (scheme.equals("default")) { 387 scheme = ""; 388 } 389 390 detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions); 391 found = true; 392 break; 393 } 394 } 395 396 if (!found) { 397 throw new IllegalStateException("Could not detect the wire format"); 398 } 399 400 return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory); 401 402 } 403 404 protected class ProtocolInfo { 405 public final TcpTransportFactory detectedTransportFactory; 406 public final WireFormatFactory detectedWireFormatFactory; 407 408 public ProtocolInfo(TcpTransportFactory detectedTransportFactory, 409 WireFormatFactory detectedWireFormatFactory) { 410 super(); 411 this.detectedTransportFactory = detectedTransportFactory; 412 this.detectedWireFormatFactory = detectedWireFormatFactory; 413 } 414 } 415 416}