00001 #include "util/read_compressed.hh"
00002 
00003 #include "util/file.hh"
00004 #include "util/have.hh"
00005 #include "util/scoped.hh"
00006 
#include <algorithm>
#include <iostream>
#include <limits>
#include <new>
#include <string>

#include <cassert>
#include <climits>
#include <cstdlib>
#include <cstring>
00014 
00015 #ifdef HAVE_ZLIB
00016 #include <zlib.h>
00017 #endif
00018 
00019 #ifdef HAVE_BZLIB
00020 #include <bzlib.h>
00021 #endif
00022 
00023 #ifdef HAVE_XZLIB
00024 #include <lzma.h>
00025 #endif
00026 
00027 namespace util {
00028 
00029 CompressedException::CompressedException() throw() {}
00030 CompressedException::~CompressedException() throw() {}
00031 
00032 GZException::GZException() throw() {}
00033 GZException::~GZException() throw() {}
00034 
00035 BZException::BZException() throw() {}
00036 BZException::~BZException() throw() {}
00037 
00038 XZException::XZException() throw() {}
00039 XZException::~XZException() throw() {}
00040 
00041 class ReadBase {
00042   public:
00043     virtual ~ReadBase() {}
00044 
00045     virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0;
00046 
00047   protected:
00048     static void ReplaceThis(ReadBase *with, ReadCompressed &thunk) {
00049       thunk.internal_.reset(with);
00050     }
00051 
00052     ReadBase *Current(ReadCompressed &thunk) { return thunk.internal_.get(); }
00053 
00054     static uint64_t &ReadCount(ReadCompressed &thunk) {
00055       return thunk.raw_amount_;
00056     }
00057 };
00058 
00059 namespace {
00060 
00061 ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, std::size_t already_size, bool require_compressed);
00062 
00063 
00064 class Complete : public ReadBase {
00065   public:
00066     std::size_t Read(void *, std::size_t, ReadCompressed &) {
00067       return 0;
00068     }
00069 };
00070 
00071 class Uncompressed : public ReadBase {
00072   public:
00073     explicit Uncompressed(int fd) : fd_(fd) {}
00074 
00075     std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
00076       std::size_t got = PartialRead(fd_.get(), to, amount);
00077       ReadCount(thunk) += got;
00078       return got;
00079     }
00080 
00081   private:
00082     scoped_fd fd_;
00083 };
00084 
00085 class UncompressedWithHeader : public ReadBase {
00086   public:
00087     UncompressedWithHeader(int fd, const void *already_data, std::size_t already_size) : fd_(fd) {
00088       assert(already_size);
00089       buf_.reset(malloc(already_size));
00090       if (!buf_.get()) throw std::bad_alloc();
00091       memcpy(buf_.get(), already_data, already_size);
00092       remain_ = static_cast<uint8_t*>(buf_.get());
00093       end_ = remain_ + already_size;
00094     }
00095 
00096     std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
00097       assert(buf_.get());
00098       assert(remain_ != end_);
00099       std::size_t sending = std::min<std::size_t>(amount, end_ - remain_);
00100       memcpy(to, remain_, sending);
00101       remain_ += sending;
00102       if (remain_ == end_) {
00103         ReplaceThis(new Uncompressed(fd_.release()), thunk);
00104       }
00105       return sending;
00106     }
00107 
00108   private:
00109     scoped_malloc buf_;
00110     uint8_t *remain_;
00111     uint8_t *end_;
00112 
00113     scoped_fd fd_;
00114 };
00115 
00116 static const std::size_t kInputBuffer = 16384;
00117 
00118 template <class Compression> class StreamCompressed : public ReadBase {
00119   public:
00120     StreamCompressed(int fd, const void *already_data, std::size_t already_size)
00121       : file_(fd),
00122         in_buffer_(MallocOrThrow(kInputBuffer)),
00123         back_(memcpy(in_buffer_.get(), already_data, already_size), already_size) {}
00124 
00125     std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
00126       if (amount == 0) return 0;
00127       back_.SetOutput(to, amount);
00128       do {
00129         if (!back_.Stream().avail_in) ReadInput(thunk);
00130         if (!back_.Process()) {
00131           
00132           std::size_t ret = static_cast<const uint8_t *>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
00133           ReplaceThis(ReadFactory(file_.release(), ReadCount(thunk), back_.Stream().next_in, back_.Stream().avail_in, true), thunk);
00134           if (ret) return ret;
00135           
00136           return Current(thunk)->Read(to, amount, thunk);
00137         }
00138       } while (back_.Stream().next_out == to);
00139       return static_cast<const uint8_t*>(static_cast<void*>(back_.Stream().next_out)) - static_cast<const uint8_t*>(to);
00140     }
00141 
00142   private:
00143     void ReadInput(ReadCompressed &thunk) {
00144       assert(!back_.Stream().avail_in);
00145       std::size_t got = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
00146       back_.SetInput(in_buffer_.get(), got);
00147       ReadCount(thunk) += got;
00148     }
00149 
00150     scoped_fd file_;
00151     scoped_malloc in_buffer_;
00152 
00153     Compression back_;
00154 };
00155 
#ifdef HAVE_ZLIB
// Thin adapter over zlib's inflate(), used as the Compression policy of
// StreamCompressed<GZip>.  It owns one z_stream for its whole lifetime.
class GZip {
  public:
    // base/amount: compressed bytes already read from the file (at least the
    // magic bytes).  They are decoded before any further input is requested.
    GZip(const void *base, std::size_t amount) {
      // Zero the whole stream, as BZip and XZip do.  This leaves zalloc, zfree
      // and opaque as Z_NULL (zlib's default allocator) and msg as NULL, and no
      // field is left indeterminate.
      memset(&stream_, 0, sizeof(stream_));
      SetInput(base, amount);
      // windowBits = 15 (largest window) + 32: zlib auto-detects a zlib or gzip header.
      UTIL_THROW_IF(Z_OK != inflateInit2(&stream_, 32 + 15), GZException, "Failed to initialize zlib.");
    }

    // inflateEnd fails only on an inconsistent stream state.  A destructor
    // must not throw, so report the failure and abort.
    ~GZip() {
      if (Z_OK != inflateEnd(&stream_)) {
        std::cerr << "zlib could not close properly." << std::endl;
        abort();
      }
    }

    // avail_out is a uInt, so larger requests are clamped; the caller then
    // simply sees a short read.
    void SetOutput(void *to, std::size_t amount) {
      stream_.next_out = static_cast<Bytef*>(to);
      stream_.avail_out = static_cast<uInt>(std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount));
    }

    // amount must fit in zlib's uInt avail_in.
    void SetInput(const void *base, std::size_t amount) {
      assert(amount <= static_cast<std::size_t>(std::numeric_limits<uInt>::max()));
      stream_.next_in = const_cast<Bytef*>(static_cast<const Bytef*>(base));
      stream_.avail_in = static_cast<uInt>(amount);
    }

    const z_stream &Stream() const { return stream_; }

    // Runs inflate once.  Returns true if decoding may continue and false at the
    // end of the gzip member / zlib stream.  Any other status throws; for
    // example zlib returns Z_BUF_ERROR when no progress is possible, which is
    // what a truncated file produces once input runs out.
    bool Process() {
      int result = inflate(&stream_, Z_NO_FLUSH);
      switch (result) {
        case Z_OK:
          return true;
        case Z_STREAM_END:
          return false;
        case Z_ERRNO:
          UTIL_THROW(ErrnoException, "zlib error");
        default:
          // The fallback has no trailing space: " code " supplies the separator.
          UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error") << " code " << result);
      }
    }

  private:
    z_stream stream_;
};
#endif // HAVE_ZLIB
00208 
#ifdef HAVE_BZLIB
// Thin adapter over libbzip2's decompressor, used as the Compression policy of
// StreamCompressed<BZip>.
class BZip {
  public:
    BZip(const void *base, std::size_t amount) {
      // Zeroed bzalloc/bzfree/opaque make libbzip2 use malloc/free.
      memset(&stream_, 0, sizeof(stream_));
      SetInput(base, amount);
      // verbosity 0; small 0 selects the faster, larger-memory algorithm.
      HandleError(BZ2_bzDecompressInit(&stream_, 0, 0));
    }

    // A destructor must not throw, so a failing BZ2_bzDecompressEnd is reported
    // and the process aborts.
    ~BZip() {
      try {
        HandleError(BZ2_bzDecompressEnd(&stream_));
      } catch (const std::exception &e) {
        std::cerr << e.what() << std::endl;
        abort();
      }
    }

    // Runs BZ2_bzDecompress once.  Returns false at the end of this bzip2
    // stream and true if decoding may continue; throws on errors, including
    // input that ends before the stream does.
    bool Process() {
      // StreamCompressed refills the input whenever it is empty before calling
      // Process, so arriving here with no input means the file is at EOF.
      const bool at_eof = (stream_.avail_in == 0);
      const bool has_room = (stream_.avail_out != 0);
      const char *const out_before = stream_.next_out;
      int ret = BZ2_bzDecompress(&stream_);
      if (ret == BZ_STREAM_END) return false;
      HandleError(ret);
      // libbzip2 answers BZ_OK, not an error, when it just needs more input.
      // With no input left, room for output and nothing produced, no later call
      // can make progress: the stream is truncated.  Without this check the
      // caller's read loop would retry forever.
      UTIL_THROW_IF(at_eof && has_room && stream_.next_out == out_before, BZException, "bzip2 input ended before the end of the compressed stream; the file may be truncated.");
      return true;
    }

    // avail_out is an unsigned int, so larger requests are clamped.
    void SetOutput(void *base, std::size_t amount) {
      stream_.next_out = static_cast<char*>(base);
      stream_.avail_out = static_cast<unsigned int>(std::min<std::size_t>(std::numeric_limits<unsigned int>::max(), amount));
    }

    void SetInput(const void *base, std::size_t amount) {
      stream_.next_in = const_cast<char*>(static_cast<const char*>(base));
      stream_.avail_in = static_cast<unsigned int>(amount);
    }

    const bz_stream &Stream() const { return stream_; }

  private:
    // Maps libbzip2 return codes to exceptions; BZ_OK returns normally.
    void HandleError(int value) {
      switch(value) {
        case BZ_OK:
          return;
        case BZ_CONFIG_ERROR:
          UTIL_THROW(BZException, "bzip2 seems to be miscompiled.");
        case BZ_PARAM_ERROR:
          UTIL_THROW(BZException, "bzip2 Parameter error");
        case BZ_DATA_ERROR:
          UTIL_THROW(BZException, "bzip2 detected a corrupt file");
        case BZ_DATA_ERROR_MAGIC:
          UTIL_THROW(BZException, "bzip2 detected bad magic bytes.  Perhaps this was not a bzip2 file after all?");
        case BZ_MEM_ERROR:
          throw std::bad_alloc();
        default:
          UTIL_THROW(BZException, "Unknown bzip2 error code " << value);
      }
    }

    bz_stream stream_;
};
#endif // HAVE_BZLIB
00269 
#ifdef HAVE_XZLIB
// Thin adapter over liblzma's .xz stream decoder, used as the Compression policy
// of StreamCompressed<XZip>.
class XZip {
  public:
    XZip(const void *base, std::size_t amount) : action_(LZMA_RUN) {
      // All-zero matches LZMA_STREAM_INIT.
      memset(&stream_, 0, sizeof(stream_));
      SetInput(base, amount);
      // No memory limit and no decoder flags.
      HandleError(lzma_stream_decoder(&stream_, UINT64_MAX, 0));
    }

    ~XZip() {
      lzma_end(&stream_);
    }

    void SetOutput(void *base, std::size_t amount) {
      stream_.avail_out = amount;
      stream_.next_out = static_cast<uint8_t*>(base);
    }

    // An empty chunk means the file is exhausted; from then on lzma_code is
    // driven with LZMA_FINISH.
    void SetInput(const void *base, std::size_t amount) {
      stream_.avail_in = amount;
      stream_.next_in = static_cast<const uint8_t*>(base);
      if (amount == 0) action_ = LZMA_FINISH;
    }

    const lzma_stream &Stream() const { return stream_; }

    // Runs lzma_code once: false at end of stream, true to continue, throws on error.
    bool Process() {
      const lzma_ret status = lzma_code(&stream_, action_);
      if (status != LZMA_STREAM_END) {
        HandleError(status);
        return true;
      }
      return false;
    }

  private:
    // Maps liblzma return codes to exceptions; LZMA_OK returns normally.
    void HandleError(lzma_ret value) {
      if (value == LZMA_OK) return;
      if (value == LZMA_MEM_ERROR) throw std::bad_alloc();
      if (value == LZMA_FORMAT_ERROR) {
        UTIL_THROW(XZException, "xzlib says file format not recognized");
      }
      if (value == LZMA_OPTIONS_ERROR) {
        UTIL_THROW(XZException, "xzlib says unsupported compression options");
      }
      if (value == LZMA_DATA_ERROR) {
        UTIL_THROW(XZException, "xzlib says this file is corrupt");
      }
      if (value == LZMA_BUF_ERROR) {
        UTIL_THROW(XZException, "xzlib says unexpected end of input");
      }
      UTIL_THROW(XZException, "unrecognized xzlib error " << value);
    }

    lzma_stream stream_;
    lzma_action action_;
};
#endif // HAVE_XZLIB
00328 
00329 class IStreamReader : public ReadBase {
00330   public:
00331     explicit IStreamReader(std::istream &stream) : stream_(stream) {}
00332 
00333     std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
00334       if (!stream_.read(static_cast<char*>(to), amount)) {
00335         UTIL_THROW_IF(!stream_.eof(), ErrnoException, "istream error");
00336         amount = stream_.gcount();
00337       }
00338       ReadCount(thunk) += amount;
00339       return amount;
00340     }
00341 
00342   private:
00343     std::istream &stream_;
00344 };
00345 
// Possible outcomes of sniffing the first bytes of a stream.
enum MagicResult {
  UTIL_UNKNOWN,  // no recognised signature
  UTIL_GZIP,     // 1f 8b
  UTIL_BZIP,     // "BZh"
  UTIL_XZIP      // fd '7' 'z' 'X' 'Z' 00
};

// Classifies a buffer by its leading magic bytes, inspecting at most `length`
// bytes.  A signature longer than `length` never matches.
MagicResult DetectMagic(const void *from_void, std::size_t length) {
  const uint8_t *const bytes = static_cast<const uint8_t*>(from_void);

  static const uint8_t kGZMagic[] = {0x1f, 0x8b};
  static const uint8_t kBZMagic[] = {'B', 'Z', 'h'};
  static const uint8_t kXZMagic[] = {0xFD, '7', 'z', 'X', 'Z', 0x00};

  if (length >= sizeof(kGZMagic) && memcmp(bytes, kGZMagic, sizeof(kGZMagic)) == 0) return UTIL_GZIP;
  if (length >= sizeof(kBZMagic) && memcmp(bytes, kBZMagic, sizeof(kBZMagic)) == 0) return UTIL_BZIP;
  if (length >= sizeof(kXZMagic) && memcmp(bytes, kXZMagic, sizeof(kXZMagic)) == 0) return UTIL_XZIP;
  return UTIL_UNKNOWN;
}
00365 
00366 ReadBase *ReadFactory(int fd, uint64_t &raw_amount, const void *already_data, const std::size_t already_size, bool require_compressed) {
00367   scoped_fd hold(fd);
00368   std::string header(reinterpret_cast<const char*>(already_data), already_size);
00369   if (header.size() < ReadCompressed::kMagicSize) {
00370     std::size_t original = header.size();
00371     header.resize(ReadCompressed::kMagicSize);
00372     std::size_t got = ReadOrEOF(fd, &header[original], ReadCompressed::kMagicSize - original);
00373     raw_amount += got;
00374     header.resize(original + got);
00375   }
00376   if (header.empty()) {
00377     return new Complete();
00378   }
00379   switch (DetectMagic(&header[0], header.size())) {
00380     case UTIL_GZIP:
00381 #ifdef HAVE_ZLIB
00382       return new StreamCompressed<GZip>(hold.release(), header.data(), header.size());
00383 #else
00384       UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in.");
00385 #endif
00386     case UTIL_BZIP:
00387 #ifdef HAVE_BZLIB
00388       return new StreamCompressed<BZip>(hold.release(), &header[0], header.size());
00389 #else
00390       UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZh), but bzip support was not compiled in.");
00391 #endif
00392     case UTIL_XZIP:
00393 #ifdef HAVE_XZLIB
00394       return new StreamCompressed<XZip>(hold.release(), header.data(), header.size());
00395 #else
00396       UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in.");
00397 #endif
00398     default:
00399       UTIL_THROW_IF(require_compressed, CompressedException, "Uncompressed data detected after a compresssed file.  This could be supported but usually indicates an error.");
00400       return new UncompressedWithHeader(hold.release(), header.data(), header.size());
00401   }
00402 }
00403 
00404 } 
00405 
00406 bool ReadCompressed::DetectCompressedMagic(const void *from_void) {
00407   return DetectMagic(from_void, kMagicSize) != UTIL_UNKNOWN;
00408 }
00409 
00410 ReadCompressed::ReadCompressed(int fd) {
00411   Reset(fd);
00412 }
00413 
00414 ReadCompressed::ReadCompressed(std::istream &in) {
00415   Reset(in);
00416 }
00417 
00418 ReadCompressed::ReadCompressed() {}
00419 
00420 ReadCompressed::~ReadCompressed() {}
00421 
00422 void ReadCompressed::Reset(int fd) {
00423   raw_amount_ = 0;
00424   internal_.reset();
00425   internal_.reset(ReadFactory(fd, raw_amount_, NULL, 0, false));
00426 }
00427 
00428 void ReadCompressed::Reset(std::istream &in) {
00429   internal_.reset();
00430   internal_.reset(new IStreamReader(in));
00431 }
00432 
00433 std::size_t ReadCompressed::Read(void *to, std::size_t amount) {
00434   return internal_->Read(to, amount, *this);
00435 }
00436 
00437 std::size_t ReadCompressed::ReadOrEOF(void *const to_in, std::size_t amount) {
00438   uint8_t *to = reinterpret_cast<uint8_t*>(to_in);
00439   while (amount) {
00440     std::size_t got = Read(to, amount);
00441     if (!got) break;
00442     to += got;
00443     amount -= got;
00444   }
00445   return to - reinterpret_cast<uint8_t*>(to_in);
00446 }
00447 
00448 }