
Anomalib Code Walkthrough, Part 2: Inside train.py (anomalib parameters)


4. About get_experiment_logger

So what actually lives inside get_experiment_logger?

```python
def get_experiment_logger(
    config: DictConfig | ListConfig,
) -> Logger | Iterable[Logger] | bool:
    """Return a logger based on the choice of logger in the config file.

    Args:
        config (DictConfig): config.yaml file for the corresponding anomalib model.

    Raises:
        ValueError: for any logger types apart from false and tensorboard

    Returns:
        Logger | Iterable[Logger] | bool: Logger
    """
    logger.info("Loading the experiment logger(s)")

    # TODO remove when logger is deprecated from project
    if "logger" in config.project.keys():
        warnings.warn(
            "'logger' key will be deprecated from 'project' section of the config file."
            " Please use the logging section in config file.",
            DeprecationWarning,
        )
        if "logging" not in config:
            config.logging = {"logger": config.project.logger, "log_graph": False}
        else:
            config.logging.logger = config.project.logger

    if config.logging.logger in (None, False):
        return False

    print("-------------------hahahhahhah")
    print(config.logging.logger)
    print("-------------------hahahhahhah-end")

    logger_list: list[Logger] = []
    if isinstance(config.logging.logger, str):
        config.logging.logger = [config.logging.logger]

    print("------------------------gao1")
    for experiment_logger in config.logging.logger:
        print("-------------------------gao2")
        print(experiment_logger)
        if experiment_logger == "tensorboard":
            logger_list.append(
                AnomalibTensorBoardLogger(
                    name="Tensorboard Logs",
                    save_dir=os.path.join(config.project.path, "logs"),
                    log_graph=config.logging.log_graph,
                )
            )
        elif experiment_logger == "wandb":
            wandb_logdir = os.path.join(config.project.path, "logs")
            Path(wandb_logdir).mkdir(parents=True, exist_ok=True)
            name = (
                config.model.name
                if "category" not in config.dataset.keys()
                else f"{config.dataset.category} {config.model.name}"
            )
            logger_list.append(
                AnomalibWandbLogger(
                    project=config.dataset.name,
                    name=name,
                    save_dir=wandb_logdir,
                )
            )
        elif experiment_logger == "comet":
            comet_logdir = os.path.join(config.project.path, "logs")
            Path(comet_logdir).mkdir(parents=True, exist_ok=True)
            run_name = (
                config.model.name
                if "category" not in config.dataset.keys()
                else f"{config.dataset.category} {config.model.name}"
            )
            logger_list.append(
                AnomalibCometLogger(project_name=config.dataset.name, experiment_name=run_name, save_dir=comet_logdir)
            )
        elif experiment_logger == "csv":
            logger_list.append(CSVLogger(save_dir=os.path.join(config.project.path, "logs")))
        else:
            raise UnknownLogger(
                f"Unknown logger type: {config.logging.logger}. "
                f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
                f"To enable the logger, set `project.logger` to `true` or use one of available loggers in config.yaml\n"
                f"To disable the logger, set `project.logger` to `false`."
            )

    print("-------------------------gao3")
    print(logger_list)
    print("-------------------------gao3-end")
    return logger_list
```

Look at the last few lines: the logger_list I print out there turns out to be an empty list, [].

All that poking around for nothing, heh.
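Why empty? The answer is in the config, not in the function. A minimal sketch of reproducing this, assuming anomalib 0.x and its bundled padim config.yaml (get_configurable_parameters and get_experiment_logger are the real anomalib helpers; the exact logging section depends on your yaml):

```python
# Assumption: the stock config.yaml ships a logging section roughly like
#
#   logging:
#     logger: []        # empty list -> no experiment logger configured
#     log_graph: false
#
# [] is neither None nor False, so the early `return False` is skipped, and the
# for-loop then iterates over an empty list and appends nothing.
from anomalib.config import get_configurable_parameters
from anomalib.utils.loggers import get_experiment_logger

config = get_configurable_parameters(model_name="padim")
print(config.logging.logger)          # [] with the default config
print(get_experiment_logger(config))  # -> [] (an empty list of loggers)
```

Set logging.logger to "tensorboard" (or a list of logger names) in config.yaml and the same loop fills logger_list with the corresponding logger objects instead.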

5. About get_callbacks

callbacks = get_callbacks(config)

Let's see what's inside:
```python
def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
    """Return base callbacks for all the lightning models.

    Args:
        config (DictConfig): Model config

    Return:
        (list[Callback]): List of callbacks.
    """
    logger.info("Loading the callbacks")

    callbacks: list[Callback] = []

    monitor_metric = None if "early_stopping" not in config.model.keys() else config.model.early_stopping.metric
    monitor_mode = "max" if "early_stopping" not in config.model.keys() else config.model.early_stopping.mode

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )

    callbacks.extend([checkpoint, TimerCallback()])

    if "resume_from_checkpoint" in config.trainer.keys() and config.trainer.resume_from_checkpoint is not None:
        load_model = LoadModelCallback(config.trainer.resume_from_checkpoint)
        callbacks.append(load_model)

    # Add post-processing configurations to AnomalyModule.
    image_threshold = (
        config.metrics.threshold.manual_image if "manual_image" in config.metrics.threshold.keys() else None
    )
    pixel_threshold = (
        config.metrics.threshold.manual_pixel if "manual_pixel" in config.metrics.threshold.keys() else None
    )
    post_processing_callback = PostProcessingConfigurationCallback(
        threshold_method=config.metrics.threshold.method,
        manual_image_threshold=image_threshold,
        manual_pixel_threshold=pixel_threshold,
    )
    callbacks.append(post_processing_callback)

    # Add metric configuration to the model via MetricsConfigurationCallback
    metrics_callback = MetricsConfigurationCallback(
        config.dataset.task,
        config.metrics.get("image", None),
        config.metrics.get("pixel", None),
    )
    callbacks.append(metrics_callback)

    if "normalization_method" in config.model.keys() and not config.model.normalization_method == "none":
        if config.model.normalization_method == "cdf":
            if config.model.name in ("padim", "stfpm"):
                if "nncf" in config.optimization and config.optimization.nncf.apply:
                    raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
                callbacks.append(CdfNormalizationCallback())
            else:
                raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        elif config.model.normalization_method == "min_max":
            callbacks.append(MinMaxNormalizationCallback())
        else:
            raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}")

    add_visualizer_callback(callbacks, config)

    if "optimization" in config.keys():
        if "nncf" in config.optimization and config.optimization.nncf.apply:
            # NNCF wraps torch's jit which conflicts with kornia's jit calls.
            # Hence, nncf is imported only when required
            nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
            nncf_callback = getattr(nncf_module, "NNCFCallback")
            nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
            callbacks.append(
                nncf_callback(
                    config=nncf_config,
                    export_dir=os.path.join(config.project.path, "compressed"),
                )
            )
        if config.optimization.export_mode is not None:
            from .export import (  # pylint: disable=import-outside-toplevel
                ExportCallback,
            )

            logger.info("Setting model export to %s", config.optimization.export_mode)
            callbacks.append(
                ExportCallback(
                    input_size=config.model.input_size,
                    dirpath=config.project.path,
                    filename="model",
                    export_mode=ExportMode(config.optimization.export_mode),
                )
            )
        else:
            warnings.warn(f"Export option: {config.optimization.export_mode} not found. Defaulting to no model export")

    # Add callback to log graph to loggers
    if config.logging.log_graph not in (None, False):
        callbacks.append(GraphLogger())

    print("-------------gao callbacks")
    print(callbacks)
    print("-------------gao callbacks-end")
    return callbacks

So what does the callbacks list printed at the end actually look like?
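You can read the answer straight off the branches above. A minimal sketch of reproducing the print, assuming anomalib 0.x and its stock padim config (min_max normalization, no NNCF); the exact classes you see depend on your config.yaml:

```python
# get_configurable_parameters and get_callbacks are the real anomalib helpers;
# the padim defaults assumed here are my reading of the bundled config.
from anomalib.config import get_configurable_parameters
from anomalib.utils.callbacks import get_callbacks

config = get_configurable_parameters(model_name="padim")
for cb in get_callbacks(config):
    print(type(cb).__name__)

# Tracing the branches above for that config, you would expect roughly:
# ModelCheckpoint, TimerCallback, PostProcessingConfigurationCallback,
# MetricsConfigurationCallback, MinMaxNormalizationCallback, plus whatever
# add_visualizer_callback appends for the chosen task, and ExportCallback
# only if optimization.export_mode is set.
```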

6. About Trainer

At last, we arrive at the most important part:

Trainer's first argument is, as shown in the sketch below, the trainer section of config.yaml.

The second argument, as we saw earlier, is just the empty logger list [].

The third argument is the pile of callbacks we built above.
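Putting the three together: a minimal sketch of my reading of how anomalib 0.x's tools/train.py hands everything to PyTorch Lightning (config comes from get_configurable_parameters(...) as in the snippets above):

```python
from pytorch_lightning import Trainer

experiment_logger = get_experiment_logger(config)  # -> [] with the default config (section 4)
callbacks = get_callbacks(config)                  # -> the callback list from section 5

trainer = Trainer(
    **config.trainer,          # "first argument": the trainer section of config.yaml, unpacked
    logger=experiment_logger,  # "second argument": the (empty) logger list
    callbacks=callbacks,       # "third argument": the callbacks
)
```

So everything we traced in sections 4 and 5 exists only to populate the logger= and callbacks= keyword arguments of the Lightning Trainer; the rest of its behaviour is driven directly by the trainer block of config.yaml.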

 
