summaryrefslogtreecommitdiff
path: root/ext/dl/lib/dl/func.rb
blob: 7a8b62e3253ef08d25b53557dcfefb19f84cd85e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
require 'dl'
require 'dl/callback'
require 'dl/stack'
require 'dl/value'
require 'thread'

module DL
  # Ruby-side wrapper that makes a C function (+cfunc+) callable from Ruby and
  # can also install a Ruby block as the function's implementation (see #bind).
  #
  # Arguments are packed/unpacked through a Stack built from +argtypes+. A
  # negative return ctype on +cfunc+ marks the return value as unsigned: the
  # sign is stripped from cfunc.ctype and results are later passed through
  # unsigned_value (see #wrap_result).
  class Function
    include DL
    include ValueUtil

    # cfunc    - C function object; cfunc.ctype is rewritten in place to its
    #            absolute value.
    # argtypes - DL type codes for the arguments; their signs are dropped
    #            (ty.abs) before building the Stack.
    # proc     - optional block, bound immediately via #bind when given.
    def initialize(cfunc, argtypes, &proc)
      @cfunc = cfunc
      @stack = Stack.new(argtypes.collect{|ty| ty.abs})
      if( @cfunc.ctype < 0 )
        @cfunc.ctype = @cfunc.ctype.abs
        @unsigned = true
      else
        @unsigned = false
      end
      if( proc )
        bind(&proc)
      end
    end

    # Integer value of the underlying C function object (delegates to cfunc).
    def to_i()
      @cfunc.to_i
    end

    # Name of the underlying C function object (delegates to cfunc).
    def name
      @cfunc.name
    end

    # Invoke the C function with +args+ and return the converted result.
    #
    # wrap_args is not defined in this class (presumably it comes from
    # ValueUtil). It converts the Ruby arguments and may collect callback
    # objects into +funcs+. Every collected object gets unbind_at_call once the
    # call is over. That cleanup runs in an ensure, so temporary callback
    # bindings are released even when argument conversion or the C call
    # itself raises.
    def call(*args, &block)
      funcs = []
      begin
        args = wrap_args(args, @stack.types, funcs, &block)
        r = @cfunc.call(@stack.pack(args))
      ensure
        funcs.each{|f| f.unbind_at_call()}
      end
      wrap_result(r)
    end

    # Convert the raw return value +r+ of the C call:
    # - TYPE_VOIDP results are wrapped in a CPtr;
    # - other results are reinterpreted with unsigned_value when the return
    #   type was declared unsigned;
    # - otherwise +r+ is returned unchanged.
    def wrap_result(r)
      case @cfunc.ctype
      when TYPE_VOIDP
        r = CPtr.new(r)
      else
        if( @unsigned )
          r = unsigned_value(r, @cfunc.ctype)
        end
      end
      r
    end

    # Register +block+ as the implementation of this function. The callback is
    # created with set_cdecl_callback or set_stdcall_callback (chosen by
    # cfunc.calltype), and the resulting pointer is stored in cfunc.ptr.
    # Does nothing if a pointer is already bound (cfunc.ptr != 0).
    #
    # When the callback runs, its arguments are unpacked with the Stack,
    # TYPE_VOIDP arguments are wrapped in CPtr, the block is called, and its
    # result is converted back with wrap_arg for the function's return type.
    #
    # Raises RuntimeError when no block is given, when cfunc.calltype is
    # neither :cdecl nor :stdcall, or when registration leaves cfunc.ptr at 0.
    def bind(&block)
      if( !block )
        raise(RuntimeError, "block must be given.")
      end
      if( @cfunc.ptr == 0 )
        cb = Proc.new{|*args|
          ary = @stack.unpack(args)
          @stack.types.each_with_index{|ty, idx|
            case ty
            when TYPE_VOIDP
              ary[idx] = CPtr.new(ary[idx])
            end
          }
          r = block.call(*ary)
          wrap_arg(r, @cfunc.ctype, [])
        }
        case @cfunc.calltype
        when :cdecl
          @cfunc.ptr = set_cdecl_callback(@cfunc.ctype, @stack.size, &cb)
        when :stdcall
          @cfunc.ptr = set_stdcall_callback(@cfunc.ctype, @stack.size, &cb)
        else
          raise(RuntimeError, "unsupported calltype: #{@cfunc.calltype}")
        end
        if( @cfunc.ptr == 0 )
          raise(RuntimeError, "can't bind C function.")
        end
      end
    end

    # Remove the callback installed by #bind, if any, and reset cfunc.ptr to 0.
    # Raises RuntimeError for an unsupported calltype.
    def unbind()
      if( @cfunc.ptr != 0 )
        case @cfunc.calltype
        when :cdecl
          remove_cdecl_callback(@cfunc.ptr, @cfunc.ctype)
        when :stdcall
          remove_stdcall_callback(@cfunc.ptr, @cfunc.ctype)
        else
          raise(RuntimeError, "unsupported calltype: #{@cfunc.calltype}")
        end
        @cfunc.ptr = 0
      end
    end

    # True when a callback pointer is currently bound (cfunc.ptr != 0).
    def bound?()
      @cfunc.ptr != 0
    end

    # Hook for binding +block+ around a call. The base class simply binds it.
    def bind_at_call(&block)
      bind(&block)
    end

    # Hook invoked by #call on each collected callback after the call.
    # The base class keeps the binding, so this is a no-op.
    def unbind_at_call()
    end
  end

  # Function variant whose binding lasts only for a single call: the
  # unbind_at_call hook (a no-op in Function) actually releases the callback.
  class TempFunction < Function
    # Bind +blk+ as this function's implementation for the upcoming call.
    def bind_at_call(&blk)
      bind(&blk)
    end

    # Drop the binding created by #bind_at_call.
    def unbind_at_call
      unbind
    end
  end

  # Callback Function whose argument number +n+ carries a user-data pointer
  # produced by #create_carrier. When the callback runs, that pointer is
  # unwrapped, the original +data+ is substituted for the argument, and the
  # block recorded for that carrier is called with the resulting arguments.
  #
  # #create_carrier and #bind_at_call form a pair: create_carrier acquires
  # @mutex and bind_at_call releases it. They must therefore be called in
  # that order, from the same thread.
  class CarriedFunction < Function
    # cfunc, argtypes - as for Function#initialize.
    # n               - index of the user-data argument among the
    #                   callback's arguments.
    def initialize(cfunc, argtypes, n)
      super(cfunc, argtypes)
      @carrier = []
      @index = n
      @mutex = Mutex.new
    end

    # Store +data+ together with an (initially empty) list of blocks as a
    # [blocks, data] pair. Return the pair wrapped with dlwrap.
    #
    # Acquires @mutex; the matching #bind_at_call releases it. Nothing in this
    # class removes entries from @carrier. Presumably this keeps wrapped
    # objects referenced while C may still hold the handle; confirm.
    def create_carrier(data)
      ary = []
      userdata = [ary, data]
      @mutex.lock()
      @carrier.push(userdata)
      return dlwrap(userdata)
    end

    # Record +block+ as a handler for the most recently created carrier.
    # Bind the dispatching callback; Function#bind installs it only while no
    # pointer is bound yet.
    #
    # Releases the @mutex taken by #create_carrier. The release is in an
    # ensure, so a failing bind cannot leave the lock held and deadlock every
    # later create_carrier.
    #
    # Raises RuntimeError if no carrier exists yet. In that case
    # create_carrier never ran, so there is no lock to release.
    def bind_at_call(&block)
      carried = @carrier[-1]
      if( !carried )
        raise(RuntimeError, "create_carrier must be called before bind_at_call.")
      end
      begin
        carried[0].push(block)
        bind{|*args|
          ptr = args[@index]
          if( !ptr )
            raise(RuntimeError, "The index of userdata should be lower than #{args.size}.")
          end
          # Block-local name: do not clobber the enclosing method's locals.
          pair = dlunwrap(Integer(ptr))
          args[@index] = pair[1]
          pair[0][0].call(*args)
        }
      ensure
        @mutex.unlock()
      end
    end
  end
end