001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.servicemix.wsn.component;
018    
019    import java.io.StringWriter;
020    import java.io.PrintWriter;
021    import java.lang.reflect.InvocationTargetException;
022    import java.lang.reflect.Method;
023    import java.util.ArrayList;
024    import java.util.Arrays;
025    import java.util.List;
026    import java.util.Collections;
027    import java.util.GregorianCalendar;
028    import java.util.Set;
029    import java.util.HashSet;
030    import java.util.concurrent.atomic.AtomicBoolean;
031    
032    import javax.jbi.messaging.ExchangeStatus;
033    import javax.jbi.messaging.Fault;
034    import javax.jbi.messaging.InOnly;
035    import javax.jbi.messaging.MessageExchange;
036    import javax.jbi.messaging.NormalizedMessage;
037    import javax.jws.Oneway;
038    import javax.jws.WebMethod;
039    import javax.jws.WebService;
040    import javax.xml.bind.JAXBContext;
041    import javax.xml.bind.JAXBException;
042    import javax.xml.bind.JAXBElement;
043    import javax.xml.bind.annotation.XmlRootElement;
044    import javax.xml.bind.annotation.XmlMixed;
045    import javax.xml.namespace.QName;
046    import javax.xml.ws.WebFault;
047    import javax.xml.transform.Source;
048    import javax.xml.transform.dom.DOMSource;
049    import javax.xml.datatype.DatatypeFactory;
050    
051    import org.w3c.dom.Document;
052    
053    import org.apache.servicemix.common.endpoints.ProviderEndpoint;
054    import org.apache.servicemix.common.util.URIResolver;
055    import org.apache.servicemix.wsn.ComponentContextAware;
056    import org.apache.servicemix.wsn.jbi.JbiWrapperHelper;
057    import org.oasis_open.docs.wsrf.bf_2.BaseFaultType;
058    
059    public class WSNEndpoint extends ProviderEndpoint {
060    
061        protected String address;
062    
063        protected Object pojo;
064    
065        protected JAXBContext jaxbContext;
066    
067        protected Set<Class> endpointInterfaces = new HashSet<Class>();
068    
069        public WSNEndpoint(String address, Object pojo) {
070            this.address = address;
071            this.pojo = pojo;
072            String[] parts = URIResolver.split3(address);
073            service = new QName(parts[0], parts[1]);
074            endpoint = parts[2];
075        }
076    
077        @Override
078        public void activate() throws Exception {
079            WebService ws = getWebServiceAnnotation(pojo.getClass());
080            if (ws == null) {
081                throw new IllegalStateException("Unable to find WebService annotation");
082            }
083            Class mainInterface = Class.forName(ws.endpointInterface());
084            endpointInterfaces.add(mainInterface);
085            // Check additional interfaces
086            for (Class pojoClass = pojo.getClass(); pojoClass != Object.class; pojoClass = pojoClass.getSuperclass()) {
087                for (Class cl : pojoClass.getInterfaces()) {
088                    if (getWebServiceAnnotation(cl) != null) {
089                        endpointInterfaces.add(cl);
090                    }
091                }
092            }
093            jaxbContext = createJAXBContext(endpointInterfaces);
094            ws = getWebServiceAnnotation(mainInterface);
095            if (ws != null) {
096                interfaceName = new QName(ws.targetNamespace(), ws.name());
097            }
098            super.activate();
099            if (pojo instanceof ComponentContextAware) {
100                ((ComponentContextAware) pojo).setContext(getContext());
101            }
102        }
103    
104        public static JAXBContext createJAXBContext(Class interfaceClass) throws JAXBException {
105            return createJAXBContext(Collections.singletonList(interfaceClass));
106        }
107    
108        public static JAXBContext createJAXBContext(Iterable<Class> interfaceClasses) throws JAXBException {
109            List<Class> classes = new ArrayList<Class>();
110            classes.add(JbiFault.class);
111            classes.add(XmlException.class);
112            for (Class interfaceClass : interfaceClasses) {
113                for (Method mth : interfaceClass.getMethods()) {
114                    WebMethod wm = (WebMethod) mth.getAnnotation(WebMethod.class);
115                    if (wm != null) {
116                        classes.add(mth.getReturnType());
117                        classes.addAll(Arrays.asList(mth.getParameterTypes()));
118                    }
119                }
120            }
121            return JAXBContext.newInstance(classes.toArray(new Class[classes.size()]));
122        }
123    
124        public JAXBContext getJaxbContext() {
125            return jaxbContext;
126        }
127    
128        @SuppressWarnings("unchecked")
129        public void process(MessageExchange exchange) throws Exception {
130            if (exchange.getStatus() == ExchangeStatus.DONE) {
131                return;
132            } else if (exchange.getStatus() == ExchangeStatus.ERROR) {
133                return;
134            }
135    
136            AtomicBoolean isJbiWrapped = new AtomicBoolean(false);
137            Source source = exchange.getMessage("in").getContent();
138            // Unwrap JBI message if needed
139            source = JbiWrapperHelper.unwrap(source, isJbiWrapped);
140    
141            Object input = jaxbContext.createUnmarshaller().unmarshal(source);
142            Method webMethod = null;
143            Class inputClass = input.getClass();
144            if (input instanceof JAXBElement) {
145                inputClass = ((JAXBElement) input).getDeclaredType();
146                input = ((JAXBElement) input).getValue(); 
147            }
148            for (Class clazz : endpointInterfaces) {
149                for (Method mth : clazz.getMethods()) {
150                    Class[] params = mth.getParameterTypes();
151                    if (params.length == 1 && params[0].isAssignableFrom(inputClass)) {
152                        if (webMethod == null) {
153                            webMethod = mth;
154                        } else if (!mth.getName().equals(webMethod.getName())) {
155                            throw new IllegalStateException("Multiple methods matching parameters");
156                        }
157                    }
158                }
159            }
160            if (webMethod == null) {
161                throw new IllegalStateException("Could not determine invoked web method");
162            }
163            boolean oneWay = webMethod.getAnnotation(Oneway.class) != null;
164            Object output;
165            try {
166                output = webMethod.invoke(pojo, new Object[] {input });
167            } catch (InvocationTargetException e) {
168                if (e.getCause() instanceof Exception) {
169                    WebFault fa = (WebFault) e.getCause().getClass().getAnnotation(WebFault.class);
170                    if (!(exchange instanceof InOnly) && fa != null) {
171                        BaseFaultType info = (BaseFaultType) e.getCause().getClass().getMethod("getFaultInfo").invoke(e.getCause());
172                        // Set description if not already set
173                        if (info.getDescription().size() == 0) {
174                            BaseFaultType.Description desc = new BaseFaultType.Description();
175                            desc.setValue(e.getCause().getMessage());
176                            info.getDescription().add(desc);
177                        }
178                        // TODO: create originator field?
179                        // Set timestamp if needed
180                        if (info.getTimestamp() == null) {
181                            info.setTimestamp(DatatypeFactory.newInstance().newXMLGregorianCalendar(new GregorianCalendar()));
182                        }
183                        
184                        // TODO: do we want to send the full stack trace here ?
185                        //BaseFaultType.FaultCause cause = new BaseFaultType.FaultCause();
186                        //cause.setAny(new XmlException(e.getCause()));
187                        //info.setFaultCause(cause);
188                        Fault fault = exchange.createFault();
189                        exchange.setFault(fault);
190                        Document doc = JbiWrapperHelper.createDocument();
191                        JAXBElement el = new JAXBElement(new QName(fa.targetNamespace(), fa.name()), info.getClass(), null, info);
192                        jaxbContext.createMarshaller().marshal(el, doc);
193                        if (isJbiWrapped.get()) {
194                            JbiWrapperHelper.wrap(doc);
195                        }
196                        fault.setContent(new DOMSource(doc));
197                        send(exchange);
198                        return;
199                    } else {
200                        throw (Exception) e.getCause();
201                    }
202                } else if (e.getCause() instanceof Error) {
203                    throw (Error) e.getCause();
204                } else {
205                    throw new RuntimeException(e.getCause());
206                }
207            }
208            if (oneWay) {
209                exchange.setStatus(ExchangeStatus.DONE);
210                send(exchange);
211            } else {
212                NormalizedMessage msg = exchange.createMessage();
213                exchange.setMessage(msg, "out");
214                Document doc = JbiWrapperHelper.createDocument();
215                jaxbContext.createMarshaller().marshal(output, doc);
216                if (isJbiWrapped.get()) {
217                    JbiWrapperHelper.wrap(doc);
218                }
219                msg.setContent(new DOMSource(doc));
220                sendSync(exchange);
221            }
222        }
223    
224        @XmlRootElement(name = "Fault")
225        public static class JbiFault {
226            private BaseFaultType info;
227    
228            public JbiFault() {
229            }
230    
231            public JbiFault(BaseFaultType info) {
232                this.info = info;
233            }
234    
235            public BaseFaultType getInfo() {
236                return info;
237            }
238    
239            public void setInfo(BaseFaultType info) {
240                this.info = info;
241            }
242        }
243    
244        @XmlRootElement(name = "Exception")
245        public static class XmlException {
246            private String stackTrace;
247            public XmlException() {
248            }
249            public XmlException(Throwable e) {
250                StringWriter sw = new StringWriter();
251                e.printStackTrace(new PrintWriter(sw));
252                stackTrace = sw.toString();
253            }
254            public String getStackTrace() {
255                return stackTrace;
256            }
257            public void setStackTrace(String stackTrace) {
258                this.stackTrace = stackTrace;
259            }
260            @XmlMixed
261            public List getContent() {
262                return Collections.singletonList(stackTrace);
263            }
264        }
265    
266        protected Method getWebServiceMethod(QName interfaceName, QName operation) throws Exception {
267            WebService ws = getWebServiceAnnotation(pojo.getClass());
268            if (ws == null) {
269                throw new IllegalStateException("Unable to find WebService annotation");
270            }
271            Class itf = Class.forName(ws.endpointInterface());
272            for (Method mth : itf.getMethods()) {
273                WebMethod wm = (WebMethod) mth.getAnnotation(WebMethod.class);
274                if (wm != null) {
275                    // TODO: get name ?
276                }
277            }
278            return null;
279        }
280    
281        @SuppressWarnings("unchecked")
282        protected WebService getWebServiceAnnotation(Class clazz) {
283            for (Class cl = clazz; cl != null; cl = cl.getSuperclass()) {
284                WebService ws = (WebService) cl.getAnnotation(WebService.class);
285                if (ws != null) {
286                    return ws;
287                }
288            }
289            return null;
290        }
291    
292    }