001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.stomp;
018
019import java.io.DataInput;
020import java.io.DataInputStream;
021import java.io.DataOutput;
022import java.io.DataOutputStream;
023import java.io.IOException;
024import java.io.InputStream;
025import java.io.PushbackInputStream;
026import java.util.HashMap;
027import java.util.Map;
028
029import org.apache.activemq.util.ByteArrayInputStream;
030import org.apache.activemq.util.ByteArrayOutputStream;
031import org.apache.activemq.util.ByteSequence;
032import org.apache.activemq.wireformat.WireFormat;
033
034/**
035 * Implements marshalling and unmarsalling the <a
036 * href="http://stomp.codehaus.org/">Stomp</a> protocol.
037 */
038public class StompWireFormat implements WireFormat {
039
040    private static final byte[] NO_DATA = new byte[] {};
041    private static final byte[] END_OF_FRAME = new byte[] {0, '\n'};
042
043    private static final int MAX_COMMAND_LENGTH = 1024;
044    private static final int MAX_HEADER_LENGTH = 1024 * 10;
045    private static final int MAX_HEADERS = 1000;
046    private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
047
048    private int version = 1;
049    private String stompVersion = Stomp.DEFAULT_VERSION;
050
051    public ByteSequence marshal(Object command) throws IOException {
052        ByteArrayOutputStream baos = new ByteArrayOutputStream();
053        DataOutputStream dos = new DataOutputStream(baos);
054        marshal(command, dos);
055        dos.close();
056        return baos.toByteSequence();
057    }
058
059    public Object unmarshal(ByteSequence packet) throws IOException {
060        ByteArrayInputStream stream = new ByteArrayInputStream(packet);
061        DataInputStream dis = new DataInputStream(stream);
062        return unmarshal(dis);
063    }
064
065    public void marshal(Object command, DataOutput os) throws IOException {
066        StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
067
068        if (stomp.getAction().equals(Stomp.Commands.KEEPALIVE)) {
069            os.write(Stomp.BREAK);
070            return;
071        }
072
073        StringBuilder buffer = new StringBuilder();
074        buffer.append(stomp.getAction());
075        buffer.append(Stomp.NEWLINE);
076
077        // Output the headers.
078        for (Map.Entry<String, String> entry : stomp.getHeaders().entrySet()) {
079            buffer.append(entry.getKey());
080            buffer.append(Stomp.Headers.SEPERATOR);
081            buffer.append(encodeHeader(entry.getValue()));
082            buffer.append(Stomp.NEWLINE);
083        }
084
085        // Add a newline to seperate the headers from the content.
086        buffer.append(Stomp.NEWLINE);
087
088        os.write(buffer.toString().getBytes("UTF-8"));
089        os.write(stomp.getContent());
090        os.write(END_OF_FRAME);
091    }
092
093    public Object unmarshal(DataInput in) throws IOException {
094
095        try {
096
097            // parse action
098            String action = parseAction(in);
099
100            // Parse the headers
101            HashMap<String, String> headers = parseHeaders(in);
102
103            // Read in the data part.
104            byte[] data = NO_DATA;
105            String contentLength = headers.get(Stomp.Headers.CONTENT_LENGTH);
106            if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLength != null) {
107
108                // Bless the client, he's telling us how much data to read in.
109                int length = parseContentLength(contentLength);
110
111                data = new byte[length];
112                in.readFully(data);
113
114                if (in.readByte() != 0) {
115                    throw new ProtocolException(Stomp.Headers.CONTENT_LENGTH + " bytes were read and " + "there was no trailing null byte", true);
116                }
117
118            } else {
119
120                // We don't know how much to read.. data ends when we hit a 0
121                byte b;
122                ByteArrayOutputStream baos = null;
123                while ((b = in.readByte()) != 0) {
124
125                    if (baos == null) {
126                        baos = new ByteArrayOutputStream();
127                    } else if (baos.size() > MAX_DATA_LENGTH) {
128                        throw new ProtocolException("The maximum data length was exceeded", true);
129                    }
130
131                    baos.write(b);
132                }
133
134                if (baos != null) {
135                    baos.close();
136                    data = baos.toByteArray();
137                }
138            }
139
140            return new StompFrame(action, headers, data);
141
142        } catch (ProtocolException e) {
143            return new StompFrameError(e);
144        }
145    }
146
147    private String readLine(DataInput in, int maxLength, String errorMessage) throws IOException {
148        ByteSequence sequence = readHeaderLine(in, maxLength, errorMessage);
149        return new String(sequence.getData(), sequence.getOffset(), sequence.getLength(), "UTF-8").trim();
150    }
151
152    private ByteSequence readHeaderLine(DataInput in, int maxLength, String errorMessage) throws IOException {
153        byte b;
154        ByteArrayOutputStream baos = new ByteArrayOutputStream(maxLength);
155        while ((b = in.readByte()) != '\n') {
156            if (baos.size() > maxLength) {
157                baos.close();
158                throw new ProtocolException(errorMessage, true);
159            }
160            baos.write(b);
161        }
162
163        baos.close();
164        ByteSequence line = baos.toByteSequence();
165
166        if (stompVersion.equals(Stomp.V1_0) || stompVersion.equals(Stomp.V1_2)) {
167            int lineLength = line.getLength();
168            if (lineLength > 0 && line.data[lineLength-1] == '\r') {
169                line.setLength(lineLength-1);
170            }
171        }
172
173        return line;
174    }
175
176    protected String parseAction(DataInput in) throws IOException {
177        String action = null;
178
179        // skip white space to next real action line
180        while (true) {
181            action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded");
182            if (action == null) {
183                throw new IOException("connection was closed");
184            } else {
185                action = action.trim();
186                if (action.length() > 0) {
187                    break;
188                }
189            }
190        }
191
192        return action;
193    }
194
195    protected HashMap<String, String> parseHeaders(DataInput in) throws IOException {
196        HashMap<String, String> headers = new HashMap<String, String>(25);
197        while (true) {
198            ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded");
199            if (line != null && line.length > 1) {
200
201                if (headers.size() > MAX_HEADERS) {
202                    throw new ProtocolException("The maximum number of headers was exceeded", true);
203                }
204
205                try {
206
207                    ByteArrayInputStream headerLine = new ByteArrayInputStream(line);
208                    ByteArrayOutputStream stream = new ByteArrayOutputStream(line.length);
209
210                    // First complete the name
211                    int result = -1;
212                    while ((result = headerLine.read()) != -1) {
213                        if (result != ':') {
214                            stream.write(result);
215                        } else {
216                            break;
217                        }
218                    }
219
220                    ByteSequence nameSeq = stream.toByteSequence();
221
222                    String name = new String(nameSeq.getData(), nameSeq.getOffset(), nameSeq.getLength(), "UTF-8");
223                    String value = decodeHeader(headerLine);
224                    if (stompVersion.equals(Stomp.V1_0)) {
225                        value = value.trim();
226                    }
227
228                    if (!headers.containsKey(name)) {
229                        headers.put(name, value);
230                    }
231
232                    stream.close();
233
234                } catch (Exception e) {
235                    throw new ProtocolException("Unable to parser header line [" + line + "]", true);
236                }
237            } else {
238                break;
239            }
240        }
241        return headers;
242    }
243
244    protected int parseContentLength(String contentLength) throws ProtocolException {
245        int length;
246        try {
247            length = Integer.parseInt(contentLength.trim());
248        } catch (NumberFormatException e) {
249            throw new ProtocolException("Specified content-length is not a valid integer", true);
250        }
251
252        if (length > MAX_DATA_LENGTH) {
253            throw new ProtocolException("The maximum data length was exceeded", true);
254        }
255
256        return length;
257    }
258
259    private String encodeHeader(String header) throws IOException {
260        String result = header;
261        if (!stompVersion.equals(Stomp.V1_0)) {
262            byte[] utf8buf = header.getBytes("UTF-8");
263            ByteArrayOutputStream stream = new ByteArrayOutputStream(utf8buf.length);
264            for(byte val : utf8buf) {
265                switch(val) {
266                case Stomp.ESCAPE:
267                    stream.write(Stomp.ESCAPE_ESCAPE_SEQ);
268                    break;
269                case Stomp.BREAK:
270                    stream.write(Stomp.NEWLINE_ESCAPE_SEQ);
271                    break;
272                case Stomp.COLON:
273                    stream.write(Stomp.COLON_ESCAPE_SEQ);
274                    break;
275                default:
276                    stream.write(val);
277                }
278            }
279            result =  new String(stream.toByteArray(), "UTF-8");
280        }
281
282        return result;
283    }
284
285    private String decodeHeader(InputStream header) throws IOException {
286        ByteArrayOutputStream decoded = new ByteArrayOutputStream();
287        PushbackInputStream stream = new PushbackInputStream(header);
288
289        int value = -1;
290        while( (value = stream.read()) != -1) {
291            if (value == 92) {
292
293                int next = stream.read();
294                if (next != -1) {
295                    switch(next) {
296                    case 110:
297                        decoded.write(Stomp.BREAK);
298                        break;
299                    case 99:
300                        decoded.write(Stomp.COLON);
301                        break;
302                    case 92:
303                        decoded.write(Stomp.ESCAPE);
304                        break;
305                    default:
306                        stream.unread(next);
307                        decoded.write(value);
308                    }
309                } else {
310                    decoded.write(value);
311                }
312
313            } else {
314                decoded.write(value);
315            }
316        }
317
318        return new String(decoded.toByteArray(), "UTF-8");
319    }
320
321    public int getVersion() {
322        return version;
323    }
324
325    public void setVersion(int version) {
326        this.version = version;
327    }
328
329    public String getStompVersion() {
330        return stompVersion;
331    }
332
333    public void setStompVersion(String stompVersion) {
334        this.stompVersion = stompVersion;
335    }
336}