001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.stomp;
018
019import java.io.DataInput;
020import java.io.DataInputStream;
021import java.io.DataOutput;
022import java.io.DataOutputStream;
023import java.io.IOException;
024import java.io.InputStream;
025import java.io.PushbackInputStream;
026import java.util.HashMap;
027import java.util.Map;
028
029import org.apache.activemq.util.ByteArrayInputStream;
030import org.apache.activemq.util.ByteArrayOutputStream;
031import org.apache.activemq.util.ByteSequence;
032import org.apache.activemq.wireformat.WireFormat;
033
034/**
035 * Implements marshalling and unmarsalling the <a
036 * href="http://stomp.codehaus.org/">Stomp</a> protocol.
037 */
038public class StompWireFormat implements WireFormat {
039
040    private static final byte[] NO_DATA = new byte[] {};
041    private static final byte[] END_OF_FRAME = new byte[] {0, '\n'};
042
043    private static final int MAX_COMMAND_LENGTH = 1024;
044    private static final int MAX_HEADER_LENGTH = 1024 * 10;
045    private static final int MAX_HEADERS = 1000;
046    private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
047
048    private int version = 1;
049    private int maxDataLength = MAX_DATA_LENGTH;
050    private String stompVersion = Stomp.DEFAULT_VERSION;
051
052    @Override
053    public ByteSequence marshal(Object command) throws IOException {
054        ByteArrayOutputStream baos = new ByteArrayOutputStream();
055        DataOutputStream dos = new DataOutputStream(baos);
056        marshal(command, dos);
057        dos.close();
058        return baos.toByteSequence();
059    }
060
061    @Override
062    public Object unmarshal(ByteSequence packet) throws IOException {
063        ByteArrayInputStream stream = new ByteArrayInputStream(packet);
064        DataInputStream dis = new DataInputStream(stream);
065        return unmarshal(dis);
066    }
067
068    @Override
069    public void marshal(Object command, DataOutput os) throws IOException {
070        StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
071
072        if (stomp.getAction().equals(Stomp.Commands.KEEPALIVE)) {
073            os.write(Stomp.BREAK);
074            return;
075        }
076
077        StringBuilder buffer = new StringBuilder();
078        buffer.append(stomp.getAction());
079        buffer.append(Stomp.NEWLINE);
080
081        // Output the headers.
082        for (Map.Entry<String, String> entry : stomp.getHeaders().entrySet()) {
083            buffer.append(entry.getKey());
084            buffer.append(Stomp.Headers.SEPERATOR);
085            buffer.append(encodeHeader(entry.getValue()));
086            buffer.append(Stomp.NEWLINE);
087        }
088
089        // Add a newline to seperate the headers from the content.
090        buffer.append(Stomp.NEWLINE);
091
092        os.write(buffer.toString().getBytes("UTF-8"));
093        os.write(stomp.getContent());
094        os.write(END_OF_FRAME);
095    }
096
097    @Override
098    public Object unmarshal(DataInput in) throws IOException {
099
100        try {
101
102            // parse action
103            String action = parseAction(in);
104
105            // Parse the headers
106            HashMap<String, String> headers = parseHeaders(in);
107
108            // Read in the data part.
109            byte[] data = NO_DATA;
110            String contentLength = headers.get(Stomp.Headers.CONTENT_LENGTH);
111            if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLength != null) {
112
113                // Bless the client, he's telling us how much data to read in.
114                int length = parseContentLength(contentLength);
115
116                data = new byte[length];
117                in.readFully(data);
118
119                if (in.readByte() != 0) {
120                    throw new ProtocolException(Stomp.Headers.CONTENT_LENGTH + " bytes were read and " + "there was no trailing null byte", true);
121                }
122
123            } else {
124
125                // We don't know how much to read.. data ends when we hit a 0
126                byte b;
127                ByteArrayOutputStream baos = null;
128                while ((b = in.readByte()) != 0) {
129
130                    if (baos == null) {
131                        baos = new ByteArrayOutputStream();
132                    } else if (baos.size() > getMaxDataLength()) {
133                        throw new ProtocolException("The maximum data length was exceeded", true);
134                    }
135
136                    baos.write(b);
137                }
138
139                if (baos != null) {
140                    baos.close();
141                    data = baos.toByteArray();
142                }
143            }
144
145            return new StompFrame(action, headers, data);
146
147        } catch (ProtocolException e) {
148            return new StompFrameError(e);
149        }
150    }
151
152    private String readLine(DataInput in, int maxLength, String errorMessage) throws IOException {
153        ByteSequence sequence = readHeaderLine(in, maxLength, errorMessage);
154        return new String(sequence.getData(), sequence.getOffset(), sequence.getLength(), "UTF-8").trim();
155    }
156
157    private ByteSequence readHeaderLine(DataInput in, int maxLength, String errorMessage) throws IOException {
158        byte b;
159        ByteArrayOutputStream baos = new ByteArrayOutputStream(maxLength);
160        while ((b = in.readByte()) != '\n') {
161            if (baos.size() > maxLength) {
162                baos.close();
163                throw new ProtocolException(errorMessage, true);
164            }
165            baos.write(b);
166        }
167
168        baos.close();
169        ByteSequence line = baos.toByteSequence();
170
171        if (stompVersion.equals(Stomp.V1_0) || stompVersion.equals(Stomp.V1_2)) {
172            int lineLength = line.getLength();
173            if (lineLength > 0 && line.data[lineLength-1] == '\r') {
174                line.setLength(lineLength-1);
175            }
176        }
177
178        return line;
179    }
180
181    protected String parseAction(DataInput in) throws IOException {
182        String action = null;
183
184        // skip white space to next real action line
185        while (true) {
186            action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded");
187            if (action == null) {
188                throw new IOException("connection was closed");
189            } else {
190                action = action.trim();
191                if (action.length() > 0) {
192                    break;
193                }
194            }
195        }
196
197        return action;
198    }
199
200    protected HashMap<String, String> parseHeaders(DataInput in) throws IOException {
201        HashMap<String, String> headers = new HashMap<String, String>(25);
202        while (true) {
203            ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded");
204            if (line != null && line.length > 1) {
205
206                if (headers.size() > MAX_HEADERS) {
207                    throw new ProtocolException("The maximum number of headers was exceeded", true);
208                }
209
210                try {
211
212                    ByteArrayInputStream headerLine = new ByteArrayInputStream(line);
213                    ByteArrayOutputStream stream = new ByteArrayOutputStream(line.length);
214
215                    // First complete the name
216                    int result = -1;
217                    while ((result = headerLine.read()) != -1) {
218                        if (result != ':') {
219                            stream.write(result);
220                        } else {
221                            break;
222                        }
223                    }
224
225                    ByteSequence nameSeq = stream.toByteSequence();
226
227                    String name = new String(nameSeq.getData(), nameSeq.getOffset(), nameSeq.getLength(), "UTF-8");
228                    String value = decodeHeader(headerLine);
229                    if (stompVersion.equals(Stomp.V1_0)) {
230                        value = value.trim();
231                    }
232
233                    if (!headers.containsKey(name)) {
234                        headers.put(name, value);
235                    }
236
237                    stream.close();
238
239                } catch (Exception e) {
240                    throw new ProtocolException("Unable to parser header line [" + line + "]", true);
241                }
242            } else {
243                break;
244            }
245        }
246        return headers;
247    }
248
249    protected int parseContentLength(String contentLength) throws ProtocolException {
250        int length;
251        try {
252            length = Integer.parseInt(contentLength.trim());
253        } catch (NumberFormatException e) {
254            throw new ProtocolException("Specified content-length is not a valid integer", true);
255        }
256
257        if (length > getMaxDataLength()) {
258            throw new ProtocolException("The maximum data length was exceeded", true);
259        }
260
261        return length;
262    }
263
264    private String encodeHeader(String header) throws IOException {
265        String result = header;
266        if (!stompVersion.equals(Stomp.V1_0)) {
267            byte[] utf8buf = header.getBytes("UTF-8");
268            ByteArrayOutputStream stream = new ByteArrayOutputStream(utf8buf.length);
269            for(byte val : utf8buf) {
270                switch(val) {
271                case Stomp.ESCAPE:
272                    stream.write(Stomp.ESCAPE_ESCAPE_SEQ);
273                    break;
274                case Stomp.BREAK:
275                    stream.write(Stomp.NEWLINE_ESCAPE_SEQ);
276                    break;
277                case Stomp.COLON:
278                    stream.write(Stomp.COLON_ESCAPE_SEQ);
279                    break;
280                default:
281                    stream.write(val);
282                }
283            }
284            result =  new String(stream.toByteArray(), "UTF-8");
285            stream.close();
286        }
287
288        return result;
289    }
290
291    private String decodeHeader(InputStream header) throws IOException {
292        ByteArrayOutputStream decoded = new ByteArrayOutputStream();
293        PushbackInputStream stream = new PushbackInputStream(header);
294
295        int value = -1;
296        while( (value = stream.read()) != -1) {
297            if (value == 92) {
298
299                int next = stream.read();
300                if (next != -1) {
301                    switch(next) {
302                    case 110:
303                        decoded.write(Stomp.BREAK);
304                        break;
305                    case 99:
306                        decoded.write(Stomp.COLON);
307                        break;
308                    case 92:
309                        decoded.write(Stomp.ESCAPE);
310                        break;
311                    default:
312                        stream.unread(next);
313                        decoded.write(value);
314                    }
315                } else {
316                    decoded.write(value);
317                }
318
319            } else {
320                decoded.write(value);
321            }
322        }
323
324        decoded.close();
325
326        return new String(decoded.toByteArray(), "UTF-8");
327    }
328
329    @Override
330    public int getVersion() {
331        return version;
332    }
333
334    @Override
335    public void setVersion(int version) {
336        this.version = version;
337    }
338
339    public String getStompVersion() {
340        return stompVersion;
341    }
342
343    public void setStompVersion(String stompVersion) {
344        this.stompVersion = stompVersion;
345    }
346
347    public void setMaxDataLength(int maxDataLength) {
348        this.maxDataLength = maxDataLength;
349    }
350
351    public int getMaxDataLength() {
352        return maxDataLength;
353    }
354}