25 #ifndef DCCLFIELDCODECARITHMETIC20120726H
26 #define DCCLFIELDCODECARITHMETIC20120726H
31 #include <boost/bimap.hpp>
32 #include <boost/lexical_cast.hpp>
34 #include "dccl/field_codec_typed.h"
36 #include "dccl/arithmetic/protobuf/arithmetic_extensions.pb.h"
37 #include "dccl/arithmetic/protobuf/arithmetic.pb.h"
39 #include "dccl/logger.h"
41 #include "dccl/binary.h"
// Integer index that identifies a symbol in the model.
60 typedef int symbol_type;
// Floating-point type of the values that are converted to and from symbols.
61 typedef double value_type;
// Reserved symbol indices. Both are negative; the model-building loop treats
// every other index as a position in the user model's `frequency` field.
63 static const symbol_type OUT_OF_RANGE_SYMBOL = -1;
64 static const symbol_type EOF_SYMBOL = -2;
// Lowest symbol index; the model-building loop starts iterating from here.
65 static const symbol_type MIN_SYMBOL = EOF_SYMBOL;
// Width, in bits, of the code values used by the codec (TOP_VALUE is
// 2^CODE_VALUE_BITS - 1).
67 static const int CODE_VALUE_BITS = 32;
// Two bits fewer than the code value width.
68 static const int FREQUENCY_BITS = CODE_VALUE_BITS - 2;
// 2^FREQUENCY_BITS - 1; the total of all frequencies is checked against this.
70 static const freq_type MAX_FREQUENCY = (1 << FREQUENCY_BITS) - 1;
// Static two-level map from strings to Bitset.
// NOTE(review): what the two string keys represent is not visible in this
// section - confirm against the code that fills this map.
74 static std::map<std::string, std::map<std::string, Bitset> > last_bits_map;
// Conversion from a value to its symbol index (declaration only; the
// definition is not part of this section).
85 symbol_type value_to_symbol(value_type value)
const;
// Conversion from a symbol index back to a value (declaration only; the
// definition is not part of this section).
86 value_type symbol_to_value(symbol_type symbol)
const;
87 symbol_type total_symbols()
88 {
return encoder_cumulative_freqs_.size(); }
91 {
return user_model_; }
93 symbol_type max_symbol()
const {
return user_model_.frequency_size() - 1; }
95 freq_type total_freq(ModelState state)
const
98 const boost::bimap<symbol_type, freq_type>& c_freqs = (state == ENCODER) ?
99 encoder_cumulative_freqs_ :
100 decoder_cumulative_freqs_;
102 return c_freqs.left.at(max_symbol());
// Called by the codec with each symbol it has just encoded or decoded,
// together with the side (ENCODER/DECODER) doing so. Declaration only; the
// definition is not part of this section.
105 void update_model(symbol_type symbol, ModelState state);
// Returns a pair of cumulative frequencies for `symbol`; the codec uses
// .first to move the low end and .second to move the high end of its
// interval. Declaration only.
108 std::pair<freq_type, freq_type> symbol_to_cumulative_freq(symbol_type symbol, ModelState state)
const;
// Maps a pair of cumulative frequencies to a pair of symbols; the decoder
// treats the result as unambiguous when both symbols are equal.
// Declaration only.
109 std::pair<symbol_type, symbol_type> cumulative_freq_to_symbol(std::pair<freq_type, freq_type> c_freq_pair, ModelState state)
const;
// Bidirectional symbol <-> cumulative-frequency tables, one per side.
// create_and_validate_model() fills the encoder table and then copies it
// into the decoder table.
114 boost::bimap<symbol_type, freq_type> encoder_cumulative_freqs_;
115 boost::bimap<symbol_type, freq_type> decoder_cumulative_freqs_;
123 Model new_model(model);
124 create_and_validate_model(&new_model);
125 if(arithmetic_models_.count(model.name()))
126 arithmetic_models_.erase(model.name());
127 arithmetic_models_.insert(std::make_pair(model.name(), new_model));
// Builds the cumulative-frequency tables of `model` from its user_model_ and
// checks the user model's fields. The statements that report a failed check
// are only partly visible in this section; the message fragments below show
// what each check is about.
130 static void create_and_validate_model(
Model* model)
// The user model must have all of its required fields set.
132 if(!model->user_model_.IsInitialized())
135 model->user_model_.DebugString() +
136 "Missing fields: " + model->user_model_.InitializationErrorString()));
// Running total of the frequencies seen so far.
139 Model::freq_type cumulative_freq = 0;
// Visit every symbol from MIN_SYMBOL up to the last entry of `frequency`.
140 for(Model::symbol_type symbol = Model::MIN_SYMBOL, n = model->user_model_.frequency_size(); symbol < n; ++symbol)
142 Model::freq_type freq;
// The two reserved symbols take their frequency from dedicated fields;
// any other symbol indexes into `frequency`.
143 if(symbol == Model::EOF_SYMBOL)
144 freq = model->user_model_.eof_frequency();
145 else if(symbol == Model::OUT_OF_RANGE_SYMBOL)
146 freq = model->user_model_.out_of_range_frequency();
148 freq = model->user_model_.frequency(symbol);
// Only the reserved symbols are allowed to have a frequency of zero.
150 if(freq == 0 && symbol != Model::OUT_OF_RANGE_SYMBOL && symbol != Model::EOF_SYMBOL)
153 model->user_model_.DebugString() +
154 "All frequencies must be nonzero."));
// Store the running total for this symbol in the encoder table.
156 cumulative_freq += freq;
157 model->encoder_cumulative_freqs_.left.insert(std::make_pair(symbol, cumulative_freq));
// The decoder table starts out as a copy of the encoder table.
162 model->decoder_cumulative_freqs_ = model->encoder_cumulative_freqs_;
// NOTE(review): the comparison is `>`, so a total exactly equal to
// MAX_FREQUENCY passes even though the message says "less than".
164 if(model->total_freq(Model::ENCODER) > Model::MAX_FREQUENCY)
167 model->user_model_.DebugString() +
168 "Sum of all frequencies must be less than " +
169 boost::lexical_cast<std::string>(Model::MAX_FREQUENCY) +
170 " in order to use 64 bit arithmetic"));
// There must be exactly one more bound than there are `frequency` entries.
173 if(model->user_model_.value_bound_size() != model->user_model_.frequency_size() + 1)
176 model->user_model_.DebugString() +
177 "`value_bound` size must be exactly 1 more than number of symbols (= size of `frequency`)."));
// adjacent_find with greater_equal locates any bound that is greater than or
// equal to its successor, so equal neighbouring bounds are rejected too.
182 if(std::adjacent_find (model->user_model_.value_bound().begin(),
183 model->user_model_.value_bound().end(),
184 std::greater_equal<Model::value_type>()) != model->user_model_.value_bound().end())
187 model->user_model_.DebugString() +
188 "`value_bound` must be monotonically increasing."));
// Looks up a model by name in arithmetic_models_ and throws an Exception
// naming the model when no entry exists. The statement that returns the
// found model is not visible in this section.
193 static Model& find(
const std::string& name)
195 std::map<std::string, Model>::iterator it = arithmetic_models_.find(name);
// end() means no model has been stored under this name.
196 if(it == arithmetic_models_.end())
197 throw(
Exception(
"Cannot find model called: " + name));
// Static registry of models keyed by model name; find() searches it.
203 static std::map<std::string, Model> arithmetic_models_;
207 template<
typename FieldType = Model::value_type>
// Largest code value: 2^CODE_VALUE_BITS - 1. The log statements divide by it
// to print interval ends as fractions.
212 static const uint64 TOP_VALUE = (
static_cast<uint64>(1) << Model::CODE_VALUE_BITS) - 1;
// 2^(CODE_VALUE_BITS - 1).
213 static const uint64 HALF = (
static_cast<uint64>(1) << (Model::CODE_VALUE_BITS-1));
// HALF / 2.
214 static const uint64 FIRST_QTR = HALF >> 1;
// HALF + FIRST_QTR.
215 static const uint64 THIRD_QTR = HALF+FIRST_QTR;
217 Bitset encode_repeated(
const std::vector<Model::value_type>& wire_value)
219 return encode_repeated(wire_value,
true);
// Arithmetic-encodes up to max_repeat() values from `wire_value` into a
// Bitset. Several parts of this function are not visible in this section
// (the second parameter, the declarations of `low`, `high` and `bits`, and a
// number of braces and branch conditions); the comments below describe only
// the statements that are visible.
222 Bitset encode_repeated(
const std::vector<Model::value_type>& wire_value,
226 using namespace dccl::logger;
228 Model& model = current_model();
// Counter handed to bit_plus_follow(), which emits that many inverted bits
// after the next bit and counts it back down to zero.
232 int bits_to_follow = 0;
// One pass per repeat slot, up to max_repeat().
236 for(
unsigned value_index = 0, n = max_repeat(); value_index < n; ++value_index)
// Slots beyond the end of wire_value keep the EOF symbol.
238 Model::symbol_type symbol = Model::EOF_SYMBOL;
240 if(wire_value.size() > value_index)
242 Model::value_type value = wire_value[value_index];
244 dlog <<
"(ArithmeticFieldCodec) value is : " << value << std::endl;
246 symbol = model.value_to_symbol(value);
// An out-of-range symbol whose configured frequency is zero is replaced by
// the EOF symbol.
250 if(symbol == Model::OUT_OF_RANGE_SYMBOL &&
251 model.user_model().out_of_range_frequency() == 0)
255 dlog <<
"(ArithmeticFieldCodec) out of range symbol, but no frequency given; ending encoding" << std::endl;
257 symbol = Model::EOF_SYMBOL;
// An EOF symbol whose configured frequency is zero is replaced as well.
// NOTE(review): the assignment below dereferences the iterator returned by
// std::max_element, so `symbol` receives the largest value stored in
// `frequency`, not the position of that entry. The log message speaks of the
// "most probable symbol", so the index (distance from begin()) may have been
// intended - confirm.
261 if(symbol == Model::EOF_SYMBOL &&
262 model.user_model().eof_frequency() == 0)
265 dlog <<
"(ArithmeticFieldCodec) end of file, but no frequency given; filling with most probable symbol" << std::endl;
266 symbol = *std::max_element(model.user_model().frequency().begin(), model.user_model().frequency().end());
271 dlog <<
"(ArithmeticFieldCodec) symbol is : " << symbol << std::endl;
274 dlog <<
"(ArithmeticFieldCodec) current interval: [" << (double)low / TOP_VALUE <<
","
275 << (
double)high / TOP_VALUE <<
")" << std::endl;
// Width of the current [low, high] interval.
278 uint64 range = (high-low)+1;
// Cumulative-frequency pair of the symbol, taken from the encoder table.
280 std::pair<Model::freq_type, Model::freq_type> c_freq_range =
281 model.symbol_to_cumulative_freq(symbol, Model::ENCODER);
284 dlog <<
"(ArithmeticFieldCodec) input symbol (" << symbol
285 <<
") cumulative freq: ["<< c_freq_range.first <<
"," << c_freq_range.second <<
")" << std::endl;
// Narrow the interval to the symbol's share of it. `high` is computed first
// because it is derived from the not-yet-updated `low`.
287 high = low + (range*c_freq_range.second)/model.total_freq(Model::ENCODER)-1;
288 low += (range*c_freq_range.first)/model.total_freq(Model::ENCODER);
291 dlog <<
"(ArithmeticFieldCodec) input symbol (" << symbol <<
") interval: ["
292 << (double)low / TOP_VALUE <<
"," << (
double)high / TOP_VALUE <<
")" << std::endl;
297 dlog <<
"(ArithmeticFieldCodec) Q1: " <<
Bitset(Model::CODE_VALUE_BITS, FIRST_QTR)
298 <<
", Q2: " <<
Bitset(Model::CODE_VALUE_BITS, HALF)
299 <<
", Q3 : " <<
Bitset(Model::CODE_VALUE_BITS, THIRD_QTR)
300 <<
", top: " <<
Bitset(Model::CODE_VALUE_BITS, TOP_VALUE) << std::endl;
302 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) low: " <<
Bitset(Model::CODE_VALUE_BITS, low).
to_string() << std::endl;
303 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) high: " <<
Bitset(Model::CODE_VALUE_BITS, high).
to_string() << std::endl;
// Tell the model which symbol was just encoded.
306 model.update_model(symbol, Model::ENCODER);
// Interval rescaling. The enclosing loop and the conditions of the first two
// cases are not visible here; the log messages name the three cases: interval
// entirely in [0, 0.5) emits a 0, entirely in [0.5, 1) emits a 1, and an
// interval straddling the middle emits nothing at this point.
312 bit_plus_follow(&bits, &bits_to_follow, 0);
313 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec): completely in [0, 0.5): EXPAND" << std::endl;
317 bit_plus_follow(&bits, &bits_to_follow, 1);
320 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec): completely in [0.5, 1): EXPAND" << std::endl;
322 else if(low>=FIRST_QTR && high < THIRD_QTR)
324 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec): straddle middle [0.25, 0.75): EXPAND" << std::endl;
336 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) low: " <<
Bitset(Model::CODE_VALUE_BITS, low).
to_string() << std::endl;
337 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) high: " <<
Bitset(Model::CODE_VALUE_BITS, high).
to_string() << std::endl;
340 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) current interval: [" << (double)low / TOP_VALUE <<
"," << (
double)high / TOP_VALUE <<
")" << std::endl;
// Emission of the closing bits. The control flow around these statements is
// only partly visible in this section.
345 if(value_index == wire_value.size())
356 if(high != TOP_VALUE || bits_to_follow > 0)
357 bit_plus_follow(&bits, &bits_to_follow, 0);
361 else if(high == TOP_VALUE)
363 bit_plus_follow(&bits, &bits_to_follow, 1);
371 bit_plus_follow(&bits, &bits_to_follow, (low < FIRST_QTR) ? 0 : 1);
385 void bit_plus_follow(
Bitset* bits,
int* bits_to_follow,
bool bit)
387 bits->push_back(bit);
388 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): emitted bit: " << bit << std::endl;
390 while(*bits_to_follow)
392 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): emitted bit (from follow): " << !bit << std::endl;
394 bits->push_back(!bit);
395 (*bits_to_follow) -= 1;
402 using namespace dccl::logger;
404 std::vector<Model::value_type> values;
406 Model& model = current_model();
414 int bit_stream_offset = Model::CODE_VALUE_BITS - bits->size();
416 for(
int i = 0, n = Model::CODE_VALUE_BITS; i < n; ++i)
418 if(i >= bit_stream_offset)
419 value |= (
static_cast<uint64>((*bits)[bits->size()-(i-bit_stream_offset)-1]) << i);
422 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec): starting value: " <<
Bitset(Model::CODE_VALUE_BITS, value).
to_string() << std::endl;
426 for(
unsigned value_index = 0, n = max_repeat(); value_index < n; ++value_index)
428 uint64 range = (high-low)+1;
430 Model::symbol_type symbol = bits_to_symbol(bits, value, bit_stream_offset, low, range);
432 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) symbol is: " << symbol << std::endl;
435 std::pair<Model::freq_type, Model::freq_type> c_freq_range =
436 model.symbol_to_cumulative_freq(symbol, Model::DECODER);
438 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) input symbol (" << symbol <<
") cumulative freq: ["<< c_freq_range.first <<
"," << c_freq_range.second <<
")" << std::endl;
440 high = low + (range*c_freq_range.second)/model.total_freq(Model::DECODER)-1;
441 low += (range*c_freq_range.first)/model.total_freq(Model::DECODER);
443 model.update_model(symbol, Model::DECODER);
445 if(symbol == Model::EOF_SYMBOL)
448 values.push_back(model.symbol_to_value(symbol));
450 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) value is: " << values.back() << std::endl;
465 else if(low >= FIRST_QTR
478 bit_stream_offset +=1;
489 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) bits used is (" << bits->size() <<
"): " << *bits << std::endl;
490 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) bits original is (" << in.size() <<
"): " << in << std::endl;
500 unsigned size_repeated(
const std::vector<Model::value_type>& wire_values)
503 return encode_repeated(wire_values,
false).size();
513 Model& model = current_model();
517 Model::freq_type out_of_range_freq = model.user_model().out_of_range_frequency();
518 if(out_of_range_freq == 0)
519 out_of_range_freq = Model::MAX_FREQUENCY;
521 Model::value_type lowest_frequency = std::min(out_of_range_freq,
522 *std::min_element(model.user_model().frequency().begin(), model.user_model().frequency().end()));
525 unsigned size_least_probable = (unsigned)(std::ceil(max_repeat()*(log2(model.total_freq(Model::ENCODER))-log2(lowest_frequency))));
527 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec) size_least_probable: " << size_least_probable << std::endl;
530 Model::freq_type eof_freq = model.user_model().eof_frequency();
532 unsigned size_least_probable_plus_eof = (unsigned)((eof_freq != 0 ) ? std::ceil(max_repeat()*log2(model.total_freq(Model::ENCODER))-(max_repeat()-1)*log2(lowest_frequency)-log2(eof_freq)) : 0);
534 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec) size_least_probable_plus_eof: " << size_least_probable_plus_eof << std::endl;
536 return std::max(size_least_probable_plus_eof, size_least_probable) + 1;
544 const Model& model = current_model();
546 if(model.user_model().is_adaptive())
551 Model::freq_type out_of_range_freq = model.user_model().out_of_range_frequency();
552 if(out_of_range_freq == 0)
553 out_of_range_freq = 1;
555 Model::freq_type eof_freq = model.user_model().eof_frequency();
557 unsigned size_empty = (unsigned)((eof_freq != 0) ? std::ceil(log2(model.total_freq(Model::ENCODER))-log2(eof_freq)) : std::numeric_limits<unsigned>::max());
559 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec) size_empty: " << size_empty << std::endl;
562 Model::value_type highest_frequency = std::max(out_of_range_freq,
563 *std::max_element(model.user_model().frequency().begin(), model.user_model().frequency().end()));
565 unsigned size_most_probable = (unsigned)(std::ceil(max_repeat()*(log2(model.total_freq(Model::ENCODER))-log2(highest_frequency))));
567 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec) size_most_probable: " << size_most_probable << std::endl;
569 return std::min(size_empty, size_most_probable);
575 "missing (dccl.field).arithmetic");
580 ModelManager::find(model_name);
// Works out which symbol the current code value corresponds to. Some
// parameters and several statements of this function are not visible in this
// section (for example the declarations of `value`, `low` and `range`, and
// the code that runs after an ambiguous result); the comments below describe
// only the visible statements.
591 Model::symbol_type bits_to_symbol(
Bitset* bits,
593 int& bit_stream_offset,
597 Model& model = current_model();
// When bit_stream_offset > 0, value_high is `value` plus a mask of
// bit_stream_offset one-bits. The other branch of this conditional is not
// visible here.
601 uint64 value_high = (bit_stream_offset > 0) ?
602 value + ((
static_cast<uint64>(1) << bit_stream_offset) - 1):
606 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): value range: [" <<
Bitset(Model::CODE_VALUE_BITS, value) <<
"," <<
Bitset(Model::CODE_VALUE_BITS, value_high) <<
")" << std::endl;
// Scale both ends of the value range into cumulative-frequency units of the
// decoder table, relative to the current interval (`low`, `range`).
609 Model::freq_type cumulative_freq = ((value-low+1)*model.total_freq(Model::DECODER)-1)/range;
610 Model::freq_type cumulative_freq_high = ((value_high-low+1)*model.total_freq(Model::DECODER)-1)/range;
612 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): c_freq: " << cumulative_freq <<
", c_freq_high: " << cumulative_freq_high << std::endl;
// Ask the model which symbols the two cumulative frequencies fall on.
615 std::pair<Model::symbol_type, Model::symbol_type> symbol_pair = model.cumulative_freq_to_symbol(std::make_pair(cumulative_freq, cumulative_freq_high), Model::DECODER);
617 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): symbol: " << symbol_pair.first <<
", " << symbol_pair.second << std::endl;
// Both ends map to the same symbol: the result is unambiguous.
620 if(symbol_pair.first == symbol_pair.second)
621 return symbol_pair.first;
626 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): bits: " << *bits << std::endl;
// Otherwise OR the last bit of `bits` into `value` at bit position
// bit_stream_offset.
629 value |=
static_cast<uint64>(bits->back()) << bit_stream_offset;
631 dccl::dlog.
is(dccl::logger::DEBUG3) && dccl::dlog <<
"(ArithmeticFieldCodec): ambiguous (symbol could be " << symbol_pair.first <<
" or " << symbol_pair.second <<
")" << std::endl;
// Returns the Model registered under `name` by delegating to
// ModelManager::find(). How `name` is obtained is not visible in this
// section.
647 Model& current_model()
650 return ModelManager::find(
name);
// Out-of-class definitions of the four static const members that
// ArithmeticFieldCodecBase declares and initialises inside the class.
// NOTE(review): presumably provided so the constants can be odr-used (for
// example bound to a reference) under pre-C++17 rules - confirm.
659 template<
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::TOP_VALUE;
660 template<
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::FIRST_QTR;
661 template<
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::HALF;
662 template<
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::THIRD_QTR;
664 template<
typename FieldType>
667 Model::value_type
pre_encode(
const FieldType& field_value)
668 {
return static_cast<Model::value_type
>(field_value); }
670 FieldType
post_decode(
const Model::value_type& wire_value)
671 {
return static_cast<FieldType
>(wire_value); }
679 Model::value_type
pre_encode(
const google::protobuf::EnumValueDescriptor*
const& field_value)
680 {
return field_value->number(); }
682 const google::protobuf::EnumValueDescriptor*
post_decode(
const Model::value_type& wire_value)
685 const google::protobuf::EnumValueDescriptor* return_value = e->FindValueByNumber((
int)wire_value);