001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.servicemix.soap;
018    
019    import java.net.URI;
020    import java.util.Collections;
021    import java.util.HashMap;
022    import java.util.Iterator;
023    import java.util.List;
024    import java.util.Map;
025    
026    import javax.jbi.JBIException;
027    import javax.jbi.component.ComponentContext;
028    import javax.jbi.messaging.DeliveryChannel;
029    import javax.jbi.messaging.MessageExchange;
030    import javax.jbi.messaging.MessageExchangeFactory;
031    import javax.jbi.messaging.MessagingException;
032    import javax.jbi.messaging.NormalizedMessage;
033    import javax.jbi.servicedesc.ServiceEndpoint;
034    import javax.wsdl.Binding;
035    import javax.wsdl.Definition;
036    import javax.wsdl.Operation;
037    import javax.wsdl.Part;
038    import javax.wsdl.Port;
039    import javax.wsdl.PortType;
040    import javax.wsdl.Service;
041    import javax.wsdl.WSDLException;
042    import javax.wsdl.factory.WSDLFactory;
043    import javax.wsdl.xml.WSDLReader;
044    import javax.xml.namespace.QName;
045    
046    import org.apache.commons.logging.Log;
047    import org.apache.commons.logging.LogFactory;
048    import org.apache.servicemix.jbi.jaxp.W3CDOMStreamWriter;
049    import org.apache.servicemix.soap.marshalers.JBIMarshaler;
050    import org.apache.servicemix.soap.marshalers.SoapMarshaler;
051    import org.apache.servicemix.soap.marshalers.SoapMessage;
052    import org.apache.servicemix.soap.marshalers.SoapWriter;
053    import org.w3c.dom.Document;
054    
055    import com.ibm.wsdl.Constants;
056    
057    /**
058     * Helper class for working with soap endpoints
059     * 
060     * @author Guillaume Nodet
061     * @version $Revision: 1.5 $
062     * @since 3.0
063     */
064    public class SoapHelper {
065    
066        private static final Log logger = LogFactory.getLog(SoapHelper.class);
067    
068        public static final URI IN_ONLY = URI.create("http://www.w3.org/2004/08/wsdl/in-only");
069        public static final URI IN_OUT = URI.create("http://www.w3.org/2004/08/wsdl/in-out");
070        public static final URI ROBUST_IN_ONLY = URI.create("http://www.w3.org/2004/08/wsdl/robust-in-only");
071    
072        private SoapEndpoint endpoint;
073        private List policies;
074        private JBIMarshaler jbiMarshaler;
075        private SoapMarshaler soapMarshaler;
076        private Map definitions;
077        private Map operationNames;
078    
079        public SoapHelper(SoapEndpoint endpoint) {
080            this.policies = endpoint.getPolicies();
081            if (this.policies == null) {
082                this.policies = Collections.EMPTY_LIST;
083            }
084            this.definitions = new HashMap();
085            this.operationNames = new HashMap();
086            this.jbiMarshaler = new JBIMarshaler();
087            this.endpoint = endpoint;
088            boolean requireDom = false;
089            for (Iterator iter = policies.iterator(); iter.hasNext();) {
090                Handler handler = (Handler) iter.next();
091                requireDom |= handler.requireDOM();
092            }
093            this.soapMarshaler = new SoapMarshaler(endpoint.isSoap(), requireDom);
094            if (endpoint.isSoap() && "1.1".equals(endpoint.getSoapVersion())) {
095                this.soapMarshaler.setSoapUri(SoapMarshaler.SOAP_11_URI);
096            }
097        }
098        
099        public SoapMarshaler getSoapMarshaler() {
100            return this.soapMarshaler;
101        }
102        
103        public JBIMarshaler getJBIMarshaler() {
104            return this.jbiMarshaler;
105        }
106    
107        public MessageExchange onReceive(Context context) throws Exception {
108            if (policies != null) {
109                for (Iterator it = policies.iterator(); it.hasNext();) {
110                    Handler policy = (Handler) it.next();
111                    policy.onReceive(context);
112                }
113            }
114    
115            // If WS-A has not set informations, use the default ones
116            if (context.getProperty(Context.SERVICE) == null && context.getProperty(Context.INTERFACE) == null) {
117                // If no target endpoint / service / interface is defined
118                // we assume we use the same informations has defined on the
119                // external endpoint
120                if (endpoint.getTargetInterfaceName() == null && endpoint.getTargetService() == null
121                                && endpoint.getTargetEndpoint() == null) {
122                    context.setProperty(Context.INTERFACE, endpoint.getInterfaceName());
123                    context.setProperty(Context.SERVICE, endpoint.getService());
124                    context.setProperty(Context.ENDPOINT, endpoint.getEndpoint());
125                } else {
126                    context.setProperty(Context.INTERFACE, endpoint.getTargetInterfaceName());
127                    context.setProperty(Context.SERVICE, endpoint.getTargetService());
128                    context.setProperty(Context.ENDPOINT, endpoint.getTargetEndpoint());
129                }
130            }
131            Operation operation = findOperation(context);
132            if (context.getProperty(Context.OPERATION) == null) {
133                if (operation != null) {
134                    // the operation QName must be retrieved from the map
135                    // so that we can have the right namespace
136                    context.setProperty(Context.OPERATION, operationNames.get(operation));
137                } else if (endpoint.getDefaultOperation() != null) {
138                    context.setProperty(Context.OPERATION, endpoint.getDefaultOperation());
139                } else {
140                    // By default, use name of body element (i.e., RPC-style)
141                    QName bodyName = context.getInMessage().getBodyName();
142                    context.setProperty(Context.OPERATION, bodyName);
143                }
144            }
145            URI mep = null; 
146            if ( operation != null ) {
147                mep = getMep(operation);
148            }
149            if (mep == null) {
150                mep = endpoint.getDefaultMep();
151            }
152            MessageExchange exchange = createExchange(mep);
153            exchange.setService((QName) context.getProperty(Context.SERVICE));
154            exchange.setInterfaceName((QName) context.getProperty(Context.INTERFACE));
155            exchange.setOperation((QName) context.getProperty(Context.OPERATION));
156            if (context.getProperty(Context.ENDPOINT) != null) {
157                ComponentContext componentContext = endpoint.getServiceUnit().getComponent().getComponentContext();
158                QName serviceName = (QName) context.getProperty(Context.SERVICE);
159                String endpointName = (String) context.getProperty(Context.ENDPOINT);
160                ServiceEndpoint se = componentContext.getEndpoint(serviceName, endpointName);
161                if (se != null) {
162                    exchange.setEndpoint(se);
163                }
164            }
165            NormalizedMessage inMessage = exchange.createMessage();
166            jbiMarshaler.toNMS(inMessage, context.getInMessage());
167            exchange.setMessage(inMessage, "in");
168            return exchange;
169        }
170    
171        public SoapMessage onReply(Context context, NormalizedMessage outMsg) throws Exception {
172            SoapMessage out = new SoapMessage();
173            if (context.getInMessage() != null) {
174                out.setEnvelopeName(context.getInMessage().getEnvelopeName());
175            }
176            jbiMarshaler.fromNMS(out, outMsg);
177            context.setOutMessage(out);
178            if (policies != null) {
179                for (Iterator it = policies.iterator(); it.hasNext();) {
180                    Handler policy = (Handler) it.next();
181                    policy.onReply(context);
182                }
183            }
184            return out;
185        }
186    
187        public SoapMessage onFault(Context context, SoapFault fault) throws Exception {
188            SoapMessage soapFault = new SoapMessage();
189            soapFault.setFault(fault);
190            if (context == null) {
191                context = new Context();
192            }
193            if (context.getInMessage() != null) {
194                soapFault.setEnvelopeName(context.getInMessage().getEnvelopeName());
195            }
196            context.setFaultMessage(soapFault);
197            if (policies != null) {
198                for (Iterator it = policies.iterator(); it.hasNext();) {
199                    Handler policy = (Handler) it.next();
200                    policy.onFault(context);
201                }
202            }
203            return soapFault;
204        }
205    
206        public void onSend(Context context) throws Exception {
207            if (policies != null) {
208                for (Iterator it = policies.iterator(); it.hasNext();) {
209                    Handler policy = (Handler) it.next();
210                    if (policy.requireDOM()) {
211                        SoapWriter writer = soapMarshaler.createWriter(context.getInMessage());
212                        W3CDOMStreamWriter domWriter = new W3CDOMStreamWriter(); 
213                        writer.writeSoapEnvelope(domWriter);
214                        context.getInMessage().setDocument(domWriter.getDocument());
215                    }
216                    policy.onSend(context);
217                }
218            }
219        }
220    
221        public void onAnswer(Context context) throws Exception {
222            if (policies != null) {
223                for (Iterator it = policies.iterator(); it.hasNext();) {
224                    Handler policy = (Handler) it.next();
225                    policy.onAnswer(context);
226                }
227            }
228        }
229    
230        public Context createContext(SoapMessage message) {
231            Context context = createContext();
232            context.setInMessage(message);
233            return context;
234        }
235        
236        public Context createContext() {
237            Context context = new Context();
238            context.setProperty(Context.AUTHENTICATION_SERVICE, endpoint.getAuthenticationService());
239            context.setProperty(Context.KEYSTORE_MANAGER, endpoint.getKeystoreManager());
240            return context;
241        }
242    
243        protected MessageExchange createExchange(URI mep) throws MessagingException {
244            ComponentContext context = endpoint.getServiceUnit().getComponent().getComponentContext();
245            DeliveryChannel channel = context.getDeliveryChannel();
246            MessageExchangeFactory factory = channel.createExchangeFactory();
247            MessageExchange exchange = factory.createExchange(mep);
248            return exchange;
249        }
250    
251        private URI getMep(Operation oper) {
252            URI mep = null;
253            if (oper != null) {
254                boolean output = oper.getOutput() != null && oper.getOutput().getMessage() != null
255                                && oper.getOutput().getMessage().getParts().size() > 0;
256                boolean faults = oper.getFaults().size() > 0;
257                if (output) {
258                    mep = IN_OUT;
259                } else if (faults) {
260                    mep = ROBUST_IN_ONLY;
261                } else {
262                    mep = IN_ONLY;
263                }
264            }
265            return mep;
266        }        
267    
268        protected Operation findOperation(Context context) throws Exception {
269            QName interfaceName = (QName) context.getProperty(Context.INTERFACE);
270            QName serviceName = (QName) context.getProperty(Context.SERVICE);
271            String endpointName = (String) context.getProperty(Context.ENDPOINT);
272            ComponentContext componentContext = endpoint.getServiceUnit().getComponent().getComponentContext();
273            QName bodyName = context.getInMessage().getBodyName();
274            
275            // Find target endpoint
276            ServiceEndpoint se = null;
277            if (serviceName != null && endpointName != null) {
278                se = componentContext.getEndpoint(serviceName, endpointName);
279            }
280            if (se == null && interfaceName != null) {
281                ServiceEndpoint[] ses = componentContext.getEndpoints(interfaceName);
282                if (ses != null && ses.length > 0) {
283                    se = ses[0];
284                }
285            }
286            
287            // Find WSDL description
288            Definition definition = null;
289            if (se != null) {
290                // Find endpoint description from the component context
291                definition = getDefinition(se);
292            }
293            if (definition == null) {
294                // Get this endpoint definition
295                definition = endpoint.getDefinition();
296            }
297    
298            // Find operation matching 
299            if (definition != null) {
300                if (interfaceName != null) {
301                    PortType portType = definition.getPortType(interfaceName);
302                    if (portType != null) {
303                        return findOperationFor(portType, bodyName);
304                    }
305                } else if (definition.getService(serviceName) != null) {
306                    Service service = definition.getService(serviceName);
307                    if (endpointName != null) {
308                        Port port = service.getPort(endpointName);
309                        if (port != null) {
310                            Binding binding = port.getBinding();
311                            if (binding != null) {
312                                PortType portType = binding.getPortType();
313                                if (portType != null) {
314                                    return findOperationFor(portType, bodyName);
315                                }
316                            }
317                        }
318                    } else if (service.getPorts().size() == 1) {
319                        Port port = (Port) service.getPorts().values().iterator().next();
320                        Binding binding = port.getBinding();
321                        if (binding != null) {
322                            PortType portType = binding.getPortType();
323                            if (portType != null) {
324                                return findOperationFor(portType, bodyName);
325                            }
326                        }
327                    }
328                } else if (definition.getPortTypes().size() == 1) {
329                    PortType portType = (PortType) definition.getPortTypes().values().iterator().next();
330                    return findOperationFor(portType, bodyName);
331                }
332            }
333            return null;
334        }
335        
336        protected Operation findOperationFor(PortType portType, QName bodyName) {
337            List list = portType.getOperations();
338            for (int i = 0; i < list.size(); i++) {
339                Operation operation = (Operation) list.get(i);
340                if (operation.getInput() != null && operation.getInput().getMessage() != null) {
341                    Map parts = operation.getInput().getMessage().getParts();
342                    Iterator iter = parts.values().iterator();
343                    while (iter.hasNext()) {
344                        Part part = (Part) iter.next();
345                        QName elementName = part.getElementName();
346                        if (elementName != null && elementName.equals(bodyName)) {
347                            // found
348                            operationNames.put(operation, new QName(portType.getQName().getNamespaceURI(), operation.getName()));
349                            return operation;
350                        }
351                    }
352                }
353            }
354            return null;
355        }
356    
357        protected Definition getDefinition(ServiceEndpoint se) throws WSDLException, JBIException {
358            Definition definition;
359            ComponentContext componentContext = endpoint.getServiceUnit().getComponent().getComponentContext();
360                String key = se.getServiceName() + se.getEndpointName();
361                synchronized (definitions) {
362                    definition = (Definition) definitions.get(key);
363                    if (definition == null) {
364                        WSDLFactory factory = WSDLFactory.newInstance();
365                        Document description = componentContext.getEndpointDescriptor(se);
366                        if (description != null) {
367                            // Parse WSDL
368                            WSDLReader reader = factory.newWSDLReader(); 
369                            reader.setFeature(Constants.FEATURE_VERBOSE, false);
370                            try {
371                                definition = reader.readWSDL(null, description);
372                                definitions.put(key, definition);
373                            } catch (WSDLException e) {
374                                logger.info("Could not read wsdl from endpoint descriptor: " + e.getMessage());
375                                if (logger.isDebugEnabled()) {
376                                    logger.debug("Could not read wsdl from endpoint descriptor", e);
377                                }
378                            }
379                        }
380                    }
381                }
382            return definition;
383        }
384    
385    }