001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.servicemix.wsn.component;
018    
019    import java.io.StringWriter;
020    import java.io.PrintWriter;
021    import java.lang.reflect.InvocationTargetException;
022    import java.lang.reflect.Method;
023    import java.util.ArrayList;
024    import java.util.Arrays;
025    import java.util.List;
026    import java.util.Collections;
027    import java.util.GregorianCalendar;
028    import java.util.Set;
029    import java.util.HashSet;
030    
031    import javax.jbi.messaging.ExchangeStatus;
032    import javax.jbi.messaging.Fault;
033    import javax.jbi.messaging.InOnly;
034    import javax.jbi.messaging.MessageExchange;
035    import javax.jbi.messaging.NormalizedMessage;
036    import javax.jws.Oneway;
037    import javax.jws.WebMethod;
038    import javax.jws.WebService;
039    import javax.xml.bind.JAXBContext;
040    import javax.xml.bind.JAXBException;
041    import javax.xml.bind.JAXBElement;
042    import javax.xml.bind.annotation.XmlRootElement;
043    import javax.xml.bind.annotation.XmlMixed;
044    import javax.xml.namespace.QName;
045    import javax.xml.ws.WebFault;
046    import javax.xml.transform.Source;
047    import javax.xml.transform.dom.DOMSource;
048    import javax.xml.datatype.DatatypeFactory;
049    
050    import org.w3c.dom.Document;
051    
052    import org.apache.servicemix.common.endpoints.ProviderEndpoint;
053    import org.apache.servicemix.common.util.URIResolver;
054    import org.apache.servicemix.wsn.ComponentContextAware;
055    import org.apache.servicemix.wsn.jbi.JbiWrapperHelper;
056    import org.oasis_open.docs.wsrf.bf_2.BaseFaultType;
057    
058    public class WSNEndpoint extends ProviderEndpoint {
059    
060        protected String address;
061    
062        protected Object pojo;
063    
064        protected JAXBContext jaxbContext;
065    
066        protected Set<Class> endpointInterfaces = new HashSet<Class>();
067    
068        public WSNEndpoint(String address, Object pojo) {
069            this.address = address;
070            this.pojo = pojo;
071            String[] parts = URIResolver.split3(address);
072            service = new QName(parts[0], parts[1]);
073            endpoint = parts[2];
074        }
075    
076        @Override
077        public void activate() throws Exception {
078            WebService ws = getWebServiceAnnotation(pojo.getClass());
079            if (ws == null) {
080                throw new IllegalStateException("Unable to find WebService annotation");
081            }
082            Class mainInterface = Class.forName(ws.endpointInterface());
083            endpointInterfaces.add(mainInterface);
084            // Check additional interfaces
085            for (Class pojoClass = pojo.getClass(); pojoClass != Object.class; pojoClass = pojoClass.getSuperclass()) {
086                for (Class cl : pojoClass.getInterfaces()) {
087                    if (getWebServiceAnnotation(cl) != null) {
088                        endpointInterfaces.add(cl);
089                    }
090                }
091            }
092            jaxbContext = createJAXBContext(endpointInterfaces);
093            ws = getWebServiceAnnotation(mainInterface);
094            if (ws != null) {
095                interfaceName = new QName(ws.targetNamespace(), ws.name());
096            }
097            super.activate();
098            if (pojo instanceof ComponentContextAware) {
099                ((ComponentContextAware) pojo).setContext(getContext());
100            }
101        }
102    
103        public static JAXBContext createJAXBContext(Class interfaceClass) throws JAXBException {
104            return createJAXBContext(Collections.singletonList(interfaceClass));
105        }
106    
107        public static JAXBContext createJAXBContext(Iterable<Class> interfaceClasses) throws JAXBException {
108            List<Class> classes = new ArrayList<Class>();
109            classes.add(JbiFault.class);
110            classes.add(XmlException.class);
111            for (Class interfaceClass : interfaceClasses) {
112                for (Method mth : interfaceClass.getMethods()) {
113                    WebMethod wm = (WebMethod) mth.getAnnotation(WebMethod.class);
114                    if (wm != null) {
115                        classes.add(mth.getReturnType());
116                        classes.addAll(Arrays.asList(mth.getParameterTypes()));
117                    }
118                }
119            }
120            return JAXBContext.newInstance(classes.toArray(new Class[classes.size()]));
121        }
122    
123        public JAXBContext getJaxbContext() {
124            return jaxbContext;
125        }
126    
127        @SuppressWarnings("unchecked")
128        public void process(MessageExchange exchange) throws Exception {
129            if (exchange.getStatus() == ExchangeStatus.DONE) {
130                return;
131            } else if (exchange.getStatus() == ExchangeStatus.ERROR) {
132                return;
133            }
134    
135            boolean isJbiWrapped = false;
136            Source source = exchange.getMessage("in").getContent();
137            // Unwrap JBI message if needed
138            source = JbiWrapperHelper.unwrap(source);
139    
140            Object input = jaxbContext.createUnmarshaller().unmarshal(source);
141            Method webMethod = null;
142            Class inputClass = input.getClass();
143            if (input instanceof JAXBElement) {
144                inputClass = ((JAXBElement) input).getDeclaredType();
145                input = ((JAXBElement) input).getValue(); 
146            }
147            for (Class clazz : endpointInterfaces) {
148                for (Method mth : clazz.getMethods()) {
149                    Class[] params = mth.getParameterTypes();
150                    if (params.length == 1 && params[0].isAssignableFrom(inputClass)) {
151                        if (webMethod == null) {
152                            webMethod = mth;
153                        } else if (!mth.getName().equals(webMethod.getName())) {
154                            throw new IllegalStateException("Multiple methods matching parameters");
155                        }
156                    }
157                }
158            }
159            if (webMethod == null) {
160                throw new IllegalStateException("Could not determine invoked web method");
161            }
162            boolean oneWay = webMethod.getAnnotation(Oneway.class) != null;
163            Object output;
164            try {
165                output = webMethod.invoke(pojo, new Object[] {input });
166            } catch (InvocationTargetException e) {
167                if (e.getCause() instanceof Exception) {
168                    WebFault fa = (WebFault) e.getCause().getClass().getAnnotation(WebFault.class);
169                    if (!(exchange instanceof InOnly) && fa != null) {
170                        BaseFaultType info = (BaseFaultType) e.getCause().getClass().getMethod("getFaultInfo").invoke(e.getCause());
171                        // Set description if not already set
172                        if (info.getDescription().size() == 0) {
173                            BaseFaultType.Description desc = new BaseFaultType.Description();
174                            desc.setValue(e.getCause().getMessage());
175                            info.getDescription().add(desc);
176                        }
177                        // TODO: create originator field?
178                        // Set timestamp if needed
179                        if (info.getTimestamp() == null) {
180                            info.setTimestamp(DatatypeFactory.newInstance().newXMLGregorianCalendar(new GregorianCalendar()));
181                        }
182                        
183                        // TODO: do we want to send the full stack trace here ?
184                        //BaseFaultType.FaultCause cause = new BaseFaultType.FaultCause();
185                        //cause.setAny(new XmlException(e.getCause()));
186                        //info.setFaultCause(cause);
187                        Fault fault = exchange.createFault();
188                        exchange.setFault(fault);
189                        Document doc = JbiWrapperHelper.createDocument();
190                        JAXBElement el = new JAXBElement(new QName(fa.targetNamespace(), fa.name()), info.getClass(), null, info);
191                        jaxbContext.createMarshaller().marshal(el, doc);
192                        if (isJbiWrapped) {
193                            JbiWrapperHelper.wrap(doc);
194                        }
195                        fault.setContent(new DOMSource(doc));
196                        send(exchange);
197                        return;
198                    } else {
199                        throw (Exception) e.getCause();
200                    }
201                } else if (e.getCause() instanceof Error) {
202                    throw (Error) e.getCause();
203                } else {
204                    throw new RuntimeException(e.getCause());
205                }
206            }
207            if (oneWay) {
208                exchange.setStatus(ExchangeStatus.DONE);
209                send(exchange);
210            } else {
211                NormalizedMessage msg = exchange.createMessage();
212                exchange.setMessage(msg, "out");
213                Document doc = JbiWrapperHelper.createDocument();
214                jaxbContext.createMarshaller().marshal(output, doc);
215                if (isJbiWrapped) {
216                    JbiWrapperHelper.wrap(doc);
217                }
218                msg.setContent(new DOMSource(doc));
219                sendSync(exchange);
220            }
221        }
222    
223        @XmlRootElement(name = "Fault")
224        public static class JbiFault {
225            private BaseFaultType info;
226    
227            public JbiFault() {
228            }
229    
230            public JbiFault(BaseFaultType info) {
231                this.info = info;
232            }
233    
234            public BaseFaultType getInfo() {
235                return info;
236            }
237    
238            public void setInfo(BaseFaultType info) {
239                this.info = info;
240            }
241        }
242    
243        @XmlRootElement(name = "Exception")
244        public static class XmlException {
245            private String stackTrace;
246            public XmlException() {
247            }
248            public XmlException(Throwable e) {
249                StringWriter sw = new StringWriter();
250                e.printStackTrace(new PrintWriter(sw));
251                stackTrace = sw.toString();
252            }
253            public String getStackTrace() {
254                return stackTrace;
255            }
256            public void setStackTrace(String stackTrace) {
257                this.stackTrace = stackTrace;
258            }
259            @XmlMixed
260            public List getContent() {
261                return Collections.singletonList(stackTrace);
262            }
263        }
264    
265        protected Method getWebServiceMethod(QName interfaceName, QName operation) throws Exception {
266            WebService ws = getWebServiceAnnotation(pojo.getClass());
267            if (ws == null) {
268                throw new IllegalStateException("Unable to find WebService annotation");
269            }
270            Class itf = Class.forName(ws.endpointInterface());
271            for (Method mth : itf.getMethods()) {
272                WebMethod wm = (WebMethod) mth.getAnnotation(WebMethod.class);
273                if (wm != null) {
274                    // TODO: get name ?
275                }
276            }
277            return null;
278        }
279    
280        @SuppressWarnings("unchecked")
281        protected WebService getWebServiceAnnotation(Class clazz) {
282            for (Class cl = clazz; cl != null; cl = cl.getSuperclass()) {
283                WebService ws = (WebService) cl.getAnnotation(WebService.class);
284                if (ws != null) {
285                    return ws;
286                }
287            }
288            return null;
289        }
290    
291    }