# experiments.py — hyperparameter grid-search experiment runner.
  1. import os
  2. from datetime import datetime
  3. from itertools import product
  4. from typing import Dict, Set, List, Any, Tuple
  5. import pandas as pd
  6. from joblib import Parallel, delayed
  7. import params
  8. from train import train_kd, ExperimentResults
  9. from utils import log, merge_dicts, pickle_object, unpickle, hash_dict
# Type aliases for experiment bookkeeping.
ParamSet = Dict[str, Any]    # one concrete hyperparameter assignment
ParamGrid = List[ParamSet]   # every combination from the parameter sweep
RunnerUUID = str             # deterministic hash of a ParamSet (see hash_dict)
  13. class ExperimentRunner:
  14. def __init__(self, experiment_parameters: Dict[str, Set[Any]], n_jobs: int):
  15. self.experiment_parameters = experiment_parameters
  16. self.n_jobs = n_jobs
  17. self.timestamp: str = datetime.now().strftime('%Y%m%d%H%M%S')
  18. os.makedirs('results/', exist_ok=True)
  19. @staticmethod
  20. def _get_param_grid(parameters: Dict[str, Set[Any]]) -> ParamGrid:
  21. return [dict(zip(parameters.keys(), t)) for t in product(*parameters.values())]
  22. @staticmethod
  23. def _file_path_experiment_results(runner_uuid: RunnerUUID) -> str:
  24. return f'results/{runner_uuid}_experiment_results.pkl'
  25. def _experiment_result_exists(self, runner_uuid: RunnerUUID) -> bool:
  26. return os.path.isfile(self._file_path_experiment_results(runner_uuid))
  27. def _param_run(self, param_set: ParamSet) -> Tuple[ExperimentResults, RunnerUUID]:
  28. log(f'Running param set: {param_set}')
  29. uuid = hash_dict(param_set)
  30. if self._experiment_result_exists(uuid):
  31. log('Loading experiment results from cache')
  32. log(uuid)
  33. experiment_results = unpickle(self._file_path_experiment_results(uuid))
  34. else:
  35. log(f'Running uuid {uuid}')
  36. experiment_results = train_kd(**param_set)
  37. pickle_object(experiment_results, self._file_path_experiment_results(uuid))
  38. return experiment_results, uuid
  39. @staticmethod
  40. def _get_dict_from_results(results: ExperimentResults) -> Dict:
  41. _, accs, f1s = results
  42. return {'val_acc': max(accs), 'val_f1': max(f1s)}
  43. def run(self):
  44. param_grid = self._get_param_grid(self.experiment_parameters)
  45. if self.n_jobs > 1:
  46. run_output = Parallel(n_jobs=self.n_jobs)(delayed(self._param_run)(param) for param in param_grid)
  47. else:
  48. run_output = [self._param_run(param) for param in param_grid]
  49. results_enriched = [
  50. merge_dicts(self._get_dict_from_results(result), param_set,
  51. {'runner_uuid': runner_uuid},
  52. {'experiment_uuid': self.timestamp})
  53. for (result, runner_uuid), param_set in zip(run_output, param_grid)
  54. ]
  55. pd.DataFrame(results_enriched).to_csv(f'results/results_{self.timestamp}.csv', index=False)
  56. if __name__ == '__main__':
  57. experiment1 = ExperimentRunner(params.experiments, n_jobs=params.n_jobs)
  58. experiment1.run()