from interval_spaces.interval_space import IntervalSpace
from decimal import *
from random import uniform
class Node(object):
    """ One closed interval [x, y] of the action space, stored as a vertex of the AVL tree.

    Bounds are converted through their string form into ``Decimal`` so that a value such as
    0.1 is stored as written rather than as its binary floating-point approximation.
    """

    def __init__(self, x: float = None, y: float = None, left: object = None, right: object = None, height: int = 1):
        """
        Args:
            x (float): Lower bound of the interval; ``None`` leaves the bound unset
            y (float): Upper bound of the interval; ``None`` leaves the bound unset
            left (Node): Child holding the smaller intervals
            right (Node): Child holding the larger intervals
            height (int): Height of the subtree rooted at this node, 1 for a leaf
        """
        lower = None
        if x is not None:
            lower = Decimal(f'{x}')
        upper = None
        if y is not None:
            upper = Decimal(f'{y}')
        self.x: Decimal = lower
        self.y: Decimal = upper
        # Short names (l, r, h) are what the tree code reads everywhere.
        self.l, self.r, self.h = left, right, height

    def __str__(self):
        # Children are formatted with their own __str__, so this renders the whole subtree.
        return f'<Node ({self.x},{self.y}), height: {self.h}, left: {self.l}, right: {self.r}>'

    def __repr__(self):
        return str(self)
class TreeSpace(IntervalSpace):
    """ Interval action space stored as an AVL tree of ``Node`` intervals.

    The action space is a union of closed intervals. Each node holds one interval; its left
    subtree holds intervals below it and its right subtree intervals above it, so an in-order
    traversal yields the intervals in increasing order. Overlapping insertions are merged into
    existing nodes. ``size`` is the running total length of the stored intervals.
    """
    # Root of the AVL tree; None when the space is empty.
    root_tree = None
    # Total length of the stored intervals, kept up to date by insert/remove.
    size: Decimal = 0
    # Random offset consumed by ``sample``; None when no draw is pending.
    draw = None

    def __init__(self, x: float, y: float):
        """
        Args:
            x (float): Lower bound of the initial interval
            y (float): Upper bound of the initial interval
        """
        super().__init__()
        # NOTE: this changes the current decimal context, which all Decimal arithmetic in this
        # thread shares, not just this object.
        getcontext().prec = 28
        self.root_tree = Node(x, y)
        # Derive the size from the node's Decimal bounds. Decimal(0.1) would capture the float's
        # binary expansion and disagree with the stored Decimal('0.1'), leaving a residue in
        # ``size`` once the interval is removed again.
        self.size = self.root_tree.y - self.root_tree.x

    def __contains__(self, item):
        return self.contains(item)

    def contains(self, x, root: object = 'root'):
        """ Determines if a number is part of the action space
        Args:
            x: Number
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            Boolean indicating if it is part of the action space (interval bounds are inclusive)
        """
        if root == 'root':
            root = self.root_tree
        x = Decimal(f'{x}')
        node = root
        while node is not None:
            if x < node.x:
                node = node.l
            elif x > node.y:
                node = node.r
            else:
                return True
        return False

    def nearest_elements(self, x, root: Node = 'root'):
        """ Finds nearest actions for a number in the action space
        Args:
            x: Number
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            The number itself (as Decimal) if it is valid. Otherwise a list holding the nearest
            interval bound, or both nearest bounds in ascending order ``[below, above]`` when they
            are equally far away.
        Raises:
            Exception: if the tree (or the given subtree) is empty
        """
        if root == 'root':
            root = self.root_tree
        if root is None:
            raise Exception('Empty Action Space')
        return self._nearest_elements(x, None, None, root)

    def _nearest_elements(self, x, min_diff, min_value, root: Node = 'root'):
        """ Searches ``root``'s subtree for the interval bounds closest to ``x``
        Args:
            x: Number
            min_diff: Kept for backward compatibility; the distance is recomputed from ``min_value``
            min_value: Best bound already known from outside the subtree, or None
            root: Node to continue the search from or 'root' for the whole tree
        Returns:
            Same as ``nearest_elements`` (an empty list only if nothing at all was found)
        """
        if root == 'root':
            root = self.root_tree
        x = Decimal(f'{x}')
        # Keep the closest bound on each side of x separately. A single "best so far" value is
        # not enough: the subtree on the far side of a node can still hold an interval that is
        # closer to x or even contains it.
        below = above = None
        if min_value is not None:
            min_value = Decimal(f'{min_value}')
            if min_value < x:
                below = min_value
            else:
                above = min_value
        node = root
        while node is not None:
            if x > node.y:
                if below is None or node.y > below:
                    below = node.y
                node = node.r
            elif x < node.x:
                if above is None or node.x < above:
                    above = node.x
                node = node.l
            else:
                return x
        if below is None or above is None:
            return [bound for bound in (below, above) if bound is not None]
        gap_below = x - below
        gap_above = above - x
        if gap_below == gap_above:
            return [below, above]
        return [below] if gap_below < gap_above else [above]

    def nearest_element(self, x, root: Node = 'root'):
        """ Finds nearest action for a number in the action space. Larger actions preferred.
        Args:
            x: Number
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            Nearest element in the action space. It is the number itself if it is valid.
        Raises:
            Exception: if the tree (or the given subtree) is empty
        """
        nearest = self.nearest_elements(x, root)
        # nearest_elements returns the number itself, not a list, when x is valid.
        if isinstance(nearest, list):
            # Ties are listed in ascending order, so the last entry is the larger action.
            return nearest[-1]
        return nearest

    def last_interval_before_or_within(self, x, root: Node = 'root'):
        """ Returns the last interval before or within a number
        Args:
            x: Number
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            Tuple containing the lower and upper boundaries of the interval and a variable indicating
            if the number lies in the interval. For example:
            (root.x, root.y), True
            If no interval contains x, the closest interval lying entirely below x is returned with
            False, or ((None, None), False) if there is none.
        """
        if root == 'root':
            root = self.root_tree
        x = Decimal(f'{x}')
        before = (None, None)
        node = root
        while node is not None:
            if x < node.x:
                node = node.l
            elif x > node.y:
                # This interval lies below x; any closer one can only be in the right subtree.
                before = (node.x, node.y)
                node = node.r
            else:
                return (node.x, node.y), True
        return before, False

    def first_interval_after_or_within(self, x, root: Node = 'root'):
        """ Returns the first interval after or within a number
        Args:
            x: Number
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            Tuple containing the lower and upper boundaries of the interval and a variable indicating
            if the number lies in the interval. For example:
            (root.x, root.y), True
            If no interval contains x, the closest interval lying entirely above x is returned with
            False, or ((None, None), False) if there is none.
        """
        if root == 'root':
            root = self.root_tree
        x = Decimal(f'{x}')
        after = (None, None)
        node = root
        while node is not None:
            if x > node.y:
                node = node.r
            elif x < node.x:
                # This interval lies above x; any closer one can only be in the left subtree.
                after = (node.x, node.y)
                node = node.l
            else:
                return (node.x, node.y), True
        return after, False

    def smallest_interval(self, root: Node = 'root'):
        """ Returns the Node of the smallest interval
        Args:
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            Node of the smallest interval (the leftmost node), or None for an empty tree
        """
        if root == 'root':
            root = self.root_tree
        if root is None:
            return None
        while root.l is not None:
            root = root.l
        return root

    def insert(self, x, y, root: Node = 'root'):
        """ Adds an interval to the action space
        Args:
            x: Lower bound of the interval
            y: Upper bound of the interval
            root: Node to start the insertion from or 'root' for inserting over the whole tree, default is 'root'
        Returns:
            Updated root node of the (sub)tree; ``self.root_tree`` is set to the same node
        """
        assert y > x, 'Upper must be larger than lower bound'
        if root == 'root':
            root = self.root_tree
        # Convert before any arithmetic: ``self.size`` is a Decimal, and Decimal + float raises TypeError.
        x = Decimal(f'{x}')
        y = Decimal(f'{y}')
        if root is None:
            leaf = Node(x, y)
            self.size += y - x
            self.root_tree = leaf
            return leaf
        if y < root.x:
            root.l = self.insert(x, y, root.l)
        elif x > root.y:
            root.r = self.insert(x, y, root.r)
        else:
            # The new interval overlaps or touches this node: widen the node to the union.
            old_size = root.y - root.x
            root.x = min(root.x, x)
            root.y = max(root.y, y)
            self.size += root.y - root.x - old_size
            updated = False
            # If the widened node reaches into a child's interval, absorb that child's bound.
            # The child itself is deleted by the remove() calls below, and the original
            # [x, y] is re-inserted to restore any extent this step cut off.
            if root.r is not None and root.y >= root.r.x:
                self.size -= root.y - root.r.y
                root.y = root.r.y
                updated = True
            if root.l is not None and root.x <= root.l.y:
                self.size -= root.l.x - root.x
                root.x = root.l.x
                updated = True
            # Trim whatever the widened node now covers out of both subtrees.
            root.r = self.remove(root.x, root.y, root.r)
            root.l = self.remove(root.x, root.y, root.l)
            if updated:
                root = self.insert(x, y, root)
        return self._rebalance(root)

    def sample(self, root: Node = 'root') -> float:
        """ Sample a random action from a uniform distribution over the action space
        Args:
            root: Root node of the action space, default is 'root'
        Returns:
            Sampled action as a float
        Raises:
            Exception: if the action space is empty
        """
        if root == 'root':
            root = self.root_tree
        if root is None:
            raise Exception('Empty Action Space')
        if self.draw is None:
            self.draw = Decimal(f'{uniform(0.0, float(self.size))}')
        remaining = self.draw
        # Reset right away so that no exit path can leave a stale draw for the next call.
        self.draw = None
        # Consume interval lengths in pre-order (node, left subtree, right subtree).
        pending = [root]
        last = root
        while pending:
            node = pending.pop()
            last = node
            remaining -= node.y - node.x
            if remaining <= 0:
                return float(node.y + remaining)
            # Push the right child first so the left subtree is consumed first.
            if node.r is not None:
                pending.append(node.r)
            if node.l is not None:
                pending.append(node.l)
        # The draw exceeded the total length (uniform() may return its upper end, and float
        # rounding of ``size`` can overshoot): fall back to a valid action instead of None.
        return float(last.y)

    def remove(self, x, y, root: Node = 'root', adjust_size: bool = True):
        """ Removes an interval from the action space
        Args:
            x: Lower bound of the interval
            y: Upper bound of the interval
            root: Node to start the removal from or 'root' for removing over the whole tree, default is 'root'
            adjust_size: Whether deleting a whole node subtracts its length from ``size``. Set to False
                internally when a node is deleted only because its interval was copied into another node.
        Returns:
            Updated root node of the action space (None if it became empty)
        """
        assert y > x, 'Upper must be larger than lower bound'
        if root == 'root':
            root = self.root_tree
        if root is None:
            return root
        x = Decimal(f'{x}')
        y = Decimal(f'{y}')
        if x > root.x and y < root.y:
            # [x, y] punches a hole: keep [root.x, x] here and re-insert [y, old upper bound].
            self.size -= root.y - x
            old_maximum = root.y
            root.y = x
            root = self.insert(y, old_maximum, root)
        elif x == root.x and y < root.y:
            self.size -= y - x
            root.x = y
        elif x > root.x and y == root.y:
            self.size -= y - x
            root.y = x
        elif x < root.x < y < root.y:
            # Cuts off the lower end here; the rest of [x, y] may reach into the left subtree.
            self.size -= y - root.x
            root.x = y
            root.l = self.remove(x, y, root.l, adjust_size)
        elif root.x < x < root.y < y:
            # Cuts off the upper end here; the rest of [x, y] may reach into the right subtree.
            self.size -= root.y - x
            root.y = x
            root.r = self.remove(x, y, root.r, adjust_size)
        elif y <= root.x:
            root.l = self.remove(x, y, root.l, adjust_size)
        elif x >= root.y:
            root.r = self.remove(x, y, root.r, adjust_size)
        else:
            # [x, y] covers this node's whole interval: delete the node.
            if adjust_size:
                self.size -= root.y - root.x
            if root.l is None:
                self.root_tree = self.remove(x, y, root.r, adjust_size)
                return self.root_tree
            if root.r is None:
                self.root_tree = self.remove(x, y, root.l, adjust_size)
                return self.root_tree
            # Two children: copy the in-order successor into this node and delete it from the
            # right subtree without touching ``size`` (its length now lives here). Then retry on
            # this node, because [x, y] may cover the successor's interval too.
            successor = self.smallest_interval(root.r)
            root.x = successor.x
            root.y = successor.y
            root.r = self.remove(successor.x, successor.y, root.r, adjust_size=False)
            root = self.remove(x, y, root, adjust_size)
        if root is None:
            return None
        return self._rebalance(root)

    def _rebalance(self, root: Node):
        """ Recomputes ``root``'s height and restores the AVL balance at it with at most two rotations
        Args:
            root (Node): Node whose subtrees may have changed height
        Returns:
            New root of this subtree; ``self.root_tree`` is set to the same node
        """
        root.h = 1 + max(self.getHeight(root.l),
                         self.getHeight(root.r))
        balance = self.getBal(root)
        # Decide on balance factors alone. The child-balance == 0 case (reachable after
        # deletions or merges) needs a single rotation.
        if balance > 1:
            if self.getBal(root.l) < 0:
                root.l = self.lRotate(root.l)
            root = self.rRotate(root)
        elif balance < -1:
            if self.getBal(root.r) > 0:
                root.r = self.rRotate(root.r)
            root = self.lRotate(root)
        self.root_tree = root
        return root

    def lRotate(self, z: Node):
        """ Performs a left rotation. Switches roles of parent and child nodes.
        Args:
            z (Node): Parent node for the rotation
        Returns:
            Updated parent Node
        """
        y = z.r
        T2 = y.l
        y.l = z
        z.r = T2
        z.h = 1 + max(self.getHeight(z.l),
                      self.getHeight(z.r))
        y.h = 1 + max(self.getHeight(y.l),
                      self.getHeight(y.r))
        return y

    def rRotate(self, z: Node):
        """ Performs a right rotation. Switches roles of parent and child nodes.
        Args:
            z (Node): Parent node for the rotation
        Returns:
            Updated parent Node
        """
        y = z.l
        T3 = y.r
        y.r = z
        z.l = T3
        z.h = 1 + max(self.getHeight(z.l),
                      self.getHeight(z.r))
        y.h = 1 + max(self.getHeight(y.l),
                      self.getHeight(y.r))
        return y

    def getHeight(self, root: Node = 'root'):
        """ Returns the height of a Node
        Args:
            root: Node to return the height from or 'root' for the height of the whole tree, default is 'root'
        Returns:
            Integer indicating the height (0 for an empty subtree)
        """
        if root == 'root':
            root = self.root_tree
        if not root:
            return 0
        return root.h

    def getBal(self, root: Node = 'root'):
        """ Calculates balance factor
        Args:
            root: Node to calculate the balance factor for or 'root' for the balance factor of the whole tree,
                default is 'root'
        Returns:
            Integer indicating the balance factor (left height minus right height)
        """
        if root == 'root':
            root = self.root_tree
        if not root:
            return 0
        return self.getHeight(root.l) - self.getHeight(root.r)

    def order(self, root: Node = 'root'):
        """ Returns all intervals of the action space ordered
        Args:
            root: Node to start the search from or 'root' for searching the whole tree, default is 'root'
        Returns:
            List of tuples containing the ordered intervals. For example:
            [(0.1,0.5), (0.7,0.9)]
        """
        if root == 'root':
            root = self.root_tree
        if root is None:
            return []
        return self.order(root.l) + [(float(root.x), float(root.y))] + self.order(root.r)

    def __str__(self):
        return f'<IntervalUnionTree>'

    def __repr__(self):
        return self.__str__()