1 /*
2 * Licensed to the University Corporation for Advanced Internet Development,
3 * Inc. (UCAID) under one or more contributor license agreements. See the
4 * NOTICE file distributed with this work for additional information regarding
5 * copyright ownership. The UCAID licenses this file to You under the Apache
6 * License, Version 2.0 (the "License"); you may not use this file except in
7 * compliance with the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package edu.internet2.middleware.shibboleth.idp.util;
19
20 import java.net.InetAddress;
21 import java.net.UnknownHostException;
22 import java.util.BitSet;
23
24 import org.opensaml.xml.util.DatatypeHelper;
25
/**
 * Represents a range of IPv4 or IPv6 addresses, expressed as a network address plus a netmask length (CIDR notation,
 * e.g. {@code 192.168.1.0/24}).
 *
 * <p>Addresses are held as {@link BitSet}s in which bit 0 is the least significant bit of the last address byte and
 * bit {@code addressLength - 1} is the most significant bit of the first address byte (see {@link #toBitSet(byte[])}).
 * The netmask therefore occupies the highest-numbered bits.</p>
 */
public class IPRange {

    /** Number of bits within an address of this range: 32 for IPv4, 128 for IPv6. */
    private final int addressLength;

    /** The IP network address for the range, with all host (non-mask) bits cleared. */
    private final BitSet network;

    /** The netmask for the range; its {@code maskSize} most significant bits are set. */
    private final BitSet mask;

    /**
     * Constructor.
     *
     * @param networkAddress the network address for the range
     * @param maskSize the number of bits in the netmask
     *
     * @throws IllegalArgumentException if the address is neither IPv4 nor IPv6, or the mask size is negative or
     *             larger than the address length
     */
    public IPRange(InetAddress networkAddress, int maskSize) {
        this(networkAddress.getAddress(), maskSize);
    }

    /**
     * Constructor.
     *
     * <p>Any host bits set in {@code networkAddress} are ignored, so {@code 192.168.1.5/24} denotes the same range as
     * {@code 192.168.1.0/24}.</p>
     *
     * @param networkAddress the network address for the range, most significant byte first (4 or 16 bytes)
     * @param maskSize the number of bits in the netmask
     *
     * @throws IllegalArgumentException if the address is neither IPv4 nor IPv6, or the mask size is negative or
     *             larger than the address length
     */
    public IPRange(byte[] networkAddress, int maskSize) {
        addressLength = networkAddress.length * 8;
        if (addressLength != 32 && addressLength != 128) {
            throw new IllegalArgumentException("Network address was neither an IPv4 or IPv6 address");
        }
        if (maskSize < 0 || maskSize > addressLength) {
            // Without this check BitSet.set would throw an IndexOutOfBoundsException.
            throw new IllegalArgumentException("Netmask size " + maskSize + " is outside the range 0-"
                    + addressLength);
        }

        mask = new BitSet(addressLength);
        mask.set(addressLength - maskSize, addressLength, true);

        network = toBitSet(networkAddress);
        // Clear host bits so that contains(), which masks the candidate address, compares like with like.
        network.and(mask);
    }

    /**
     * Parses a CIDR block definition in to an IP range.
     *
     * @param cidrBlock the CIDR block definition, of the form {@code address/maskSize}
     *
     * @return the resultant IP range
     *
     * @throws IllegalArgumentException if the definition is null or empty, lacks a {@code /}, has an invalid
     *             address, or has an invalid netmask size
     */
    public static IPRange parseCIDRBlock(String cidrBlock) {
        String block = cidrBlock == null ? null : cidrBlock.trim();
        if (block == null || block.length() == 0) {
            throw new IllegalArgumentException("CIDR block definition may not be null");
        }

        int slash = block.indexOf('/');
        if (slash < 0) {
            throw new IllegalArgumentException("CIDR block definition must be of the form address/netmask-size");
        }
        String addressPart = block.substring(0, slash);
        String maskPart = block.substring(slash + 1);
        if (addressPart.length() == 0) {
            // InetAddress.getByName("") would silently return the loopback address.
            throw new IllegalArgumentException("Invalid IP address");
        }

        InetAddress networkAddress;
        try {
            // NOTE(review): getByName performs a DNS lookup when given a host name rather than an IP literal.
            networkAddress = InetAddress.getByName(addressPart);
        } catch (UnknownHostException e) {
            throw new IllegalArgumentException("Invalid IP address", e);
        }

        int maskSize;
        try {
            maskSize = Integer.parseInt(maskPart);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid netmask size", e);
        }

        return new IPRange(networkAddress, maskSize);
    }

    /**
     * Determines whether the given address is contained in the IP range.
     *
     * @param address the address to check
     *
     * @return true if the address is in the range, false if not (including when the address family differs)
     */
    public boolean contains(InetAddress address) {
        return contains(address.getAddress());
    }

    /**
     * Determines whether the given address is contained in the IP range.
     *
     * @param address the address to check, most significant byte first
     *
     * @return true if the address is in the range, false if not (including when the address length differs)
     */
    public boolean contains(byte[] address) {
        if (address.length * 8 != addressLength) {
            return false;
        }

        BitSet addrNetwork = toBitSet(address);
        addrNetwork.and(mask);

        return addrNetwork.equals(network);
    }

    /**
     * Converts a byte array to a BitSet.
     *
     * <p>The supplied byte array is assumed to have the most significant bit in element 0. In the result, bit 0 is
     * the least significant bit of the last element and bit {@code bytes.length * 8 - 1} is the most significant bit
     * of element 0.</p>
     *
     * @param bytes the byte array with most significant bit in element 0.
     *
     * @return the BitSet
     */
    protected BitSet toBitSet(byte[] bytes) {
        int bitCount = bytes.length * 8;
        BitSet bits = new BitSet(bitCount);

        for (int i = 0; i < bitCount; i++) {
            // Walk bytes from last to first; within each byte, test bits from least to most significant.
            if ((bytes[bytes.length - i / 8 - 1] & (1 << (i % 8))) != 0) {
                bits.set(i);
            }
        }

        return bits;
    }
}