From bdd6b995f91277a258db8614371be08e3fe9988d Mon Sep 17 00:00:00 2001 From: nobu Date: Fri, 24 Feb 2017 08:36:16 +0000 Subject: Integer.sqrt [Feature #13219] git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@57705 b2dd03c8-39d4-4d8f-98ff-823fe69b080e --- bignum.c | 67 ++++++++++++++++++++++++++++++++++++++++++++--- numeric.c | 59 +++++++++++++++++++++++++++++++++++++++++ test/ruby/test_integer.rb | 24 +++++++++++++++++ 3 files changed, 147 insertions(+), 3 deletions(-) diff --git a/bignum.c b/bignum.c index 638fec4610..106b5c0d4b 100644 --- a/bignum.c +++ b/bignum.c @@ -419,14 +419,13 @@ static void bary_small_rshift(BDIGIT *zds, const BDIGIT *xds, size_t n, int shift, BDIGIT higher_bdigit) { BDIGIT_DBL num = 0; - BDIGIT x; assert(0 <= shift && shift < BITSPERDIG); num = BIGUP(higher_bdigit); while (n--) { - num = (num | xds[n]) >> shift; - x = xds[n]; + BDIGIT x = xds[n]; + num = (num | x) >> shift; zds[n] = BIGLO(num); num = BIGUP(x); } @@ -6762,6 +6761,68 @@ rb_big_even_p(VALUE num) return Qtrue; } +unsigned long rb_ulong_isqrt(unsigned long); +#if SIZEOF_BDIGIT*2 > SIZEOF_LONG +BDIGIT rb_bdigit_dbl_isqrt(BDIGIT_DBL); +#else +# define rb_bdigit_dbl_isqrt(x) (BDIGIT)rb_ulong_isqrt(x) +#endif + +VALUE +rb_big_isqrt(VALUE n) +{ + BDIGIT *nds = BDIGITS(n); + size_t len = BIGNUM_LEN(n); + + if (len <= 2) { + BDIGIT sq = rb_bdigit_dbl_isqrt(bary2bdigitdbl(nds, len)); +#if SIZEOF_BDIGIT > SIZEOF_LONG + return ULL2NUM(sq); +#else + return ULONG2NUM(sq); +#endif + } + else { + int zbits = nlz(nds[len-1]); + int shift_bits = (len&1 ? BITSPERDIG/2 : BITSPERDIG) - (zbits+1)/2 + 1; + size_t tn = (len+1) / 2, xn = tn; + VALUE t, x = bignew_1(0, xn, 1); /* division may release the GVL */ + BDIGIT *tds, *xds = BDIGITS(x); + + /* x = (n >> (b/2+1)) */ + if (shift_bits == BITSPERDIG) { + MEMCPY(xds, nds+tn, BDIGIT, xn); + } + else if (shift_bits > BITSPERDIG) { + bary_small_rshift(xds, nds+len-xn, xn, shift_bits-BITSPERDIG, 0); + } + else { + bary_small_rshift(xds, nds+len-xn-1, xn, shift_bits, nds[len-1]); + } + /* x |= (1 << (b-1)/2) */ + xds[xn-1] |= (BDIGIT)1u << + ((len&1 ? 0 : BITSPERDIG/2) + (BITSPERDIG-zbits-1)/2); + + /* t = n/x */ + tn += BIGDIVREM_EXTRA_WORDS; + t = bignew_1(0, tn, 1); + tds = BDIGITS(t); + tn = BIGNUM_LEN(t); + while (bary_divmod_branch(tds, tn, NULL, 0, nds, len, xds, xn), + bary_cmp(tds, tn, xds, xn) < 0) { + int carry; + BARY_TRUNC(tds, tn); + carry = bary_add(xds, xn, xds, xn, tds, tn); + bary_small_rshift(xds, xds, xn, 1, carry); + tn = BIGNUM_LEN(t); + } + rb_big_realloc(t, 0); + rb_gc_force_recycle(t); + RBASIC_SET_CLASS_RAW(x, rb_cInteger); + return x; + } +} + /* * Bignum objects hold integers outside the range of * Fixnum. Bignum objects are created diff --git a/numeric.c b/numeric.c index 16cdbb261a..2784b6dd86 100644 --- a/numeric.c +++ b/numeric.c @@ -5128,6 +5128,64 @@ int_truncate(int argc, VALUE* argv, VALUE num) return rb_int_truncate(num, ndigits); } +#define DEFINE_INT_SQRT(rettype, prefix, argtype) \ +rettype \ +prefix##_isqrt(argtype n) \ +{ \ + if (sizeof(n) * CHAR_BIT > DBL_MANT_DIG && \ + n >= ((argtype)1UL << DBL_MANT_DIG)) { \ + unsigned int b = bit_length(n); \ + argtype t; \ + rettype x = (rettype)(n >> (b/2+1)); \ + x |= ((rettype)1LU << (b-1)/2); \ + while ((t = n/x) < (argtype)x) x = (rettype)((x + t) >> 1); \ + return x; \ + } \ + return (rettype)sqrt((double)n); \ +} + +DEFINE_INT_SQRT(unsigned long, rb_ulong, unsigned long) +#if SIZEOF_BDIGIT*2 > SIZEOF_LONG +DEFINE_INT_SQRT(BDIGIT, rb_bdigit_dbl, BDIGIT_DBL) +#endif + +#define domain_error(msg) \ + rb_raise(rb_eMathDomainError, "Numerical argument is out of domain - " #msg) + +VALUE rb_big_isqrt(VALUE); + +static VALUE +rb_int_s_isqrt(VALUE self, VALUE num) +{ + unsigned long n, sq; + if (FIXNUM_P(num)) { + if (FIXNUM_NEGATIVE_P(num)) { + domain_error("isqrt"); + } + n = FIX2ULONG(num); + sq = rb_ulong_isqrt(n); + return LONG2FIX(sq); + } + if (RB_TYPE_P(num, T_BIGNUM)) { + size_t biglen; + if (RBIGNUM_NEGATIVE_P(num)) { + domain_error("isqrt"); + } + biglen = BIGNUM_LEN(num); + if (biglen == 0) return INT2FIX(0); +#if SIZEOF_BDIGIT <= SIZEOF_LONG + /* short-circuit */ + if (biglen == 1) { + n = BIGNUM_DIGITS(num)[0]; + sq = rb_ulong_isqrt(n); + return ULONG2NUM(sq); + } +#endif + return rb_big_isqrt(num); + } + return Qnil; +} + /* * Document-class: ZeroDivisionError * @@ -5281,6 +5339,7 @@ Init_Numeric(void) rb_cInteger = rb_define_class("Integer", rb_cNumeric); rb_undef_alloc_func(rb_cInteger); rb_undef_method(CLASS_OF(rb_cInteger), "new"); + rb_define_singleton_method(rb_cInteger, "sqrt", rb_int_s_isqrt, 1); rb_define_method(rb_cInteger, "to_s", int_to_s, -1); rb_define_alias(rb_cInteger, "inspect", "to_s"); diff --git a/test/ruby/test_integer.rb b/test/ruby/test_integer.rb index 0a6b2f4539..05b0b68ccc 100644 --- a/test/ruby/test_integer.rb +++ b/test/ruby/test_integer.rb @@ -464,4 +464,28 @@ class TestInteger < Test::Unit::TestCase end assert_equal([0, 1], 10.digits(o)) end + + def test_square_root + assert_raise(Math::DomainError) {Integer.sqrt(-1)} + assert_equal(0, Integer.sqrt(0)) + (1...4).each {|i| assert_equal(1, Integer.sqrt(i))} + (4...9).each {|i| assert_equal(2, Integer.sqrt(i))} + (9...16).each {|i| assert_equal(3, Integer.sqrt(i))} + (1..40).each do |i| + mesg = "10**#{i}" + s = Integer.sqrt(n = 10**i) + if i.even? + assert_equal(10**(i/2), Integer.sqrt(n), mesg) + else + assert_include((s**2)...(s+1)**2, n, mesg) + end + end + 50.step(400, 10) do |i| + exact = 10**(i/2) + x = 10**i + assert_equal(exact, Integer.sqrt(x), "10**#{i}") + assert_equal(exact, Integer.sqrt(x+1), "10**#{i}+1") + assert_equal(exact-1, Integer.sqrt(x-1), "10**#{i}-1") + end + end end -- cgit v1.2.3