diff --git a/.travis.yml b/.travis.yml index 91192f8621..2cacf861c6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,23 +1,25 @@ language: rust sudo: false +dist: trusty # still in beta, but required for the prebuilt TF binaries cache: cargo: true directories: - $HOME/.cache/bazel -rust: nightly +rust: stable install: - export CC="gcc-4.9" CXX="g++-4.9" - source travis-ci/install.sh script: + - export RUST_BACKTRACE=1 - cargo test -vv -j 2 --features tensorflow_unstable - cargo run --example regression - cargo run --features tensorflow_unstable --example expressions - cargo doc -vv --features tensorflow_unstable - - (cd tensorflow-sys && cargo test -vv -j 1) + - # TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1) - (cd tensorflow-sys && cargo doc -vv) addons: diff --git a/tensorflow-sys/Cargo.toml b/tensorflow-sys/Cargo.toml index b87a3c3fcb..0c20176d7d 100644 --- a/tensorflow-sys/Cargo.toml +++ b/tensorflow-sys/Cargo.toml @@ -17,5 +17,11 @@ links = "tensorflow" libc = "0.2" [build-dependencies] +curl = "0.4" +flate2 = "0.2" pkg-config = "0.3" semver = "0.5" +tar = "0.4" + +[features] +tensorflow_gpu = [] diff --git a/tensorflow-sys/build.rs b/tensorflow-sys/build.rs index a0248b8a4e..80e210a36a 100644 --- a/tensorflow-sys/build.rs +++ b/tensorflow-sys/build.rs @@ -1,17 +1,28 @@ +extern crate curl; +extern crate flate2; extern crate pkg_config; extern crate semver; +extern crate tar; use std::error::Error; use std::fs::File; +use std::io::BufWriter; +use std::io::Write; use std::path::{Path, PathBuf}; use std::process; use std::process::Command; use std::{env, fs}; + +use curl::easy::Easy; +use flate2::read::GzDecoder; use semver::Version; +use tar::Archive; const LIBRARY: &'static str = "tensorflow"; const REPOSITORY: &'static str = "https://github.com/tensorflow/tensorflow.git"; const TARGET: &'static str = "tensorflow:libtensorflow.so"; +// `VERSION` and `TAG` are separate because the tag is not always `'v' + VERSION`. +const VERSION: &'static str = "1.0.0"; const TAG: &'static str = "v1.0.0"; const MIN_BAZEL: &'static str = "0.3.2"; @@ -30,6 +41,91 @@ fn main() { return; } + let force_src = match env::var("TF_RUST_BUILD_FROM_SRC") { + Ok(s) => s == "true", + Err(_) => false, + }; + if !force_src && env::consts::ARCH == "x86_64" && (env::consts::OS == "linux" || env::consts::OS == "macos") { + install_prebuilt(); + } else { + build_from_src(); + } +} + +fn remove_suffix(value: &mut String, suffix: &str) { + if value.ends_with(suffix) { + let n = value.len(); + value.truncate(n - suffix.len()); + } +} + +fn extract, P2: AsRef>(archive_path: P, extract_to: P2) { + let file = File::open(archive_path).unwrap(); + let unzipped = GzDecoder::new(file).unwrap(); + let mut a = Archive::new(unzipped); + a.unpack(extract_to).unwrap(); +} + +// Downloads and unpacks a prebuilt binary. Only works for certain platforms. +fn install_prebuilt() { + // Figure out the file names. + let os = match env::consts::OS { + "macos" => "darwin", + x => x, + }; + let proc_type = if cfg!(feature = "tensorflow_gpu") {"gpu"} else {"cpu"}; + let binary_url = format!( + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-{}-{}-{}-{}.tar.gz", + proc_type, os, env::consts::ARCH, VERSION); + log_var!(binary_url); + let short_file_name = binary_url.split("/").last().unwrap(); + let mut base_name = short_file_name.to_string(); + remove_suffix(&mut base_name, ".tar.gz"); + log_var!(base_name); + let download_dir = match env::var("TF_RUST_DOWNLOAD_DIR") { + Ok(s) => PathBuf::from(s), + Err(_) => PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join("target"), + }; + if !download_dir.exists() { + fs::create_dir(&download_dir).unwrap(); + } + let file_name = download_dir.join(short_file_name); + log_var!(file_name); + + // Download the tarball. + if !file_name.exists() { + let f = File::create(&file_name).unwrap(); + let mut writer = BufWriter::new(f); + let mut easy = Easy::new(); + easy.url(&binary_url).unwrap(); + easy.write_function(move |data| { + Ok(writer.write(data).unwrap()) + }).unwrap(); + easy.perform().unwrap(); + + let response_code = easy.response_code().unwrap(); + if response_code != 200 { + panic!("Unexpected response code {} for {}", response_code, binary_url); + } + } + + // Extract the tarball. + let unpacked_dir = download_dir.join(base_name); + let lib_dir = unpacked_dir.join("lib"); + if !lib_dir.join(format!("lib{}.so", LIBRARY)).exists() { + extract(file_name, &unpacked_dir); + } + + //run("find", |command| command); // TODO: remove + run("ls", |command| { + command.arg("-l").arg(lib_dir.to_str().unwrap()) + }); // TODO: remove + + println!("cargo:rustc-link-lib=dylib={}", LIBRARY); + println!("cargo:rustc-link-search={}", lib_dir.display()); +} + +fn build_from_src() { let output = PathBuf::from(&get!("OUT_DIR")); log_var!(output); let source = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join(format!("target/source-{}", TAG)); @@ -71,7 +167,10 @@ fn main() { let configure_hint_file = Path::new(&configure_hint_file_pb); if !configure_hint_file.exists() { run("bash", - |command| command.current_dir(&source).arg("-c").arg("yes ''|./configure")); + |command| command.current_dir(&source) + .env("TF_NEED_CUDA", if cfg!(feature = "tensorflow_gpu") {"1"} else {"0"}) + .arg("-c") + .arg("yes ''|./configure")); File::create(configure_hint_file).unwrap(); } run("bazel", |command| {