Last active: December 14, 2020 21:33
Save waleedkadous/9aed8fae43449ecff5d3a0f53f18a80a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def best_split_for_idx(tree, idx, X, y, num_parent, best_gini):
    """Find the best Gini split threshold for one feature of a node.

    Candidate thresholds are midpoints between consecutive distinct values of
    feature ``idx`` after sorting. Per-class counts on each side are updated
    incrementally as the split position moves. Each position therefore costs
    O(n_classes) rather than a full recount of the children.

    Args:
        tree: Object exposing ``n_classes_``, the number of classes. Labels in
            ``y`` are used as indices into count lists, so they are expected
            to be integers in ``range(tree.n_classes_)``.
        idx: Column of ``X`` to split on.
        X: 2-D array of samples (indexed as ``X[:, idx]``).
        y: Array of class labels, one per row of ``X``.
        num_parent: Per-class sample counts for the node (length
            ``tree.n_classes_``). It is copied, not modified.
        best_gini: Gini impurity to beat. Only splits strictly better than
            this are accepted.

    Returns:
        ``(best_gini, best_thr)``: the best weighted Gini impurity found and
        its threshold. If no split beats the input ``best_gini``, including
        when there are fewer than two samples, the input value and ``nan``
        are returned.
    """
    m = y.size
    best_thr = float("nan")
    # With fewer than two samples there is no split position. Returning here
    # also avoids the ValueError raised by unpacking zip(*[]) on empty input.
    if m < 2:
        return best_gini, best_thr

    n_classes = tree.n_classes_
    # Sort samples along the selected feature, keeping labels aligned.
    thresholds, classes = zip(*sorted(zip(X[:, idx], y)))

    # Sweep the split position left to right, moving one sample at a time
    # from the right child to the left child.
    num_left = [0] * n_classes
    num_right = list(num_parent)
    for i in range(1, m):
        c = classes[i - 1]
        num_left[c] += 1
        num_right[c] -= 1
        # Samples with identical feature values must end up on the same side,
        # so there is no valid threshold between them. Skip before doing the
        # impurity arithmetic.
        if thresholds[i] == thresholds[i - 1]:
            continue
        n_right = m - i
        gini_left = 1.0 - sum((num_left[k] / i) ** 2 for k in range(n_classes))
        gini_right = 1.0 - sum(
            (num_right[k] / n_right) ** 2 for k in range(n_classes)
        )
        # The impurity of a split is the size-weighted average of the
        # children's impurities.
        gini = (i * gini_left + n_right * gini_right) / m
        if gini < best_gini:
            best_gini = gini
            # Compute the midpoint in Python floats. Adding two narrow integer
            # scalars (e.g. numpy uint8) directly can overflow and wrap.
            best_thr = (float(thresholds[i]) + float(thresholds[i - 1])) / 2
    return best_gini, best_thr
@ray.remote
def best_split_for_idx_remote(tree, idx, X, y, num_parent, best_gini):
    """Ray task wrapper around ``best_split_for_idx``.

    Forwards every argument unchanged and returns whatever
    ``best_split_for_idx`` returns, so the split search for one feature can
    be scheduled as a ``ray.remote`` task.
    """
    split_result = best_split_for_idx(
        tree,
        idx,
        X,
        y,
        num_parent,
        best_gini,
    )
    return split_result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.