Skip to content

Instantly share code, notes, and snippets.

@waleedkadous
Last active December 14, 2020 21:33
Show Gist options
  • Save waleedkadous/9aed8fae43449ecff5d3a0f53f18a80a to your computer and use it in GitHub Desktop.
Save waleedkadous/9aed8fae43449ecff5d3a0f53f18a80a to your computer and use it in GitHub Desktop.
def best_split_for_idx(tree, idx, X, y, num_parent, best_gini):
    """Find the best split threshold for a node along a single feature.

    Samples are sorted by the value of feature ``idx`` and every position
    between two consecutive, distinct feature values is evaluated as a
    candidate split. Class counts on each side are updated incrementally,
    so the scan does one constant-size update per sample instead of
    recounting both children for every candidate.

    Args:
        tree: Object exposing ``n_classes_`` (the number of classes).
        idx: Column of ``X`` to evaluate.
        X: Feature matrix, indexed as ``X[:, idx]``.
        y: Class labels with a ``size`` attribute; each label is used as an
            index into the per-class count lists.
        num_parent: Per-class sample counts of the node being split. It is
            copied, never mutated.
        best_gini: Gini impurity a split must beat (strictly) to be chosen.

    Returns:
        A ``(best_gini, best_thr)`` tuple. ``best_gini`` is the lowest
        weighted Gini impurity found (or the value passed in when nothing
        beats it). ``best_thr`` is the midpoint between the two feature
        values around the chosen split, or NaN when no split improved on
        ``best_gini``.
    """
    m = y.size
    best_thr = float('NaN')
    # With fewer than two samples there is no split position to try. Returning
    # here also avoids the ValueError that unpacking zip(*...) raises on
    # empty input.
    if m < 2:
        return best_gini, best_thr

    # Sort data along selected feature.
    thresholds, classes = zip(*sorted(zip(X[:, idx], y)))

    n_classes = tree.n_classes_
    num_left = [0] * n_classes
    num_right = num_parent.copy()
    for i in range(1, m):  # possible split positions
        # Move sample i - 1 from the right child to the left child. This must
        # happen on every iteration, even when the position is skipped below.
        c = classes[i - 1]
        num_left[c] += 1
        num_right[c] -= 1

        # Two points with identical values for this feature cannot be
        # separated (both have to end up on the same side of a split), so
        # skip the position before doing any impurity arithmetic.
        if thresholds[i] == thresholds[i - 1]:
            continue

        gini_left = 1.0 - sum(
            (num_left[x] / i) ** 2 for x in range(n_classes)
        )
        gini_right = 1.0 - sum(
            (num_right[x] / (m - i)) ** 2 for x in range(n_classes)
        )
        # The Gini impurity of a split is the weighted average of the Gini
        # impurity of the children.
        gini = (i * gini_left + (m - i) * gini_right) / m
        if gini < best_gini:
            best_gini = gini
            best_thr = (thresholds[i] + thresholds[i - 1]) / 2  # midpoint
    return best_gini, best_thr
@ray.remote
def best_split_for_idx_remote(tree, idx, X, y, num_parent, best_gini):
    """Ray-decorated wrapper around ``best_split_for_idx``.

    Forwards every argument unchanged to ``best_split_for_idx`` and returns
    whatever that call returns, so the per-feature split search can be
    submitted through Ray.
    """
    split_result = best_split_for_idx(tree, idx, X, y, num_parent, best_gini)
    return split_result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment