001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.auto;
018
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

import javax.net.ServerSocketFactory;

import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.openwire.OpenWireFormatFactory;
import org.apache.activemq.transport.InactivityIOException;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.TransportServer;
import org.apache.activemq.transport.protocol.AmqpProtocolVerifier;
import org.apache.activemq.transport.protocol.MqttProtocolVerifier;
import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier;
import org.apache.activemq.transport.protocol.ProtocolVerifier;
import org.apache.activemq.transport.protocol.StompProtocolVerifier;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer;
import org.apache.activemq.transport.tcp.TcpTransportFactory;
import org.apache.activemq.transport.tcp.TcpTransportServer;
import org.apache.activemq.util.FactoryFinder;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.IntrospectionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.apache.activemq.wireformat.WireFormatFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
063
064/**
065 * A TCP based implementation of {@link TransportServer}
066 */
067public class AutoTcpTransportServer extends TcpTransportServer {
068
069    private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class);
070
071    protected Map<String, Map<String, Object>> wireFormatOptions;
072    protected Map<String, Object> autoTransportOptions;
073    protected Set<String> enabledProtocols;
074    protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>();
075
076    protected BrokerService brokerService;
077
078    protected final ThreadPoolExecutor newConnectionExecutor;
079    protected final ThreadPoolExecutor protocolDetectionExecutor;
080    protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE;
081    protected int protocolDetectionTimeOut = 30000;
082
083    private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/");
084    private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>();
085
086    private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/");
087
088    public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException {
089        WireFormatFactory wff = null;
090        try {
091            wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme);
092            if (options != null) {
093                final Map<String, Object> wfOptions = new HashMap<>();
094                if (options.get(AutoTransportUtils.ALL) != null) {
095                    wfOptions.putAll(options.get(AutoTransportUtils.ALL));
096                }
097                if (options.get(scheme) != null) {
098                    wfOptions.putAll(options.get(scheme));
099                }
100                IntrospectionSupport.setProperties(wff, wfOptions);
101            }
102            return wff;
103        } catch (Throwable e) {
104           throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e);
105        }
106    }
107
108    public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException {
109        scheme = append(scheme, "nio");
110        scheme = append(scheme, "ssl");
111
112        if (scheme.isEmpty()) {
113            scheme = "tcp";
114        }
115
116        TransportFactory tf = transportFactories.get(scheme);
117        if (tf == null) {
118            // Try to load if from a META-INF property.
119            try {
120                tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme);
121                if (options != null) {
122                    IntrospectionSupport.setProperties(tf, options);
123                }
124                transportFactories.put(scheme, tf);
125            } catch (Throwable e) {
126                throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e);
127            }
128        }
129        return tf;
130    }
131
132    protected String append(String currentScheme, String scheme) {
133        if (this.getBindLocation().getScheme().contains(scheme)) {
134            if (!currentScheme.isEmpty()) {
135                currentScheme += "+";
136            }
137            currentScheme += scheme;
138        }
139        return currentScheme;
140    }
141
142    /**
143     * @param transportFactory
144     * @param location
145     * @param serverSocketFactory
146     * @throws IOException
147     * @throws URISyntaxException
148     */
149    public AutoTcpTransportServer(TcpTransportFactory transportFactory,
150            URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService,
151            Set<String> enabledProtocols)
152            throws IOException, URISyntaxException {
153        super(transportFactory, location, serverSocketFactory);
154
155        //Use an executor service here to handle new connections.  Setting the max number
156        //of threads to the maximum number of connections the thread count isn't unbounded
157        newConnectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
158                maxConnectionThreadPoolSize,
159                30L, TimeUnit.SECONDS,
160                new LinkedBlockingQueue<Runnable>());
161        //allow the thread pool to shrink if the max number of threads isn't needed
162        //and the pool can grow and shrink as needed if contention is high
163        newConnectionExecutor.allowCoreThreadTimeOut(true);
164
165        //Executor for waiting for bytes to detection of protocol
166        protocolDetectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
167                maxConnectionThreadPoolSize,
168                30L, TimeUnit.SECONDS,
169                new LinkedBlockingQueue<Runnable>());
170        //allow the thread pool to shrink if the max number of threads isn't needed
171        protocolDetectionExecutor.allowCoreThreadTimeOut(true);
172
173        this.brokerService = brokerService;
174        this.enabledProtocols = enabledProtocols;
175        initProtocolVerifiers();
176    }
177
178    public int getMaxConnectionThreadPoolSize() {
179        return maxConnectionThreadPoolSize;
180    }
181
182    /**
183     * Set the number of threads to be used for processing connections.  Defaults
184     * to Integer.MAX_SIZE.  Set this value to be lower to reduce the
185     * number of simultaneous connection attempts.  If not set then the maximum number of
186     * threads will generally be controlled by the transport maxConnections setting:
187     * {@link TcpTransportServer#setMaximumConnections(int)}.
188     *<p>
189     * Note that this setter controls two thread pools because connection attempts
190     * require 1 thread to start processing the connection and another thread to read from the
191     * socket and to detect the protocol. Two threads are needed because some transports
192     * block on socket read so the first thread needs to be able to abort the second thread on timeout.
193     * Therefore this setting will set each thread pool to the size passed in essentially giving
194     * 2 times as many potential threads as the value set.
195     *<p>
196     * Both thread pools will close idle threads after a period of time
197     * essentially allowing the thread pools to grow and shrink dynamically based on load.
198     *
199     * @see {@link TcpTransportServer#setMaximumConnections(int)}.
200     * @param maxConnectionThreadPoolSize
201     */
202    public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) {
203        this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize;
204        newConnectionExecutor.setCorePoolSize(maxConnectionThreadPoolSize);
205        newConnectionExecutor.setMaximumPoolSize(maxConnectionThreadPoolSize);
206        protocolDetectionExecutor.setCorePoolSize(maxConnectionThreadPoolSize);
207        protocolDetectionExecutor.setMaximumPoolSize(maxConnectionThreadPoolSize);
208    }
209
210    public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) {
211        this.protocolDetectionTimeOut = protocolDetectionTimeOut;
212    }
213
214    @Override
215    public void setWireFormatFactory(WireFormatFactory factory) {
216        super.setWireFormatFactory(factory);
217        initOpenWireProtocolVerifier();
218    }
219
220    protected void initProtocolVerifiers() {
221        initOpenWireProtocolVerifier();
222
223        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) {
224            protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier());
225        }
226        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) {
227            protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier());
228        }
229        if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) {
230            protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier());
231        }
232    }
233
234    protected void initOpenWireProtocolVerifier() {
235        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) {
236            OpenWireProtocolVerifier owpv;
237            if (wireFormatFactory instanceof OpenWireFormatFactory) {
238                owpv = new OpenWireProtocolVerifier(((OpenWireFormatFactory) wireFormatFactory).isSizePrefixDisabled());
239            } else {
240                owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory().isSizePrefixDisabled());
241            }
242            protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv);
243        }
244    }
245
246    protected boolean isAllProtocols() {
247        return enabledProtocols == null || enabledProtocols.isEmpty();
248    }
249
250    @Override
251    protected void handleSocket(final Socket socket) {
252        final AutoTcpTransportServer server = this;
253        //This needs to be done in a new thread because
254        //the socket might be waiting on the client to send bytes
255        //doHandleSocket can't complete until the protocol can be detected
256        newConnectionExecutor.submit(new Runnable() {
257            @Override
258            public void run() {
259                server.doHandleSocket(socket);
260            }
261        });
262    }
263
264    @Override
265    protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception {
266        final InputStream is = socket.getInputStream();
267        final AtomicInteger readBytes = new AtomicInteger(0);
268        final ByteBuffer data = ByteBuffer.allocate(8);
269
270        // We need to peak at the first 8 bytes of the buffer to detect the protocol
271        Future<?> future = protocolDetectionExecutor.submit(new Runnable() {
272            @Override
273            public void run() {
274                try {
275                    do {
276                        //will block until enough bytes or read or a timeout
277                        //and the socket is closed
278                        int read = is.read();
279                        if (read == -1) {
280                            throw new IOException("Connection failed, stream is closed.");
281                        }
282                        data.put((byte) read);
283                        readBytes.incrementAndGet();
284                    } while (readBytes.get() < 8 && !Thread.interrupted());
285                } catch (Exception e) {
286                    throw new IllegalStateException(e);
287                }
288            }
289        });
290
291        try {
292            //If this fails and throws an exception and the socket will be closed
293            waitForProtocolDetectionFinish(future, readBytes);
294        } finally {
295            //call cancel in case task didn't complete
296            future.cancel(true);
297        }
298        data.flip();
299        ProtocolInfo protocolInfo = detectProtocol(data.array());
300
301        InitBuffer initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get()));
302        initBuffer.buffer.put(data.array());
303
304        if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) {
305            ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService);
306        }
307
308        WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat();
309        Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory, initBuffer);
310
311        return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory);
312    }
313
314    protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception {
315        try {
316            //Wait for protocolDetectionTimeOut if defined
317            if (protocolDetectionTimeOut > 0) {
318                future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS);
319            } else {
320                future.get();
321            }
322        } catch (TimeoutException e) {
323            throw new InactivityIOException("Client timed out before wire format could be detected. " +
324                    " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent.");
325        }
326    }
327
328    /**
329     * @param socket
330     * @param format
331     * @param detectedTransportFactory
332     * @return
333     */
334    protected TcpTransport createTransport(Socket socket, WireFormat format,
335            TcpTransportFactory detectedTransportFactory, InitBuffer initBuffer) throws IOException {
336        return new TcpTransport(format, socket, initBuffer);
337    }
338
339    public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) {
340        this.wireFormatOptions = wireFormatOptions;
341    }
342
343    public void setEnabledProtocols(Set<String> enabledProtocols) {
344        this.enabledProtocols = enabledProtocols;
345    }
346
347    public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) {
348        this.autoTransportOptions = autoTransportOptions;
349        if (autoTransportOptions.get("protocols") != null) {
350            this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols"));
351        }
352    }
353    @Override
354    protected void doStop(ServiceStopper stopper) throws Exception {
355        if (newConnectionExecutor != null) {
356            newConnectionExecutor.shutdownNow();
357            try {
358                if (!newConnectionExecutor.awaitTermination(3, TimeUnit.SECONDS)) {
359                    LOG.warn("Auto Transport newConnectionExecutor didn't shutdown cleanly");
360                }
361            } catch (InterruptedException e) {
362            }
363        }
364        if (protocolDetectionExecutor != null) {
365            protocolDetectionExecutor.shutdownNow();
366            try {
367                if (!protocolDetectionExecutor.awaitTermination(3, TimeUnit.SECONDS)) {
368                    LOG.warn("Auto Transport protocolDetectionExecutor didn't shutdown cleanly");
369                }
370            } catch (InterruptedException e) {
371            }
372        }
373        super.doStop(stopper);
374    }
375
376    protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException {
377        TcpTransportFactory detectedTransportFactory = transportFactory;
378        WireFormatFactory detectedWireFormatFactory = wireFormatFactory;
379
380        boolean found = false;
381        for (String scheme : protocolVerifiers.keySet()) {
382            if (protocolVerifiers.get(scheme).isProtocol(buffer)) {
383                LOG.debug("Detected protocol " + scheme);
384                detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions);
385
386                if (scheme.equals("default")) {
387                    scheme = "";
388                }
389
390                detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions);
391                found = true;
392                break;
393            }
394        }
395
396        if (!found) {
397            throw new IllegalStateException("Could not detect the wire format");
398        }
399
400        return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory);
401
402    }
403
404    protected class ProtocolInfo {
405        public final TcpTransportFactory detectedTransportFactory;
406        public final WireFormatFactory detectedWireFormatFactory;
407
408        public ProtocolInfo(TcpTransportFactory detectedTransportFactory,
409                WireFormatFactory detectedWireFormatFactory) {
410            super();
411            this.detectedTransportFactory = detectedTransportFactory;
412            this.detectedWireFormatFactory = detectedWireFormatFactory;
413        }
414    }
415
416}