Just Train Twice: Improving Group Robustness Without Training Group Information

International Conference on Machine Learning (ICML), Vol. 139, 2021, pp. 6781–6792


Abstract

Standard training via empirical risk minimization (ERM) can produce models that achieve low error on average but high error on certain groups, especially in the presence of spurious correlations between the input and label. Prior approaches that achieve low worst-group error, like group distributionally robust optimization (group DRO), require expensive group annotations for each training point. JTT instead trains twice: it first trains a standard ERM model for several epochs, then trains a second model that upweights the training examples the first model misclassified, requiring group annotations only on a small validation set to tune hyperparameters. Averaged over four benchmarks with spurious correlations, JTT closes 75% of the worst-group accuracy gap between ERM and group DRO.
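The two-stage procedure lends itself to a compact implementation. Below is a minimal PyTorch-style sketch, assuming a generic classification dataset and model constructor; the hyperparameter defaults for T (identification epochs) and lambda_up (upweighting factor) are illustrative placeholders rather than the authors' released code, and weighted sampling stands in for the paper's explicit upsampling of error-set examples.

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def train_erm(model, loader, epochs, lr=1e-3, weight_decay=1e-4):
    # Plain ERM training loop; optimizer settings are placeholders.
    opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model

def jtt(make_model, train_set, T=50, lambda_up=50, final_epochs=300):
    # Stage 1: train an identification model with standard ERM for T epochs.
    ident = train_erm(make_model(),
                      DataLoader(train_set, batch_size=64, shuffle=True), T)

    # Build the error set: examples the identification model misclassifies.
    ident.eval()
    weights = []
    with torch.no_grad():
        for x, y in DataLoader(train_set, batch_size=256, shuffle=False):
            wrong = ident(x).argmax(dim=1).ne(y)
            # Error-set examples get weight lambda_up; the rest keep weight 1.
            weights.extend(float(lambda_up) if w else 1.0 for w in wrong.tolist())

    # Stage 2: retrain a fresh model on the reweighted data (sampling with
    # replacement approximates upsampling each error-set example lambda_up times).
    sampler = WeightedRandomSampler(weights, num_samples=len(train_set),
                                    replacement=True)
    return train_erm(make_model(),
                     DataLoader(train_set, batch_size=64, sampler=sampler),
                     final_epochs)
```

T and lambda_up are then tuned on the small group-annotated validation set, as described in the highlights below.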

Introduction
  • The standard approach of empirical risk minimization (ERM)—training machine learning models to minimize average training loss—can produce models that achieve low test error on average but still incur high error on certain groups of examples (Hovy & Søgaard, 2015; Blodgett et al., 2016; Tatman, 2017; Hashimoto et al., 2018; Duchi et al., 2019).
  • These performance disparities across groups can be especially pronounced in the presence of spurious correlations between the input and the label.
  • While prior approaches that use group annotations, like group DRO, have been successful at improving worst-group performance, the required training group annotations are often expensive to obtain; for example, in toxicity classification, each comment has to be annotated with all the demographic identities that are mentioned.
Highlights
  • The standard approach of empirical risk minimization (ERM)—training machine learning models to minimize average training loss—can produce models that achieve low test error on average but still incur high error on certain groups of examples (Hovy & Søgaard, 2015; Blodgett et al., 2016; Tatman, 2017; Hashimoto et al., 2018; Duchi et al., 2019).
  • We propose a simple algorithm, JTT (Just Train Twice), for improving worst-group error without training group annotations, instead requiring group annotations only on a much smaller validation set to tune hyperparameters.
  • JTT recovers a significant portion of the gap in worst-group accuracy between ERM and group distributionally robust optimization (group DRO), closing 75% of the gap on average.
  • We presented Just Train Twice (JTT), a simple algorithm that substantially improves worst-group performance without requiring expensive group annotations during training.
  • A better theoretical understanding of when and why JTT works would help us to refine and further develop methods for training models that are less susceptible to spurious correlations.
  • JTT's worst-group accuracy improvements come at only a modest drop in average accuracy, averaging only 4.2% worse than the highest average accuracy on each dataset.
  • JTT and many prior methods on robustness without group information all rely on a validation set that is representative of the distribution shift or annotated with group information.
Results
  • Compared to other approaches that do not use training group information, JTT consistently achieves higher worst-group accuracy on all 4 datasets.
  • JTT recovers a significant portion of the gap in worst-group accuracy between ERM and group DRO, closing 75% of the gap on average.
  • The authors note that simple label balancing achieves comparable worst-group accuracy to group DRO on CivilComments.
  • JTT’s worst-group accuracy improvements come at only a modest drop in average accuracy, averaging only 4.2% worse than the highest average accuracy on each dataset.
  • This drop is consistent with Sagawa et al. (2020a), who observe a tradeoff between average and worst-group accuracies.
Conclusion
  • The authors presented Just Train Twice (JTT), a simple algorithm that substantially improves worst-group performance without requiring expensive group annotations during training.
  • A better theoretical understanding of when and why JTT works would help them to refine and further develop methods for training models that are less susceptible to spurious correlations.
  • JTT and many prior methods on robustness without group information all rely on a validation set that is representative of the distribution shift or annotated with group information.
  • While these annotations are significantly cheaper than labeling the entire training set, they still require the practitioner to be aware of any spurious correlations and to define groups accordingly.
  • Doing so may be notably difficult in real-world applications.
  • This leaves open the question of whether methods can perform well with mis-specified groups or no group annotations whatsoever.
Tables
  • Table 1: Average and worst-group test accuracies of models trained via JTT and baselines. JTT substantially improves worst-group accuracy relative to ERM and CVaR DRO and outperforms LfF (Nam et al., 2020), a recently proposed algorithm for improving worst-group accuracy without group annotations. We also compare with group DRO, an oracle that assumes group annotations. JTT recovers a significant fraction of the gap in worst-group accuracy between ERM and group DRO
  • Table 2: The precision and recall of the worst-group examples (i.e., the group with lowest validation accuracy) belonging to JTT's error set. The error set includes a high fraction of the worst-group examples and includes them at a much higher rate than they occur in the training data
  • Table 3: Waterbirds error set breakdowns
  • Table 4: CelebA error set breakdowns
  • Table 5: MultiNLI error set breakdowns
  • Table 6: Worst-group test accuracies with 3 variants of the error set on Waterbirds: (i) the standard, unchanged error set; (ii) removing all waterbird-on-water-background examples; (iii) swapping each error set example with a random example from the same group
  • Table 7: Across the methods that do not use training group annotations, worst-group test performance is significantly higher when hyperparameters are tuned for worst-group validation accuracy instead of average validation accuracy. This shows that for these methods, it is still critical to have validation group annotations (a selection sketch follows this list)
  • Table 8: JTT retains high test worst-group accuracy on Waterbirds and CelebA
  • Table 9: Effect of dynamically computing JTT's error set on Waterbirds. We first train the identification model for T = 50 epochs, as usual. Then, we dynamically update the error set using the final model after every K epochs of training the final model. Lower values of K have significantly lower accuracies
  • Table 10: Learning rates and ℓ2 regularization strengths for Waterbirds
  • Table 11: Learning rates and ℓ2 regularization strengths for CelebA
  • Table 12: CivilComments error set breakdowns
  • Table 13: Average and worst-group test accuracies. UPSAMPLE MINORITY, which upsamples y = a examples, achieves higher worst-group accuracy than ERM, but lower than JTT
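Table 7's point can be made concrete with a hypothetical grid search over JTT's two hyperparameters, selecting by worst-group accuracy on the group-annotated validation set; the grids and the train_and_eval helper below are placeholders, not the authors' tooling.

```python
from itertools import product
import numpy as np

def worst_group_accuracy(preds, labels, groups):
    # Minimum per-group accuracy over the annotated validation groups.
    return min((preds[groups == g] == labels[groups == g]).mean()
               for g in np.unique(groups))

def select_hyperparameters(train_and_eval, T_grid=(20, 50, 100),
                           lam_grid=(20, 50, 100)):
    # train_and_eval(T, lam) trains JTT and returns validation
    # (preds, labels, groups); it is a placeholder, not a real API.
    best, best_acc = None, -1.0
    for T, lam in product(T_grid, lam_grid):
        preds, labels, groups = train_and_eval(T, lam)
        acc = worst_group_accuracy(preds, labels, groups)
        if acc > best_acc:
            best, best_acc = (T, lam), acc
    return best, best_acc
```

Selecting for average validation accuracy instead simply swaps out worst_group_accuracy, which is exactly the comparison Table 7 reports.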
Related work
  • In this paper, we focus on group robustness (i.e., training models that obtain good performance on each of a set of predefined groups in the dataset), though other notions of robustness are also studied, such as adversarial examples (Biggio et al, 2013; Szegedy et al, 2014) or domain generalization (Blanchard et al, 2011; Muandet et al, 2013). Approaches for group robustness fall into the two main categories we discuss below.

    Robustness using group information. Several approaches leverage group information during training, either to combat spurious correlations or to handle shifts in group proportions between train and test distributions. For example, Mohri et al. (2019), Sagawa et al. (2020a), and Zhang et al. (2020) minimize the worst-group loss during training; Goel et al. (2020) synthetically expand the minority groups via generative modeling; Shimodaira (2000), Byrd & Lipton (2019), and Sagawa et al. (2020b) reweight or subsample to artificially balance the majority and minority groups; Cao et al. (2019; 2020) impose heavy Lipschitz regularization around minority points. These approaches substantially reduce worst-group error, but obtaining group annotations for the entire training set can be extremely expensive. The second category, methods that do not use training group annotations (e.g., CVaR DRO (Levy et al., 2020) and LfF (Nam et al., 2020)), is the setting JTT targets.
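For reference, the worst-group objective that group DRO minimizes can be stated as follows (a standard formulation; the notation is generic rather than copied from the paper):

$$\min_{\theta} \; \max_{g \in \mathcal{G}} \; \mathbb{E}_{(x,y) \sim P_g} \big[ \ell(\theta; (x, y)) \big],$$

where $\mathcal{G}$ is the set of predefined groups and $P_g$ is the data distribution of group $g$; ERM instead minimizes the average loss $\mathbb{E}_{(x,y) \sim P}[\ell(\theta; (x,y))]$ over the pooled distribution.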
Funding
  • This work was supported by NSF Award Grant No. 1805310 and in part by Google.
  • EL is supported by a National Science Foundation Graduate Research Fellowship under Grant No. DGE-1656518.
  • AR is supported by a Google PhD Fellowship and an Open Philanthropy Project AI Fellowship.
  • SS is supported by a Herbert Kunzel Stanford Graduate Fellowship.
Study subjects and analysis
Datasets: 4

  • In CivilComments-WILDS, where the input is an online comment, the label (toxic vs. non-toxic) spuriously correlates with the mention of particular demographics, as discussed above. Our method outperforms ERM on all four datasets, with an average worst-group accuracy improvement of 16.2%, while maintaining competitive average accuracy (only 4.2% worse on average). Furthermore, despite having no group annotations during training, JTT closes 75% of the gap between ERM and group DRO, which uses complete group information on the training data.
  • Setup. We study four datasets in which prior work has observed poor worst-group performance due to spurious correlations (Figure 2). Full details about these datasets are in Appendix B.
  • Table 4 reports the average and worst-group accuracies of all approaches. Compared to other approaches that do not use training group information, JTT consistently achieves higher worst-group accuracy on all 4 datasets. Additionally, JTT performs well even relative to approaches that use training group information.
  • As reported in Table 2, the error set contains worst-group examples at a much higher rate (precision) than they appear in the training data (empirical rate): worst-group examples appear in the error set 2.2x to 15.9x more frequently than in the training data, across the 4 datasets. In other words, the worst group is significantly enriched in the error set compared to the training dataset, which may explain why JTT has much better worst-group performance than ERM. Additionally, the error set has high worst-group recall, ranging from 67.1% to 96.9% and averaging 86.4% across the 4 datasets. Together, these results indicate that the worst group is included in the error set with both relatively high precision and recall.
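To pin down how the precision, recall, and enrichment numbers above are computed, here is a small illustrative sketch; the boolean-mask names are assumptions for exposition, not identifiers from the released code.

```python
import numpy as np

def error_set_stats(in_error_set: np.ndarray, is_worst_group: np.ndarray):
    """in_error_set[i]: example i was misclassified by the identification model.
    is_worst_group[i]: example i belongs to the group with lowest validation accuracy."""
    both = (in_error_set & is_worst_group).sum()
    precision = both / in_error_set.sum()    # fraction of the error set that is worst-group
    recall = both / is_worst_group.sum()     # fraction of worst-group examples captured
    empirical_rate = is_worst_group.mean()   # worst-group frequency in the training data
    enrichment = precision / empirical_rate  # the 2.2x-15.9x factor reported above
    return precision, recall, empirical_rate, enrichment
```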

References
  • Agarwal, A., Beygelzimer, A., Dudik, M., Langford, J., and Wallach, H. A reductions approach to fair classification. In International Conference on Machine Learning (ICML), pp. 60–69, 2018.
  • Badgeley, M. A., Zech, J. R., Oakden-Rayner, L., Glicksberg, B. S., Liu, M., Gale, W., McConnell, M. V., Percha, B., Snyder, T. M., and Dudley, J. T. Deep learning predicts hip fracture using confounding patient and healthcare variables. npj Digital Medicine, 2, 2019.
  • Ben-Tal, A., den Hertog, D., Waegenaere, A. D., Melenberg, B., and Rennen, G. Robust solutions of optimization problems affected by uncertain probabilities. Management Science, 59:341–357, 2013.
  • Biggio, B., Corona, I., Maiorca, D., Nelson, B., Srndic, N., Laskov, P., Giacinto, G., and Roli, F. Evasion attacks against machine learning at test time. In Joint European conference on machine learning and knowledge discovery in databases, pp. 387–402, 2013.
  • Blanchard, G., Lee, G., and Scott, C. Generalizing from several related classification tasks to a new unlabeled sample. In Advances in neural information processing systems, pp. 2178–2186, 2011.
  • Blodgett, S. L., Green, L., and O’Connor, B. Demographic dialectal variation in social media: A case study of African-American English. In Empirical Methods in Natural Language Processing (EMNLP), pp. 1119–1130, 2016.
  • Borkan, D., Dixon, L., Sorensen, J., Thain, N., and Vasserman, L. Nuanced metrics for measuring unintended bias with real data for text classification. In World Wide Web (WWW), pp. 491–500, 2019.
  • Byrd, J. and Lipton, Z. What is the effect of importance weighting in deep learning? In International Conference on Machine Learning (ICML), pp. 872–881, 2019.
  • Cao, K., Wei, C., Gaidon, A., Arechiga, N., and Ma, T. Learning imbalanced datasets with label-distribution-aware margin loss. In Advances in Neural Information Processing Systems (NeurIPS), 2019.
  • Cao, K., Chen, Y., Lu, J., Arechiga, N., Gaidon, A., and Ma, T. Heteroskedastic and imbalanced deep learning with adaptive regularization. arXiv preprint arXiv:2006.15766, 2020.
  • Creager, E., Jacobsen, J.-H., and Zemel, R. Environment inference for invariant learning. In International Conference on Machine Learning (ICML), pp. 2189–2200, 2021.
  • Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In Association for Computational Linguistics (ACL), pp. 4171–4186, 2019.
  • Duchi, J., Glynn, P., and Namkoong, H. Statistics of robust optimization: A generalized empirical likelihood approach. arXiv, 2016.
  • Duchi, J., Hashimoto, T., and Namkoong, H. Distributionally robust losses against mixture covariate shifts. https://cs.stanford.edu/~thashim/assets/publications/condrisk.pdf, 2019.
  • Goel, K., Gu, A., Li, Y., and Re, C. Model patching: Closing the subgroup performance gap with data augmentation. arXiv preprint arXiv:2008.06775, 2020.
  • Gururangan, S., Swayamdipta, S., Levy, O., Schwartz, R., Bowman, S., and Smith, N. A. Annotation artifacts in natural language inference data. In Association for Computational Linguistics (ACL), pp. 107–112, 2018.
  • Hardt, M., Price, E., and Srebro, N. Equality of opportunity in supervised learning. In Advances in Neural Information Processing Systems (NeurIPS), pp. 3315–3323, 2016.
  • Hashimoto, T. B., Srivastava, M., Namkoong, H., and Liang, P. Fairness without demographics in repeated loss minimization. In International Conference on Machine Learning (ICML), 2018.
  • He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Computer Vision and Pattern Recognition (CVPR), 2016.
  • Hovy, D. and Søgaard, A. Tagging performance correlates with age. In Association for Computational Linguistics (ACL), pp. 483–488, 2015.
  • Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning (ICML), pp. 448–456, 2015.
  • Khani, F., Raghunathan, A., and Liang, P. Maximum weighted loss discrepancy. arXiv preprint arXiv:1906.03518, 2019.
  • Kim, M. P., Ghorbani, A., and Zou, J. Multiaccuracy: Blackbox post-processing for fairness in classification. In Association for the Advancement of Artificial Intelligence (AAAI), pp. 247–254, 2019.
  • Koh, P. W., Sagawa, S., Marklund, H., Xie, S. M., Zhang, M., Balsubramani, A., Hu, W., Yasunaga, M., Phillips, R. L., Gao, I., Lee, T., David, E., Stavness, I., Guo, W., Earnshaw, B. A., Haque, I. S., Beery, S., Leskovec, J., Kundaje, A., Pierson, E., Levine, S., Finn, C., and Liang, P. WILDS: A benchmark of in-the-wild distribution shifts. In International Conference on Machine Learning (ICML), 2021.
  • Lam, H. and Zhou, E. Quantifying input uncertainty in stochastic optimization. In 2015 Winter Simulation Conference, 2015.
  • Levy, D., Carmon, Y., Duchi, J. C., and Sidford, A. Largescale methods for distributionally robust optimization. arXiv preprint arXiv:2010.05893, 2020.
  • Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3730–3738, 2015.
  • Loshchilov, I. and Hutter, F. SGDR: Stochastic gradient descent with warm restarts. In International Conference on Learning Representations (ICLR), 2017.
  • McCoy, R. T., Pavlick, E., and Linzen, T. Right for the wrong reasons: Diagnosing syntactic heuristics in natural language inference. In Association for Computational Linguistics (ACL), 2019.
  • Mohri, M., Sivek, G., and Suresh, A. T. Agnostic federated learning. In International Conference on Machine Learning (ICML), pp. 4615–4625, 2019.
  • Muandet, K., Balduzzi, D., and Scholkopf, B. Domain generalization via invariant feature representation. In International Conference on Machine Learning (ICML), pp. 10–18, 2013.
  • Nam, J., Cha, H., Ahn, S., Lee, J., and Shin, J. Learning from failure: Training debiased classifier from biased classifier. arXiv preprint arXiv:2007.02561, 2020.
  • Namkoong, H. and Duchi, J. Variance regularization with convex objectives. In Advances in Neural Information Processing Systems (NeurIPS), 2017.
  • Oakden-Rayner, L., Dunnmon, J., Carneiro, G., and Re, C. Hidden stratification causes clinically meaningful failures in machine learning for medical imaging. In Proceedings of the ACM Conference on Health, Inference, and Learning, pp. 151–159, 2020.
  • Oren, Y., Sagawa, S., Hashimoto, T., and Liang, P. Distributionally robust language modeling. In Empirical Methods in Natural Language Processing (EMNLP), 2019.
  • Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in pytorch, 2017.
  • Pezeshki, M., Kaba, S.-O., Bengio, Y., Courville, A., Precup, D., and Lajoie, G. Gradient starvation: A learning proclivity in neural networks. arXiv preprint arXiv:2011.09468, 2020.
  • Pleiss, G., Raghavan, M., Wu, F., Kleinberg, J., and Weinberger, K. Q. On fairness and calibration. In Advances in Neural Information Processing Systems (NeurIPS), pp. 5684–5693, 2017.
  • Ren, M., Zeng, W., Yang, B., and Urtasun, R. Learning to reweight examples for robust deep learning. In International Conference on Machine Learning (ICML), 2018.
  • Rockafellar, R. T. and Uryasev, S. Optimization of conditional value-at-risk. Journal of Risk, 2:21–41, 2000.
  • Sagawa, S., Koh, P. W., Hashimoto, T. B., and Liang, P. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In International Conference on Learning Representations (ICLR), 2020a.
  • Sagawa, S., Raghunathan, A., Koh, P. W., and Liang, P. An investigation of why overparameterization exacerbates spurious correlations. In International Conference on Machine Learning (ICML), 2020b.
  • Shimodaira, H. Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of Statistical Planning and Inference, 90:227–244, 2000.
  • Sohoni, N. S., Dunnmon, J. A., Angus, G., Gu, A., and Re, C. No subclass left behind: Fine-grained robustness in coarse-grained classification problems. arXiv preprint arXiv:2011.12945, 2020.
  • Szegedy, C., Zaremba, W., Sutskever, I., Bruna, J., Erhan, D., Goodfellow, I., and Fergus, R. Intriguing properties of neural networks. In International Conference on Learning Representations (ICLR), 2014.
  • Tatman, R. Gender and dialect bias in YouTube’s automatic captions. In Workshop on Ethics in Natural Language Processing, volume 1, pp. 53–59, 2017.
  • Wah, C., Branson, S., Welinder, P., Perona, P., and Belongie, S. The Caltech-UCSD Birds-200-2011 dataset. Technical report, California Institute of Technology, 2011.
  • Williams, A., Nangia, N., and Bowman, S. A broad-coverage challenge corpus for sentence understanding through inference. In Association for Computational Linguistics (ACL), pp. 1112–1122, 2018.
  • Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., Cistac, P., Rault, T., Louf, R., Funtowicz, M., and Brew, J. HuggingFace’s transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019.
  • Woodworth, B., Gunasekar, S., Ohannessian, M. I., and Srebro, N. Learning non-discriminatory predictors. In Conference on Learning Theory (COLT), pp. 1920–1953, 2017.
  • Zhang, J., Menon, A., Veit, A., Bhojanapalli, S., Kumar, S., and Sra, S. Coping with label shift via distributionally robust optimisation. arXiv preprint arXiv:2010.12230, 2020.
  • Zhang, Z. and Sabuncu, M. R. Generalized cross entropy loss for training deep neural networks with noisy labels. In Advances in Neural Information Processing Systems (NeurIPS), 2018.
  • Zhou, B., Lapedriza, A., Khosla, A., Oliva, A., and Torralba, A. Places: A 10 million image database for scene recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence, 40(6):1452–1464, 2017.
  • Shu, J., Xie, Q., Yi, L., Zhao, Q., Zhou, S., Xu, Z., and Meng, D. Meta-Weight-Net: Learning an explicit mapping for sample weighting. In Advances in Neural Information Processing Systems (NeurIPS), pp. 1919–1930, 2019.
Experimental details
  • In this section, we detail the model architectures and hyperparameters used by each approach. Within each dataset, we used the same model architecture across all approaches: ResNet-50 (He et al., 2016) for Waterbirds and CelebA, and BERT (Devlin et al., 2019) for MultiNLI and CivilComments. We used the PyTorch (Paszke et al., 2017) implementation of ResNet-50, starting from ImageNet-pretrained weights, and the HuggingFace implementation (Wolf et al., 2019) of BERT, also starting from pretrained weights.
  • We use the LfF implementation released by Nam et al. (2020). We use the group DRO and ERM implementations released by Sagawa et al. (2020a) and implement CVaR DRO and JTT on top of this code base, with the CVaR DRO implementation adapted from Levy et al. (2020) (a standard form of the CVaR objective is written out after this list). For the group DRO experiments on Waterbirds, CelebA, and MultiNLI, we directly use the performance numbers reported by Sagawa et al. (2020a). We note that these numbers utilize group-specific loss adjustments that encourage the model to attain lower training losses on smaller groups, which was shown to improve worst-group generalization. We train our own group DRO model on CivilComments-WILDS, as it was not included in Sagawa et al. (2020a); for this, we did not implement these group adjustments. We train our own models for all other algorithms.
  • For LfF, we tune the hyperparameter q by grid searching over q ∈ {0.1, 0.3, 0.5, 0.7, 0.9}. For CivilComments, we additionally sample two values log-uniformly from (0, 0.1]. This hyperparameter was not tuned in the experiments of Nam et al. (2020).
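For completeness, the CVaR DRO baseline mentioned above can be written in the dual form of Rockafellar & Uryasev (2000); this is a textbook statement, with α denoting the worst-case fraction:

$$\min_{\theta, \eta} \; \eta + \frac{1}{\alpha} \, \mathbb{E}_{(x,y) \sim P} \big[ \big( \ell(\theta; (x, y)) - \eta \big)_+ \big],$$

i.e., it minimizes the average loss over the worst α-fraction of training examples, with no group labels required.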