def best_split(tree, X, y):
    """Find the best split for a node.

    "Best" means that the average impurity of the two children, weighted by their
    population, is the smallest possible. Additionally it must be less than the
    impurity of the current node.

    To find the best split, we loop through all the features, and consider all the
    midpoints between adjacent training samples as possible thresholds. We compute
    the Gini impurity of the split generated by that particular feature/threshold
    pair, and return the pair with smallest impurity.

    Returns:
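The docstring above describes the search in words. The function below is a minimal sketch of that definition, not the post's implementation. It scans every feature and every midpoint between distinct sorted values. Each candidate is scored by the population-weighted Gini impurity of the two children, where the Gini impurity of a node is 1 minus the sum of the squared class proportions. A split is kept only if it beats the parent node. The names `gini` and `best_split_naive` and the `n_classes` parameter are illustrative assumptions. This version also recounts both children for every threshold, which is the quadratic approach the post avoids.

```python
import numpy as np


def gini(counts):
    """Gini impurity 1 - sum_k p_k**2 for a list of per-class counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)


def best_split_naive(X, y, n_classes):
    """Exhaustive search over features and midpoints (illustrative sketch).

    Returns (best_idx, best_thr), or (None, None) if no split improves on
    the impurity of the current node.
    """
    m, n_features = X.shape
    if m <= 1:
        return None, None
    parent = [int(np.sum(y == k)) for k in range(n_classes)]
    best_gini = gini(parent)  # a split must be strictly better than the parent
    best_idx, best_thr = None, None
    for idx in range(n_features):
        values = np.unique(X[:, idx])  # sorted, distinct values of this feature
        for thr in (values[:-1] + values[1:]) / 2:  # candidate midpoints
            left = X[:, idx] < thr
            n_left = int(left.sum())
            # Recount both children from scratch for every threshold.
            g_left = gini([int(np.sum(y[left] == k)) for k in range(n_classes)])
            g_right = gini([int(np.sum(y[~left] == k)) for k in range(n_classes)])
            # Children's impurity weighted by their population.
            g = (n_left * g_left + (m - n_left) * g_right) / m
            if g < best_gini:
                best_gini, best_idx, best_thr = g, idx, float(thr)
    return best_idx, best_thr


X = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 1.0], [4.0, 1.0]])
y = np.array([0, 0, 1, 1])
print(best_split_naive(X, y, n_classes=2))  # (0, 2.5)
```

Both features separate the classes perfectly here. Feature 0 is kept because a later candidate replaces the current best only when it is strictly better.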
def best_split_for_idx(tree, idx, X, y, num_parent, best_gini):
    """Find the best split for a node along the feature with index idx."""
    # Sort data along selected feature.
    thresholds, classes = zip(*sorted(zip(X[:, idx], y)))
    # We could actually split the node according to each feature/threshold pair
    # and count the resulting population for each class in the children, but
    # instead we compute them in an iterative fashion, making this for loop
    # linear rather than quadratic.
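The "iterative fashion" mentioned in the comment works like this. After sorting by the feature, advancing the split point by one sample moves exactly one sample from the right child to the left child. The class counts can therefore be updated in O(number of classes) per step instead of being recounted, so the scan after the initial sort is linear in the number of samples. The sketch below assumes integer labels `0..n_classes-1`. Its function name and signature are hypothetical and differ from `best_split_for_idx`.

```python
import numpy as np


def best_split_for_feature(x, y, n_classes):
    """Scan one feature in a single pass over the sorted samples (sketch).

    Returns (best_gini, best_thr); best_thr is None when the feature offers
    no split point (fewer than two samples, or all values identical).
    """
    m = len(y)
    order = np.argsort(x, kind="stable")
    xs, ys = x[order], y[order]
    num_left = [0] * n_classes
    num_right = [int(np.sum(y == k)) for k in range(n_classes)]
    best_gini, best_thr = float("inf"), None
    for i in range(1, m):  # split between sorted positions i-1 and i
        # Move one sample from the right child to the left child instead of
        # recounting both children from scratch.
        c = ys[i - 1]
        num_left[c] += 1
        num_right[c] -= 1
        if xs[i] == xs[i - 1]:
            continue  # equal values cannot be separated by a threshold
        g_left = 1.0 - sum((n / i) ** 2 for n in num_left)
        g_right = 1.0 - sum((n / (m - i)) ** 2 for n in num_right)
        g = (i * g_left + (m - i) * g_right) / m  # population-weighted Gini
        if g < best_gini:
            best_gini, best_thr = g, float((xs[i] + xs[i - 1]) / 2)
    return best_gini, best_thr


print(best_split_for_feature(np.array([3.0, 1.0, 4.0, 2.0]),
                             np.array([1, 0, 1, 0]), n_classes=2))  # (0.0, 2.5)
```

The counts are updated before the equal-value check. This matters because a sample that cannot be separated from its neighbour still has to be counted on the left side of the next valid split.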
def grow_tree_local(tree, X, y, depth):
    """Build a decision tree by recursively finding the best split."""
    # Population of each class in the current node. The predicted class is the
    # one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
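The `grow_tree_local` excerpt stops inside the `Node(...)` call. Below is a minimal self-contained sketch of the same recursive idea. Several pieces are assumptions for illustration, not the post's code:

- the `Node` fields beyond those visible above (`feature_index`, `threshold`, `left`, `right`)
- the `max_depth` stopping rule
- the helpers `gini_of`, `best_split`, and `predict_one`

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class Node:
    # Hypothetical stand-in for the post's Node class.
    gini: float
    num_samples: int
    num_samples_per_class: list
    predicted_class: int
    feature_index: Optional[int] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def gini_of(y, n_classes):
    """Gini impurity of a non-empty integer label array."""
    p = np.bincount(y, minlength=n_classes) / y.size
    return 1.0 - float(np.sum(p ** 2))


def best_split(X, y, n_classes):
    """Compact exhaustive search: midpoints, population-weighted child Gini."""
    best = (gini_of(y, n_classes), None, None)  # must beat the parent
    for idx in range(X.shape[1]):
        v = np.unique(X[:, idx])
        for thr in (v[:-1] + v[1:]) / 2:
            mask = X[:, idx] < thr  # both sides non-empty: thr lies between values
            g = (mask.sum() * gini_of(y[mask], n_classes)
                 + (~mask).sum() * gini_of(y[~mask], n_classes)) / y.size
            if g < best[0]:
                best = (g, idx, float(thr))
    return best[1], best[2]


def grow_tree(X, y, n_classes, max_depth, depth=0):
    """Recursively split until max_depth or until no split improves impurity."""
    counts = [int(np.sum(y == k)) for k in range(n_classes)]
    node = Node(gini=gini_of(y, n_classes), num_samples=int(y.size),
                num_samples_per_class=counts,
                predicted_class=int(np.argmax(counts)))
    if depth < max_depth:
        idx, thr = best_split(X, y, n_classes)
        if idx is not None:
            mask = X[:, idx] < thr
            node.feature_index, node.threshold = idx, thr
            node.left = grow_tree(X[mask], y[mask], n_classes, max_depth, depth + 1)
            node.right = grow_tree(X[~mask], y[~mask], n_classes, max_depth, depth + 1)
    return node


def predict_one(node, x):
    """Walk from the root to a leaf and return the leaf's majority class."""
    while node.left is not None:
        node = node.left if x[node.feature_index] < node.threshold else node.right
    return node.predicted_class


X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([0, 0, 1, 1])
root = grow_tree(X, y, n_classes=2, max_depth=3)
print(root.feature_index, root.threshold)  # 0 2.5
```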
@ray.remote
def grow_tree_remote(tree, X, y, depth=0):
    """Build a decision tree by recursively finding the best split."""
    # Population of each class in the current node. The predicted class is the
    # one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
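`grow_tree_remote` is decorated with `@ray.remote`, so each call becomes a Ray task. This suggests the subtrees are built as independent tasks. Ray may not be installed everywhere, so the sketch below swaps in Python's standard `concurrent.futures.ThreadPoolExecutor` to show the same fork-join shape. Subtrees above a hypothetical `fork_depth` are submitted as pool tasks and joined with `.result()`. Deeper subtrees are built serially.

This is not the post's Ray code. The helper names and dict-based nodes are assumptions. Threads only illustrate the structure: for CPU-bound pure-Python work, real speedup would need processes or a framework such as Ray.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def gini_of(y, n_classes):
    p = np.bincount(y, minlength=n_classes) / y.size
    return 1.0 - float(np.sum(p ** 2))


def best_split(X, y, n_classes):
    best = (gini_of(y, n_classes), None, None)  # must beat the parent
    for idx in range(X.shape[1]):
        v = np.unique(X[:, idx])
        for thr in (v[:-1] + v[1:]) / 2:
            mask = X[:, idx] < thr
            g = (mask.sum() * gini_of(y[mask], n_classes)
                 + (~mask).sum() * gini_of(y[~mask], n_classes)) / y.size
            if g < best[0]:
                best = (g, idx, float(thr))
    return best[1], best[2]


def grow(X, y, n_classes, max_depth, depth=0, pool=None, fork_depth=0):
    """Return a nested-dict tree; nodes shallower than fork_depth fork their children."""
    counts = np.bincount(y, minlength=n_classes)
    node = {"predicted_class": int(np.argmax(counts)), "num_samples": int(y.size)}
    if depth >= max_depth:
        return node
    idx, thr = best_split(X, y, n_classes)
    if idx is None:
        return node
    mask = X[:, idx] < thr
    node["feature_index"], node["threshold"] = idx, thr
    args_l = (X[mask], y[mask], n_classes, max_depth, depth + 1, pool, fork_depth)
    args_r = (X[~mask], y[~mask], n_classes, max_depth, depth + 1, pool, fork_depth)
    if pool is not None and depth < fork_depth:
        # Fork: both children run as independent tasks; join by waiting on them.
        fut_l, fut_r = pool.submit(grow, *args_l), pool.submit(grow, *args_r)
        node["left"], node["right"] = fut_l.result(), fut_r.result()
    else:
        node["left"], node["right"] = grow(*args_l), grow(*args_r)
    return node


def grow_parallel(X, y, n_classes, max_depth, fork_depth=2):
    # At most 2 + 4 + ... + 2**fork_depth tasks are ever submitted, so this many
    # workers lets all of them run at once: a parent blocked on .result() can
    # never starve its own children of a worker (no deadlock).
    with ThreadPoolExecutor(max_workers=2 ** (fork_depth + 1)) as pool:
        return grow(X, y, n_classes, max_depth, pool=pool, fork_depth=fork_depth)


X = np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y = np.array([0, 0, 1, 1, 0, 0])
print(grow_parallel(X, y, n_classes=2, max_depth=4) == grow(X, y, 2, max_depth=4))  # True
```

The parallel and serial builds produce the same tree because the split search is deterministic. Only the order in which subtrees are computed changes.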
2020-12-14 10:01:40,477 INFO services.py:1090 -- View the Ray dashboard at http://127.0.0.1:8265
Serial execution took 421.71840476989746 seconds | |
Test accuracy: 56% |
% pip install ray numpy |