博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow :ckpt模型转换为pytorch : hdf5模型
阅读量:6909 次
发布时间:2019-06-27

本文共 1333 字,大约阅读时间需要 4 分钟。

参考链接:

import tensorflow as tfimport deepdish as ddimport argparseimport osimport numpy as npdef tr(v):    # tensorflow weights to pytorch weights    if v.ndim == 4:        return np.ascontiguousarray(v.transpose(3,2,0,1))    elif v.ndim == 2:        return np.ascontiguousarray(v.transpose())    return vdef read_ckpt(ckpt):    # https://github.com/tensorflow/tensorflow/issues/1823    reader = tf.train.NewCheckpointReader(ckpt)    weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().iteritems()}    pyweights = {k: tr(v) for (k, v) in weights.items()}    return pyweightsif __name__ == '__main__':    parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")    parser.add_argument("infile", type=str,                        help="Path to the ckpt.")    parser.add_argument("outfile", type=str, nargs='?', default='',                        help="Output file (inferred if missing).")    args = parser.parse_args()    if args.outfile == '':        args.outfile = os.path.splitext(args.infile)[0] + '.h5'    outdir = os.path.dirname(args.outfile)    if not os.path.exists(outdir):        os.makedirs(outdir)    weights = read_ckpt(args.infile)    dd.io.save(args.outfile, weights)    weights2 = dd.io.load(args.outfile)

 

转载于:https://www.cnblogs.com/wangyarui/p/9076401.html

你可能感兴趣的文章
eclipse 注释字体不一致的问题
查看>>
运放的PID电路
查看>>
Ubuntu下sqlite3的安装及使用
查看>>
LintCode - Backpack
查看>>
使用percona-xtrabackup工具对mysql数据库的备份方案
查看>>
C# URL 中文编码与解码
查看>>
jquery源码解析:pushStack,end,ready,eq详解
查看>>
Qt核心模块的组成
查看>>
hdu Is It A Tree?
查看>>
linux下xargs命令用法详解
查看>>
HDU1492 The number of divisors(约数) about Humble Numbers【约数】
查看>>
Vijos P1596 加法表【迭代】
查看>>
整体二分笔记
查看>>
css学习_文本有关的样式属性、sublime快捷生成标签
查看>>
Mysql学习
查看>>
jsp页面无法获取controler层model值解决方案
查看>>
[C++] Swap Two Num
查看>>
详解ABBYY FineReader 12扫描亮度设置
查看>>
线程同步利与弊,线程同步的前提
查看>>
js的escape()、encodeURI()、encodeURIComponent()区别详解
查看>>