| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  | from numba import njit | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | from typing import Union, Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SegmentTree: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |     """Implementation of Segment Tree.
 | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |     The segment tree stores an array ``arr`` with size ``n``. It supports value | 
					
						
							|  |  |  |     update and fast query of the sum for the interval ``[left, right)`` in | 
					
						
							|  |  |  |     O(log n) time. The detailed procedure is as follows: | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |     1. Pad the array to have length of power of 2, so that leaf nodes in the \ | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |     segment tree have the same depth. | 
					
						
							|  |  |  |     2. Store the segment tree in a binary heap. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     :param int size: the size of segment tree. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     def __init__(self, size: int) -> None: | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         bound = 1 | 
					
						
							|  |  |  |         while bound < size: | 
					
						
							|  |  |  |             bound *= 2 | 
					
						
							|  |  |  |         self._size = size | 
					
						
							|  |  |  |         self._bound = bound | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |         self._value = np.zeros([bound * 2]) | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         self._compile() | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def __len__(self) -> int: | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         return self._size | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def __getitem__( | 
					
						
							|  |  |  |         self, index: Union[int, np.ndarray] | 
					
						
							|  |  |  |     ) -> Union[float, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |         """Return self[index].""" | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         return self._value[index + self._bound] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def __setitem__( | 
					
						
							|  |  |  |         self, index: Union[int, np.ndarray], value: Union[float, np.ndarray] | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |         """Update values in segment tree.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Duplicate values in ``index`` are handled by numpy: later index | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         overwrites previous ones. | 
					
						
							|  |  |  |         :: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             >>> a = np.array([1, 2, 3, 4]) | 
					
						
							|  |  |  |             >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7] | 
					
						
							|  |  |  |             >>> print(a) | 
					
						
							|  |  |  |             [6 7 3 4] | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if isinstance(index, int): | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |             index, value = np.array([index]), np.array([value]) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         assert np.all(0 <= index) and np.all(index < self._size) | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |         _setitem(self._value, index + self._bound, value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reduce(self, start: int = 0, end: Optional[int] = None) -> float: | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         """Return operation(value[start:end]).""" | 
					
						
							|  |  |  |         if start == 0 and end is None: | 
					
						
							|  |  |  |             return self._value[1] | 
					
						
							|  |  |  |         if end is None: | 
					
						
							|  |  |  |             end = self._size | 
					
						
							|  |  |  |         if end < 0: | 
					
						
							|  |  |  |             end += self._size | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |         return _reduce(self._value, start + self._bound - 1, end + self._bound) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_prefix_sum_idx( | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         self, value: Union[float, np.ndarray] | 
					
						
							|  |  |  |     ) -> Union[int, np.ndarray]: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |         r"""Find the index with given value.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Return the minimum index for each ``v`` in ``value`` so that | 
					
						
							|  |  |  |         :math:`v \le \mathrm{sums}_i`, where | 
					
						
							|  |  |  |         :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`. | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         .. warning:: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             Please make sure all of the values inside the segment tree are | 
					
						
							|  |  |  |             non-negative when using this function. | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |         assert np.all(value >= 0.0) and np.all(value < self._value[1]) | 
					
						
							| 
									
										
										
										
											2020-08-06 10:26:24 +08:00
										 |  |  |         single = False | 
					
						
							|  |  |  |         if not isinstance(value, np.ndarray): | 
					
						
							|  |  |  |             value = np.array([value]) | 
					
						
							|  |  |  |             single = True | 
					
						
							|  |  |  |         index = _get_prefix_sum_idx(value, self._bound, self._value) | 
					
						
							|  |  |  |         return index.item() if single else index | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     def _compile(self) -> None: | 
					
						
							|  |  |  |         f64 = np.array([0, 1], dtype=np.float64) | 
					
						
							|  |  |  |         f32 = np.array([0, 1], dtype=np.float32) | 
					
						
							|  |  |  |         i64 = np.array([0, 1], dtype=np.int64) | 
					
						
							|  |  |  |         _setitem(f64, i64, f64) | 
					
						
							|  |  |  |         _setitem(f64, i64, f32) | 
					
						
							|  |  |  |         _reduce(f64, 0, 1) | 
					
						
							|  |  |  |         _get_prefix_sum_idx(f64, 1, f64) | 
					
						
							|  |  |  |         _get_prefix_sum_idx(f32, 1, f64) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | @njit | 
					
						
							|  |  |  | def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |     """Numba version, 4x faster: 0.1 -> 0.024.""" | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     tree[index] = value | 
					
						
							|  |  |  |     while index[0] > 1: | 
					
						
							|  |  |  |         index //= 2 | 
					
						
							|  |  |  |         tree[index] = tree[index * 2] + tree[index * 2 + 1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @njit | 
					
						
							|  |  |  | def _reduce(tree: np.ndarray, start: int, end: int) -> float: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |     """Numba version, 2x faster: 0.009 -> 0.005.""" | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     # nodes in (start, end) should be aggregated | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  |     result = 0.0 | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     while end - start > 1:  # (start, end) interval is not empty | 
					
						
							|  |  |  |         if start % 2 == 0: | 
					
						
							|  |  |  |             result += tree[start + 1] | 
					
						
							|  |  |  |         start //= 2 | 
					
						
							|  |  |  |         if end % 2 == 1: | 
					
						
							|  |  |  |             result += tree[end - 1] | 
					
						
							|  |  |  |         end //= 2 | 
					
						
							|  |  |  |     return result | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @njit | 
					
						
							| 
									
										
										
										
											2020-09-12 15:39:01 +08:00
										 |  |  | def _get_prefix_sum_idx( | 
					
						
							|  |  |  |     value: np.ndarray, bound: int, sums: np.ndarray | 
					
						
							|  |  |  | ) -> np.ndarray: | 
					
						
							| 
									
										
										
										
											2020-09-11 07:55:37 +08:00
										 |  |  |     """Numba version (v0.51), 5x speed up with size=100000 and bsz=64.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-02 13:03:32 +08:00
										 |  |  |     vectorized np: 0.0923 (numpy best) -> 0.024 (now) | 
					
						
							|  |  |  |     for-loop: 0.2914 -> 0.019 (but not so stable) | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     index = np.ones(value.shape, dtype=np.int64) | 
					
						
							|  |  |  |     while index[0] < bound: | 
					
						
							|  |  |  |         index *= 2 | 
					
						
							|  |  |  |         lsons = sums[index] | 
					
						
							|  |  |  |         direct = lsons < value | 
					
						
							|  |  |  |         value -= lsons * direct | 
					
						
							|  |  |  |         index += direct | 
					
						
							|  |  |  |     index -= bound | 
					
						
							|  |  |  |     return index |