metrics.py (13389B)
from typing import Dict, Optional, Union
import math

import torch
import transformers

from gbure.data.dictionary import RelationDictionary
import gbure.data.graph

# NOTE(review): `gbure.utils.dotdict` is referenced in the annotations below but `gbure.utils` is not
# imported explicitly in this file -- presumably it is loaded transitively by the imports above; confirm.


class Metrics:
    """
    Class for computing metrics.

    Twenty-two metrics are exposed through `all`:
        - Optimized loss (usually negative log likelihood)
        - Accuracy, plus accuracy restricted to the non-empty (bucket >= 1) and full (last bucket) neighborhood buckets
        - {directed, undirected, half_directed} {micro, macro} {f1, precision, recall}
    Note that the Accuracy is the true accuracy, taking directionality into account and scoring the unknown relation as any other relation.
    The last 18 metrics follow the SemEval scorer:
        - The unknown ("Other") relation is only scored indirectly
        - Directed is equivalent to the metrics "USING DIRECTIONALITY"
        - Undirected is equivalent to the metrics "IGNORING DIRECTIONALITY"
        - Half-directed is equivalent to the metrics "TAKING DIRECTIONALITY INTO ACCOUNT -- OFFICIAL"
    Note that the directed and half_directed micro metrics are equivalents.

    Counts are accumulated in `per_bucket_confusion`, indexed as [neighborhood bucket, predicted relation, target relation].
    """

    def __init__(self, config: gbure.utils.dotdict, tokenizer: transformers.PreTrainedTokenizer, relation_dictionary: RelationDictionary, graph: Optional[gbure.data.graph.Graph]) -> None:
        """
        Initialize all metrics.

        Args:
            config: configuration; only the optional "neighborhood_size" entry is read here (defaults to 0)
            tokenizer: stored on the instance, not used by this class itself
            relation_dictionary: provides the number of (base) relations, the unknown relation and the id to base id mapping
            graph: when None, every sample is counted in neighborhood bucket 0
        """
        self.config: gbure.utils.dotdict = config
        self.tokenizer: transformers.PreTrainedTokenizer = tokenizer
        self.relation_dictionary: RelationDictionary = relation_dictionary
        self.graph: Optional[gbure.data.graph.Graph] = graph
        self.num_relations: int = len(relation_dictionary)
        self.num_base_relations: int = relation_dictionary.base_size()

        self.mask: torch.Tensor = self.build_mask()
        self.base_transition: torch.Tensor = self.build_base_transition()
        # A base relation is scored as soon as at least one of its directed relations is scored.
        self.base_mask: torch.Tensor = self.base_transition.t().mv(self.mask).clamp(0, 1)

        self.loss_sum: float = 0.0
        self.relation_buckets: int = 1 + self.config.get("neighborhood_size", 0)
        self.per_bucket_confusion: torch.Tensor = torch.zeros((self.relation_buckets, self.num_relations, self.num_relations), dtype=torch.int32)

    @property
    def confusion(self) -> torch.Tensor:
        """ Confusion matrix (prediction × target) summed over all neighborhood buckets. """
        return self.per_bucket_confusion.sum(0)

    def build_mask(self) -> torch.Tensor:
        """ Return the relation mask used by the SemEval scorer (which partly ignores the unknown relation). """
        mask: torch.Tensor = torch.ones(self.num_relations)
        if self.relation_dictionary.unknown is not None:
            # The masking below relies on the unknown relation having id 0.
            assert self.relation_dictionary.decode(0) == self.relation_dictionary.unknown
            mask[0] = 0
        return mask

    def build_base_transition(self) -> torch.Tensor:
        """ Return the transition matrix from "directed relations" to "undirected relations". """
        base_transition: torch.Tensor = torch.zeros((self.num_relations, self.num_base_relations))
        for relation_id, base_id in enumerate(self.relation_dictionary.id_to_bid):
            base_transition[relation_id, base_id] = 1
        return base_transition

    def compute_neighborhood_bucket(self, batch: Dict[str, torch.Tensor], index: int) -> int:
        """
        Return an index between 0 and self.config.neighborhood_size corresponding to the minimum number of neighbors in the sample.

        Only batch features whose name contains "degree" but not "neighborhood" are considered.
        Bucket 0 is returned when there is no graph or when no such feature is present in the batch.
        """
        if self.graph is None:
            return 0
        neighborhood_size: Union[float, int] = math.inf
        for feature, value in batch.items():
            if "degree" in feature and "neighborhood" not in feature:
                neighborhood_size = min(neighborhood_size, value[index].min().item())
        # TODO here subtract 1 for unsupervised (not that important since we don't care about unsupervised accuracies)
        if neighborhood_size == math.inf:
            return 0
        # The result indexes per_bucket_confusion, so make sure it is a plain int whatever the dtype of the degree tensors.
        return int(min(neighborhood_size, self.relation_buckets - 1))

    def update(self, batch: Dict[str, torch.Tensor], loss: torch.Tensor, losses: Dict[str, torch.Tensor], variables: Dict[str, torch.Tensor]) -> None:
        """
        Update metrics according to the given batch and the outputs of the model on this batch.

        The variables dictionary returned by the model should usually contain a predicted_relation tensor.
        Targets are read from batch["relation"], falling back on batch["query_relation"].
        When neither predictions nor targets are found, variables["prediction_relative"] and batch["answer"] are used instead.

        Args:
            batch: the input values used for evaluation
            loss: the loss optimized by the model
            losses: intermediary (unweighted) losses
            variables: internal variables used by the model to compute the loss
        """
        predictions: Optional[torch.Tensor] = variables.get("predicted_relation")
        targets: Optional[torch.Tensor] = batch.get("relation")
        if targets is None:
            targets = batch.get("query_relation")

        if predictions is None and targets is None:
            predictions = variables.get("prediction_relative")
            targets = batch.get("answer")

        for i, (prediction, target) in enumerate(zip(predictions, targets)):
            neighborhood_bucket: int = self.compute_neighborhood_bucket(batch, i)
            self.per_bucket_confusion[neighborhood_bucket, prediction, target] += 1

        # The loss is weighted by the batch size so that `loss` is a per-sample average.
        batch_size: int = predictions.shape[0]
        self.loss_sum += loss.item() * batch_size

    @property
    def summary(self) -> Dict[str, str]:
        """ Return a summary of metrics to be quickly displayed. """
        metrics: Dict[str, str] = {"accuracy": f"{self.accuracy*100:.2f}",
                                   "loss": f"{self.loss:.2f}"}
        if self.relation_buckets > 1:
            metrics.update({"accuracy_non_empty": f"{self.accuracy_non_empty*100:.2f}",
                            "accuracy_full": f"{self.accuracy_full*100:.2f}"})
        return metrics

    @property
    def all(self) -> Dict[str, float]:
        """ Return a dictionary of all metrics. """
        keys = ["accuracy", "accuracy_non_empty", "accuracy_full", "loss"] + [
                f"{direction}_{level}_{metric}"
                for direction in ["directed", "undirected", "half_directed"]
                for level in ["macro", "micro"]
                for metric in ["f1", "precision", "recall"]]
        return {key: getattr(self, key) for key in keys}

    @property
    def base_confusion(self) -> torch.Tensor:
        """ Confusion matrix between "undirected" relation classes. """
        return self.base_transition.t().mm(self.confusion.type_as(self.base_transition)).mm(self.base_transition)

    @staticmethod
    def _accuracy_of(confusion: torch.Tensor) -> float:
        """ Return the fraction of diagonal entries of a confusion matrix as a Python float, or NaN when it is empty. """
        total = confusion.sum().item()
        if total == 0:
            return math.nan
        return confusion.diagonal().sum().item() / total

    @property
    def accuracy(self) -> float:
        """ Accuracy over all samples seen so far (NaN if none). """
        return self._accuracy_of(self.confusion)

    @property
    def accuracy_non_empty(self) -> float:
        """ Accuracy over samples in neighborhood buckets 1 and above (NaN if none). """
        return self._accuracy_of(self.per_bucket_confusion[1:].sum(0))

    @property
    def accuracy_full(self) -> float:
        """ Accuracy over samples in the last neighborhood bucket (NaN if none). """
        return self._accuracy_of(self.per_bucket_confusion[-1])

    @property
    def loss(self) -> float:
        """ Average loss per sample seen so far (NaN if none). """
        total = self.confusion.sum().item()
        return math.nan if total == 0 else self.loss_sum / total

    ##########################
    # Directed macro metrics #
    ##########################

    @property
    def directed_class_precision(self) -> torch.Tensor:
        """ Per-relation precision; relations never predicted get 0. """
        confusion: torch.Tensor = self.confusion
        norm: torch.Tensor = confusion.sum(1)
        norm[norm == 0] = 1  # avoid 0/0, the numerator is 0 for these classes anyway
        return confusion.diagonal() / norm.type(torch.float32)

    @property
    def directed_class_recall(self) -> torch.Tensor:
        """ Per-relation recall; relations never seen as target get 0. """
        confusion: torch.Tensor = self.confusion
        norm: torch.Tensor = confusion.sum(0)
        norm[norm == 0] = 1  # avoid 0/0, the numerator is 0 for these classes anyway
        return confusion.diagonal() / norm.type(torch.float32)

    @property
    def directed_class_f1(self) -> torch.Tensor:
        """ Per-relation F1 (harmonic mean of precision and recall, 0 when both are 0). """
        precision: torch.Tensor = self.directed_class_precision
        recall: torch.Tensor = self.directed_class_recall
        norm: torch.Tensor = precision + recall
        norm[norm == 0] = 1
        return 2 * precision * recall / norm

    @property
    def directed_macro_precision(self) -> float:
        return ((self.directed_class_precision * self.mask).sum() / self.mask.sum()).item()

    @property
    def directed_macro_recall(self) -> float:
        return ((self.directed_class_recall * self.mask).sum() / self.mask.sum()).item()

    @property
    def directed_macro_f1(self) -> float:
        return ((self.directed_class_f1 * self.mask).sum() / self.mask.sum()).item()

    ############################
    # Undirected macro metrics #
    ############################

    @property
    def undirected_class_precision(self) -> torch.Tensor:
        """ Per-base-relation precision computed on the undirected confusion matrix. """
        base_confusion: torch.Tensor = self.base_confusion
        norm: torch.Tensor = base_confusion.sum(1)
        norm[norm == 0] = 1
        return base_confusion.diagonal() / norm

    @property
    def undirected_class_recall(self) -> torch.Tensor:
        """ Per-base-relation recall computed on the undirected confusion matrix. """
        base_confusion: torch.Tensor = self.base_confusion
        norm: torch.Tensor = base_confusion.sum(0)
        norm[norm == 0] = 1
        return base_confusion.diagonal() / norm

    @property
    def undirected_class_f1(self) -> torch.Tensor:
        """ Per-base-relation F1 computed from the undirected precision and recall. """
        precision: torch.Tensor = self.undirected_class_precision
        recall: torch.Tensor = self.undirected_class_recall
        norm: torch.Tensor = precision + recall
        norm[norm == 0] = 1
        return 2 * precision * recall / norm

    @property
    def undirected_macro_precision(self) -> float:
        return ((self.undirected_class_precision * self.base_mask).sum() / self.base_mask.sum()).item()

    @property
    def undirected_macro_recall(self) -> float:
        return ((self.undirected_class_recall * self.base_mask).sum() / self.base_mask.sum()).item()

    @property
    def undirected_macro_f1(self) -> float:
        return ((self.undirected_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item()

    ###############################
    # Half-directed macro metrics #
    ###############################

    @property
    def half_directed_class_precision(self) -> torch.Tensor:
        """ Per-base-relation precision: directed true positives over undirected prediction counts. """
        norm: torch.Tensor = self.base_confusion.sum(1)
        norm[norm == 0] = 1
        return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm

    @property
    def half_directed_class_recall(self) -> torch.Tensor:
        """ Per-base-relation recall: directed true positives over undirected target counts. """
        norm: torch.Tensor = self.base_confusion.sum(0)
        norm[norm == 0] = 1
        return self.base_transition.t().mv(self.confusion.diagonal().type_as(self.base_transition)) / norm

    @property
    def half_directed_class_f1(self) -> torch.Tensor:
        """ Per-base-relation F1 computed from the half-directed precision and recall. """
        precision: torch.Tensor = self.half_directed_class_precision
        recall: torch.Tensor = self.half_directed_class_recall
        norm: torch.Tensor = precision + recall
        norm[norm == 0] = 1
        return 2 * precision * recall / norm

    @property
    def half_directed_macro_precision(self) -> float:
        return ((self.half_directed_class_precision * self.base_mask).sum() / self.base_mask.sum()).item()

    @property
    def half_directed_macro_recall(self) -> float:
        return ((self.half_directed_class_recall * self.base_mask).sum() / self.base_mask.sum()).item()

    @property
    def half_directed_macro_f1(self) -> float:
        return ((self.half_directed_class_f1 * self.base_mask).sum() / self.base_mask.sum()).item()

    #################
    # Micro metrics #
    #################

    @property
    def directed_micro_precision(self) -> float:
        confusion: torch.Tensor = self.confusion
        norm: torch.Tensor = (confusion.sum(1) * self.mask).sum()
        return 0.0 if norm == 0 else ((confusion.diagonal() * self.mask).sum() / norm).item()

    @property
    def directed_micro_recall(self) -> float:
        confusion: torch.Tensor = self.confusion
        norm: torch.Tensor = (confusion.sum(0) * self.mask).sum()
        return 0.0 if norm == 0 else ((confusion.diagonal() * self.mask).sum() / norm).item()

    @property
    def directed_micro_f1(self) -> float:
        precision: float = self.directed_micro_precision
        recall: float = self.directed_micro_recall
        norm: float = precision + recall
        return 0.0 if norm == 0 else 2 * (precision * recall) / norm

    # The half-directed micro metrics are computed exactly like the directed ones, hence the delegation.

    @property
    def half_directed_micro_precision(self) -> float:
        return self.directed_micro_precision

    @property
    def half_directed_micro_recall(self) -> float:
        return self.directed_micro_recall

    @property
    def half_directed_micro_f1(self) -> float:
        return self.directed_micro_f1

    @property
    def undirected_micro_precision(self) -> float:
        base_confusion: torch.Tensor = self.base_confusion
        norm: torch.Tensor = (base_confusion.sum(1) * self.base_mask).sum()
        return 0.0 if norm == 0 else ((base_confusion.diagonal() * self.base_mask).sum() / norm).item()

    @property
    def undirected_micro_recall(self) -> float:
        base_confusion: torch.Tensor = self.base_confusion
        norm: torch.Tensor = (base_confusion.sum(0) * self.base_mask).sum()
        return 0.0 if norm == 0 else ((base_confusion.diagonal() * self.base_mask).sum() / norm).item()

    @property
    def undirected_micro_f1(self) -> float:
        precision: float = self.undirected_micro_precision
        recall: float = self.undirected_micro_recall
        norm: float = precision + recall
        return 0.0 if norm == 0 else 2 * (precision * recall) / norm