当前位置:   article > 正文

基于pytorch故障诊断深度学习中的小utils(新手版自用)_pytorch故障诊断测试

pytorch故障诊断测试

1.metrics 指标函数

  1. def recur(fn, input, *args):
  2. if isinstance(input, torch.Tensor) or isinstance(input, np.ndarray): # tensor
  3. output = fn(input, *args)
  4. elif isinstance(input, list): # list
  5. output = []
  6. for i in range(len(input)):
  7. output.append(recur(fn, input[i], *args))
  8. elif isinstance(input, tuple): # tuple
  9. output = []
  10. for i in range(len(input)):
  11. output.append(recur(fn, input[i], *args))
  12. output = tuple(output)
  13. elif isinstance(input, dict): # dict
  14. output = {}
  15. for key in input:
  16. output[key] = recur(fn, input[key], *args)
  17. elif isinstance(input, str): # str
  18. output = input
  19. elif input is None:
  20. output = None
  21. else:
  22. raise ValueError('Not valid input type')
  23. return output
  24. def Accuracy(output, target, topk=1):
  25. with torch.no_grad():
  26. if target.dtype != torch.int64:
  27. target = (target.topk(1, 1, True, True)[1]).view(-1)
  28. batch_size = target.size(0)
  29. pred_k = output.topk(topk, 1, True, True)[1]
  30. correct_k = pred_k.eq(target.view(-1, 1).expand_as(pred_k)).float().sum()
  31. acc = (correct_k * (100.0 / batch_size)).item()
  32. return acc
  33. def MAccuracy(output, target, mask, topk=1):
  34. if torch.any(mask):
  35. output = output[mask]
  36. target = target[mask]
  37. acc = Accuracy(output, target, topk)
  38. else:
  39. acc = 0
  40. return acc
  41. def LabelRatio(mask):
  42. with torch.no_grad():
  43. lr = mask.float().mean().item()
  44. return lr
  45. class Metric(object):
  46. def __init__(self, metric_name):
  47. self.metric_name = self.make_metric_name(metric_name)
  48. self.metric = {'Loss': (lambda input, output: output['loss'].item()),
  49. 'Accuracy': (lambda input, output: recur(Accuracy, output['target'], input['target'])),
  50. 'PAccuracy': (lambda input, output: recur(Accuracy, output['target'], input['target'])),
  51. 'MAccuracy': (lambda input, output: recur(MAccuracy, output['target'], input['target'],
  52. output['mask'])),
  53. 'LabelRatio': (lambda input, output: recur(LabelRatio, output['mask']))}
  54. def make_metric_name(self, metric_name):
  55. return metric_name
  56. def evaluate(self, metric_names, input, output):
  57. evaluation = {}
  58. for metric_name in metric_names:
  59. evaluation[metric_name] = self.metric[metric_name](input, output)
  60. return evaluation
  61. def compare(self, val):
  62. if self.pivot_direction == 'down':
  63. compared = self.pivot > val
  64. elif self.pivot_direction == 'up':
  65. compared = self.pivot < val
  66. else:
  67. raise ValueError('Not valid pivot direction')
  68. return compared
  69. def update(self, val):
  70. self.pivot = val
  71. return

========================================================================

2.Python collections模块之Counter

Counter()

主要功能:可以支持方便、快速的计数,将元素数量统计,然后计数并返回一个字典,键为元素,值为元素个数。

  1. from collections import Counter
  2. list1 = ["a", "a", "a", "b", "c", "c", "f", "g", "g", "g", "f"]
  3. dic = Counter(list1)
  4. print(dic)
  5. #结果:次数是从高到低的
  6. #Counter({'a': 3, 'g': 3, 'c': 2, 'f': 2, 'b': 1})
  7. print(dict(dic))
  8. #结果:按字母顺序排序的
  9. #{'a': 3, 'b': 1, 'c': 2, 'f': 2, 'g': 3}
  10. print(dic.items()) #dic.items()获取字典的key和value
  11. #结果:按字母顺序排序的
  12. #dict_items([('a', 3), ('b', 1), ('c', 2), ('f', 2), ('g', 3)])
  13. print(dic.keys())
  14. #结果:
  15. #dict_keys(['a', 'b', 'c', 'f', 'g'])
  16. print(dic.values())
  17. #结果:
  18. #dict_values([3, 1, 2, 2, 3])
  19. print(sorted(dic.items(), key=lambda s: (-s[1])))
  20. #结果:按统计次数降序排序
  21. #[('a', 3), ('g', 3), ('c', 2), ('f', 2), ('b', 1)]
  22. for i, v in dic.items():
  23. if v == 1:
  24. print(i)
  25. #结果:
  26. #b
  1. from collections import Counter
  2. str1 = "aabbfkrigbgsejaae"
  3. print(Counter(str1))
  4. print(dict(Counter(str1)))
  5. #结果:
  6. #Counter({'a': 4, 'b': 3, 'g': 2, 'e': 2, 'f': 1, 'k': 1, 'r': 1, 'i': 1, 's': 1, 'j': 1})
  7. #{'a': 4, 'b': 3, 'f': 1, 'k': 1, 'r': 1, 'i': 1, 'g': 2, 's': 1, 'e': 2, 'j': 1}
  8. dic1 = {'a': 3, 'b': 4, 'c': 0, 'd': -2}
  9. print(Counter(dic1))

python中的pop()函数

pop()

用于删除并返回列表中的一个元素(默认为最后一个元素

  1. >>> list1 = [1,2,4,"hello","xy","你好"]
  2. >>> a = list1.pop()#默认弹出最后一个元素
  3. >>> print(a,list1)
  4. 你好 [1,2,4,"hello","xy"]
  1. >>> list2 = [1,2,4,"hello","xy","你好"]
  2. >>> b = list2.pop(3)#弹出列表中第四个元素
  3. >>> print(b,list2)
  4. hello [1,2,4,"xy","你好"]

张量操作:

torch.cat() 功能:将张量按维度dim进行拼接

torch.stack() 功能:在新创建的维度dim上进行拼接

---------------------------------------------------

从list中随机抽取元素的方法

1.随机抽取一个元素

  1. from random import choice
  2. l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  3. print(choice(l)) # 随机抽取一个

 2.随机抽取若干的元素(无重复)

  1. from random import sample
  2. l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  3. print(sample(l, 5)) # 随机抽取5个元素

3.随机抽取若干个元素(有重复)

  1. import numpy as np
  2. l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  3. idxs = np.random.randint(0, len(l), size=5) # 生成长度为5的随机数组,范围为 [0,10),作为索引
  4. print([l[i] for i in idxs]) # 按照索引,去l中获取到对应的值

----------------------------------------------------

打乱列表的顺序

  1. import random
  2. x = [i for i in range(10)]
  3. print(x)
  4. random.shuffle(x)
  5. print(x)

---------------------------------------------------- 

set()

set() 函数创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可以计算交集、差集、并集等。

----------------------------------------------------  

-------------------------------------------------- 

iid数据划分:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/413971
推荐阅读
相关标签
  

闽ICP备14008679号