Tuesday, September 10, 2024

Getting the number of trees in an XGBoost model

We can get the number of trees in an XGBoost model as follows:

Load the model and save it as a JSON file. This step is not needed if the model was already saved in JSON format.
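Because the conversion can be skipped when the file already holds JSON, one option is to check for that first. The following is a minimal sketch: try_read_model_json is a hypothetical helper (not part of XGBoost), and it assumes that a JSON model file parses as a JSON object with a top-level "learner" key, which is the layout recent XGBoost versions write.

```python
import json
from typing import Optional


def try_read_model_json(path: str) -> Optional[dict]:
    """Return the parsed model if `path` already holds a JSON model, else None.

    Hypothetical helper: it only checks that the file parses as a JSON object
    with a "learner" key (assumed XGBoost JSON layout); it does not fully
    validate the model.
    """
    with open(path, "rb") as fp:
        data = fp.read()
    try:
        # json.loads accepts bytes and detects UTF-8/16/32 encodings
        parsed = json.loads(data)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (binary formats)
        return None
    if isinstance(parsed, dict) and "learner" in parsed:
        return parsed
    return None
```

If it returns None (for example for a legacy binary or UBJSON model), fall back to loading the file with xgb.Booster and re-saving it as JSON.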

Then read the num_trees value from the JSON file. It is stored under learner, gradient_booster, model, gbtree_model_param.

The following code snippet implements both steps:


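    import json
    import os
    import sys
    import tempfile

    import xgboost as xgb

    if len(sys.argv) < 2:
        print(f'Usage: {sys.argv[0]} <model-file-path>')
        sys.exit(1)

    # Load the model from any format XGBoost can read
    loaded_model = xgb.Booster()
    loaded_model.load_model(sys.argv[1])

    # Re-save it as JSON in a temporary directory so it can be parsed with json
    with tempfile.TemporaryDirectory() as tmpdir:
        json_path = os.path.join(tmpdir, 'a_model.json')
        loaded_model.save_model(json_path)
        with open(json_path, 'r') as fp:
            jsonrepr = json.load(fp)

    # num_trees is stored as a string in the JSON model, so convert it to int
    print(int(jsonrepr['learner']['gradient_booster']['model']['gbtree_model_param']['num_trees']))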


After saving the code snippet to a file named xgb_tree_count.py, we can run it as follows:

    python xgb_tree_count.py model-file-path
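The snippet above assumes the default gbtree booster. A model trained with the dart booster nests the same tree data one level deeper, and a gblinear model has no trees at all. The sketch below uses a hypothetical count_trees helper to handle those cases and to cross-check num_trees against the length of the serialized trees list. The dart layout (an extra "gbtree" object inside gradient_booster) is an assumption based on the JSON model schema of recent XGBoost releases, so verify it against your version.

```python
def count_trees(model_json: dict) -> int:
    """Return the number of trees recorded in a parsed XGBoost JSON model.

    Hypothetical helper. Assumes the JSON layout of recent XGBoost releases:
    gbtree keeps its trees under learner/gradient_booster/model, while dart
    (assumed) wraps the same structure in an extra "gbtree" object.
    """
    booster = model_json["learner"]["gradient_booster"]
    name = booster.get("name")
    if name == "dart":
        # Assumed dart layout: the tree model sits inside a nested "gbtree"
        booster = booster["gbtree"]
    elif name != "gbtree":
        # gblinear models store linear weights, not trees
        raise ValueError(f"booster {name!r} does not contain trees")
    model = booster["model"]
    # num_trees is stored as a string, e.g. "100"
    num_trees = int(model["gbtree_model_param"]["num_trees"])
    # Cross-check against the list of serialized trees
    if num_trees != len(model["trees"]):
        raise ValueError("num_trees does not match the serialized trees")
    return num_trees
```

In the script above, the final line could then become print(count_trees(jsonrepr)). Note that num_trees counts trees, not boosting rounds: a multi-class model with num_class classes adds num_class trees per round.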
