summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornobu <nobu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e>2017-02-24 08:36:16 +0000
committernobu <nobu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e>2017-02-24 08:36:16 +0000
commitbdd6b995f91277a258db8614371be08e3fe9988d (patch)
tree9bd730f01373ca5ebf32160f2da0382121bde394
parent395ad27e7235bf32d89e747507f7f1518d5aa844 (diff)
Integer.sqrt [Feature #13219]
git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@57705 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
-rw-r--r--bignum.c67
-rw-r--r--numeric.c59
-rw-r--r--test/ruby/test_integer.rb24
3 files changed, 147 insertions, 3 deletions
diff --git a/bignum.c b/bignum.c
index 638fec4610..106b5c0d4b 100644
--- a/bignum.c
+++ b/bignum.c
@@ -419,14 +419,13 @@ static void
bary_small_rshift(BDIGIT *zds, const BDIGIT *xds, size_t n, int shift, BDIGIT higher_bdigit)
{
BDIGIT_DBL num = 0;
- BDIGIT x;
assert(0 <= shift && shift < BITSPERDIG);
num = BIGUP(higher_bdigit);
while (n--) {
- num = (num | xds[n]) >> shift;
- x = xds[n];
+ BDIGIT x = xds[n];
+ num = (num | x) >> shift;
zds[n] = BIGLO(num);
num = BIGUP(x);
}
@@ -6762,6 +6761,68 @@ rb_big_even_p(VALUE num)
return Qtrue;
}
+unsigned long rb_ulong_isqrt(unsigned long);
+#if SIZEOF_BDIGIT*2 > SIZEOF_LONG
+BDIGIT rb_bdigit_dbl_isqrt(BDIGIT_DBL);
+#else
+# define rb_bdigit_dbl_isqrt(x) (BDIGIT)rb_ulong_isqrt(x)
+#endif
+
+VALUE
+rb_big_isqrt(VALUE n)
+{
+ BDIGIT *nds = BDIGITS(n);
+ size_t len = BIGNUM_LEN(n);
+
+ if (len <= 2) {
+ BDIGIT sq = rb_bdigit_dbl_isqrt(bary2bdigitdbl(nds, len));
+#if SIZEOF_BDIGIT > SIZEOF_LONG
+ return ULL2NUM(sq);
+#else
+ return ULONG2NUM(sq);
+#endif
+ }
+ else {
+ int zbits = nlz(nds[len-1]);
+ int shift_bits = (len&1 ? BITSPERDIG/2 : BITSPERDIG) - (zbits+1)/2 + 1;
+ size_t tn = (len+1) / 2, xn = tn;
+ VALUE t, x = bignew_1(0, xn, 1); /* division may release the GVL */
+ BDIGIT *tds, *xds = BDIGITS(x);
+
+ /* x = (n >> (b/2+1)) */
+ if (shift_bits == BITSPERDIG) {
+ MEMCPY(xds, nds+tn, BDIGIT, xn);
+ }
+ else if (shift_bits > BITSPERDIG) {
+ bary_small_rshift(xds, nds+len-xn, xn, shift_bits-BITSPERDIG, 0);
+ }
+ else {
+ bary_small_rshift(xds, nds+len-xn-1, xn, shift_bits, nds[len-1]);
+ }
+ /* x |= (1 << (b-1)/2) */
+ xds[xn-1] |= (BDIGIT)1u <<
+ ((len&1 ? 0 : BITSPERDIG/2) + (BITSPERDIG-zbits-1)/2);
+
+ /* t = n/x */
+ tn += BIGDIVREM_EXTRA_WORDS;
+ t = bignew_1(0, tn, 1);
+ tds = BDIGITS(t);
+ tn = BIGNUM_LEN(t);
+ while (bary_divmod_branch(tds, tn, NULL, 0, nds, len, xds, xn),
+ bary_cmp(tds, tn, xds, xn) < 0) {
+ int carry;
+ BARY_TRUNC(tds, tn);
+ carry = bary_add(xds, xn, xds, xn, tds, tn);
+ bary_small_rshift(xds, xds, xn, 1, carry);
+ tn = BIGNUM_LEN(t);
+ }
+ rb_big_realloc(t, 0);
+ rb_gc_force_recycle(t);
+ RBASIC_SET_CLASS_RAW(x, rb_cInteger);
+ return x;
+ }
+}
+
/*
* Bignum objects hold integers outside the range of
* Fixnum. Bignum objects are created
diff --git a/numeric.c b/numeric.c
index 16cdbb261a..2784b6dd86 100644
--- a/numeric.c
+++ b/numeric.c
@@ -5128,6 +5128,64 @@ int_truncate(int argc, VALUE* argv, VALUE num)
return rb_int_truncate(num, ndigits);
}
+#define DEFINE_INT_SQRT(rettype, prefix, argtype) \
+rettype \
+prefix##_isqrt(argtype n) \
+{ \
+ if (sizeof(n) * CHAR_BIT > DBL_MANT_DIG && \
+ n >= ((argtype)1UL << DBL_MANT_DIG)) { \
+ unsigned int b = bit_length(n); \
+ argtype t; \
+ rettype x = (rettype)(n >> (b/2+1)); \
+ x |= ((rettype)1LU << (b-1)/2); \
+ while ((t = n/x) < (argtype)x) x = (rettype)((x + t) >> 1); \
+ return x; \
+ } \
+ return (rettype)sqrt((double)n); \
+}
+
+DEFINE_INT_SQRT(unsigned long, rb_ulong, unsigned long)
+#if SIZEOF_BDIGIT*2 > SIZEOF_LONG
+DEFINE_INT_SQRT(BDIGIT, rb_bdigit_dbl, BDIGIT_DBL)
+#endif
+
+#define domain_error(msg) \
+ rb_raise(rb_eMathDomainError, "Numerical argument is out of domain - " #msg)
+
+VALUE rb_big_isqrt(VALUE);
+
+static VALUE
+rb_int_s_isqrt(VALUE self, VALUE num)
+{
+ unsigned long n, sq;
+ if (FIXNUM_P(num)) {
+ if (FIXNUM_NEGATIVE_P(num)) {
+ domain_error("isqrt");
+ }
+ n = FIX2ULONG(num);
+ sq = rb_ulong_isqrt(n);
+ return LONG2FIX(sq);
+ }
+ if (RB_TYPE_P(num, T_BIGNUM)) {
+ size_t biglen;
+ if (RBIGNUM_NEGATIVE_P(num)) {
+ domain_error("isqrt");
+ }
+ biglen = BIGNUM_LEN(num);
+ if (biglen == 0) return INT2FIX(0);
+#if SIZEOF_BDIGIT <= SIZEOF_LONG
+ /* short-circuit */
+ if (biglen == 1) {
+ n = BIGNUM_DIGITS(num)[0];
+ sq = rb_ulong_isqrt(n);
+ return ULONG2NUM(sq);
+ }
+#endif
+ return rb_big_isqrt(num);
+ }
+ return Qnil;
+}
+
/*
* Document-class: ZeroDivisionError
*
@@ -5281,6 +5339,7 @@ Init_Numeric(void)
rb_cInteger = rb_define_class("Integer", rb_cNumeric);
rb_undef_alloc_func(rb_cInteger);
rb_undef_method(CLASS_OF(rb_cInteger), "new");
+ rb_define_singleton_method(rb_cInteger, "sqrt", rb_int_s_isqrt, 1);
rb_define_method(rb_cInteger, "to_s", int_to_s, -1);
rb_define_alias(rb_cInteger, "inspect", "to_s");
diff --git a/test/ruby/test_integer.rb b/test/ruby/test_integer.rb
index 0a6b2f4539..05b0b68ccc 100644
--- a/test/ruby/test_integer.rb
+++ b/test/ruby/test_integer.rb
@@ -464,4 +464,28 @@ class TestInteger < Test::Unit::TestCase
end
assert_equal([0, 1], 10.digits(o))
end
+
+ def test_square_root
+ assert_raise(Math::DomainError) {Integer.sqrt(-1)}
+ assert_equal(0, Integer.sqrt(0))
+ (1...4).each {|i| assert_equal(1, Integer.sqrt(i))}
+ (4...9).each {|i| assert_equal(2, Integer.sqrt(i))}
+ (9...16).each {|i| assert_equal(3, Integer.sqrt(i))}
+ (1..40).each do |i|
+ mesg = "10**#{i}"
+ s = Integer.sqrt(n = 10**i)
+ if i.even?
+ assert_equal(10**(i/2), Integer.sqrt(n), mesg)
+ else
+ assert_include((s**2)...(s+1)**2, n, mesg)
+ end
+ end
+ 50.step(400, 10) do |i|
+ exact = 10**(i/2)
+ x = 10**i
+ assert_equal(exact, Integer.sqrt(x), "10**#{i}")
+ assert_equal(exact, Integer.sqrt(x+1), "10**#{i}+1")
+ assert_equal(exact-1, Integer.sqrt(x-1), "10**#{i}-1")
+ end
+ end
end