Press "Enter" to skip to content

Tensorflow中使用tf.keras.utils.get_file下载数据集

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

在神经网络中经常需要下载数据集(Dataset),Tensorflow的Keras提供了tf.keras.utils.get_file()函数帮助我们实现数据集下载解压的功能。

 

函数原型

 

tf.keras.utils.get_file(
    fname, origin, untar=False, md5_hash=None, 
    file_hash=None, cache_subdir='datasets',
    hash_algorithm='auto', extract=False,
    archive_format='auto', cache_dir=None
)

 

origin: 数据集(Dataset)的URL路径;

 

fname: 下载到本地后的文件名称,如果是绝对路径,下载的文件就会存储在这个路径下;

 

md5_hash: 已废弃,使用file_hash;

 

file_hash: 文件的md5,用于数据校验;

 

cache_subdir: 下载到本地的文件存储子目录;

 

cache_dir: 下载到本地的文件存储目录,默认路径~/.keras;

 

hash_algorithm:计算hash值使用的算法,可选值有: 'md5'
'sha256'
, 和 'auto'
。默认的auto会自动检测使用的哈希算法;

 

extract:为True时,该函数会尝试解压下载的文件;

 

archive_format:下载文件的解压算法,可选的值有 'auto'
'tar'
'zip'
和  None
。tar包含tar,tar.gz和tar.bz文件,auto对应tar和zip。

 

使用Demo

 

下图的示例代码展示了tf.keras.utils.get_file函数下载数据集的常规用法。首先检测目录是否存在,如果不存在,则调用tf.keras.utils.get_file进行数据下载,下载完成后,再通过os.remove()函数将解压前的压缩文件移除掉。

 

使用tf.keras.utils.get_file函数下载图片标题文件:

 

# Download caption annotation files
annotation_folder = '/annotations/'
if not os.path.exists(os.path.abspath('.') + annotation_folder):
  annotation_zip = tf.keras.utils.get_file('captions.zip',                                         
      cache_subdir=os.path.abspath('.'),
      origin = 'http://images.cocodataset.org/annotations/ 
            annotations_trainval2014.zip',
      extract = True)
  annotation_file = os.path.dirname(annotation_zip)+'/annotations 
      /captions_train2014.json'
  os.remove(annotation_zip)

 

使用tf.keras.utils.get_file函数下载图片文件:

 

# Download image files
image_folder = '/train2014/'
if not os.path.exists(os.path.abspath('.') + image_folder):
  image_zip = tf.keras.utils.get_file('train2014.zip',                                
      cache_subdir=os.path.abspath('.'),
      origin = 'http://images.cocodataset.org/zips/train2014.zip',
      extract = True)
  PATH = os.path.dirname(image_zip) + image_folder
  os.remove(image_zip)
else:
  PATH = os.path.abspath('.') + image_folder

 

参考材料

 

1. https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注